diff --git a/pkg/sqlcmd/batch.go b/pkg/sqlcmd/batch.go index 7b8082e5..17432319 100644 --- a/pkg/sqlcmd/batch.go +++ b/pkg/sqlcmd/batch.go @@ -109,7 +109,7 @@ parse: i, ok = readMultilineComment(b.raw, i, b.rawlen) b.comment = !ok // start of a string - case c == '\'' || c == '"': + case c == '\'' || c == '"' || c == '[': b.quote = c // inline sql comment, skip to end of line case c == '-' && next == '-': @@ -145,25 +145,24 @@ parse: } } if err == nil { - i = min(i, b.rawlen) - empty := isEmptyLine(b.raw, 0, i) - appendLine := true - if !b.comment && command != nil && empty { - appendLine = false - } - if appendLine { - // any variables on the line need to be added to the global map - inc := 0 - if b.Length > 0 { - inc = len(lineend) - } - if b.linevarmap != nil { - for v := range b.linevarmap { - b.varmap[v+b.Length+inc] = b.linevarmap[v] + if command == nil { + i = min(i, b.rawlen) + empty := i == 0 + appendLine := !empty || b.comment || b.quote != 0 + if appendLine { + // any variables on the line need to be added to the global map + inc := 0 + if b.Length > 0 { + inc = len(lineend) + } + if b.linevarmap != nil { + for v := range b.linevarmap { + b.varmap[v+b.Length+inc] = b.linevarmap[v] + } } + // log.Printf(">> appending: `%s`", string(r[st:i])) + b.append(b.raw[:i], lineend) } - // log.Printf(">> appending: `%s`", string(r[st:i])) - b.append(b.raw[:i], lineend) b.batchline++ } b.raw = b.raw[i:] @@ -242,11 +241,13 @@ func (b *Batch) readString(r []rune, i, end int, quote rune, line uint) (int, bo } else { return i, false, syntaxError(line) } - case quote == '\'' && c == '\'' && next == '\'': + case quote == '\'' && c == '\'' && next == '\'', + quote == '[' && c == ']' && next == ']': i++ continue case quote == '\'' && c == '\'' && prev != '\'', - quote == '"' && c == '"': + quote == '"' && c == '"', + quote == '[' && c == ']': return i, true, nil } prev = c diff --git a/pkg/sqlcmd/batch_test.go b/pkg/sqlcmd/batch_test.go index a323c1fb..29d290da 100644 --- a/pkg/sqlcmd/batch_test.go +++ b/pkg/sqlcmd/batch_test.go @@ -33,6 +33,10 @@ func TestBatchNext(t *testing.T) { {"select 1\n:exit()", []string{"select 1"}, []string{"EXIT"}, "-"}, {"select 1\n:exit (select 10)", []string{"select 1"}, []string{"EXIT"}, "-"}, {"select 1\n:exit", []string{"select 1"}, []string{"EXIT"}, "-"}, + {"select [a'b] = 'c'", []string{"select [a'b] = 'c'"}, nil, "-"}, + {"select [bracket", []string{"select [bracket"}, nil, "["}, + {"select [bracket]]a]", []string{"select [bracket]]a]"}, nil, "-"}, + {"exit_1", []string{"exit_1"}, nil, "-"}, } for _, test := range tests { b := NewBatch(sp(test.s, "\n"), newCommands()) diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index b921ff32..3422f151 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -37,7 +37,7 @@ func newCommands() Commands { // Commands is the set of Command implementations return map[string]*Command{ "EXIT": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:?EXIT(?:[ \t]*(\(?.*\)?$)|$)`), + regex: regexp.MustCompile(`(?im)^[\t ]*?:?EXIT([\( \t]+.*\)*$|$)`), action: exitCommand, name: "EXIT", }, @@ -186,15 +186,17 @@ func exitCommand(s *Sqlcmd, args []string, line uint) error { } } query = strings.TrimSpace(params[1 : len(params)-1]) - s.batch.Reset([]rune(query)) - _, _, err := s.batch.Next() - if err != nil { - return err - } - query = s.batch.String() - if s.batch.String() != "" { - query = s.getRunnableQuery(query) - s.Exitcode, _ = s.runQuery(query) + if len(query) > 0 { + s.batch.Reset([]rune(query)) + _, _, err := s.batch.Next() + if err != nil { + return err + } + query = s.batch.String() + if s.batch.String() != "" { + query = s.getRunnableQuery(query) + s.Exitcode, _ = s.runQuery(query) + } } return ErrExitRequested } diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index d90a9749..2540bd1b 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -41,9 +41,9 @@ func TestCommandParsing(t *testing.T) { {` :Error c:\folder\file`, "ERROR", []string{`c:\folder\file`}}, {`:Setvar A1 "some value" `, "SETVAR", []string{`A1 "some value" `}}, {` :Listvar`, "LISTVAR", []string{""}}, - {`:EXIT (select 100 as count)`, "EXIT", []string{"(select 100 as count)"}}, - {`:EXIT ( )`, "EXIT", []string{"( )"}}, - {`EXIT `, "EXIT", []string{""}}, + {`:EXIT (select 100 as count)`, "EXIT", []string{" (select 100 as count)"}}, + {`:EXIT ( )`, "EXIT", []string{" ( )"}}, + {`EXIT `, "EXIT", []string{" "}}, {`:Connect someserver -U someuser`, "CONNECT", []string{"someserver -U someuser"}}, {`:r c:\$(var)\file.sql`, "READFILE", []string{`c:\$(var)\file.sql`}}, {`:!! notepad`, "EXEC", []string{" notepad"}}, diff --git a/pkg/sqlcmd/sqlcmd_test.go b/pkg/sqlcmd/sqlcmd_test.go index 81b66633..c90ed782 100644 --- a/pkg/sqlcmd/sqlcmd_test.go +++ b/pkg/sqlcmd/sqlcmd_test.go @@ -67,10 +67,11 @@ func TestConnectionStringFromSqlCmd(t *testing.T) { } } -/* The following tests require a working SQL instance and rely on SqlCmd environment variables +/* + The following tests require a working SQL instance and rely on SqlCmd environment variables + to manage the initial connection string. The default connection when no environment variables are set will be to localhost using Windows auth. - */ func TestSqlCmdConnectDb(t *testing.T) { v := InitializeVariables(true) @@ -185,6 +186,34 @@ func TestIncludeFileWithVariables(t *testing.T) { } } +func TestIncludeFileMultilineString(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + dataPath := "testdata" + string(os.PathSeparator) + err := s.IncludeFile(dataPath+"blanks.sql", true) + if assert.NoError(t, err, "IncludeFile blanks.sql true") { + assert.Equal(t, "=", s.batch.State(), "s.batch.State() after IncludeFile blanks.sql true") + assert.Equal(t, "", s.batch.String(), "s.batch.String() after IncludeFile blanks.sql true") + s.SetOutput(nil) + o := buf.buf.String() + assert.Equal(t, "line 1"+SqlcmdEol+SqlcmdEol+SqlcmdEol+SqlcmdEol+"line2"+SqlcmdEol+SqlcmdEol, o) + } +} + +func TestIncludeFileQuotedIdentifiers(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + dataPath := "testdata" + string(os.PathSeparator) + err := s.IncludeFile(dataPath+"quotedidentifiers.sql", true) + if assert.NoError(t, err, "IncludeFile quotedidentifiers.sql true") { + assert.Equal(t, "=", s.batch.State(), "s.batch.State() after IncludeFile quotedidentifiers.sql true") + assert.Equal(t, "", s.batch.String(), "s.batch.String() after IncludeFile quotedidentifiers.sql true") + s.SetOutput(nil) + o := buf.buf.String() + assert.Equal(t, `ab 1 a"b`+SqlcmdEol+SqlcmdEol, o) + } +} + func TestGetRunnableQuery(t *testing.T) { v := InitializeVariables(false) v.Set("var1", "v1") diff --git a/pkg/sqlcmd/testdata/blanks.sql b/pkg/sqlcmd/testdata/blanks.sql new file mode 100644 index 00000000..0e400ae5 --- /dev/null +++ b/pkg/sqlcmd/testdata/blanks.sql @@ -0,0 +1,7 @@ +set nocount on +:setvar l line2 +select 'line 1 + + + +$(l)' \ No newline at end of file diff --git a/pkg/sqlcmd/testdata/quotedidentifiers.sql b/pkg/sqlcmd/testdata/quotedidentifiers.sql new file mode 100644 index 00000000..25b8c1cc --- /dev/null +++ b/pkg/sqlcmd/testdata/quotedidentifiers.sql @@ -0,0 +1,3 @@ +set nocount on +set quoted_identifier on +select [a]]b] = 'ab', "a'b" = 1, [a"b] = 'a"b' diff --git a/pkg/sqlcmd/testdata/variablesnogo.sql b/pkg/sqlcmd/testdata/variablesnogo.sql index bea1d010..10aded38 100644 --- a/pkg/sqlcmd/testdata/variablesnogo.sql +++ b/pkg/sqlcmd/testdata/variablesnogo.sql @@ -1,5 +1,12 @@ set nocount on :setvar hundred 100 --- comment + +-- verify fix for https://github.com/microsoft/go-sqlcmd/issues/197 + +-- Correctly handle the first line of a batch having a variable after an empty line + +GO + select $(hundred) +GO \ No newline at end of file