diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index b0af2ab6..6a258e49 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -182,28 +182,34 @@ func exitCommand(s *Sqlcmd, args []string, line uint) error { if !strings.HasPrefix(params, "(") || !strings.HasSuffix(params, ")") { return InvalidCommandError("EXIT", line) } - // First we run the current batch - query := s.batch.String() - if query != "" { - query = s.getRunnableQuery(query) - if exitCode, err := s.runQuery(query); err != nil { - s.Exitcode = exitCode - return ErrExitRequested - } - } - query = strings.TrimSpace(params[1 : len(params)-1]) - if len(query) > 0 { - s.batch.Reset([]rune(query)) + // First we save the current batch + query1 := s.batch.String() + if len(query1) > 0 { + query1 = s.getRunnableQuery(query1) + } + // Now parse the params of EXIT as a batch without commands + cmd := s.batch.cmd + s.batch.cmd = nil + defer func() { + s.batch.cmd = cmd + }() + query2 := strings.TrimSpace(params[1 : len(params)-1]) + if len(query2) > 0 { + s.batch.Reset([]rune(query2)) _, _, err := s.batch.Next() if err != nil { return err } - query = s.batch.String() - if s.batch.String() != "" { - query = s.getRunnableQuery(query) - s.Exitcode, _ = s.runQuery(query) + query2 = s.batch.String() + if len(query2) > 0 { + query2 = s.getRunnableQuery(query2) } } + + if len(query1) > 0 || len(query2) > 0 { + query := query1 + SqlcmdEol + query2 + s.Exitcode, _ = s.runQuery(query) + } return ErrExitRequested } diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index eff3a509..1119df93 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -375,3 +375,23 @@ func TestEchoInput(t *testing.T) { assert.Equal(t, "set nocount on"+SqlcmdEol+"select 100"+SqlcmdEol+"100"+SqlcmdEol+SqlcmdEol, buf.buf.String(), "Incorrect output with echo true") } } + +func TestExitCommandAppendsParameterToCurrentBatch(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + c := []string{"set nocount on", "declare @v integer = 2", "select 1", "exit(select @v)"} + err := runSqlCmd(t, s, c) + if assert.NoError(t, err, "exit should not error") { + output := buf.buf.String() + assert.Equal(t, "1"+SqlcmdEol+SqlcmdEol+"2"+SqlcmdEol+SqlcmdEol, output, "Incorrect output") + assert.Equal(t, 2, s.Exitcode, "exit should set Exitcode") + } + s, buf1 := setupSqlCmdWithMemoryOutput(t) + defer buf1.Close() + c = []string{"set nocount on", "select 1", "exit(select @v)"} + err = runSqlCmd(t, s, c) + if assert.NoError(t, err, "exit should not error") { + assert.Equal(t, -101, s.Exitcode, "exit should not set Exitcode on script error") + } + +}