Skip to content

Commit 7ada359

Browse files
corylanouclaude
andauthored
Fix: Use proper context in acquireReadLock during DB close (#701)
Co-authored-by: Claude <[email protected]>
1 parent 347dabb commit 7ada359

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

db.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ func (db *DB) Close(ctx context.Context) (err error) {
326326

327327
// init initializes the connection to the database.
328328
// Skipped if already initialized or if the database file does not exist.
329-
func (db *DB) init() (err error) {
329+
func (db *DB) init(ctx context.Context) (err error) {
330330
// Exit if already initialized.
331331
if db.db != nil {
332332
return nil
@@ -382,7 +382,7 @@ func (db *DB) init() (err error) {
382382
}
383383

384384
// Disable autocheckpoint for litestream's connection.
385-
if _, err := db.db.ExecContext(db.ctx, `PRAGMA wal_autocheckpoint = 0;`); err != nil {
385+
if _, err := db.db.ExecContext(ctx, `PRAGMA wal_autocheckpoint = 0;`); err != nil {
386386
return fmt.Errorf("disable autocheckpoint: %w", err)
387387
}
388388

@@ -400,7 +400,7 @@ func (db *DB) init() (err error) {
400400

401401
// Start a long-running read transaction to prevent other transactions
402402
// from checkpointing.
403-
if err := db.acquireReadLock(); err != nil {
403+
if err := db.acquireReadLock(ctx); err != nil {
404404
return fmt.Errorf("acquire read lock: %w", err)
405405
}
406406

@@ -479,7 +479,7 @@ func (db *DB) verifyHeadersMatch() error {
479479
*/
480480

481481
// acquireReadLock begins a read transaction on the database to prevent checkpointing.
482-
func (db *DB) acquireReadLock() error {
482+
func (db *DB) acquireReadLock(ctx context.Context) error {
483483
if db.rtx != nil {
484484
return nil
485485
}
@@ -491,7 +491,7 @@ func (db *DB) acquireReadLock() error {
491491
}
492492

493493
// Execute read query to obtain read lock.
494-
if _, err := tx.ExecContext(db.ctx, `SELECT COUNT(1) FROM _litestream_seq;`); err != nil {
494+
if _, err := tx.ExecContext(ctx, `SELECT COUNT(1) FROM _litestream_seq;`); err != nil {
495495
_ = tx.Rollback()
496496
return err
497497
}
@@ -520,7 +520,7 @@ func (db *DB) Sync(ctx context.Context) (err error) {
520520
defer db.mu.Unlock()
521521

522522
// Initialize database, if necessary. Exit if no DB exists.
523-
if err := db.init(); err != nil {
523+
if err := db.init(ctx); err != nil {
524524
return err
525525
} else if db.db == nil {
526526
db.Logger.Debug("sync: no database found")
@@ -1017,7 +1017,7 @@ func (db *DB) checkpoint(ctx context.Context, mode string) error {
10171017

10181018
// Execute checkpoint and immediately issue a write to the WAL to ensure
10191019
// a new page is written.
1020-
if err := db.execCheckpoint(mode); err != nil {
1020+
if err := db.execCheckpoint(ctx, mode); err != nil {
10211021
return err
10221022
} else if _, err = db.db.Exec(`INSERT INTO _litestream_seq (id, seq) VALUES (1, 1) ON CONFLICT (id) DO UPDATE SET seq = seq + 1`); err != nil {
10231023
return err
@@ -1057,7 +1057,7 @@ func (db *DB) checkpoint(ctx context.Context, mode string) error {
10571057
return nil
10581058
}
10591059

1060-
func (db *DB) execCheckpoint(mode string) (err error) {
1060+
func (db *DB) execCheckpoint(ctx context.Context, mode string) (err error) {
10611061
// Ignore if there is no underlying database.
10621062
if db.db == nil {
10631063
return nil
@@ -1079,7 +1079,7 @@ func (db *DB) execCheckpoint(mode string) (err error) {
10791079
if err := db.releaseReadLock(); err != nil {
10801080
return fmt.Errorf("release read lock: %w", err)
10811081
}
1082-
defer func() { _ = db.acquireReadLock() }()
1082+
defer func() { _ = db.acquireReadLock(ctx) }()
10831083

10841084
// A non-forced checkpoint is issued as "PASSIVE". This will only checkpoint
10851085
// if there are not pending transactions. A forced checkpoint ("RESTART")
@@ -1096,8 +1096,8 @@ func (db *DB) execCheckpoint(mode string) (err error) {
10961096
db.Logger.Debug("checkpoint", "mode", mode, "result", fmt.Sprintf("%d,%d,%d", row[0], row[1], row[2]))
10971097

10981098
// Reacquire the read lock immediately after the checkpoint.
1099-
if err := db.acquireReadLock(); err != nil {
1100-
return fmt.Errorf("release read lock: %w", err)
1099+
if err := db.acquireReadLock(ctx); err != nil {
1100+
return fmt.Errorf("reacquire read lock: %w", err)
11011101
}
11021102

11031103
return nil
@@ -1412,7 +1412,7 @@ func (db *DB) CRC64(ctx context.Context) (uint64, ltx.Pos, error) {
14121412
db.mu.Lock()
14131413
defer db.mu.Unlock()
14141414

1415-
if err := db.init(); err != nil {
1415+
if err := db.init(ctx); err != nil {
14161416
return 0, ltx.Pos{}, err
14171417
} else if db.db == nil {
14181418
return 0, ltx.Pos{}, os.ErrNotExist

0 commit comments

Comments
 (0)