Skip to content

Commit cdaf25b

Browse files
fix: ensure transactions rollback on failure (#767)
1 parent 9840d43 commit cdaf25b

File tree

6 files changed

+406
-493
lines changed

6 files changed

+406
-493
lines changed

google/cloud/firestore_v1/async_transaction.py

Lines changed: 31 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ async def _rollback(self) -> None:
110110
111111
Raises:
112112
ValueError: If no transaction is in progress.
113+
google.api_core.exceptions.GoogleAPICallError: If the rollback fails.
113114
"""
114115
if not self.in_progress:
115116
raise ValueError(_CANT_ROLLBACK)
@@ -124,6 +125,7 @@ async def _rollback(self) -> None:
124125
metadata=self._client._rpc_metadata,
125126
)
126127
finally:
128+
# clean up, even if rollback fails
127129
self._clean_up()
128130

129131
async def _commit(self) -> list:
@@ -223,10 +225,6 @@ async def _pre_commit(
223225
) -> Coroutine:
224226
"""Begin transaction and call the wrapped coroutine.
225227
226-
If the coroutine raises an exception, the transaction will be rolled
227-
back. If not, the transaction will be "ready" for ``Commit`` (i.e.
228-
it will have staged writes).
229-
230228
Args:
231229
transaction
232230
(:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`):
@@ -250,41 +248,7 @@ async def _pre_commit(
250248
self.current_id = transaction._id
251249
if self.retry_id is None:
252250
self.retry_id = self.current_id
253-
try:
254-
return await self.to_wrap(transaction, *args, **kwargs)
255-
except: # noqa
256-
# NOTE: If ``rollback`` fails this will lose the information
257-
# from the original failure.
258-
await transaction._rollback()
259-
raise
260-
261-
async def _maybe_commit(self, transaction: AsyncTransaction) -> bool:
262-
"""Try to commit the transaction.
263-
264-
If the transaction is read-write and the ``Commit`` fails with the
265-
``ABORTED`` status code, it will be retried. Any other failure will
266-
not be caught.
267-
268-
Args:
269-
transaction
270-
(:class:`~google.cloud.firestore_v1.transaction.Transaction`):
271-
The transaction to be ``Commit``-ed.
272-
273-
Returns:
274-
bool: Indicating if the commit succeeded.
275-
"""
276-
try:
277-
await transaction._commit()
278-
return True
279-
except exceptions.GoogleAPICallError as exc:
280-
if transaction._read_only:
281-
raise
282-
283-
if isinstance(exc, exceptions.Aborted):
284-
# If a read-write transaction returns ABORTED, retry.
285-
return False
286-
else:
287-
raise
251+
return await self.to_wrap(transaction, *args, **kwargs)
288252

289253
async def __call__(self, transaction, *args, **kwargs):
290254
"""Execute the wrapped callable within a transaction.
@@ -306,22 +270,35 @@ async def __call__(self, transaction, *args, **kwargs):
306270
``max_attempts``.
307271
"""
308272
self._reset()
273+
retryable_exceptions = (
274+
(exceptions.Aborted) if not transaction._read_only else ()
275+
)
276+
last_exc = None
309277

310-
for attempt in range(transaction._max_attempts):
311-
result = await self._pre_commit(transaction, *args, **kwargs)
312-
succeeded = await self._maybe_commit(transaction)
313-
if succeeded:
314-
return result
315-
316-
# Subsequent requests will use the failed transaction ID as part of
317-
# the ``BeginTransactionRequest`` when restarting this transaction
318-
# (via ``options.retry_transaction``). This preserves the "spot in
319-
# line" of the transaction, so exponential backoff is not required
320-
# in this case.
321-
322-
await transaction._rollback()
323-
msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts)
324-
raise ValueError(msg)
278+
try:
279+
for attempt in range(transaction._max_attempts):
280+
result = await self._pre_commit(transaction, *args, **kwargs)
281+
try:
282+
await transaction._commit()
283+
return result
284+
except retryable_exceptions as exc:
285+
last_exc = exc
286+
# Retry attempts that result in retryable exceptions
287+
# Subsequent requests will use the failed transaction ID as part of
288+
# the ``BeginTransactionRequest`` when restarting this transaction
289+
# (via ``options.retry_transaction``). This preserves the "spot in
290+
# line" of the transaction, so exponential backoff is not required
291+
# in this case.
292+
# retries exhausted
293+
# wrap the last exception in a ValueError before raising
294+
msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts)
295+
raise ValueError(msg) from last_exc
296+
297+
except BaseException:
298+
# rollback the transaction on any error
299+
# errors raised during _rollback will be chained to the original error through __context__
300+
await transaction._rollback()
301+
raise
325302

326303

327304
def async_transactional(

google/cloud/firestore_v1/base_transaction.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,5 @@ def _reset(self) -> None:
185185
def _pre_commit(self, transaction, *args, **kwargs) -> NoReturn:
186186
raise NotImplementedError
187187

188-
def _maybe_commit(self, transaction) -> NoReturn:
189-
raise NotImplementedError
190-
191188
def __call__(self, transaction, *args, **kwargs):
192189
raise NotImplementedError

google/cloud/firestore_v1/transaction.py

Lines changed: 31 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
# Types needed only for Type Hints
4545
from google.cloud.firestore_v1.base_document import DocumentSnapshot
4646
from google.cloud.firestore_v1.types import CommitResponse
47-
from typing import Any, Callable, Generator, Optional
47+
from typing import Any, Callable, Generator
4848

4949

5050
class Transaction(batch.WriteBatch, BaseTransaction):
@@ -108,6 +108,7 @@ def _rollback(self) -> None:
108108
109109
Raises:
110110
ValueError: If no transaction is in progress.
111+
google.api_core.exceptions.GoogleAPICallError: If the rollback fails.
111112
"""
112113
if not self.in_progress:
113114
raise ValueError(_CANT_ROLLBACK)
@@ -122,6 +123,7 @@ def _rollback(self) -> None:
122123
metadata=self._client._rpc_metadata,
123124
)
124125
finally:
126+
# clean up, even if rollback fails
125127
self._clean_up()
126128

127129
def _commit(self) -> list:
@@ -214,10 +216,6 @@ def __init__(self, to_wrap) -> None:
214216
def _pre_commit(self, transaction: Transaction, *args, **kwargs) -> Any:
215217
"""Begin transaction and call the wrapped callable.
216218
217-
If the callable raises an exception, the transaction will be rolled
218-
back. If not, the transaction will be "ready" for ``Commit`` (i.e.
219-
it will have staged writes).
220-
221219
Args:
222220
transaction
223221
(:class:`~google.cloud.firestore_v1.transaction.Transaction`):
@@ -241,41 +239,7 @@ def _pre_commit(self, transaction: Transaction, *args, **kwargs) -> Any:
241239
self.current_id = transaction._id
242240
if self.retry_id is None:
243241
self.retry_id = self.current_id
244-
try:
245-
return self.to_wrap(transaction, *args, **kwargs)
246-
except: # noqa
247-
# NOTE: If ``rollback`` fails this will lose the information
248-
# from the original failure.
249-
transaction._rollback()
250-
raise
251-
252-
def _maybe_commit(self, transaction: Transaction) -> Optional[bool]:
253-
"""Try to commit the transaction.
254-
255-
If the transaction is read-write and the ``Commit`` fails with the
256-
``ABORTED`` status code, it will be retried. Any other failure will
257-
not be caught.
258-
259-
Args:
260-
transaction
261-
(:class:`~google.cloud.firestore_v1.transaction.Transaction`):
262-
The transaction to be ``Commit``-ed.
263-
264-
Returns:
265-
bool: Indicating if the commit succeeded.
266-
"""
267-
try:
268-
transaction._commit()
269-
return True
270-
except exceptions.GoogleAPICallError as exc:
271-
if transaction._read_only:
272-
raise
273-
274-
if isinstance(exc, exceptions.Aborted):
275-
# If a read-write transaction returns ABORTED, retry.
276-
return False
277-
else:
278-
raise
242+
return self.to_wrap(transaction, *args, **kwargs)
279243

280244
def __call__(self, transaction: Transaction, *args, **kwargs):
281245
"""Execute the wrapped callable within a transaction.
@@ -297,22 +261,34 @@ def __call__(self, transaction: Transaction, *args, **kwargs):
297261
``max_attempts``.
298262
"""
299263
self._reset()
264+
retryable_exceptions = (
265+
(exceptions.Aborted) if not transaction._read_only else ()
266+
)
267+
last_exc = None
300268

301-
for attempt in range(transaction._max_attempts):
302-
result = self._pre_commit(transaction, *args, **kwargs)
303-
succeeded = self._maybe_commit(transaction)
304-
if succeeded:
305-
return result
306-
307-
# Subsequent requests will use the failed transaction ID as part of
308-
# the ``BeginTransactionRequest`` when restarting this transaction
309-
# (via ``options.retry_transaction``). This preserves the "spot in
310-
# line" of the transaction, so exponential backoff is not required
311-
# in this case.
312-
313-
transaction._rollback()
314-
msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts)
315-
raise ValueError(msg)
269+
try:
270+
for attempt in range(transaction._max_attempts):
271+
result = self._pre_commit(transaction, *args, **kwargs)
272+
try:
273+
transaction._commit()
274+
return result
275+
except retryable_exceptions as exc:
276+
last_exc = exc
277+
# Retry attempts that result in retryable exceptions
278+
# Subsequent requests will use the failed transaction ID as part of
279+
# the ``BeginTransactionRequest`` when restarting this transaction
280+
# (via ``options.retry_transaction``). This preserves the "spot in
281+
# line" of the transaction, so exponential backoff is not required
282+
# in this case.
283+
# retries exhausted
284+
# wrap the last exception in a ValueError before raising
285+
msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts)
286+
raise ValueError(msg) from last_exc
287+
except BaseException: # noqa: B901
288+
# rollback the transaction on any error
289+
# errors raised during _rollback will be chained to the original error through __context__
290+
transaction._rollback()
291+
raise
316292

317293

318294
def transactional(to_wrap: Callable) -> _Transactional:

google/cloud/firestore_v1/watch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,9 @@ def _on_snapshot_target_change_remove(self, target_change):
401401

402402
error_message = "Error %s: %s" % (code, message)
403403

404-
raise RuntimeError(error_message)
404+
raise RuntimeError(error_message) from exceptions.from_grpc_status(
405+
code, message
406+
)
405407

406408
def _on_snapshot_target_change_reset(self, target_change):
407409
# Whatever changes have happened so far no longer matter.

0 commit comments

Comments
 (0)