-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSelectQuery.hs
178 lines (160 loc) · 6.25 KB
/
SelectQuery.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
{-# LANGUAGE LambdaCase #-}
module SelectQuery (
parseSelectQuery,
typeCheckSelectQuery,
combinePrimaryKeys,
extractWhereExpr,
extractJoinTables,
isSupportedAggregOp)
where
import Control.Monad
import Database.HsSqlPpp.Annotation
import Database.HsSqlPpp.Catalog
import Database.HsSqlPpp.Dialect
import Database.HsSqlPpp.Parse
import Database.HsSqlPpp.Syntax
import Database.HsSqlPpp.TypeCheck
import Data.Char
import Data.Generics.Uniplate.Data
import Data.IORef
import Data.List
import Data.Maybe
import System.Exit
import Text.Printf
import qualified Data.Text.Lazy as T
import Schema -- TODO: workaround
import Logging
--parseSelectQuery :: Dialect -> FilePath -> T.Text -> IO QueryExpr
--parseSelectQuery dialect fp src =
-- case parseQueryExpr parseFlags fp Nothing src of
-- Left err -> fatal (show err)
-- Right query@Select{} -> return query
-- Right _ -> fatal "Unsupported query type. Expecting basic SELECT query."
-- where
-- parseFlags = defaultParseFlags { pfDialect = dialect }
-- | Parse a query and split it into its SELECT leaves.
--
-- Returns the parsed query together with every 'Select' it is built from,
-- collected left to right through INTERSECT, UNION and EXCEPT combinations.
-- A parse error, or any other query form (including any other kind of
-- combination), is reported through 'fatal'.
parseSelectQuery :: Dialect -> FilePath -> T.Text -> IO (QueryExpr, [QueryExpr])
parseSelectQuery dialect fp src =
  either (fatal . show) withLeaves (parseQueryExpr flags fp Nothing src)
  where
    flags = defaultParseFlags { pfDialect = dialect }
    -- Pair the whole query with its flattened list of SELECT leaves.
    withLeaves parsed = (,) parsed <$> selectLeaves parsed
    -- Left subtree is visited before the right one, so leaves come out in
    -- source order.
    selectLeaves = \case
      leaf@Select{} -> return [leaf]
      CombineQueryExpr _ op lhs rhs
        | op `elem` [Intersect, Union, Except] ->
            (++) <$> selectLeaves lhs <*> selectLeaves rhs
      _ -> fatal "Unsupported query type. Expecting basic SELECT query."
-- | Combine per-SELECT primary-key flags following the shape of a
-- (possibly compound) query.
--
-- The flag lists are consumed left to right, one per 'Select' leaf, which is
-- the same traversal order 'parseSelectQuery' uses to collect the leaves
-- (presumably the caller builds the flag lists from those leaves). Each flag
-- list marks which result columns belong to a primary key.
--
-- * INTERSECT: a column is a key if it is a key on either side (OR).
-- * EXCEPT: the result is a subset of the left side, so its keys survive.
-- * UNION, and any other combination: no column is reported as a key.
--   All-'False' is the conservative answer.
--
-- Fails with a descriptive 'error' if there are fewer flag lists than
-- SELECT leaves, or if the query contains a node other than 'Select' or
-- 'CombineQueryExpr'.
combinePrimaryKeys :: QueryExpr -> [[Bool]] -> [Bool]
combinePrimaryKeys query flagLists = fst (go query flagLists)
  where
    -- Returns the combined flags for this subtree plus the unconsumed lists.
    go :: QueryExpr -> [[Bool]] -> ([Bool], [[Bool]])
    go Select{} (pk : rest) = (pk, rest)
    go Select{} [] =
      error "combinePrimaryKeys: fewer primary-key lists than SELECT leaves"
    go (CombineQueryExpr _ cqt q1 q2) pks0 =
      let (pk1, pks1) = go q1 pks0
          (pk2, pks2) = go q2 pks1
          pk = case cqt of
            Intersect -> zipWith (||) pk1 pk2
            Except -> pk1
            _ -> map (const False) pk1
       in (pk, pks2)
    go _ _ =
      error "combinePrimaryKeys: unsupported query form (expected SELECT or a set combination)"
-- | Short human-readable description of why something is rejected.
type Reason = String

-- | Optional source position, reusing hssqlppp's 'SourcePosition'.
type Loc = Maybe SourcePosition -- using location of hssqlppp

-- | Append the source position, when known, to a message.
prettyLoc :: String -> Loc -> String
prettyLoc msg = maybe msg withPos
  where
    withPos (file, row, col) = printf "%s. Error at %s:%d:%d" msg file row col
-- | Names of the SELECT clauses present in the query that the analysis
-- cannot handle (LIMIT, OFFSET, HAVING), in that order.
--
-- The 'Bool' argument (the @local@ flag at call sites) currently makes no
-- difference: both modes reject the same clauses. A stricter non-local
-- check is disabled:
--   ALL with non-aggregating expressions and without GROUP BY, i.e.
--   selDistinct query == All
--     && not (isSelectListOnlyAggregExprs (selSelectList query))
--     && null (selGroupBy query)
unsupportedClauses :: Bool -> QueryExpr -> [Reason]
unsupportedClauses _local query =
  [ clause | (clause, present) <- checks, present ]
  where
    checks =
      [ ("LIMIT", isJust (selLimit query))
      , ("OFFSET", isJust (selOffset query))
      , ("HAVING", isJust (selHaving query))
      ]
-- | FROM-clause items the analysis cannot handle, each paired with its
-- source location and a short description.
--
-- Results are grouped by kind, in this order: subqueries ('SubTref', only
-- when @local@ is 'False'), 'FunTref', 'OdbcTableRef', joins other than
-- inner/cross, JOIN USING, and 'FullAlias'.
unsupportedFrom :: Bool -> QueryExpr -> [(Loc, Reason)]
unsupportedFrom local query =
  concat
    [ [ (anSrc ann, "Subquery") | not local, SubTref ann _ <- tableRefs ]
    , [ (anSrc ann, "Function") | FunTref ann _ <- tableRefs ]
    , [ (anSrc ann, "???") | OdbcTableRef ann _ <- tableRefs ]
    , [ (anSrc ann, describeJoin jt)
      | JoinTref ann _ _ jt _ _ _ <- tableRefs
      , jt `notElem` [Inner, Cross]
      ]
      -- TODO: handle these properly
    , [ (anSrc ann, "Join USING")
      | JoinTref _ _ _ _ _ _ (Just (JoinUsing ann _)) <- tableRefs
      ]
      -- TODO: also handle this properly
    , [ (anSrc ann, "Full alias") | FullAlias ann _ _ _ <- tableRefs ]
    ]
  where
    -- Every table reference anywhere in the query, nested ones included.
    tableRefs :: [TableRef]
    tableRefs = universeBi query
    -- 'show' of the join type, lowercased, first letter capitalised.
    describeJoin jt = capitalise (map toLower (show jt)) ++ " join"
    capitalise [] = []
    capitalise (c : cs) = toUpper c : cs
-- | Source locations of expressions the analysis cannot handle.
--
-- Every 'ScalarExpr' reachable from the WHERE clause is tested with
-- 'isSupportedWhereExpr'. Unless @local@ is set, every 'ScalarExpr'
-- reachable from the FROM clause (e.g. join ON conditions) is tested too.
--
-- TODO: Don't descend under already unsupported expressions?
unsupportedWhere :: Bool -> QueryExpr -> [Loc]
unsupportedWhere local query =
  [ anSrc (getAnnotation e) | e <- candidates, not (isSupportedWhereExpr e) ]
  where
    whereExprs, fromExprs :: [ScalarExpr]
    whereExprs = universeBi (selWhere query)
    fromExprs = universeBi (selTref query)
    candidates
      | local = whereExprs
      | otherwise = whereExprs ++ fromExprs
-- | Whether a single expression node belongs to the supported WHERE
-- fragment. The node's children are not inspected here;
-- 'unsupportedWhere' applies this to every sub-expression via 'universeBi'.
isSupportedWhereExpr :: ScalarExpr -> Bool
isSupportedWhereExpr expr =
  case expr of
    NumberLit{} -> True
    StringLit{} -> True
    -- NullLit{} -> True   (NULL literals are currently not accepted)
    BooleanLit{} -> True
    Identifier{} -> True
    Parens{} -> True
    PrefixOp _ op _ -> allowedOp op
    BinaryOp _ op _ _ -> allowedOp op
    SpecialOp _ op _ -> nameToStr op == "between"
    _ -> False
  where
    allowedOp op = nameToStr op `elem` allowedOps
    allowedOps =
      [ "=", "<", ">", "<>", "!=", "<=", ">="
      , "and", "or"
      , "+", "-", "*", "/", "%"
      , "not"
      ]
-- | Whether an aggregate function name (as rendered by 'nameToStr') is one
-- of count, sum, avg, min or max.
isSupportedAggregOp :: Name -> Bool
isSupportedAggregOp = (`elem` supportedAggregates) . nameToStr
  where
    supportedAggregates = ["count", "sum", "avg", "min", "max"]
-- | Render a (possibly qualified) name as its components joined by ".".
-- Anti-quotation names are not expected here and are handed to 'ice'.
nameToStr :: Name -> String
nameToStr = \case
  Name _ parts -> intercalate "." (ncStr <$> parts)
  AntiName{} -> ice "Unexpected AntiName."
-- | All predicate expressions of a query: the WHERE condition (if any),
-- followed by the ON condition of every join found anywhere in the query.
--
-- NOTE(review): join conditions are gathered with 'universeBi', so joins
-- nested inside subqueries are included as well, whereas 'extractJoinTables'
-- only flattens the top-level join tree. Confirm this is intended.
extractWhereExpr :: QueryExpr -> [ScalarExpr]
extractWhereExpr query = whereCond ++ joinConds
  where
    whereCond = maybeToList (selWhere query)
    joinConds =
      [ cond
      | JoinTref _ _ _ _ _ _ (Just (JoinOn _ cond)) <- universeBi query
      ]
-- | The table references of the FROM clause, with join trees flattened
-- left to right. Only 'JoinTref' nodes are descended into; every other
-- table reference is returned unchanged.
extractJoinTables :: QueryExpr -> [TableRef]
extractJoinTables query = selTref query >>= flatten
  where
    flatten tref = case tref of
      JoinTref _ lhs _ _ _ rhs _ -> flatten lhs ++ flatten rhs
      other -> [other]
-- | Type-check a query against a catalog and reject unsupported features.
--
-- 'checkAndReportErrors' is run on the type-checked query. If it returns
-- 'True', the program exits with 'exitFailure'.
--
-- When @checkUnsupporteds@ is set, every unsupported clause, FROM item and
-- WHERE/join expression is reported with 'err', all of them rather than
-- only the first. If any were found, the program then exits with
-- 'exitFailure'.
--
-- Otherwise the type-checked (and possibly rewritten) query is returned.
-- The 'FilePath' argument is currently unused.
typeCheckSelectQuery :: Dialect -> Bool -> Bool -> FilePath -> Catalog -> QueryExpr -> IO QueryExpr
typeCheckSelectQuery dialect local checkUnsupporteds _fp catalog rawQuery = do
  let query = typeCheckQueryExpr typeCheckFlags catalog rawQuery
  hasTypeErrors <- checkAndReportErrors query
  when hasTypeErrors exitFailure -- dont bail?
  -- Because type checker may rewrite queries to a different form
  -- we perform feature check late.
  let problems
        | checkUnsupporteds = unsupportedFeatures query
        | otherwise = []
  mapM_ err problems
  unless (null problems) exitFailure
  return query
  where
    typeCheckFlags = defaultTypeCheckFlags {
      tcfAddQualifiers = True,
      -- tcfAddFullTablerefAliases = True,
      tcfAddSelectItemAliases = True,
      tcfExpandStars = True,
      tcfDialect = dialect
      }
    -- Every unsupported-feature message, in reporting order: clauses
    -- first, then FROM items, then WHERE/join expressions.
    unsupportedFeatures q =
      [ str ++ " clause is not supported"
      | str <- unsupportedClauses local q
      ] ++
      [ prettyLoc (printf "%s not supported in FROM clause." str) loc
      | (loc, str) <- unsupportedFrom local q
      ] ++
      [ prettyLoc "Unsupported expression in WHERE clause or join." loc
      | loc <- unsupportedWhere local q
      ]