-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSchema.hs
180 lines (153 loc) · 5.78 KB
/
Schema.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
{-# LANGUAGE OverloadedStrings #-}
module Schema (
checkAndReportErrors,
parseSchema,
typeCheckSchema,
extractChecks,
extractUniques,
extractCatalogUpdates,
extractName,
removeChecks,
) where
import Control.Monad
import Data.Data
import Data.Generics.Uniplate.Data
import Data.List
import Data.Maybe
import Data.Text (Text, pack)
import Database.HsSqlPpp.Annotation
import Database.HsSqlPpp.Catalog
import Database.HsSqlPpp.Dialect
import Database.HsSqlPpp.Parse
import Database.HsSqlPpp.Syntax
import Database.HsSqlPpp.TypeCheck
import Text.Printf
import System.Exit
import qualified Data.Text.Lazy as T
import Logging
-- TODO: move elsewhere
-- | Report all type errors and return True if any were found.
-- Every value of type 'Annotation' reachable from the argument is collected;
-- those carrying a non-empty error list are logged through 'err', one header
-- line per annotation followed by one line per error.
checkAndReportErrors :: Data a => a -> IO Bool
checkAndReportErrors x = do
  mapM_ report failing
  return (not (null failing))
  where
    -- Annotations that carry at least one type error.
    failing = [ann | ann <- universeBi x, not (null (anErrs ann))]
    report ann = do
      err (maybe "Type error" located (anSrc ann))
      forM_ (anErrs ann) $ \e -> err (" " ++ show e)
    -- Header used when the annotation has a source position.
    located (fp, r, c) = printf "Type error at %s:%d:%d" fp r c
-------------------------------
-- Parsing and type checking --
-------------------------------
-- | Parse the schema text @src@ (read from @fp@) using the given dialect.
-- A parse error is shown and passed to 'fatal'; otherwise the parsed
-- statements are returned.
parseSchema :: Dialect -> FilePath -> T.Text -> IO [Statement]
parseSchema dialect fp src = either (fatal . show) return parsed
  where
    parsed = parseStatements flags fp Nothing src
    flags = defaultParseFlags { pfDialect = dialect }
-- | Type check schema statements against a starting catalog.
-- The statements are first passed through 'verifySchema', then type checked;
-- if 'checkAndReportErrors' finds any error the program exits with failure.
-- Otherwise the catalog and statements produced by the type checker are
-- returned.
typeCheckSchema :: Dialect -> FilePath -> Catalog -> [Statement] -> IO (Catalog, [Statement])
typeCheckSchema dialect fp catalog stmts = do
  verifySchema fp stmts
  let (checkedCatalog, checkedStmts) = typeCheckStatements flags catalog stmts
  failed <- checkAndReportErrors checkedStmts
  when failed exitFailure
  return (checkedCatalog, checkedStmts)
  where
    flags = defaultTypeCheckFlags
      { tcfAddQualifiers = True
      -- , tcfAddFullTablerefAliases = True
      , tcfAddSelectItemAliases = True
      , tcfExpandStars = True
      , tcfDialect = dialect
      }
-- Verify that every given statement is a CREATE TABLE or CREATE FUNCTION.
-- Any other statement is reported through 'fatal', using the statement's own
-- source position when it has one and the file path argument otherwise.
verifySchema :: FilePath -> [Statement] -> IO ()
verifySchema fp stmts = forM_ stmts $ \stmt -> case stmt of
  CreateTable{} -> return ()
  CreateFunction{} -> return ()
  _ -> fatal (complaint (anSrc (getAnnotation stmt)))
  where
    complaint Nothing =
      printf "Expecting only CREATE statements in schema. Error in %s." fp
    complaint (Just (srcFile, r, c)) =
      printf "Expecting only CREATE statements in schema. Error at %s:%d:%d." srcFile r c
--Just (_, r, c) -> fatal (printf "Expecting only CREATE statements in schema. Error at %s:%d:%d." fp r c)
------------------------------
-- Extract info from schema --
------------------------------
-- | Render a 'Name' as 'Text', joining its components with dots.
-- An 'AntiName' or a 'Name' with no components is reported through 'ice'.
nameToText :: Name -> Text
nameToText (AntiName _) = ice "AntiName."
nameToText (Name _ []) = ice "Empty name."
-- 'intercalate' replaces a hand-written join whose helper had no equation
-- for the empty list.
nameToText (Name _ ns) = pack (intercalate "." (map ncStr ns))
-- | The 'ncStr' rendering of a name component, packed into 'Text'.
ncStrT :: NameComponent -> Text
ncStrT component = pack (ncStr component)
-- | Convert a 'TypeName' to a 'CatNameExtra' by way of 'typeNameToCatName'
-- and 'mkCatNameExtra'.
typeNameToCatNameExtra :: TypeName -> CatNameExtra
typeNameToCatNameExtra ty = mkCatNameExtra (typeNameToCatName ty)
-- | Convert a 'SimpleTypeName' to its catalog name; any other kind of
-- 'TypeName' is reported through 'ice'.
typeNameToCatName :: TypeName -> CatName
typeNameToCatName ty = case ty of
  SimpleTypeName _ n -> nameToText n
  _ -> ice "Unsupported TypeName."
-- | Remove every 'CheckConstraint' and 'RowCheckConstraint' from the
-- constraint lists inside a statement, keeping all other constraints.
removeChecks :: Statement -> Statement
removeChecks stmt = withoutRowChecks (withoutChecks stmt)
  where
    -- 'CheckConstraint's are filtered first, then 'RowCheckConstraint's.
    withoutChecks s = transformBi (filter (not . isCheck)) s
    withoutRowChecks s = transformBi (filter (not . isRowCheck)) s
    isCheck CheckConstraint{} = True
    isCheck _ = False
    isRowCheck RowCheckConstraint{} = True
    isRowCheck _ = False
-- | Collect the expressions of all check constraints inside a statement:
-- those from 'RowCheckConstraint's first, followed by those from
-- 'CheckConstraint's.
extractChecks :: Statement -> [ScalarExpr]
extractChecks stmt = rowChecks ++ checks
  where
    rowChecks = [cond | RowCheckConstraint _ _ cond <- universeBi stmt]
    checks = [cond | CheckConstraint _ _ cond <- universeBi stmt]
-- | The name declared by a CREATE TABLE or CREATE FUNCTION statement.
-- Any other statement is reported through 'ice'.
extractName :: Statement -> Name
extractName stmt = case stmt of
  CreateTable _ n _ _ _ _ -> n
  CreateFunction _ n _ _ _ _ _ _ -> n
  _ -> ice "Expecting a CREATE TABLE or FUNCTION statement."
-- | Collect the column sets that a statement declares unique.
-- For CREATE TABLE these come from attributes carrying a row-level unique or
-- primary key constraint (one singleton set each) and from table-level unique
-- and primary key constraints. CREATE FUNCTION yields no sets; any other
-- statement is reported through 'ice'.
extractUniques :: Statement -> [[NameComponent]]
extractUniques stmt = case stmt of
  CreateTable _ _ attrs constraints _ _ ->
    dedupe (mapMaybe uniqueColumn attrs ++ mapMaybe uniqueColumns constraints)
  CreateFunction{} -> []
  _ -> ice "Expecting a CREATE TABLE or FUNCTION statement."
  where
    -- Sort each column set so sets differing only in column order compare
    -- equal, then drop duplicates.
    dedupe keys = nub (map sort keys) -- TODO: O(n^2)
    uniqueColumn (AttributeDef _ col _ _ rcs)
      | any marksUnique rcs = Just [col]
      | otherwise = Nothing
    marksUnique rc = case rc of
      RowPrimaryKeyConstraint{} -> True
      RowUniqueConstraint{} -> True
      _ -> False
    uniqueColumns c = case c of
      UniqueConstraint _ _ cols -> Just cols
      PrimaryKeyConstraint _ _ cols -> Just cols
      _ -> Nothing
-- TODO: it's also possible that table constraints make some more columns
-- NOT NULL. This is because a primary key constraint implies NOT NULL.
-- | Extract catalog updates from CREATE TABLE and CREATE FUNCTION statements.
-- A table yields one 'CatCreateTable' under the "public" schema with one
-- entry per attribute; a function yields one 'CatCreateFunction'. Any other
-- statement is reported through 'ice'.
extractCatalogUpdates :: Statement -> [CatalogUpdate]
extractCatalogUpdates stmt = case stmt of
  CreateTable _ tbl attrs _ _ _ -> [tableUpdate tbl attrs]
  CreateFunction _ fn params retTy _ _ _ _ -> [functionUpdate fn params retTy]
  _ -> ice "Expecting a CREATE TABLE or FUNCTION statement."
  where
    tableUpdate tbl attrs =
      CatCreateTable
        ("public", nameToText tbl)
        [column col ty rcs | AttributeDef _ col ty _ rcs <- attrs]

    -- A column is nullable unless it has a NOT NULL or row-level primary key
    -- constraint and no explicit NULL constraint.
    column col ty rcs =
      ( pack (ncStr col)
      , (typeNameToCatNameExtra ty)
          { catNullable = not (any forbidsNull rcs) || any allowsNull rcs }
      )

    forbidsNull rc = case rc of
      NotNullConstraint{} -> True
      RowPrimaryKeyConstraint{} -> True
      _ -> False

    allowsNull rc = case rc of
      NullConstraint{} -> True
      _ -> False

    functionUpdate fn params retTy =
      CatCreateFunction
        (nameToText fn)
        (map (typeNameToCatName . paramType) params)
        False
        (typeNameToCatName retTy)

    paramType (ParamDef _ _ t) = t
    paramType (ParamDefTp _ t) = t