diff --git a/packages/core/src/transaction.ts b/packages/core/src/transaction.ts index 5c11262a1..1226173b2 100644 --- a/packages/core/src/transaction.ts +++ b/packages/core/src/transaction.ts @@ -220,6 +220,19 @@ class Transaction { return this._state === _states.ACTIVE } + /** + * Closes the transaction + * + * This method will roll back the transaction if it is not already committed or rolled back. + * + * @returns {Promise} An empty promise if closed successfully or error if any error happened during + */ + async close(): Promise { + if (this.isOpen()) { + await this.rollback() + } + } + _onErrorCallback(err: any): Promise { // error will be "acknowledged" by sending a RESET message // database will then forget about this transaction and cleanup all corresponding resources diff --git a/packages/core/test/transaction.test.ts b/packages/core/test/transaction.test.ts index d4a8546dc..1c3956668 100644 --- a/packages/core/test/transaction.test.ts +++ b/packages/core/test/transaction.test.ts @@ -60,6 +60,59 @@ describe('Transaction', () => { }) + describe('.close()', () => { + describe('when transaction is open', () => { + it('should roll back the transaction', async () => { + const connection = newFakeConnection() + const tx = newTransaction({ connection }) + + await tx.run('RETURN 1') + await tx.close() + + expect(connection.rollbackInvoked).toEqual(1) + }) + + it('should surface errors during the rollback', async () => { + const expectedError = new Error('rollback error') + const connection = newFakeConnection().withRollbackError(expectedError) + const tx = newTransaction({ connection }) + + await tx.run('RETURN 1') + + try { + await tx.close() + fail('should have thrown') + } catch (error) { + expect(error).toEqual(expectedError) + } + }) + }) + + describe('when transaction is closed', () => { + const commit = async (tx: Transaction) => tx.commit() + const rollback = async (tx: Transaction) => tx.rollback() + const error = async (tx: Transaction, conn: FakeConnection) => { + conn.withRollbackError(new Error('rollback error')) + return tx.rollback().catch(() => { }) + } + + it.each([ + ['commmited', commit], + ['rolled back', rollback], + ['with error', error] + ])('should not roll back the connection', async (_, operation) => { + const connection = newFakeConnection() + const tx = newTransaction({ connection }) + + await operation(tx, connection) + const rollbackInvokedAfterOperation = connection.rollbackInvoked + + await tx.close() + + expect(connection.rollbackInvoked).toEqual(rollbackInvokedAfterOperation) + }) + }) + }) }) function newTransaction({ @@ -69,9 +122,9 @@ function newTransaction({ lowRecordWatermark = 300 }: { connection: FakeConnection - fetchSize: number - highRecordWatermark: number, - lowRecordWatermark: number + fetchSize?: number + highRecordWatermark?: number, + lowRecordWatermark?: number }): Transaction { const connectionProvider = new ConnectionProvider() connectionProvider.acquireConnection = () => Promise.resolve(connection) diff --git a/packages/core/test/utils/connection.fake.ts b/packages/core/test/utils/connection.fake.ts index 2417e3483..aaae1e024 100644 --- a/packages/core/test/utils/connection.fake.ts +++ b/packages/core/test/utils/connection.fake.ts @@ -44,6 +44,8 @@ export default class FakeConnection extends Connection { public protocolErrorsHandled: number public seenProtocolErrors: string[] public seenRequestRoutingInformation: any[] + public rollbackInvoked: number + public _rollbackError: Error | null constructor() { super() @@ -64,6 +66,8 @@ export default class FakeConnection extends Connection { this.protocolErrorsHandled = 0 this.seenProtocolErrors = [] this.seenRequestRoutingInformation = [] + this.rollbackInvoked = 0 + this._rollbackError = null } get id(): string { @@ -105,6 +109,13 @@ export default class FakeConnection extends Connection { beginTransaction: () => { return Promise.resolve() }, + rollbackTransaction: () => { + this.rollbackInvoked ++ + if (this._rollbackError !== null) { + return mockResultStreamObserverWithError('ROLLBACK', {}, this._rollbackError) + } + return mockResultStreamObserver('ROLLBACK', {}) + }, requestRoutingInformation: (params: any | undefined) => { this.seenRequestRoutingInformation.push(params) if (this._requestRoutingInformationMock) { @@ -161,12 +172,27 @@ export default class FakeConnection extends Connection { return this } + withRollbackError(error: Error) { + this._rollbackError = error + return this + } + closed() { this._open = false return this } } +function mockResultStreamObserverWithError (query: string, parameters: any | undefined, error: Error) { + const observer = mockResultStreamObserver(query, parameters) + observer.subscribe = (observer: ResultObserver) => { + if (observer && observer.onError) { + observer.onError(error) + } + } + return observer +} + function mockResultStreamObserver(query: string, parameters: any | undefined): ResultStreamObserver { return { onError: (error: any) => { }, diff --git a/packages/neo4j-driver/src/transaction-rx.js b/packages/neo4j-driver/src/transaction-rx.js index 5fa713c5c..11acac062 100644 --- a/packages/neo4j-driver/src/transaction-rx.js +++ b/packages/neo4j-driver/src/transaction-rx.js @@ -90,4 +90,22 @@ export default class RxTransaction { .catch(err => observer.error(err)) }) } + + /** + * Closes the transaction + * + * This method will roll back the transaction if it is not already committed or rolled back. + * + * @returns {Observable} - An empty observable + */ + close () { + return new Observable(observer => { + this._txc + .close() + .then(() => { + observer.complete() + }) + .catch(err => observer.error(err)) + }) + } } diff --git a/packages/neo4j-driver/test/rx/transaction.test.js b/packages/neo4j-driver/test/rx/transaction.test.js index 783a0620b..14ca8fa0b 100644 --- a/packages/neo4j-driver/test/rx/transaction.test.js +++ b/packages/neo4j-driver/test/rx/transaction.test.js @@ -29,6 +29,7 @@ import { } from 'rxjs/operators' import neo4j from '../../src' import RxSession from '../../src/session-rx' +import RxTransaction from '../../src/transaction-rx' import sharedNeo4j from '../internal/shared-neo4j' import { newError } from 'neo4j-driver-core' @@ -148,6 +149,35 @@ describe('#integration-rx transaction', () => { expect(await countNodes(42)).toBe(0) }) + it('should run query and close', async () => { + if (protocolVersion < 4.0) { + return + } + + const result = await session + .beginTransaction() + .pipe( + flatMap(txc => + txc + .run('CREATE (n:Node {id: 42}) RETURN n') + .records() + .pipe( + map(r => r.get('n').properties.id), + concat(txc.close()) + ) + ), + materialize(), + toArray() + ) + .toPromise() + expect(result).toEqual([ + Notification.createNext(neo4j.int(42)), + Notification.createComplete() + ]) + + expect(await countNodes(42)).toBe(0) + }) + it('should run multiple queries and commit', async () => { await verifyCanRunMultipleQueries(true) }) @@ -720,3 +750,37 @@ describe('#integration-rx transaction', () => { .toPromise() } }) + +describe('#unit', () => { + describe('.close()', () => { + it('should delegate to the original Transaction', async () => { + const txc = { + close: jasmine.createSpy('close').and.returnValue(Promise.resolve()) + } + + const transaction = new RxTransaction(txc) + + await transaction.close().toPromise() + + expect(txc.close).toHaveBeenCalled() + }) + + it('should fail if to the original Transaction.close call fails', async () => { + const expectedError = new Error('expected') + const txc = { + close: jasmine + .createSpy('close') + .and.returnValue(Promise.reject(expectedError)) + } + + const transaction = new RxTransaction(txc) + + try { + await transaction.close().toPromise() + fail('should have thrown') + } catch (error) { + expect(error).toBe(expectedError) + } + }) + }) +}) diff --git a/packages/neo4j-driver/test/types/transaction-rx.test.ts b/packages/neo4j-driver/test/types/transaction-rx.test.ts index 6f9275d8b..b10139ceb 100644 --- a/packages/neo4j-driver/test/types/transaction-rx.test.ts +++ b/packages/neo4j-driver/test/types/transaction-rx.test.ts @@ -68,3 +68,7 @@ tx.commit() tx.rollback() .pipe(concat(of('rolled back'))) .subscribe(stringObserver) + +tx.close() + .pipe(concat(of('closed'))) + .subscribe(stringObserver) diff --git a/packages/neo4j-driver/types/transaction-rx.d.ts b/packages/neo4j-driver/types/transaction-rx.d.ts index ddf69708c..64d6494a7 100644 --- a/packages/neo4j-driver/types/transaction-rx.d.ts +++ b/packages/neo4j-driver/types/transaction-rx.d.ts @@ -26,6 +26,8 @@ declare interface RxTransaction { commit(): Observable rollback(): Observable + + close(): Observable } export default RxTransaction diff --git a/packages/testkit-backend/src/request-handlers.js b/packages/testkit-backend/src/request-handlers.js index 4686d67c5..eaf8d2b44 100644 --- a/packages/testkit-backend/src/request-handlers.js +++ b/packages/testkit-backend/src/request-handlers.js @@ -289,6 +289,14 @@ export function TransactionRollback (context, data, wire) { .catch(e => wire.writeError(e)) } +export function TransactionClose (context, data, wire) { + const { txId: id } = data + const { tx } = context.getTx(id) + return tx.close() + .then(() => wire.writeResponse('Transaction', { id })) + .catch(e => wire.writeError(e)) +} + export function SessionLastBookmarks (context, data, wire) { const { sessionId } = data const session = context.getSession(sessionId) @@ -338,6 +346,7 @@ export function GetFeatures (_context, _params, wire) { 'Feature:Bolt:4.4', 'Feature:API:Result.List', 'Temporary:ConnectionAcquisitionTimeout', + 'Temporary:TransactionClose', ...SUPPORTED_TLS ] })