Skip to content

Commit caef5a3

Browse files
committed
feat(engine): create a diff package for query processing, add tests
1 parent 4249aac commit caef5a3

File tree

3 files changed

+209
-180
lines changed

3 files changed

+209
-180
lines changed

engine/cmd/schema-diff/main.go

Lines changed: 2 additions & 180 deletions
Original file line numberDiff line numberDiff line change
@@ -4,202 +4,24 @@ import (
44
"fmt"
55
"log"
66

7-
pg_query "github.com/pganalyze/pg_query_go/v2"
7+
"gitlab.com/postgres-ai/database-lab/v3/internal/schema/diff"
88
)
99

1010
const idxExample = `
1111
CREATE UNIQUE INDEX title_idx ON films (title);
12-
13-
DROP INDEX title_idx;
14-
15-
ALTER TABLE distributors
16-
ADD CONSTRAINT zipchk CHECK (char_length(zipcode) = 5);
17-
18-
ALTER TABLE distributors
19-
ADD CONSTRAINT distfk FOREIGN KEY (address) REFERENCES addresses (address);
20-
21-
ALTER TABLE pgbench_accounts
22-
ADD COLUMN test integer NOT NULL DEFAULT 0;
2312
`
2413

2514
/*
2615
Optimized queries:
2716
2817
CREATE UNIQUE INDEX CONCURRENTLY title_idx ON films USING btree (title);
29-
30-
DROP INDEX CONCURRENTLY title_idx;
31-
32-
ALTER TABLE distributors ADD CONSTRAINT zipchk CHECK (char_length(zipcode) = 5) NOT VALID;
33-
BEGIN; ALTER TABLE distributors VALIDATE CONSTRAINT zipchk; COMMIT;
34-
35-
ALTER TABLE distributors ADD CONSTRAINT distfk FOREIGN KEY (address) REFERENCES addresses (address) NOT VALID;
36-
BEGIN; ALTER TABLE distributors VALIDATE CONSTRAINT distfk; COMMIT;
37-
38-
ALTER TABLE pgbench_accounts ADD COLUMN test int;
39-
ALTER TABLE pgbench_accounts ALTER COLUMN test SET DEFAULT 0;
4018
*/
4119

4220
func main() {
43-
scanTree, err := pg_query.ParseToJSON(idxExample)
44-
if err != nil {
45-
log.Fatal(err)
46-
}
47-
48-
fmt.Printf("JSON: %s\n", scanTree)
49-
50-
idxTree, err := pg_query.Parse(idxExample)
51-
if err != nil {
52-
log.Fatal(err)
53-
}
54-
55-
fmt.Printf("Original query:\n%v\n\n", idxExample)
56-
fmt.Printf("Parse Tree:\n%#v\n\n", idxTree)
57-
58-
stmts := idxTree.GetStmts()
59-
nodes := processStmts(stmts)
60-
idxTree.Stmts = nodes
61-
62-
fmt.Printf("Parse Tree after processing:\n%#v\n\n", idxTree.GetStmts())
63-
64-
resIdxStr, err := pg_query.Deparse(idxTree)
21+
resIdxStr, err := diff.OptimizeQueries(idxExample)
6522
if err != nil {
6623
log.Fatal(err)
6724
}
6825

6926
fmt.Printf("Optimized queries:\n%v\n", resIdxStr)
7027
}
71-
72-
func processStmts(stmts []*pg_query.RawStmt) []*pg_query.RawStmt {
73-
rawStmts := []*pg_query.RawStmt{}
74-
75-
for _, stmt := range stmts {
76-
for _, node := range detectNodeType(stmt.Stmt) {
77-
rawStmt := &pg_query.RawStmt{
78-
Stmt: node,
79-
}
80-
81-
rawStmts = append(rawStmts, rawStmt)
82-
}
83-
}
84-
85-
return rawStmts
86-
}
87-
88-
func detectNodeType(node *pg_query.Node) []*pg_query.Node {
89-
switch stmt := node.Node.(type) {
90-
case *pg_query.Node_IndexStmt:
91-
IndexStmt(stmt)
92-
93-
case *pg_query.Node_DropStmt:
94-
DropStmt(stmt)
95-
96-
case *pg_query.Node_AlterTableStmt:
97-
fmt.Println("Alter Type")
98-
return AlterStmt(node)
99-
100-
case *pg_query.Node_SelectStmt:
101-
fmt.Println("Select Type")
102-
}
103-
104-
return []*pg_query.Node{node}
105-
}
106-
107-
// IndexStmt processes index statement.
108-
func IndexStmt(stmt *pg_query.Node_IndexStmt) {
109-
stmt.IndexStmt.Concurrent = true
110-
}
111-
112-
// DropStmt processes drop statement.
113-
func DropStmt(stmt *pg_query.Node_DropStmt) {
114-
switch stmt.DropStmt.RemoveType {
115-
case pg_query.ObjectType_OBJECT_INDEX:
116-
stmt.DropStmt.Concurrent = true
117-
default:
118-
}
119-
}
120-
121-
// AlterStmt processes alter statement.
122-
func AlterStmt(node *pg_query.Node) []*pg_query.Node {
123-
alterTableStmt := node.GetAlterTableStmt()
124-
if alterTableStmt == nil {
125-
return []*pg_query.Node{node}
126-
}
127-
128-
var alterStmts []*pg_query.Node
129-
130-
initialCommands := alterTableStmt.GetCmds()
131-
132-
for _, cmd := range initialCommands {
133-
switch v := cmd.Node.(type) {
134-
case *pg_query.Node_AlterTableCmd:
135-
fmt.Printf("%#v\n", v)
136-
fmt.Printf("%#v\n", v.AlterTableCmd.Def.Node)
137-
fmt.Println(v.AlterTableCmd.Subtype.Enum())
138-
139-
switch v.AlterTableCmd.Subtype {
140-
case pg_query.AlterTableType_AT_AddColumn:
141-
def := v.AlterTableCmd.Def.GetColumnDef()
142-
143-
constraints := def.GetConstraints()
144-
constraintsMap := make(map[pg_query.ConstrType]int)
145-
146-
for i, constr := range constraints {
147-
constraintsMap[constr.GetConstraint().Contype] = i
148-
}
149-
150-
if index, ok := constraintsMap[pg_query.ConstrType_CONSTR_DEFAULT]; ok {
151-
def.Constraints = make([]*pg_query.Node, 0)
152-
153-
alterStmts = append(alterStmts, node)
154-
155-
defaultDefinitionTemp := fmt.Sprintf(`alter table %s alter column %s set default %v;`,
156-
alterTableStmt.GetRelation().GetRelname(), def.Colname,
157-
constraints[index].GetConstraint().GetRawExpr().GetAConst().GetVal().GetInteger().GetIval())
158-
159-
alterStmts = append(alterStmts, generateNodes(defaultDefinitionTemp)...)
160-
161-
// TODO: Update rows
162-
163-
// TODO: apply the rest constraints
164-
constraints = append(constraints[:index], constraints[index+1:]...)
165-
fmt.Println(constraints)
166-
}
167-
168-
case pg_query.AlterTableType_AT_AddConstraint:
169-
constraint := v.AlterTableCmd.Def.GetConstraint()
170-
constraint.SkipValidation = true
171-
172-
alterStmts = append(alterStmts, node)
173-
174-
validationTemp := fmt.Sprintf(`begin; alter table %s validate constraint %s; commit;`,
175-
alterTableStmt.GetRelation().GetRelname(), constraint.GetConname())
176-
177-
alterStmts = append(alterStmts, generateNodes(validationTemp)...)
178-
179-
default:
180-
alterStmts = append(alterStmts, node)
181-
}
182-
183-
default:
184-
alterStmts = append(alterStmts, node)
185-
186-
fmt.Printf("%T\n", v)
187-
}
188-
}
189-
190-
return alterStmts
191-
}
192-
193-
func generateNodes(nodeTemplate string) []*pg_query.Node {
194-
defDefinition, err := pg_query.Parse(nodeTemplate)
195-
if err != nil {
196-
log.Fatal(err)
197-
}
198-
199-
nodes := []*pg_query.Node{}
200-
for _, rawStmt := range defDefinition.Stmts {
201-
nodes = append(nodes, rawStmt.Stmt)
202-
}
203-
204-
return nodes
205-
}
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
// Package diff parses SQL queries and processes statements for optimization.
2+
package diff
3+
4+
import (
5+
"fmt"
6+
"log"
7+
8+
pg_query "github.com/pganalyze/pg_query_go/v2"
9+
)
10+
11+
// OptimizeQueries rewrites incoming queries into queries with zero downtime risk.
12+
func OptimizeQueries(queries string) (string, error) {
13+
idxTree, err := pg_query.Parse(queries)
14+
if err != nil {
15+
return "", fmt.Errorf("failed to parse queries %w", err)
16+
}
17+
18+
log.Printf("Original query:\n%v\n\n", queries)
19+
log.Printf("Parse Tree:\n%#v\n\n", idxTree)
20+
21+
stmts := idxTree.GetStmts()
22+
nodes := processStmts(stmts)
23+
idxTree.Stmts = nodes
24+
25+
return pg_query.Deparse(idxTree)
26+
}
27+
28+
func processStmts(stmts []*pg_query.RawStmt) []*pg_query.RawStmt {
29+
rawStmts := []*pg_query.RawStmt{}
30+
31+
for _, stmt := range stmts {
32+
for _, node := range detectNodeType(stmt.Stmt) {
33+
rawStmt := &pg_query.RawStmt{
34+
Stmt: node,
35+
}
36+
37+
rawStmts = append(rawStmts, rawStmt)
38+
}
39+
}
40+
41+
return rawStmts
42+
}
43+
44+
func detectNodeType(node *pg_query.Node) []*pg_query.Node {
45+
switch stmt := node.Node.(type) {
46+
case *pg_query.Node_IndexStmt:
47+
IndexStmt(stmt)
48+
49+
case *pg_query.Node_DropStmt:
50+
DropStmt(stmt)
51+
52+
case *pg_query.Node_AlterTableStmt:
53+
fmt.Println("Alter Type")
54+
return AlterStmt(node)
55+
56+
case *pg_query.Node_SelectStmt:
57+
fmt.Println("Select Type")
58+
}
59+
60+
return []*pg_query.Node{node}
61+
}
62+
63+
// IndexStmt processes index statement.
64+
func IndexStmt(stmt *pg_query.Node_IndexStmt) {
65+
stmt.IndexStmt.Concurrent = true
66+
}
67+
68+
// DropStmt processes drop statement.
69+
func DropStmt(stmt *pg_query.Node_DropStmt) {
70+
switch stmt.DropStmt.RemoveType {
71+
case pg_query.ObjectType_OBJECT_INDEX:
72+
stmt.DropStmt.Concurrent = true
73+
default:
74+
}
75+
}
76+
77+
// AlterStmt processes alter statement.
78+
func AlterStmt(node *pg_query.Node) []*pg_query.Node {
79+
alterTableStmt := node.GetAlterTableStmt()
80+
if alterTableStmt == nil {
81+
return []*pg_query.Node{node}
82+
}
83+
84+
var alterStmts []*pg_query.Node
85+
86+
initialCommands := alterTableStmt.GetCmds()
87+
88+
for _, cmd := range initialCommands {
89+
switch v := cmd.Node.(type) {
90+
case *pg_query.Node_AlterTableCmd:
91+
fmt.Printf("%#v\n", v)
92+
fmt.Printf("%#v\n", v.AlterTableCmd.Def.Node)
93+
fmt.Println(v.AlterTableCmd.Subtype.Enum())
94+
95+
switch v.AlterTableCmd.Subtype {
96+
case pg_query.AlterTableType_AT_AddColumn:
97+
def := v.AlterTableCmd.Def.GetColumnDef()
98+
99+
constraints := def.GetConstraints()
100+
constraintsMap := make(map[pg_query.ConstrType]int)
101+
102+
for i, constr := range constraints {
103+
constraintsMap[constr.GetConstraint().Contype] = i
104+
}
105+
106+
if index, ok := constraintsMap[pg_query.ConstrType_CONSTR_DEFAULT]; ok {
107+
def.Constraints = make([]*pg_query.Node, 0)
108+
109+
alterStmts = append(alterStmts, node)
110+
111+
defaultDefinitionTemp := fmt.Sprintf(`alter table %s alter column %s set default %v;`,
112+
alterTableStmt.GetRelation().GetRelname(), def.Colname,
113+
constraints[index].GetConstraint().GetRawExpr().GetAConst().GetVal().GetInteger().GetIval())
114+
115+
alterStmts = append(alterStmts, generateNodes(defaultDefinitionTemp)...)
116+
117+
// TODO: Update rows
118+
119+
// TODO: apply the rest constraints
120+
constraints = append(constraints[:index], constraints[index+1:]...)
121+
fmt.Println(constraints)
122+
}
123+
124+
case pg_query.AlterTableType_AT_AddConstraint:
125+
constraint := v.AlterTableCmd.Def.GetConstraint()
126+
constraint.SkipValidation = true
127+
128+
alterStmts = append(alterStmts, node)
129+
130+
validationTemp := fmt.Sprintf(`begin; alter table %s validate constraint %s; commit;`,
131+
alterTableStmt.GetRelation().GetRelname(), constraint.GetConname())
132+
133+
alterStmts = append(alterStmts, generateNodes(validationTemp)...)
134+
135+
default:
136+
alterStmts = append(alterStmts, node)
137+
}
138+
139+
default:
140+
alterStmts = append(alterStmts, node)
141+
142+
fmt.Printf("%T\n", v)
143+
}
144+
}
145+
146+
return alterStmts
147+
}
148+
149+
func generateNodes(nodeTemplate string) []*pg_query.Node {
150+
defDefinition, err := pg_query.Parse(nodeTemplate)
151+
if err != nil {
152+
log.Println(err)
153+
return nil
154+
}
155+
156+
nodes := []*pg_query.Node{}
157+
for _, rawStmt := range defDefinition.Stmts {
158+
nodes = append(nodes, rawStmt.Stmt)
159+
}
160+
161+
return nodes
162+
}

0 commit comments

Comments
 (0)