@@ -4,202 +4,24 @@ import (
4
4
"fmt"
5
5
"log"
6
6
7
- pg_query "github .com/pganalyze/pg_query_go/v2 "
7
+ "gitlab .com/postgres-ai/database-lab/v3/internal/schema/diff "
8
8
)
9
9
10
10
const idxExample = `
11
11
CREATE UNIQUE INDEX title_idx ON films (title);
12
-
13
- DROP INDEX title_idx;
14
-
15
- ALTER TABLE distributors
16
- ADD CONSTRAINT zipchk CHECK (char_length(zipcode) = 5);
17
-
18
- ALTER TABLE distributors
19
- ADD CONSTRAINT distfk FOREIGN KEY (address) REFERENCES addresses (address);
20
-
21
- ALTER TABLE pgbench_accounts
22
- ADD COLUMN test integer NOT NULL DEFAULT 0;
23
12
`
24
13
25
14
/*
26
15
Optimized queries:
27
16
28
17
CREATE UNIQUE INDEX CONCURRENTLY title_idx ON films USING btree (title);
29
-
30
- DROP INDEX CONCURRENTLY title_idx;
31
-
32
- ALTER TABLE distributors ADD CONSTRAINT zipchk CHECK (char_length(zipcode) = 5) NOT VALID;
33
- BEGIN; ALTER TABLE distributors VALIDATE CONSTRAINT zipchk; COMMIT;
34
-
35
- ALTER TABLE distributors ADD CONSTRAINT distfk FOREIGN KEY (address) REFERENCES addresses (address) NOT VALID;
36
- BEGIN; ALTER TABLE distributors VALIDATE CONSTRAINT distfk; COMMIT;
37
-
38
- ALTER TABLE pgbench_accounts ADD COLUMN test int;
39
- ALTER TABLE pgbench_accounts ALTER COLUMN test SET DEFAULT 0;
40
18
*/
41
19
42
20
func main () {
43
- scanTree , err := pg_query .ParseToJSON (idxExample )
44
- if err != nil {
45
- log .Fatal (err )
46
- }
47
-
48
- fmt .Printf ("JSON: %s\n " , scanTree )
49
-
50
- idxTree , err := pg_query .Parse (idxExample )
51
- if err != nil {
52
- log .Fatal (err )
53
- }
54
-
55
- fmt .Printf ("Original query:\n %v\n \n " , idxExample )
56
- fmt .Printf ("Parse Tree:\n %#v\n \n " , idxTree )
57
-
58
- stmts := idxTree .GetStmts ()
59
- nodes := processStmts (stmts )
60
- idxTree .Stmts = nodes
61
-
62
- fmt .Printf ("Parse Tree after processing:\n %#v\n \n " , idxTree .GetStmts ())
63
-
64
- resIdxStr , err := pg_query .Deparse (idxTree )
21
+ resIdxStr , err := diff .OptimizeQueries (idxExample )
65
22
if err != nil {
66
23
log .Fatal (err )
67
24
}
68
25
69
26
fmt .Printf ("Optimized queries:\n %v\n " , resIdxStr )
70
27
}
71
-
72
- func processStmts (stmts []* pg_query.RawStmt ) []* pg_query.RawStmt {
73
- rawStmts := []* pg_query.RawStmt {}
74
-
75
- for _ , stmt := range stmts {
76
- for _ , node := range detectNodeType (stmt .Stmt ) {
77
- rawStmt := & pg_query.RawStmt {
78
- Stmt : node ,
79
- }
80
-
81
- rawStmts = append (rawStmts , rawStmt )
82
- }
83
- }
84
-
85
- return rawStmts
86
- }
87
-
88
- func detectNodeType (node * pg_query.Node ) []* pg_query.Node {
89
- switch stmt := node .Node .(type ) {
90
- case * pg_query.Node_IndexStmt :
91
- IndexStmt (stmt )
92
-
93
- case * pg_query.Node_DropStmt :
94
- DropStmt (stmt )
95
-
96
- case * pg_query.Node_AlterTableStmt :
97
- fmt .Println ("Alter Type" )
98
- return AlterStmt (node )
99
-
100
- case * pg_query.Node_SelectStmt :
101
- fmt .Println ("Select Type" )
102
- }
103
-
104
- return []* pg_query.Node {node }
105
- }
106
-
107
- // IndexStmt processes index statement.
108
- func IndexStmt (stmt * pg_query.Node_IndexStmt ) {
109
- stmt .IndexStmt .Concurrent = true
110
- }
111
-
112
- // DropStmt processes drop statement.
113
- func DropStmt (stmt * pg_query.Node_DropStmt ) {
114
- switch stmt .DropStmt .RemoveType {
115
- case pg_query .ObjectType_OBJECT_INDEX :
116
- stmt .DropStmt .Concurrent = true
117
- default :
118
- }
119
- }
120
-
121
- // AlterStmt processes alter statement.
122
- func AlterStmt (node * pg_query.Node ) []* pg_query.Node {
123
- alterTableStmt := node .GetAlterTableStmt ()
124
- if alterTableStmt == nil {
125
- return []* pg_query.Node {node }
126
- }
127
-
128
- var alterStmts []* pg_query.Node
129
-
130
- initialCommands := alterTableStmt .GetCmds ()
131
-
132
- for _ , cmd := range initialCommands {
133
- switch v := cmd .Node .(type ) {
134
- case * pg_query.Node_AlterTableCmd :
135
- fmt .Printf ("%#v\n " , v )
136
- fmt .Printf ("%#v\n " , v .AlterTableCmd .Def .Node )
137
- fmt .Println (v .AlterTableCmd .Subtype .Enum ())
138
-
139
- switch v .AlterTableCmd .Subtype {
140
- case pg_query .AlterTableType_AT_AddColumn :
141
- def := v .AlterTableCmd .Def .GetColumnDef ()
142
-
143
- constraints := def .GetConstraints ()
144
- constraintsMap := make (map [pg_query.ConstrType ]int )
145
-
146
- for i , constr := range constraints {
147
- constraintsMap [constr .GetConstraint ().Contype ] = i
148
- }
149
-
150
- if index , ok := constraintsMap [pg_query .ConstrType_CONSTR_DEFAULT ]; ok {
151
- def .Constraints = make ([]* pg_query.Node , 0 )
152
-
153
- alterStmts = append (alterStmts , node )
154
-
155
- defaultDefinitionTemp := fmt .Sprintf (`alter table %s alter column %s set default %v;` ,
156
- alterTableStmt .GetRelation ().GetRelname (), def .Colname ,
157
- constraints [index ].GetConstraint ().GetRawExpr ().GetAConst ().GetVal ().GetInteger ().GetIval ())
158
-
159
- alterStmts = append (alterStmts , generateNodes (defaultDefinitionTemp )... )
160
-
161
- // TODO: Update rows
162
-
163
- // TODO: apply the rest constraints
164
- constraints = append (constraints [:index ], constraints [index + 1 :]... )
165
- fmt .Println (constraints )
166
- }
167
-
168
- case pg_query .AlterTableType_AT_AddConstraint :
169
- constraint := v .AlterTableCmd .Def .GetConstraint ()
170
- constraint .SkipValidation = true
171
-
172
- alterStmts = append (alterStmts , node )
173
-
174
- validationTemp := fmt .Sprintf (`begin; alter table %s validate constraint %s; commit;` ,
175
- alterTableStmt .GetRelation ().GetRelname (), constraint .GetConname ())
176
-
177
- alterStmts = append (alterStmts , generateNodes (validationTemp )... )
178
-
179
- default :
180
- alterStmts = append (alterStmts , node )
181
- }
182
-
183
- default :
184
- alterStmts = append (alterStmts , node )
185
-
186
- fmt .Printf ("%T\n " , v )
187
- }
188
- }
189
-
190
- return alterStmts
191
- }
192
-
193
- func generateNodes (nodeTemplate string ) []* pg_query.Node {
194
- defDefinition , err := pg_query .Parse (nodeTemplate )
195
- if err != nil {
196
- log .Fatal (err )
197
- }
198
-
199
- nodes := []* pg_query.Node {}
200
- for _ , rawStmt := range defDefinition .Stmts {
201
- nodes = append (nodes , rawStmt .Stmt )
202
- }
203
-
204
- return nodes
205
- }
0 commit comments