14
14
15
15
"""DB-API Connection for the Google Cloud Spanner."""
16
16
17
+ import time
17
18
import warnings
18
19
20
+ from google .api_core .exceptions import Aborted
19
21
from google .api_core .gapic_v1 .client_info import ClientInfo
20
22
from google .cloud import spanner_v1 as spanner
23
+ from google .cloud .spanner_v1 .session import _get_retry_delay
21
24
25
+ from google .cloud .spanner_dbapi .checksum import _compare_checksums
26
+ from google .cloud .spanner_dbapi .checksum import ResultsChecksum
22
27
from google .cloud .spanner_dbapi .cursor import Cursor
23
28
from google .cloud .spanner_dbapi .exceptions import InterfaceError
24
29
from google .cloud .spanner_dbapi .version import DEFAULT_USER_AGENT
25
30
from google .cloud .spanner_dbapi .version import PY_VERSION
26
31
27
32
28
33
AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode"
34
+ MAX_INTERNAL_RETRIES = 50
29
35
30
36
31
37
class Connection :
@@ -48,9 +54,16 @@ def __init__(self, instance, database):
48
54
49
55
self ._transaction = None
50
56
self ._session = None
57
+ # SQL statements, which were executed
58
+ # within the current transaction
59
+ self ._statements = []
51
60
52
61
self .is_closed = False
53
62
self ._autocommit = False
63
+ # indicator to know if the session pool used by
64
+ # this connection should be cleared on the
65
+ # connection close
66
+ self ._own_pool = True
54
67
55
68
@property
56
69
def autocommit (self ):
@@ -114,6 +127,58 @@ def _release_session(self):
114
127
self .database ._pool .put (self ._session )
115
128
self ._session = None
116
129
130
+ def retry_transaction (self ):
131
+ """Retry the aborted transaction.
132
+
133
+ All the statements executed in the original transaction
134
+ will be re-executed in new one. Results checksums of the
135
+ original statements and the retried ones will be compared.
136
+
137
+ :raises: :class:`google.cloud.spanner_dbapi.exceptions.RetryAborted`
138
+ If results checksum of the retried statement is
139
+ not equal to the checksum of the original one.
140
+ """
141
+ attempt = 0
142
+ while True :
143
+ self ._transaction = None
144
+ attempt += 1
145
+ if attempt > MAX_INTERNAL_RETRIES :
146
+ raise
147
+
148
+ try :
149
+ self ._rerun_previous_statements ()
150
+ break
151
+ except Aborted as exc :
152
+ delay = _get_retry_delay (exc .errors [0 ], attempt )
153
+ if delay :
154
+ time .sleep (delay )
155
+
156
+ def _rerun_previous_statements (self ):
157
+ """
158
+ Helper to run all the remembered statements
159
+ from the last transaction.
160
+ """
161
+ for statement in self ._statements :
162
+ res_iter , retried_checksum = self .run_statement (statement , retried = True )
163
+ # executing all the completed statements
164
+ if statement != self ._statements [- 1 ]:
165
+ for res in res_iter :
166
+ retried_checksum .consume_result (res )
167
+
168
+ _compare_checksums (statement .checksum , retried_checksum )
169
+ # executing the failed statement
170
+ else :
171
+ # streaming up to the failed result or
172
+ # to the end of the streaming iterator
173
+ while len (retried_checksum ) < len (statement .checksum ):
174
+ try :
175
+ res = next (iter (res_iter ))
176
+ retried_checksum .consume_result (res )
177
+ except StopIteration :
178
+ break
179
+
180
+ _compare_checksums (statement .checksum , retried_checksum )
181
+
117
182
def transaction_checkout (self ):
118
183
"""Get a Cloud Spanner transaction.
119
184
@@ -158,6 +223,9 @@ def close(self):
158
223
):
159
224
self ._transaction .rollback ()
160
225
226
+ if self ._own_pool :
227
+ self .database ._pool .clear ()
228
+
161
229
self .is_closed = True
162
230
163
231
def commit (self ):
@@ -168,8 +236,13 @@ def commit(self):
168
236
if self ._autocommit :
169
237
warnings .warn (AUTOCOMMIT_MODE_WARNING , UserWarning , stacklevel = 2 )
170
238
elif self ._transaction :
171
- self ._transaction .commit ()
172
- self ._release_session ()
239
+ try :
240
+ self ._transaction .commit ()
241
+ self ._release_session ()
242
+ self ._statements = []
243
+ except Aborted :
244
+ self .retry_transaction ()
245
+ self .commit ()
173
246
174
247
def rollback (self ):
175
248
"""Rolls back any pending transaction.
@@ -182,6 +255,7 @@ def rollback(self):
182
255
elif self ._transaction :
183
256
self ._transaction .rollback ()
184
257
self ._release_session ()
258
+ self ._statements = []
185
259
186
260
def cursor (self ):
187
261
"""Factory to create a DB-API Cursor."""
@@ -198,6 +272,32 @@ def run_prior_DDL_statements(self):
198
272
199
273
return self .database .update_ddl (ddl_statements ).result ()
200
274
275
+ def run_statement (self , statement , retried = False ):
276
+ """Run single SQL statement in begun transaction.
277
+
278
+ This method is never used in autocommit mode. In
279
+ !autocommit mode however it remembers every executed
280
+ SQL statement with its parameters.
281
+
282
+ :type statement: :class:`dict`
283
+ :param statement: SQL statement to execute.
284
+
285
+ :rtype: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet`,
286
+ :class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum`
287
+ :returns: Streamed result set of the statement and a
288
+ checksum of this statement results.
289
+ """
290
+ transaction = self .transaction_checkout ()
291
+ if not retried :
292
+ self ._statements .append (statement )
293
+
294
+ return (
295
+ transaction .execute_sql (
296
+ statement .sql , statement .params , param_types = statement .param_types ,
297
+ ),
298
+ ResultsChecksum () if retried else statement .checksum ,
299
+ )
300
+
201
301
def __enter__ (self ):
202
302
return self
203
303
@@ -207,7 +307,12 @@ def __exit__(self, etype, value, traceback):
207
307
208
308
209
309
def connect (
210
- instance_id , database_id , project = None , credentials = None , pool = None , user_agent = None
310
+ instance_id ,
311
+ database_id ,
312
+ project = None ,
313
+ credentials = None ,
314
+ pool = None ,
315
+ user_agent = None ,
211
316
):
212
317
"""Creates a connection to a Google Cloud Spanner database.
213
318
@@ -261,4 +366,8 @@ def connect(
261
366
if not database .exists ():
262
367
raise ValueError ("database '%s' does not exist." % database_id )
263
368
264
- return Connection (instance , database )
369
+ conn = Connection (instance , database )
370
+ if pool is not None :
371
+ conn ._own_pool = False
372
+
373
+ return conn
0 commit comments