diff --git a/cmd/sqlcmd/main.go b/cmd/sqlcmd/main.go index c795abd5..c094f19d 100644 --- a/cmd/sqlcmd/main.go +++ b/cmd/sqlcmd/main.go @@ -201,22 +201,28 @@ func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sq connect.ErrorSeverityLevel = args.ErrorSeverityLevel } +func isConsoleInitializationRequired(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments) bool { + iactive := args.InputFile == nil && args.Query == "" + return iactive || connect.RequiresPassword() +} + func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { wd, err := os.Getwd() if err != nil { return 1, err } - iactive := args.InputFile == nil && args.Query == "" + var connectConfig sqlcmd.ConnectSettings + setConnect(&connectConfig, args, vars) var line sqlcmd.Console = nil - if iactive { + if isConsoleInitializationRequired(&connectConfig, args) { line = console.NewConsole("") defer line.Close() } s := sqlcmd.New(line, wd, vars) s.UnicodeOutputFile = args.UnicodeOutputFile - setConnect(&s.Connect, args, vars) + if args.BatchTerminator != "GO" { err = s.Cmd.SetBatchTerminator(args.BatchTerminator) if err != nil { @@ -227,7 +233,7 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { return 1, err } - setConnect(&s.Connect, args, vars) + s.Connect = &connectConfig s.Format = sqlcmd.NewSQLCmdDefaultFormatter(false) if args.OutputFile != "" { err = s.RunCommand(s.Cmd["OUT"], []string{args.OutputFile}) @@ -257,10 +263,12 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { s.Query = args.Query } // connect using no overrides - err = s.ConnectDb(nil, !iactive) + err = s.ConnectDb(nil, line == nil) if err != nil { return 1, err } + + iactive := args.InputFile == nil && args.Query == "" if iactive || s.Query != "" { err = s.Run(once, false) } else { diff --git a/cmd/sqlcmd/main_test.go b/cmd/sqlcmd/main_test.go index ea13dbee..c0a55c9e 100644 --- a/cmd/sqlcmd/main_test.go +++ b/cmd/sqlcmd/main_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/alecthomas/kong" + "github.com/microsoft/go-mssqldb/azuread" "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -327,6 +328,54 @@ func TestMissingInputFile(t *testing.T) { assert.Equal(t, 1, exitCode, "exitCode") } +func TestConditionsForPasswordPrompt(t *testing.T) { + + type test struct { + authenticationMethod string + inputFile []string + username string + pwd string + expectedResult bool + } + tests := []test{ + // Positive Testcases + {sqlcmd.SqlPassword, []string{""}, "someuser", "", true}, + {sqlcmd.NotSpecified, []string{"testdata/someFile.sql"}, "someuser", "", true}, + {azuread.ActiveDirectoryPassword, []string{""}, "someuser", "", true}, + {azuread.ActiveDirectoryPassword, []string{"testdata/someFile.sql"}, "someuser", "", true}, + {azuread.ActiveDirectoryServicePrincipal, []string{""}, "someuser", "", true}, + {azuread.ActiveDirectoryServicePrincipal, []string{"testdata/someFile.sql"}, "someuser", "", true}, + {azuread.ActiveDirectoryApplication, []string{""}, "someuser", "", true}, + {azuread.ActiveDirectoryApplication, []string{"testdata/someFile.sql"}, "someuser", "", true}, + + //Negative Testcases + {sqlcmd.NotSpecified, []string{""}, "", "", false}, + {sqlcmd.NotSpecified, []string{"testdata/someFile.sql"}, "", "", false}, + {azuread.ActiveDirectoryDefault, []string{""}, "someuser", "", false}, + {azuread.ActiveDirectoryDefault, []string{"testdata/someFile.sql"}, "someuser", "", false}, + {azuread.ActiveDirectoryInteractive, []string{""}, "someuser", "", false}, + {azuread.ActiveDirectoryInteractive, []string{"testdata/someFile.sql"}, "someuser", "", false}, + {azuread.ActiveDirectoryManagedIdentity, []string{""}, "someuser", "", false}, + {azuread.ActiveDirectoryManagedIdentity, []string{"testdata/someFile.sql"}, "someuser", "", false}, + } + + for _, testcase := range tests { + t.Log(testcase.authenticationMethod, testcase.inputFile, testcase.username, testcase.pwd, testcase.expectedResult) + args := newArguments() + args.DisableCmdAndWarn = true + args.InputFile = testcase.inputFile + args.UserName = testcase.username + vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + setVars(vars, &args) + var connectConfig sqlcmd.ConnectSettings + setConnect(&connectConfig, &args, vars) + connectConfig.AuthenticationMethod = testcase.authenticationMethod + connectConfig.Password = testcase.pwd + assert.Equal(t, testcase.expectedResult, isConsoleInitializationRequired(&connectConfig, &args), "Unexpected test result encountered for console initialization") + assert.Equal(t, testcase.expectedResult, connectConfig.RequiresPassword() && connectConfig.Password == "", "Unexpected test result encountered for password prompt conditions") + } +} + // Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set func canTestAzureAuth() bool { server := os.Getenv(sqlcmd.SQLCMDSERVER) diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 4af4164a..f9eec1bf 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -353,7 +353,7 @@ func connectCommand(s *Sqlcmd, args []string, line uint) error { return InvalidCommandError("CONNECT", line) } - connect := s.Connect + connect := *s.Connect connect.UserName, _ = resolveArgumentVariables(s, []rune(arguments.Username), false) connect.Password, _ = resolveArgumentVariables(s, []rune(arguments.Password), false) connect.ServerName, _ = resolveArgumentVariables(s, []rune(arguments.Server), false) diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index a19d626b..9299cac7 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -171,7 +171,7 @@ func TestConnectCommand(t *testing.T) { err := connectCommand(s, []string{"someserver -U someuser"}, 1) assert.NoError(t, err, "connectCommand with valid arguments doesn't return an error on connect failure") assert.True(t, prompted, "connectCommand with user name and no password should prompt for password") - assert.NotEqual(t, "someserver", s.Connect.ServerName, "On error, sqlCmd.Connect does not copy inputs") + assert.NotEqual(t, "someserver", s.Connect.ServerName, "On connection failure, sqlCmd.Connect does not copy inputs") err = connectCommand(s, []string{}, 2) assert.EqualError(t, err, InvalidCommandError("CONNECT", 2).Error(), ":Connect with no arguments should return an error") diff --git a/pkg/sqlcmd/connect.go b/pkg/sqlcmd/connect.go index aa489ccb..01dac1ec 100644 --- a/pkg/sqlcmd/connect.go +++ b/pkg/sqlcmd/connect.go @@ -64,7 +64,7 @@ func (connect ConnectSettings) sqlAuthentication() bool { (!connect.UseTrustedConnection && connect.authenticationMethod() == NotSpecified && connect.UserName != "") } -func (connect ConnectSettings) requiresPassword() bool { +func (connect ConnectSettings) RequiresPassword() bool { requiresPassword := connect.sqlAuthentication() if !requiresPassword { switch connect.authenticationMethod() { diff --git a/pkg/sqlcmd/sqlcmd.go b/pkg/sqlcmd/sqlcmd.go index a3d34898..d0ebb491 100644 --- a/pkg/sqlcmd/sqlcmd.go +++ b/pkg/sqlcmd/sqlcmd.go @@ -62,7 +62,7 @@ type Sqlcmd struct { batch *Batch // Exitcode is returned to the operating system when the process exits Exitcode int - Connect ConnectSettings + Connect *ConnectSettings vars *Variables Format Formatter Query string @@ -79,6 +79,7 @@ func New(l Console, workingDirectory string, vars *Variables) *Sqlcmd { workingDirectory: workingDirectory, vars: vars, Cmd: newCommands(), + Connect: &ConnectSettings{}, } s.batch = NewBatch(s.scanNext, s.Cmd) mssql.SetContextLogger(s) @@ -213,12 +214,12 @@ func (s *Sqlcmd) SetError(e io.WriteCloser) { func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error { newConnection := connect != nil if connect == nil { - connect = &s.Connect + connect = s.Connect } var connector driver.Connector useAad := !connect.sqlAuthentication() && !connect.integratedAuthentication() - if connect.requiresPassword() && !nopw && connect.Password == "" { + if connect.RequiresPassword() && !nopw && connect.Password == "" { var err error if connect.Password, err = s.promptPassword(); err != nil { return err @@ -259,7 +260,7 @@ func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error { s.vars.Set(SQLCMDUSER, u.Username) } if newConnection { - s.Connect = *connect + s.Connect = connect } if s.batch != nil { s.batch.batchline = 1 diff --git a/pkg/sqlcmd/sqlcmd_test.go b/pkg/sqlcmd/sqlcmd_test.go index 2799b87b..f7909334 100644 --- a/pkg/sqlcmd/sqlcmd_test.go +++ b/pkg/sqlcmd/sqlcmd_test.go @@ -367,10 +367,10 @@ func TestPromptForPasswordPositive(t *testing.T) { v := InitializeVariables(true) s := New(console, "", v) // attempt without password prompt - err := s.ConnectDb(&c, true) + err := s.ConnectDb(c, true) assert.False(t, prompted, "ConnectDb with nopw=true should not prompt for password") assert.Error(t, err, "ConnectDb with nopw==true and no password provided") - err = s.ConnectDb(&c, false) + err = s.ConnectDb(c, false) assert.True(t, prompted, "ConnectDb with !nopw should prompt for password") assert.NoError(t, err, "ConnectDb with !nopw and valid password returned from prompt") if s.Connect.Password != password { @@ -506,7 +506,7 @@ func canTestAzureAuth() bool { return strings.Contains(server, ".database.windows.net") && userName == "" } -func newConnect(t testing.TB) ConnectSettings { +func newConnect(t testing.TB) *ConnectSettings { t.Helper() connect := ConnectSettings{ UserName: os.Getenv(SQLCMDUSER), @@ -518,5 +518,5 @@ func newConnect(t testing.TB) ConnectSettings { t.Log("Using ActiveDirectoryDefault") connect.AuthenticationMethod = azuread.ActiveDirectoryDefault } - return connect + return &connect }