Skip to content

Commit d59d502

Browse files
author
Ilya Gurov
authored
feat(dbapi): add aborted transactions retry support (#168)
Fixes #34. See googleapis/python-spanner-django#544.
1 parent e801a2e commit d59d502

File tree

9 files changed

+1109
-24
lines changed

9 files changed

+1109
-24
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
#     http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""API to calculate checksums of SQL statements results."""
16+
17+
import hashlib
18+
import pickle
19+
20+
from google.cloud.spanner_dbapi.exceptions import RetryAborted
21+
22+
23+
class ResultsChecksum:
    """Cumulative checksum.

    Used to calculate a total checksum of all the results
    returned by operations executed within transaction.
    Includes methods for checksums comparison.
    These checksums are used while retrying an aborted
    transaction to check if the results of a retried transaction
    are equal to the results of the original transaction.

    Note that an instance with no consumed results has a length of
    zero and is therefore falsy; compare against ``None`` explicitly
    instead of relying on truthiness.
    """

    def __init__(self):
        # running SHA-256 over the pickled form of every consumed result
        self.checksum = hashlib.sha256()
        self.count = 0  # counter of consumed results

    def __len__(self):
        """Return the number of consumed results.

        :rtype: :class:`int`
        :returns: The number of results.
        """
        return self.count

    def __eq__(self, other):
        """Check if checksums are equal.

        :type other: :class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum`
        :param other: Another checksum to compare with this one.

        :rtype: :class:`bool`
        :returns: True if both digests are identical. ``NotImplemented``
                  is returned for objects that are not checksums, so
                  that ``checksum == None`` evaluates to False instead
                  of raising :exc:`AttributeError`.
        """
        if not isinstance(other, ResultsChecksum):
            return NotImplemented

        return self.checksum.digest() == other.checksum.digest()

    def consume_result(self, result):
        """Add the given result into the checksum.

        The result is serialized with :func:`pickle.dumps` and the
        bytes are fed into the running hash.

        :type result: Union[int, list]
        :param result: Streamed row or row count from an UPDATE operation.
        """
        self.checksum.update(pickle.dumps(result))
        self.count += 1
62+
63+
64+
def _compare_checksums(original, retried):
65+
"""Compare the given checksums.
66+
67+
Raise an error if the given checksums are not equal.
68+
69+
:type original: :class:`~google.cloud.spanner_dbapi.checksum.ResultsChecksum`
70+
:param original: results checksum of the original transaction.
71+
72+
:type retried: :class:`~google.cloud.spanner_dbapi.checksum.ResultsChecksum`
73+
:param retried: results checksum of the retried transaction.
74+
75+
:raises: :exc:`google.cloud.spanner_dbapi.exceptions.RetryAborted` in case if checksums are not equal.
76+
"""
77+
if retried != original:
78+
raise RetryAborted(
79+
"The transaction was aborted and could not be retried due to a concurrent modification."
80+
)

google/cloud/spanner_dbapi/connection.py

Lines changed: 113 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,24 @@
1414

1515
"""DB-API Connection for the Google Cloud Spanner."""
1616

17+
import time
1718
import warnings
1819

20+
from google.api_core.exceptions import Aborted
1921
from google.api_core.gapic_v1.client_info import ClientInfo
2022
from google.cloud import spanner_v1 as spanner
23+
from google.cloud.spanner_v1.session import _get_retry_delay
2124

25+
from google.cloud.spanner_dbapi.checksum import _compare_checksums
26+
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
2227
from google.cloud.spanner_dbapi.cursor import Cursor
2328
from google.cloud.spanner_dbapi.exceptions import InterfaceError
2429
from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT
2530
from google.cloud.spanner_dbapi.version import PY_VERSION
2631

2732

2833
AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode"
34+
MAX_INTERNAL_RETRIES = 50
2935

3036

3137
class Connection:
@@ -48,9 +54,16 @@ def __init__(self, instance, database):
4854

4955
self._transaction = None
5056
self._session = None
57+
# SQL statements, which were executed
58+
# within the current transaction
59+
self._statements = []
5160

5261
self.is_closed = False
5362
self._autocommit = False
63+
# indicator to know if the session pool used by
64+
# this connection should be cleared on the
65+
# connection close
66+
self._own_pool = True
5467

5568
@property
5669
def autocommit(self):
@@ -114,6 +127,58 @@ def _release_session(self):
114127
self.database._pool.put(self._session)
115128
self._session = None
116129

130+
def retry_transaction(self):
    """Retry the aborted transaction.

    All the statements executed in the original transaction
    will be re-executed in new one. Results checksums of the
    original statements and the retried ones will be compared.

    The re-run is attempted at most ``MAX_INTERNAL_RETRIES`` times;
    between attempts the method sleeps for the delay computed by
    ``_get_retry_delay`` from the first underlying error and the
    attempt number.

    :raises: :class:`google.cloud.spanner_dbapi.exceptions.RetryAborted`
        If results checksum of the retried statement is
        not equal to the checksum of the original one.

    :raises: :class:`google.api_core.exceptions.Aborted`
        If the transaction is still aborted after
        ``MAX_INTERNAL_RETRIES`` attempts.
    """
    attempt = 0
    while True:
        # forget the aborted transaction before re-running
        # (presumably so that the next checkout begins a new
        # one -- confirm against transaction_checkout)
        self._transaction = None
        attempt += 1

        try:
            self._rerun_previous_statements()
            break
        except Aborted as exc:
            if attempt >= MAX_INTERNAL_RETRIES:
                # Re-raise from inside the handler: a bare ``raise``
                # outside of an ``except`` block has no active
                # exception and fails with RuntimeError instead.
                self._transaction = None
                raise

            delay = _get_retry_delay(exc.errors[0], attempt)
            if delay:
                time.sleep(delay)
156+
def _rerun_previous_statements(self):
    """
    Helper to run all the remembered statements
    from the last transaction.

    Every statement in ``self._statements`` is executed again with
    ``retried=True`` and the checksum of the new results is compared
    with the checksum stored on the statement. Statements other than
    the last one are streamed to the end; the last one is streamed
    only until as many results were consumed as the original checksum
    holds, or until the stream is exhausted.

    :raises: :class:`google.cloud.spanner_dbapi.exceptions.RetryAborted`
        If a retried checksum differs from the original one
        (raised by ``_compare_checksums``).
    """
    for statement in self._statements:
        res_iter, retried_checksum = self.run_statement(statement, retried=True)

        # executing all the completed statements
        # NOTE(review): "last" is detected by value equality, so an
        # earlier statement equal to the final one takes the partial
        # branch as well -- confirm this is intended.
        if statement != self._statements[-1]:
            for res in res_iter:
                retried_checksum.consume_result(res)
        # executing the failed statement
        else:
            # Obtain the iterator once: calling iter() on every pass
            # would restart any iterable that is not its own iterator.
            rows = iter(res_iter)

            # streaming up to the failed result or
            # to the end of the streaming iterator
            while len(retried_checksum) < len(statement.checksum):
                try:
                    res = next(rows)
                except StopIteration:
                    break

                retried_checksum.consume_result(res)

        _compare_checksums(statement.checksum, retried_checksum)
181+
117182
def transaction_checkout(self):
118183
"""Get a Cloud Spanner transaction.
119184
@@ -158,6 +223,9 @@ def close(self):
158223
):
159224
self._transaction.rollback()
160225

226+
if self._own_pool:
227+
self.database._pool.clear()
228+
161229
self.is_closed = True
162230

163231
def commit(self):
@@ -168,8 +236,13 @@ def commit(self):
168236
if self._autocommit:
169237
warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
170238
elif self._transaction:
171-
self._transaction.commit()
172-
self._release_session()
239+
try:
240+
self._transaction.commit()
241+
self._release_session()
242+
self._statements = []
243+
except Aborted:
244+
self.retry_transaction()
245+
self.commit()
173246

174247
def rollback(self):
175248
"""Rolls back any pending transaction.
@@ -182,6 +255,7 @@ def rollback(self):
182255
elif self._transaction:
183256
self._transaction.rollback()
184257
self._release_session()
258+
self._statements = []
185259

186260
def cursor(self):
187261
"""Factory to create a DB-API Cursor."""
@@ -198,6 +272,32 @@ def run_prior_DDL_statements(self):
198272

199273
return self.database.update_ddl(ddl_statements).result()
200274

275+
def run_statement(self, statement, retried=False):
    """Execute one SQL statement inside the connection's transaction.

    A transaction is obtained through ``transaction_checkout()`` and
    the statement is executed in it. On a first (non-retried) run the
    statement is also appended to ``self._statements`` so that it can
    be re-executed later.

    :type statement: object with ``sql``, ``params``, ``param_types``
                     and ``checksum`` attributes
    :param statement: SQL statement to execute.

    :type retried: bool
    :param retried: (Optional) True when the statement is being
                    re-executed; the statement is then not remembered
                    again and a fresh checksum is returned in place of
                    the one stored on the statement.

    :rtype: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet`,
            :class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum`
    :returns: Streamed result set of the statement and a
              checksum of this statement results.
    """
    txn = self.transaction_checkout()

    first_run = not retried
    if first_run:
        self._statements.append(statement)

    result_set = txn.execute_sql(
        statement.sql, statement.params, param_types=statement.param_types,
    )

    if first_run:
        return result_set, statement.checksum

    return result_set, ResultsChecksum()
300+
201301
def __enter__(self):
202302
return self
203303

@@ -207,7 +307,12 @@ def __exit__(self, etype, value, traceback):
207307

208308

209309
def connect(
210-
instance_id, database_id, project=None, credentials=None, pool=None, user_agent=None
310+
instance_id,
311+
database_id,
312+
project=None,
313+
credentials=None,
314+
pool=None,
315+
user_agent=None,
211316
):
212317
"""Creates a connection to a Google Cloud Spanner database.
213318
@@ -261,4 +366,8 @@ def connect(
261366
if not database.exists():
262367
raise ValueError("database '%s' does not exist." % database_id)
263368

264-
return Connection(instance, database)
369+
conn = Connection(instance, database)
370+
if pool is not None:
371+
conn._own_pool = False
372+
373+
return conn

google/cloud/spanner_dbapi/cursor.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Database cursor for Google Cloud Spanner DB-API."""
1616

17+
from google.api_core.exceptions import Aborted
1718
from google.api_core.exceptions import AlreadyExists
1819
from google.api_core.exceptions import FailedPrecondition
1920
from google.api_core.exceptions import InternalServerError
@@ -22,7 +23,7 @@
2223
from collections import namedtuple
2324

2425
from google.cloud import spanner_v1 as spanner
25-
26+
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
2627
from google.cloud.spanner_dbapi.exceptions import IntegrityError
2728
from google.cloud.spanner_dbapi.exceptions import InterfaceError
2829
from google.cloud.spanner_dbapi.exceptions import OperationalError
@@ -34,11 +35,13 @@
3435

3536
from google.cloud.spanner_dbapi import parse_utils
3637
from google.cloud.spanner_dbapi.parse_utils import get_param_types
38+
from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner
3739
from google.cloud.spanner_dbapi.utils import PeekIterator
3840

3941
_UNSET_COUNT = -1
4042

4143
ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
44+
Statement = namedtuple("Statement", "sql, params, param_types, checksum")
4245

4346

4447
class Cursor(object):
@@ -54,6 +57,8 @@ def __init__(self, connection):
5457
self._row_count = _UNSET_COUNT
5558
self.connection = connection
5659
self._is_closed = False
60+
# the currently running SQL statement results checksum
61+
self._checksum = None
5762

5863
# the number of rows to fetch at a time with fetchmany()
5964
self.arraysize = 1
@@ -166,12 +171,13 @@ def execute(self, sql, args=None):
166171
self.connection.run_prior_DDL_statements()
167172

168173
if not self.connection.autocommit:
169-
transaction = self.connection.transaction_checkout()
170-
171-
sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, args)
174+
sql, params = sql_pyformat_args_to_spanner(sql, args)
172175

173-
self._result_set = transaction.execute_sql(
174-
sql, params, param_types=get_param_types(params)
176+
statement = Statement(
177+
sql, params, get_param_types(params), ResultsChecksum(),
178+
)
179+
(self._result_set, self._checksum,) = self.connection.run_statement(
180+
statement
175181
)
176182
self._itr = PeekIterator(self._result_set)
177183
return
@@ -213,9 +219,31 @@ def fetchone(self):
213219
self._raise_if_closed()
214220

215221
try:
216-
return next(self)
222+
res = next(self)
223+
self._checksum.consume_result(res)
224+
return res
217225
except StopIteration:
218-
return None
226+
return
227+
except Aborted:
228+
self.connection.retry_transaction()
229+
return self.fetchone()
230+
231+
def fetchall(self):
    """Fetch all (remaining) rows of a query result, returning them as
    a sequence of sequences.

    Every fetched row is added to the checksum of the currently
    running statement. If the transaction is aborted while rows are
    being streamed, the transaction is retried through the connection
    and the fetch is started again.

    :rtype: :class:`list`
    :returns: The fetched rows.
    """
    self._raise_if_closed()

    res = []
    try:
        for row in self:
            self._checksum.consume_result(row)
            res.append(row)
    except Aborted:
        # the connection is stored on the cursor as ``connection``
        # (no ``_connection`` attribute is set in ``__init__``)
        self.connection.retry_transaction()
        # NOTE(review): rows gathered in ``res`` before the abort are
        # dropped here and only the result of the new call is
        # returned -- confirm this matches the retry semantics.
        return self.fetchall()

    return res
219247

220248
def fetchmany(self, size=None):
221249
"""Fetch the next set of rows of a query result, returning a sequence
@@ -236,20 +264,17 @@ def fetchmany(self, size=None):
236264
items = []
237265
for i in range(size):
238266
try:
239-
items.append(tuple(self.__next__()))
267+
res = next(self)
268+
self._checksum.consume_result(res)
269+
items.append(res)
240270
except StopIteration:
241271
break
272+
except Aborted:
273+
self._connection.retry_transaction()
274+
return self.fetchmany(size)
242275

243276
return items
244277

245-
def fetchall(self):
246-
"""Fetch all (remaining) rows of a query result, returning them as
247-
a sequence of sequences.
248-
"""
249-
self._raise_if_closed()
250-
251-
return list(self.__iter__())
252-
253278
def nextset(self):
254279
"""A no-op, raising an error if the cursor or connection is closed."""
255280
self._raise_if_closed()

0 commit comments

Comments
 (0)