Skip to content
This repository was archived by the owner on Apr 15, 2025. It is now read-only.

Commit 875bb3d

Browse files
committed
feat(client): add transaction isolation level
1 parent 3f16eba commit 875bb3d

File tree

5 files changed

+349
-30
lines changed

5 files changed

+349
-30
lines changed

databases/sync_tests/test_transactions.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,124 @@ def test_transaction_already_closed(client: Prisma) -> None:
201201
transaction.user.delete_many()
202202

203203
assert exc.match('Transaction already closed')
204+
205+
206+
@pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available')
207+
def test_read_uncommited_isolation_level(client: Prisma) -> None:
208+
"""A transaction isolation level is set to `READ_UNCOMMITED`"""
209+
client2 = Prisma()
210+
client2.connect()
211+
212+
user = client.user.create(data={'name': 'Robert'})
213+
214+
with client.tx(isolation_level=prisma.TransactionIsolationLevel.READ_UNCOMMITED) as tx1:
215+
tx1_user = tx1.user.find_first_or_raise(where={'id': user.id})
216+
tx1_count = tx1.user.count()
217+
218+
with client2.tx() as tx2:
219+
tx2.user.update(data={'name': 'Tegan'}, where={'id': user.id})
220+
tx2.user.create(data={'name': 'Bobby'})
221+
222+
dirty_user = tx1.user.find_first_or_raise(where={'id': user.id})
223+
224+
non_repeatable_user = tx1.user.find_first_or_raise(where={'id': user.id})
225+
phantom_count = tx1.user.count()
226+
227+
# Have dirty read
228+
assert tx1_user.name != dirty_user.name
229+
# Have non-repeatable read
230+
assert tx1_user.name != non_repeatable_user.name
231+
# Have phantom read
232+
assert tx1_count != phantom_count
233+
234+
235+
@pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available')
236+
def test_read_commited_isolation_level(client: Prisma) -> None:
237+
"""A transaction isolation level is set to `READ_COMMITED`"""
238+
client2 = Prisma()
239+
client2.connect()
240+
241+
user = client.user.create(data={'name': 'Robert'})
242+
243+
with client.tx(isolation_level=prisma.TransactionIsolationLevel.READ_COMMITED) as tx1:
244+
tx1_user = tx1.user.find_first_or_raise(where={'id': user.id})
245+
tx1_count = tx1.user.count()
246+
247+
with client2.tx() as tx2:
248+
tx2.user.update(data={'name': 'Tegan'}, where={'id': user.id})
249+
tx2.user.create(data={'name': 'Bobby'})
250+
251+
dirty_user = tx1.user.find_first_or_raise(where={'id': user.id})
252+
253+
non_repeatable_user = tx1.user.find_first_or_raise(where={'id': user.id})
254+
phantom_count = tx1.user.count()
255+
256+
# No dirty read
257+
assert tx1_user.name == dirty_user.name
258+
# Have non-repeatable read
259+
assert tx1_user.name != non_repeatable_user.name
260+
# Have phantom read
261+
assert tx1_count != phantom_count
262+
263+
264+
@pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available')
265+
def test_repeatable_read_isolation_level(client: Prisma) -> None:
266+
"""A transaction isolation level is set to `REPEATABLE_READ`"""
267+
client2 = Prisma()
268+
client2.connect()
269+
270+
user = client.user.create(data={'name': 'Robert'})
271+
272+
with client.tx(isolation_level=prisma.TransactionIsolationLevel.REPEATABLE_READ) as tx1:
273+
tx1_user = tx1.user.find_first_or_raise(where={'id': user.id})
274+
tx1_count = tx1.user.count()
275+
276+
with client2.tx() as tx2:
277+
tx2.user.update(data={'name': 'Tegan'}, where={'id': user.id})
278+
tx2.user.create(data={'name': 'Bobby'})
279+
280+
dirty_user = tx1.user.find_first_or_raise(where={'id': user.id})
281+
282+
non_repeatable_user = tx1.user.find_first_or_raise(where={'id': user.id})
283+
phantom_count = tx1.user.count()
284+
285+
# No dirty read
286+
assert tx1_user.name == dirty_user.name
287+
# No non-repeatable read
288+
assert tx1_user.name == non_repeatable_user.name
289+
# Have phantom read
290+
assert tx1_count != phantom_count
291+
292+
293+
@pytest.mark.skipif(True, reason='Available for SQL Server only')
294+
def test_snapshot_isolation_level() -> None:
295+
"""A transaction isolation level is set to `SNAPSHOT`"""
296+
raise NotImplementedError
297+
298+
299+
def test_serializable_isolation_level(client: Prisma) -> None:
300+
"""A transaction isolation level is set to `SERIALIZABLE`"""
301+
client2 = Prisma()
302+
client2.connect()
303+
304+
user = client.user.create(data={'name': 'Robert'})
305+
306+
with client.tx(isolation_level=prisma.TransactionIsolationLevel.SERIALIZABLE) as tx1:
307+
tx1_user = tx1.user.find_first_or_raise(where={'id': user.id})
308+
tx1_count = tx1.user.count()
309+
310+
with client2.tx() as tx2:
311+
tx2.user.update(data={'name': 'Tegan'}, where={'id': user.id})
312+
tx2.user.create(data={'name': 'Bobby'})
313+
314+
dirty_user = tx1.user.find_first_or_raise(where={'id': user.id})
315+
316+
non_repeatable_user = tx1.user.find_first_or_raise(where={'id': user.id})
317+
phantom_count = tx1.user.count()
318+
319+
# No dirty read
320+
assert tx1_user.name == dirty_user.name
321+
# No non-repeatable read
322+
assert tx1_user.name == non_repeatable_user.name
323+
# No phantom read
324+
assert tx1_count == phantom_count

databases/tests/test_transactions.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,129 @@ async def test_transaction_already_closed(client: Prisma) -> None:
212212
await transaction.user.delete_many()
213213

214214
assert exc.match('Transaction already closed')
215+
216+
217+
@pytest.mark.asyncio
218+
@pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available')
219+
async def test_read_uncommited_isolation_level(client: Prisma) -> None:
220+
"""A transaction isolation level is set to `READ_UNCOMMITED`"""
221+
client2 = Prisma()
222+
await client2.connect()
223+
224+
user = await client.user.create(data={'name': 'Robert'})
225+
226+
async with client.tx(isolation_level=prisma.TransactionIsolationLevel.READ_UNCOMMITED) as tx1:
227+
tx1_user = await tx1.user.find_first_or_raise(where={'id': user.id})
228+
tx1_count = await tx1.user.count()
229+
230+
async with client2.tx() as tx2:
231+
await tx2.user.update(data={'name': 'Tegan'}, where={'id': user.id})
232+
await tx2.user.create(data={'name': 'Bobby'})
233+
234+
dirty_user = await tx1.user.find_first_or_raise(where={'id': user.id})
235+
236+
non_repeatable_user = await tx1.user.find_first_or_raise(where={'id': user.id})
237+
phantom_count = await tx1.user.count()
238+
239+
# Have dirty read
240+
assert tx1_user.name != dirty_user.name
241+
# Have non-repeatable read
242+
assert tx1_user.name != non_repeatable_user.name
243+
# Have phantom read
244+
assert tx1_count != phantom_count
245+
246+
247+
@pytest.mark.asyncio
248+
@pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available')
249+
async def test_read_commited_isolation_level(client: Prisma) -> None:
250+
"""A transaction isolation level is set to `READ_COMMITED`"""
251+
client2 = Prisma()
252+
await client2.connect()
253+
254+
user = await client.user.create(data={'name': 'Robert'})
255+
256+
async with client.tx(isolation_level=prisma.TransactionIsolationLevel.READ_COMMITED) as tx1:
257+
tx1_user = await tx1.user.find_first_or_raise(where={'id': user.id})
258+
tx1_count = await tx1.user.count()
259+
260+
async with client2.tx() as tx2:
261+
await tx2.user.update(data={'name': 'Tegan'}, where={'id': user.id})
262+
await tx2.user.create(data={'name': 'Bobby'})
263+
264+
dirty_user = await tx1.user.find_first_or_raise(where={'id': user.id})
265+
266+
non_repeatable_user = await tx1.user.find_first_or_raise(where={'id': user.id})
267+
phantom_count = await tx1.user.count()
268+
269+
# No dirty read
270+
assert tx1_user.name == dirty_user.name
271+
# Have non-repeatable read
272+
assert tx1_user.name != non_repeatable_user.name
273+
# Have phantom read
274+
assert tx1_count != phantom_count
275+
276+
277+
@pytest.mark.asyncio
278+
@pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available')
279+
async def test_repeatable_read_isolation_level(client: Prisma) -> None:
280+
"""A transaction isolation level is set to `REPEATABLE_READ`"""
281+
client2 = Prisma()
282+
await client2.connect()
283+
284+
user = await client.user.create(data={'name': 'Robert'})
285+
286+
async with client.tx(isolation_level=prisma.TransactionIsolationLevel.REPEATABLE_READ) as tx1:
287+
tx1_user = await tx1.user.find_first_or_raise(where={'id': user.id})
288+
tx1_count = await tx1.user.count()
289+
290+
async with client2.tx() as tx2:
291+
await tx2.user.update(data={'name': 'Tegan'}, where={'id': user.id})
292+
await tx2.user.create(data={'name': 'Bobby'})
293+
294+
dirty_user = await tx1.user.find_first_or_raise(where={'id': user.id})
295+
296+
non_repeatable_user = await tx1.user.find_first_or_raise(where={'id': user.id})
297+
phantom_count = await tx1.user.count()
298+
299+
# No dirty read
300+
assert tx1_user.name == dirty_user.name
301+
# No non-repeatable read
302+
assert tx1_user.name == non_repeatable_user.name
303+
# Have phantom read
304+
assert tx1_count != phantom_count
305+
306+
307+
@pytest.mark.asyncio
308+
@pytest.mark.skipif(True, reason='Available for SQL Server only')
309+
async def test_snapshot_isolation_level() -> None:
310+
"""A transaction isolation level is set to `SNAPSHOT`"""
311+
raise NotImplementedError
312+
313+
314+
@pytest.mark.asyncio
315+
async def test_serializable_isolation_level(client: Prisma) -> None:
316+
"""A transaction isolation level is set to `SERIALIZABLE`"""
317+
client2 = Prisma()
318+
await client2.connect()
319+
320+
user = await client.user.create(data={'name': 'Robert'})
321+
322+
async with client.tx(isolation_level=prisma.TransactionIsolationLevel.SERIALIZABLE) as tx1:
323+
tx1_user = await tx1.user.find_first_or_raise(where={'id': user.id})
324+
tx1_count = await tx1.user.count()
325+
326+
async with client2.tx() as tx2:
327+
await tx2.user.update(data={'name': 'Tegan'}, where={'id': user.id})
328+
await tx2.user.create(data={'name': 'Bobby'})
329+
330+
dirty_user = await tx1.user.find_first_or_raise(where={'id': user.id})
331+
332+
non_repeatable_user = await tx1.user.find_first_or_raise(where={'id': user.id})
333+
phantom_count = await tx1.user.count()
334+
335+
# No dirty read
336+
assert tx1_user.name == dirty_user.name
337+
# No non-repeatable read
338+
assert tx1_user.name == non_repeatable_user.name
339+
# No phantom read
340+
assert tx1_count == phantom_count

src/prisma/generator/templates/client.py.jinja

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import warnings
66
import logging
77
from datetime import timedelta
8+
from enum import Enum
89
from pathlib import Path
910
from types import TracebackType
1011

@@ -39,6 +40,7 @@ __all__ = (
3940
'Batch',
4041
'Prisma',
4142
'Client',
43+
'TransactionIsolationLevel',
4244
'load_env',
4345
'register',
4446
'get_client',
@@ -323,6 +325,7 @@ class Prisma:
323325
def tx(
324326
self,
325327
*,
328+
isolation_level: Optional['TransactionIsolationLevel'] = None,
326329
max_wait: Union[int, timedelta] = DEFAULT_TX_MAX_WAIT,
327330
timeout: Union[int, timedelta] = DEFAULT_TX_TIMEOUT,
328331
) -> 'TransactionManager':
@@ -332,6 +335,9 @@ class Prisma:
332335
actions within a transaction, queries will be isolated to the Prisma instance and
333336
will not be commited to the database until the context manager exits.
334337

338+
By default, Prisma sets the isolation level to the value currently configured in the database. You can modify this
339+
default with the `isolation_level` argument (see [supported isolation levels](https://www.prisma.io/docs/orm/prisma-client/queries/transactions#supported-isolation-levels)).
340+
335341
By default, Prisma will wait a maximum of 2 seconds to acquire a transaction from the database. You can modify this
336342
default with the `max_wait` argument which accepts a value in milliseconds or `datetime.timedelta`.
337343

@@ -348,7 +354,7 @@ class Prisma:
348354

349355
In the above example, if the first database call succeeds but the second does not then neither of the records will be created.
350356
"""
351-
return TransactionManager(client=self, max_wait=max_wait, timeout=timeout)
357+
return TransactionManager(client=self, isolation_level=isolation_level, max_wait=max_wait, timeout=timeout)
352358

353359
def is_transaction(self) -> bool:
354360
"""Returns True if the client is wrapped within a transaction"""
@@ -475,16 +481,34 @@ class Prisma:
475481
}
476482

477483

484+
# See here: https://www.prisma.io/docs/orm/prisma-client/queries/transactions#supported-isolation-levels
485+
class TransactionIsolationLevel(str, Enum):
486+
READ_UNCOMMITED = "ReadUncommitted"
487+
READ_COMMITED = "ReadCommitted"
488+
REPEATABLE_READ = "RepeatableRead"
489+
SNAPSHOT = "Snapshot"
490+
SERIALIZABLE = "Serializable"
491+
492+
478493
class TransactionManager:
479494
"""Context manager for wrapping a Prisma instance within a transaction.
480495

481496
This should never be created manually, instead it should be used
482497
through the Prisma.tx() method.
483498
"""
484499

485-
def __init__(self, *, client: Prisma, max_wait: Union[int, timedelta], timeout: Union[int, timedelta]) -> None:
500+
def __init__(
501+
self,
502+
*,
503+
client: Prisma,
504+
isolation_level: Optional['TransactionIsolationLevel'],
505+
max_wait: Union[int, timedelta],
506+
timeout: Union[int, timedelta],
507+
) -> None:
486508
self.__client = client
487509

510+
self._isolation_level = isolation_level
511+
488512
if isinstance(max_wait, int):
489513
message = (
490514
'Passing an int as `max_wait` argument is deprecated '
@@ -520,14 +544,14 @@ class TransactionManager:
520544
stacklevel=3 if _from_context else 2
521545
)
522546

523-
tx_id = {{ maybe_await }}self.__client._engine.start_transaction(
524-
content=dumps(
525-
{
526-
'timeout': int(self._timeout.total_seconds() * 1000),
527-
'max_wait': int(self._max_wait.total_seconds() * 1000),
528-
}
529-
),
530-
)
547+
content_dict: dict[str, Any] = {
548+
'timeout': int(self._timeout.total_seconds() * 1000),
549+
'max_wait': int(self._max_wait.total_seconds() * 1000),
550+
}
551+
if self._isolation_level:
552+
content_dict['isolation_level'] = self._isolation_level.value
553+
554+
tx_id = {{ maybe_await }}self.__client._engine.start_transaction(content=dumps(content_dict))
531555
self._tx_id = tx_id
532556
client = self.__client._copy()
533557
client._tx_id = tx_id

0 commit comments

Comments
 (0)