diff --git a/pkg/sqlcmd/batch_test.go b/pkg/sqlcmd/batch_test.go index c55d45b9..948c2a93 100644 --- a/pkg/sqlcmd/batch_test.go +++ b/pkg/sqlcmd/batch_test.go @@ -175,6 +175,7 @@ func TestReadStringVarmap(t *testing.T) { } tests := []mapTest{ {`'var $(var1) var2 $(var2)'`, map[int]string{5: "var1", 18: "var2"}}, + {`'var $(va_1) var2 $(va-2)'`, map[int]string{5: "va_1", 18: "va-2"}}, } for _, test := range tests { b := NewBatch(nil, newCommands()) diff --git a/pkg/sqlcmd/parse.go b/pkg/sqlcmd/parse.go index d0dd9cc9..0fc6a47c 100644 --- a/pkg/sqlcmd/parse.go +++ b/pkg/sqlcmd/parse.go @@ -4,6 +4,7 @@ package sqlcmd import ( + "strings" "unicode" ) @@ -63,7 +64,7 @@ func readVariableReference(r []rune, i int, end int) (int, bool) { if r[i] == ')' { return i, true } - if (r[i] >= 'a' && r[i] <= 'z') || (r[i] >= 'A' && r[i] <= 'Z') || (r[i] >= '0' && r[i] <= '9') { + if (r[i] >= 'a' && r[i] <= 'z') || (r[i] >= 'A' && r[i] <= 'Z') || (r[i] >= '0' && r[i] <= '9') || strings.ContainsRune(validVariableRunes, r[i]) { continue } break diff --git a/pkg/sqlcmd/variables.go b/pkg/sqlcmd/variables.go index 3c2886a8..1d7909b7 100644 --- a/pkg/sqlcmd/variables.go +++ b/pkg/sqlcmd/variables.go @@ -252,14 +252,14 @@ func (variables *Variables) Setvar(name, value string) error { return nil } +const validVariableRunes = "_-" + // ValidIdentifier determines if a given string can be used as a variable name -// The native sqlcmd allowed some punctuation characters as part of a variable name -// but this version will not. func ValidIdentifier(name string) error { first := true for _, c := range name { - if !unicode.IsLetter(c) && (first || !unicode.IsDigit(c)) { + if !unicode.IsLetter(c) && (first || (!unicode.IsDigit(c) && !strings.ContainsRune(validVariableRunes, c))) { return fmt.Errorf("Invalid variable identifier %s", name) } first = false diff --git a/pkg/sqlcmd/variables_test.go b/pkg/sqlcmd/variables_test.go index 0d3e5208..6c1f75fb 100644 --- a/pkg/sqlcmd/variables_test.go +++ b/pkg/sqlcmd/variables_test.go @@ -103,6 +103,7 @@ func TestValidIdentifier(t *testing.T) { {"1A", false}, {"A1", true}, {"A+", false}, + {"A-_b", true}, } for _, tst := range tests { err := ValidIdentifier(tst.raw)