This repository was archived by the owner on May 17, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 280
/
Copy path test_sql.py
108 lines (92 loc) · 3.78 KB
/
test_sql.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import unittest
import attrs
from tests.common import TEST_MYSQL_CONN_STRING
from data_diff.databases import connect
from data_diff.databases.base import Compiler
from data_diff.queries.api import Count, Explain, Select, table, In, BinOp, Code
class TestSQL(unittest.TestCase):
    """Check the SQL text that ``Compiler`` produces for query-builder expressions.

    The compiler is built from a MySQL connection, so the expected strings use
    backtick-quoted identifiers and MySQL-specific syntax (``EXPLAIN FORMAT=TREE``).
    """

    def setUp(self):
        """Open a MySQL connection and build a ``Compiler`` around it.

        ``setUp`` runs before every test, so each test gets its own compiler.
        """
        self.mysql = connect(TEST_MYSQL_CONN_STRING)
        # NOTE(review): presumably the connection is needed so the compiler picks the
        # MySQL dialect; no queries are executed by these tests.
        self.compiler = Compiler(self.mysql)

    @staticmethod
    def _walrus():
        """Return the ``marine_mammals.walrus`` table reference shared by the tests."""
        return table("marine_mammals", "walrus")

    def test_compile_string(self):
        """A ``Code`` fragment is emitted verbatim."""
        self.assertEqual("SELECT 1", self.compiler.compile(Code("SELECT 1")))

    def test_compile_int(self):
        """A bare Python int compiles to its decimal literal."""
        self.assertEqual("1", self.compiler.compile(1))

    def test_compile_table_name(self):
        """A table path compiles to a backtick-quoted, dot-separated name."""
        # NOTE(review): root=False presumably compiles the table as a sub-expression
        # rather than as a standalone statement — confirm against Compiler.
        compiler = attrs.evolve(self.compiler, root=False)
        self.assertEqual("`marine_mammals`.`walrus`", compiler.compile(self._walrus()))

    def test_compile_select(self):
        """A Select with one column and no filter compiles to ``SELECT ... FROM ...``."""
        expected_sql = "SELECT name FROM `marine_mammals`.`walrus`"
        self.assertEqual(
            expected_sql,
            self.compiler.compile(
                Select(
                    self._walrus(),
                    [Code("name")],
                )
            ),
        )

    def test_compare(self):
        """Multiple WHERE conditions are each parenthesized and joined with AND."""
        expected_sql = "SELECT name FROM `marine_mammals`.`walrus` WHERE (id <= 1000) AND (id > 1)"
        self.assertEqual(
            expected_sql,
            self.compiler.compile(
                Select(
                    self._walrus(),
                    [Code("name")],
                    [BinOp("<=", [Code("id"), Code("1000")]), BinOp(">", [Code("id"), Code("1")])],
                )
            ),
        )

    def test_in(self):
        """An ``In`` condition compiles to ``(col IN (v1, v2, ...))``."""
        expected_sql = "SELECT name FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))"
        self.assertEqual(
            expected_sql,
            self.compiler.compile(
                Select(self._walrus(), [Code("name")], [In(Code("id"), [1, 2, 3])])
            ),
        )

    def test_count(self):
        """``Count()`` with no argument compiles to ``count(*)``."""
        expected_sql = "SELECT count(*) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))"
        self.assertEqual(
            expected_sql,
            self.compiler.compile(Select(self._walrus(), [Count()], [In(Code("id"), [1, 2, 3])])),
        )

    def test_count_with_column(self):
        """``Count(expr)`` compiles to ``count(expr)``."""
        expected_sql = "SELECT count(id) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))"
        self.assertEqual(
            expected_sql,
            self.compiler.compile(
                Select(self._walrus(), [Count(Code("id"))], [In(Code("id"), [1, 2, 3])])
            ),
        )

    def test_explain(self):
        """``Explain`` prefixes the wrapped query with MySQL's ``EXPLAIN FORMAT=TREE``."""
        expected_sql = "EXPLAIN FORMAT=TREE SELECT count(id) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))"
        self.assertEqual(
            expected_sql,
            self.compiler.compile(
                Explain(Select(self._walrus(), [Count(Code("id"))], [In(Code("id"), [1, 2, 3])]))
            ),
        )