diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 6a258e49..31c2d9c0 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -264,13 +264,18 @@ func outCommand(s *Sqlcmd, args []string, line uint) error { if len(args) == 0 || args[0] == "" { return InvalidCommandError("OUT", line) } + filePath, err := resolveArgumentVariables(s, []rune(args[0]), true) + if err != nil { + return err + } + switch { - case strings.EqualFold(args[0], "stdout"): + case strings.EqualFold(filePath, "stdout"): s.SetOutput(os.Stdout) - case strings.EqualFold(args[0], "stderr"): + case strings.EqualFold(filePath, "stderr"): s.SetOutput(os.Stderr) default: - o, err := os.OpenFile(args[0], os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) + o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) if err != nil { return InvalidFileError(err, args[0]) } @@ -290,15 +295,19 @@ func outCommand(s *Sqlcmd, args []string, line uint) error { // errorCommand changes the error writer to use a file func errorCommand(s *Sqlcmd, args []string, line uint) error { if len(args) == 0 || args[0] == "" { - return InvalidCommandError("OUT", line) + return InvalidCommandError("ERROR", line) + } + filePath, err := resolveArgumentVariables(s, []rune(args[0]), true) + if err != nil { + return err } switch { - case strings.EqualFold(args[0], "stderr"): + case strings.EqualFold(filePath, "stderr"): s.SetError(os.Stderr) - case strings.EqualFold(args[0], "stdout"): + case strings.EqualFold(filePath, "stdout"): s.SetError(os.Stdout) default: - o, err := os.OpenFile(args[0], os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) + o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) if err != nil { return InvalidFileError(err, args[0]) } @@ -549,7 +558,7 @@ func xmlCommand(s *Sqlcmd, args []string, line uint) error { func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) { var b *strings.Builder end := len(arg) - for i := 0; i < end; { + for i := 0; i < end && !s.Connect.DisableVariableSubstitution; { c, next := arg[i], grab(arg, i+1, end) switch { case c == '$' && next == '(': diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index 94caeb2e..7a9e8f3f 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -246,6 +246,7 @@ func TestConnectCommand(t *testing.T) { func TestErrorCommand(t *testing.T) { s, buf := setupSqlCmdWithMemoryOutput(t) + defer s.SetError(nil) defer buf.Close() file, err := os.CreateTemp("", "sqlcmderr") assert.NoError(t, err, "os.CreateTemp") @@ -253,17 +254,20 @@ func TestErrorCommand(t *testing.T) { fileName := file.Name() _ = file.Close() err = errorCommand(s, []string{""}, 1) - assert.EqualError(t, err, InvalidCommandError("OUT", 1).Error(), "errorCommand with empty file name") + assert.EqualError(t, err, InvalidCommandError("ERROR", 1).Error(), "errorCommand with empty file name") err = errorCommand(s, []string{fileName}, 1) assert.NoError(t, err, "errorCommand") // Only some error kinds go to the error output err = runSqlCmd(t, s, []string{"print N'message'", "RAISERROR(N'Error', 16, 1)", "SELECT 1", ":SETVAR 1", "GO"}) assert.NoError(t, err, "runSqlCmd") - s.SetError(nil) errText, err := os.ReadFile(file.Name()) if assert.NoError(t, err, "ReadFile") { assert.Regexp(t, "Msg 50000, Level 16, State 1, Server .*, Line 2"+SqlcmdEol+"Error"+SqlcmdEol, string(errText), "Error file contents: "+string(errText)) } + s.vars.Set("myvar", "stdout") + err = errorCommand(s, []string{"$(myvar)"}, 1) + assert.NoError(t, err, "errorCommand with a variable") + assert.Equal(t, os.Stdout, s.err, "error set to stdout using a variable") } func TestOnErrorCommand(t *testing.T) { @@ -320,6 +324,11 @@ func TestResolveArgumentVariables(t *testing.T) { if assert.ErrorContains(t, err, UndefinedVariable("var2").Error(), "fail on unresolved variable") { assert.Empty(t, actual, "fail on unresolved variable") } + s.Connect.DisableVariableSubstitution = true + input := "$(var1) notvar" + actual, err = resolveArgumentVariables(s, []rune(input), true) + assert.NoError(t, err) + assert.Equal(t, input, actual, "resolveArgumentVariables when DisableVariableSubstitution is false") } func TestExecCommand(t *testing.T) {