Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 122 additions & 4 deletions cmd/litestream/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,20 +290,34 @@ func (c *Config) Validate() error {
}

// Validate database configs
for _, db := range c.DBs {
for idx, db := range c.DBs {
// Validate that either path or directory is specified, but not both
if db.Path != "" && db.Directory != "" {
return fmt.Errorf("database config #%d: cannot specify both 'path' and 'directory'", idx+1)
}
if db.Path == "" && db.Directory == "" {
return fmt.Errorf("database config #%d: must specify either 'path' or 'directory'", idx+1)
}

// Use path or directory for identifying the config in error messages
dbIdentifier := db.Path
if dbIdentifier == "" {
dbIdentifier = db.Directory
}

// Validate sync intervals for replicas
if db.Replica != nil && db.Replica.SyncInterval != nil && *db.Replica.SyncInterval <= 0 {
return &ConfigValidationError{
Err: ErrInvalidSyncInterval,
Field: fmt.Sprintf("dbs[%s].replica.sync-interval", db.Path),
Field: fmt.Sprintf("dbs[%s].replica.sync-interval", dbIdentifier),
Value: *db.Replica.SyncInterval,
}
}
for i, replica := range db.Replicas {
if replica.SyncInterval != nil && *replica.SyncInterval <= 0 {
return &ConfigValidationError{
Err: ErrInvalidSyncInterval,
Field: fmt.Sprintf("dbs[%s].replicas[%d].sync-interval", db.Path, i),
Field: fmt.Sprintf("dbs[%s].replicas[%d].sync-interval", dbIdentifier, i),
Value: *replica.SyncInterval,
}
}
Expand Down Expand Up @@ -461,9 +475,12 @@ type CompactionLevelConfig struct {
Interval time.Duration `yaml:"interval"`
}

// DBConfig represents the configuration for a single database.
// DBConfig represents the configuration for a single database or directory of databases.
type DBConfig struct {
Path string `yaml:"path"`
Directory string `yaml:"directory"` // Directory to scan for databases
Pattern string `yaml:"pattern"` // File pattern to match (e.g., "*.db", "*.sqlite")
Recursive bool `yaml:"recursive"` // Scan subdirectories recursively
MetaPath *string `yaml:"meta-path"`
MonitorInterval *time.Duration `yaml:"monitor-interval"`
CheckpointInterval *time.Duration `yaml:"checkpoint-interval"`
Expand Down Expand Up @@ -533,6 +550,107 @@ func NewDBFromConfig(dbc *DBConfig) (*litestream.DB, error) {
return db, nil
}

// NewDBsFromDirectoryConfig scans a directory and creates DB instances for all SQLite databases found.
func NewDBsFromDirectoryConfig(dbc *DBConfig) ([]*litestream.DB, error) {
if dbc.Directory == "" {
return nil, fmt.Errorf("directory path is required for directory replication")
}

dirPath, err := expand(dbc.Directory)
if err != nil {
return nil, err
}

// Default pattern if not specified
pattern := dbc.Pattern
if pattern == "" {
pattern = "*.db"
}

// Find all SQLite databases in the directory
dbPaths, err := FindSQLiteDatabases(dirPath, pattern, dbc.Recursive)
if err != nil {
return nil, fmt.Errorf("failed to scan directory %s: %w", dirPath, err)
}

if len(dbPaths) == 0 {
return nil, fmt.Errorf("no SQLite databases found in directory %s with pattern %s", dirPath, pattern)
}

// Create DB instances for each found database
var dbs []*litestream.DB
for _, dbPath := range dbPaths {
// Create a copy of the config for each database
dbConfigCopy := *dbc
dbConfigCopy.Path = dbPath
dbConfigCopy.Directory = "" // Clear directory field for individual DB

db, err := NewDBFromConfig(&dbConfigCopy)
if err != nil {
return nil, fmt.Errorf("failed to create DB for %s: %w", dbPath, err)
}
dbs = append(dbs, db)
}

return dbs, nil
}

// FindSQLiteDatabases recursively finds all SQLite database files in a directory.
// Exported for testing.
func FindSQLiteDatabases(dir string, pattern string, recursive bool) ([]string, error) {
var dbPaths []string

err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}

// Skip directories unless recursive
if info.IsDir() {
if !recursive && path != dir {
return filepath.SkipDir
}
return nil
}

// Check if file matches pattern
matched, err := filepath.Match(pattern, filepath.Base(path))
if err != nil {
return err
}
if !matched {
return nil
}

// Check if it's a SQLite database
if IsSQLiteDatabase(path) {
dbPaths = append(dbPaths, path)
}

return nil
})

return dbPaths, err
}

// IsSQLiteDatabase checks if a file is a SQLite database by reading its header.
// Exported for testing.
func IsSQLiteDatabase(path string) bool {
file, err := os.Open(path)
if err != nil {
return false
}
defer file.Close()

// SQLite files start with "SQLite format 3\x00"
header := make([]byte, 16)
if _, err := file.Read(header); err != nil {
return false
}

return string(header) == "SQLite format 3\x00"
}

// ReplicaConfig represents the configuration for a single replica in a database.
type ReplicaConfig struct {
Type string `yaml:"type"` // "file", "s3"
Expand Down
206 changes: 206 additions & 0 deletions cmd/litestream/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -622,3 +622,209 @@ func TestConfig_DefaultValues(t *testing.T) {
t.Errorf("expected default snapshot retention of 24h, got %v", *config.Snapshot.Retention)
}
}

func TestFindSQLiteDatabases(t *testing.T) {
// Create a temporary directory using t.TempDir() - automatically cleaned up
tmpDir := t.TempDir()

// Create test files
testFiles := []struct {
path string
isSQLite bool
shouldFind bool
}{
{"test1.db", true, true},
{"test2.sqlite", true, true},
{"test3.db", false, false}, // Not a SQLite file
{"test.txt", false, false},
{"subdir/test4.db", true, true},
{"subdir/test5.sqlite", true, true},
{"subdir/deep/test6.db", true, true},
}

// Create test files
for _, tf := range testFiles {
fullPath := filepath.Join(tmpDir, tf.path)
dir := filepath.Dir(fullPath)
if err := os.MkdirAll(dir, 0755); err != nil {
t.Fatal(err)
}

file, err := os.Create(fullPath)
if err != nil {
t.Fatal(err)
}

if tf.isSQLite {
// Write SQLite header
if _, err := file.Write([]byte("SQLite format 3\x00")); err != nil {
t.Fatal(err)
}
} else {
// Write non-SQLite content
if _, err := file.Write([]byte("not a sqlite file")); err != nil {
t.Fatal(err)
}
}
if err := file.Close(); err != nil {
t.Fatal(err)
}
}

t.Run("non-recursive *.db pattern", func(t *testing.T) {
dbs, err := main.FindSQLiteDatabases(tmpDir, "*.db", false)
if err != nil {
t.Fatal(err)
}

// Should only find test1.db in root directory
if len(dbs) != 1 {
t.Errorf("expected 1 database, got %d", len(dbs))
}
})

t.Run("recursive *.db pattern", func(t *testing.T) {
dbs, err := main.FindSQLiteDatabases(tmpDir, "*.db", true)
if err != nil {
t.Fatal(err)
}

// Should find test1.db, test4.db, and test6.db
if len(dbs) != 3 {
t.Errorf("expected 3 databases, got %d", len(dbs))
}
})

t.Run("recursive *.sqlite pattern", func(t *testing.T) {
dbs, err := main.FindSQLiteDatabases(tmpDir, "*.sqlite", true)
if err != nil {
t.Fatal(err)
}

// Should find test2.sqlite and test5.sqlite
if len(dbs) != 2 {
t.Errorf("expected 2 databases, got %d", len(dbs))
}
})

t.Run("recursive * pattern", func(t *testing.T) {
dbs, err := main.FindSQLiteDatabases(tmpDir, "*", true)
if err != nil {
t.Fatal(err)
}

// Should find all 5 SQLite databases
if len(dbs) != 5 {
t.Errorf("expected 5 databases, got %d", len(dbs))
}
})
}

func TestIsSQLiteDatabase(t *testing.T) {
// Create temporary test files using t.TempDir() - automatically cleaned up
tmpDir := t.TempDir()

t.Run("valid SQLite file", func(t *testing.T) {
path := filepath.Join(tmpDir, "valid.db")
file, err := os.Create(path)
if err != nil {
t.Fatal(err)
}
if _, err := file.Write([]byte("SQLite format 3\x00")); err != nil {
t.Fatal(err)
}
if err := file.Close(); err != nil {
t.Fatal(err)
}

if !main.IsSQLiteDatabase(path) {
t.Error("expected file to be identified as SQLite database")
}
})

t.Run("invalid SQLite file", func(t *testing.T) {
path := filepath.Join(tmpDir, "invalid.db")
file, err := os.Create(path)
if err != nil {
t.Fatal(err)
}
if _, err := file.Write([]byte("not a sqlite file")); err != nil {
t.Fatal(err)
}
if err := file.Close(); err != nil {
t.Fatal(err)
}

if main.IsSQLiteDatabase(path) {
t.Error("expected file to NOT be identified as SQLite database")
}
})

t.Run("non-existent file", func(t *testing.T) {
path := filepath.Join(tmpDir, "doesnotexist.db")
if main.IsSQLiteDatabase(path) {
t.Error("expected non-existent file to NOT be identified as SQLite database")
}
})
}

func TestDBConfigValidation(t *testing.T) {
t.Run("both path and directory specified", func(t *testing.T) {
config := main.Config{
DBs: []*main.DBConfig{
{
Path: "/path/to/db.sqlite",
Directory: "/path/to/dir",
},
},
}

err := config.Validate()
if err == nil {
t.Error("expected validation error when both path and directory are specified")
}
})

t.Run("neither path nor directory specified", func(t *testing.T) {
config := main.Config{
DBs: []*main.DBConfig{
{},
},
}

err := config.Validate()
if err == nil {
t.Error("expected validation error when neither path nor directory are specified")
}
})

t.Run("valid path configuration", func(t *testing.T) {
config := main.DefaultConfig()
config.DBs = []*main.DBConfig{
{
Path: "/path/to/db.sqlite",
},
}

err := config.Validate()
if err != nil {
t.Errorf("unexpected validation error for valid path config: %v", err)
}
})

t.Run("valid directory configuration", func(t *testing.T) {
config := main.DefaultConfig()
config.DBs = []*main.DBConfig{
{
Directory: "/path/to/dir",
Pattern: "*.db",
Recursive: true,
},
}

err := config.Validate()
if err != nil {
t.Errorf("unexpected validation error for valid directory config: %v", err)
}
})
}
21 changes: 16 additions & 5 deletions cmd/litestream/replicate.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,24 @@ func (c *ReplicateCommand) Run(ctx context.Context) (err error) {
slog.Error("no databases specified in configuration")
}

dbs := make([]*litestream.DB, 0, len(c.Config.DBs))
var dbs []*litestream.DB
for _, dbConfig := range c.Config.DBs {
db, err := NewDBFromConfig(dbConfig)
if err != nil {
return err
// Handle directory configuration
if dbConfig.Directory != "" {
dirDbs, err := NewDBsFromDirectoryConfig(dbConfig)
if err != nil {
return err
}
dbs = append(dbs, dirDbs...)
slog.Info("found databases in directory", "directory", dbConfig.Directory, "count", len(dirDbs))
} else {
// Handle single database configuration
db, err := NewDBFromConfig(dbConfig)
if err != nil {
return err
}
dbs = append(dbs, db)
}
dbs = append(dbs, db)
}

levels := c.Config.CompactionLevels()
Expand Down
Loading