From 12346fa9294908f85e48b598ba3c95c391c07299 Mon Sep 17 00:00:00 2001 From: Ting-Lan Wang Date: Fri, 19 Sep 2025 09:48:21 -0400 Subject: [PATCH] support case insensitive field matching --- oracle/oracle.go | 12 ++++++++++++ oracle/query.go | 20 ++++++++++++++++++-- tests/config_test.go | 42 ++++++++++++++++++++++++++++++++++++++---- 3 files changed, 68 insertions(+), 6 deletions(-) diff --git a/oracle/oracle.go b/oracle/oracle.go index ee7e0a2..b4b7efa 100644 --- a/oracle/oracle.go +++ b/oracle/oracle.go @@ -109,6 +109,18 @@ func (d Dialector) Initialize(db *gorm.DB) (err error) { callback.Update().Replace("gorm:update", Update) callback.Query().Before("gorm:query").Register("oracle:before_query", BeforeQuery) + if d.SkipQuoteIdentifiers { + // When identifiers are not quoted, columns are returned by Oracle in uppercase. + // Fields in the models may be lower case for compatibility with other databases. + // Match them up with the fields using the column mapping. + oracleCaseHandler := "oracle:case_handler" + if callback.Query().Get(oracleCaseHandler) == nil { + if err := callback.Query().Before("gorm:query").Register(oracleCaseHandler, MismatchedCaseHandler); err != nil { + return err + } + } + } + maps.Copy(db.ClauseBuilders, OracleClauseBuilders()) if d.Conn == nil { diff --git a/oracle/query.go b/oracle/query.go index 5a45c26..ca63b0a 100644 --- a/oracle/query.go +++ b/oracle/query.go @@ -39,9 +39,10 @@ package oracle import ( - "gorm.io/gorm" "regexp" "strings" + + "gorm.io/gorm" ) // Identifies the table name alias provided as @@ -63,5 +64,20 @@ func BeforeQuery(db *gorm.DB) { } } } - return +} + +// MismatchedCaseHandler handles Oracle Case Insensitivity. +// When identifiers are not quoted, columns are returned by Oracle in uppercase. +// Fields in the models may be lower case for compatibility with other databases. +// Match them up with the fields using the column mapping. +func MismatchedCaseHandler(gormDB *gorm.DB) { + if gormDB.Statement == nil || gormDB.Statement.Schema == nil { + return + } + if len(gormDB.Statement.Schema.Fields) > 0 && gormDB.Statement.ColumnMapping == nil { + gormDB.Statement.ColumnMapping = map[string]string{} + } + for _, field := range gormDB.Statement.Schema.Fields { + gormDB.Statement.ColumnMapping[strings.ToUpper(field.DBName)] = field.Name + } } diff --git a/tests/config_test.go b/tests/config_test.go index a9dd435..7fff870 100644 --- a/tests/config_test.go +++ b/tests/config_test.go @@ -78,23 +78,57 @@ func TestSkipQuoteIdentifiers(t *testing.T) { t.Errorf("Failed to get column: name") } + student := Student{ID: 1, Name: "John"} + if err := db.Model(&Student{}).Create(&student).Error; err != nil { + t.Errorf("Failed to insert student, got %v", err) + } + + var result Student + if err := db.First(&result).Error; err != nil { + t.Errorf("Failed to query first student, got %v", err) + } + + if result.ID != student.ID { + t.Errorf("id should be %v, but got %v", student.ID, result.ID) + } + + if result.Name != student.Name { + t.Errorf("name should be %v, but got %v", student.Name, result.Name) + } +} + +func TestSkipQuoteIdentifiersSQL(t *testing.T) { + db, err := openTestDBWithOptions( + &oracle.Config{SkipQuoteIdentifiers: true}, + &gorm.Config{Logger: newLogger}) + if err != nil { + t.Fatalf("failed to connect database, got error %v", err) + } dryrunDB := db.Session(&gorm.Session{DryRun: true}) - result := dryrunDB.Model(&Student{}).Create(&Student{ID: 1, Name: "John"}) + insertedStudent := Student{ID: 1, Name: "John"} + result := dryrunDB.Model(&Student{}).Create(&insertedStudent) + if !regexp.MustCompile(`^INSERT INTO STUDENTS \(name,id\) VALUES \(:1,:2\)$`).MatchString(result.Statement.SQL.String()) { t.Errorf("invalid insert SQL, got %v", result.Statement.SQL.String()) } - result = dryrunDB.First(&Student{}) + // Test First + var firstStudent Student + result = dryrunDB.First(&firstStudent) + if !regexp.MustCompile(`^SELECT \* FROM STUDENTS ORDER BY STUDENTS\.id FETCH NEXT 1 ROW ONLY$`).MatchString(result.Statement.SQL.String()) { t.Fatalf("SQL should include selected names, but got %v", result.Statement.SQL.String()) } - result = dryrunDB.Find(&Student{ID: 1, Name: "John"}) - if !regexp.MustCompile(`^SELECT \* FROM STUDENTS WHERE STUDENTS\.id = :1$`).MatchString(result.Statement.SQL.String()) { + // Test Find + var foundStudent Student + result = dryrunDB.Find(foundStudent, "id = ?", insertedStudent.ID) + if !regexp.MustCompile(`^SELECT \* FROM STUDENTS WHERE id = :1$`).MatchString(result.Statement.SQL.String()) { t.Fatalf("SQL should include selected names, but got %v", result.Statement.SQL.String()) } + // Test Save result = dryrunDB.Save(&Student{ID: 2, Name: "Mary"}) if !regexp.MustCompile(`^UPDATE STUDENTS SET name=:1 WHERE id = :2$`).MatchString(result.Statement.SQL.String()) { t.Fatalf("SQL should include selected names, but got %v", result.Statement.SQL.String())