@@ -391,36 +391,26 @@ def test_get_param_types_none(self):
391
391
392
392
@unittest .skipIf (skip_condition , skip_message )
393
393
def test_ensure_where_clause (self ):
394
+ from google .cloud .spanner_dbapi .exceptions import ProgrammingError
394
395
from google .cloud .spanner_dbapi .parse_utils import ensure_where_clause
395
396
396
- cases = [
397
- (
398
- "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1" ,
399
- "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1" ,
400
- ),
401
- (
402
- "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5" ,
403
- "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5 WHERE 1=1" ,
404
- ),
405
- (
406
- "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2" ,
407
- "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2" ,
408
- ),
409
- (
410
- "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)" ,
411
- "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)" ,
412
- ),
413
- (
414
- "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)" ,
415
- "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)" ,
416
- ),
417
- ("DELETE * FROM TABLE" , "DELETE * FROM TABLE WHERE 1=1" ),
418
- ]
397
+ cases = (
398
+ "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1" ,
399
+ "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2" ,
400
+ "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)" ,
401
+ )
402
+ err_cases = (
403
+ "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5" ,
404
+ "DELETE * FROM TABLE" ,
405
+ )
406
+ for sql in cases :
407
+ with self .subTest (sql = sql ):
408
+ ensure_where_clause (sql )
419
409
420
- for sql , want in cases :
410
+ for sql in err_cases :
421
411
with self .subTest (sql = sql ):
422
- got = ensure_where_clause ( sql )
423
- self . assertEqual ( got , want )
412
+ with self . assertRaises ( ProgrammingError ):
413
+ ensure_where_clause ( sql )
424
414
425
415
@unittest .skipIf (skip_condition , skip_message )
426
416
def test_escape_name (self ):
0 commit comments