diff --git a/README.md b/README.md index ab5eb239..8c9565fc 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,11 @@ We will be implementing command line switches and behaviors over time. Several s - Some behaviors that were kept to maintain compatibility with `OSQL` may be changed, such as alignment of column headers for some data types. - All commands must fit on one line, even `EXIT`. Interactive mode will not check for open parentheses or quotes for commands and prompt for successive lines. The ODBC sqlcmd allows the query run by `EXIT(query)` to span multiple lines. +### Miscellaneous enhancements + +- `:Connect` now has an optional `-G` parameter to select one of the authentication methods for Azure SQL Database - `SqlAuthentication`, `ActiveDirectoryDefault`, `ActiveDirectoryIntegrated`, `ActiveDirectoryServicePrincipal`, `ActiveDirectoryManagedIdentity`, `ActiveDirectoryPassword`. If `-G` is not provided, either Integrated security or SQL Authentication will be used, dependent on the presence of a `-U` user name parameter. +- The new `--driver-logging-level` command line parameter allows you to see traces from the `go-mssqldb` client driver. Use `64` to see all traces. + ### Azure Active Directory Authentication This version of sqlcmd supports a broader range of AAD authentication models, based on the [azidentity package](https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity). The implementation relies on an AAD Connector in the [driver](https://github.com/denisenkom/go-mssqldb). diff --git a/cmd/sqlcmd/main.go b/cmd/sqlcmd/main.go index c19351d6..bb09e023 100644 --- a/cmd/sqlcmd/main.go +++ b/cmd/sqlcmd/main.go @@ -99,7 +99,7 @@ func main() { setVars(vars, &args) // so far sqlcmd prints all the errors itself so ignore it - exitCode, _ := run(vars) + exitCode, _ := run(vars, &args) os.Exit(exitCode) } @@ -156,43 +156,57 @@ func setVars(vars *sqlcmd.Variables, args *SQLCmdArguments) { } -func setConnect(s *sqlcmd.Sqlcmd, args *SQLCmdArguments) { +func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sqlcmd.Variables) { if !args.DisableCmdAndWarn { - s.Connect.Password = os.Getenv(sqlcmd.SQLCMDPASSWORD) - } - s.Connect.UseTrustedConnection = args.UseTrustedConnection - s.Connect.TrustServerCertificate = args.TrustServerCertificate - s.Connect.AuthenticationMethod = args.authenticationMethod(s.Connect.Password != "") - s.Connect.DisableEnvironmentVariables = args.DisableCmdAndWarn - s.Connect.DisableVariableSubstitution = args.DisableVariableSubstitution - s.Connect.ApplicationIntent = args.ApplicationIntent - s.Connect.LoginTimeoutSeconds = args.LoginTimeout - s.Connect.Encrypt = args.EncryptConnection - s.Connect.PacketSize = args.PacketSize - s.Connect.WorkstationName = args.WorkstationName - s.Connect.LogLevel = args.DriverLoggingLevel - s.Connect.ExitOnError = args.ExitOnError - s.Connect.ErrorSeverityLevel = args.ErrorSeverityLevel + connect.Password = os.Getenv(sqlcmd.SQLCMDPASSWORD) + } + connect.ServerName = args.Server + if connect.ServerName == "" { + connect.ServerName, _ = vars.Get(sqlcmd.SQLCMDSERVER) + } + connect.Database = args.DatabaseName + if connect.Database == "" { + connect.Database, _ = vars.Get(sqlcmd.SQLCMDDBNAME) + } + connect.UserName = args.UserName + if connect.UserName == "" { + connect.UserName, _ = vars.Get(sqlcmd.SQLCMDUSER) + } + connect.UseTrustedConnection = args.UseTrustedConnection + connect.TrustServerCertificate = args.TrustServerCertificate + connect.AuthenticationMethod = args.authenticationMethod(connect.Password != "") + connect.DisableEnvironmentVariables = args.DisableCmdAndWarn + connect.DisableVariableSubstitution = args.DisableVariableSubstitution + connect.ApplicationIntent = args.ApplicationIntent + connect.LoginTimeoutSeconds = args.LoginTimeout + connect.Encrypt = args.EncryptConnection + connect.PacketSize = args.PacketSize + connect.WorkstationName = args.WorkstationName + connect.LogLevel = args.DriverLoggingLevel + connect.ExitOnError = args.ExitOnError + connect.ErrorSeverityLevel = args.ErrorSeverityLevel } -func run(vars *sqlcmd.Variables) (int, error) { +func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { wd, err := os.Getwd() if err != nil { return 1, err } iactive := args.InputFile == nil && args.Query == "" + var console sqlcmd.Console = nil var line *readline.Instance if iactive { line, err = readline.New(">") if err != nil { return 1, err } + console = line defer line.Close() } - s := sqlcmd.New(line, wd, vars) - + s := sqlcmd.New(console, wd, vars) + setConnect(&s.Connect, args, vars) if args.BatchTerminator != "GO" { err = s.Cmd.SetBatchTerminator(args.BatchTerminator) if err != nil { @@ -203,7 +217,7 @@ func run(vars *sqlcmd.Variables) (int, error) { return 1, err } - setConnect(s, &args) + setConnect(&s.Connect, args, vars) s.Format = sqlcmd.NewSQLCmdDefaultFormatter(false) if args.OutputFile != "" { err = s.RunCommand(s.Cmd["OUT"], []string{args.OutputFile}) @@ -218,7 +232,8 @@ func run(vars *sqlcmd.Variables) (int, error) { once = true s.Query = args.Query } - err = s.ConnectDb("", "", "", !iactive) + // connect using no overrides + err = s.ConnectDb(nil, !iactive) if err != nil { return 1, err } diff --git a/cmd/sqlcmd/main_test.go b/cmd/sqlcmd/main_test.go index 6334c57c..0129b1d8 100644 --- a/cmd/sqlcmd/main_test.go +++ b/cmd/sqlcmd/main_test.go @@ -122,7 +122,7 @@ func TestRunInputFiles(t *testing.T) { vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0") setVars(vars, &args) - exitCode, err := run(vars) + exitCode, err := run(vars, &args) assert.NoError(t, err, "run") assert.Equal(t, 0, exitCode, "exitCode") bytes, err := os.ReadFile(o.Name()) @@ -148,7 +148,7 @@ func TestQueryAndExit(t *testing.T) { vars.Set("VAR1", "100") setVars(vars, &args) - exitCode, err := run(vars) + exitCode, err := run(vars, &args) assert.NoError(t, err, "run") assert.Equal(t, 0, exitCode, "exitCode") bytes, err := os.ReadFile(o.Name()) @@ -173,8 +173,7 @@ func TestAzureAuth(t *testing.T) { vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0") setVars(vars, &args) - - exitCode, err := run(vars) + exitCode, err := run(vars, &args) assert.NoError(t, err, "run") assert.Equal(t, 0, exitCode, "exitCode") bytes, err := os.ReadFile(o.Name()) diff --git a/go.mod b/go.mod index 6fbcd109..d26d8dd1 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,6 @@ module github.com/microsoft/go-sqlcmd go 1.16 require ( - github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0 - github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0 github.com/alecthomas/kong v0.2.18-0.20210621093454-54558f65e86f github.com/chzyer/logex v1.1.10 // indirect github.com/chzyer/test v0.0.0-20210722231415-061457976a23 // indirect @@ -15,3 +13,4 @@ require ( github.com/stretchr/testify v1.7.0 ) +replace github.com/denisenkom/go-mssqldb => github.com/shueybubbles/go-mssqldb v0.10.1-0.20220303143659-8896461e4ec7 diff --git a/go.sum b/go.sum index 520d944e..38288101 100644 --- a/go.sum +++ b/go.sum @@ -13,8 +13,6 @@ github.com/chzyer/test v0.0.0-20210722231415-061457976a23/go.mod h1:Q3SI9o4m/ZMn github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/denisenkom/go-mssqldb v0.12.0 h1:VtrkII767ttSPNRfFekePK3sctr+joXgO58stqQbtUA= -github.com/denisenkom/go-mssqldb v0.12.0/go.mod h1:iiK0YP1ZeepvmBQk/QpLEhhTNJgfzrpArPY/aFvc9yU= github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= github.com/gohxs/readline v0.0.0-20171011095936-a780388e6e7c h1:yE35fKFwcelIte3q5q1/cPiY7pI7vvf5/j/0ddxNCKs= github.com/gohxs/readline v0.0.0-20171011095936-a780388e6e7c/go.mod h1:9S/fKAutQ6wVHqm1jnp9D9sc5hu689s9AaTWFS92LaU= @@ -31,6 +29,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/shueybubbles/go-mssqldb v0.10.1-0.20220303143659-8896461e4ec7 h1:4CIaYagSRCGr0/Gh6cfF5cQx3RVE3qrQukZn8iMO6Y8= +github.com/shueybubbles/go-mssqldb v0.10.1-0.20220303143659-8896461e4ec7/go.mod h1:iiK0YP1ZeepvmBQk/QpLEhhTNJgfzrpArPY/aFvc9yU= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/pkg/sqlcmd/azure_auth.go b/pkg/sqlcmd/azure_auth.go index 7b76bde7..a610dd16 100644 --- a/pkg/sqlcmd/azure_auth.go +++ b/pkg/sqlcmd/azure_auth.go @@ -5,6 +5,7 @@ package sqlcmd import ( "database/sql/driver" + "fmt" "net/url" "os" @@ -24,28 +25,27 @@ func getSqlClientId() string { return sqlClientId } -func (s *Sqlcmd) GetTokenBasedConnection(connstr string, user string, password string) (driver.Connector, error) { +func GetTokenBasedConnection(connstr string, authenticationMethod string) (driver.Connector, error) { connectionUrl, err := url.Parse(connstr) if err != nil { return nil, err } - if user != "" { - connectionUrl.User = url.UserPassword(user, password) - } - query := connectionUrl.Query() - query.Set("fedauth", s.Connect.authenticationMethod()) + query.Set("fedauth", authenticationMethod) query.Set("applicationclientid", getSqlClientId()) - - switch s.Connect.AuthenticationMethod { - case azuread.ActiveDirectoryServicePrincipal: - case azuread.ActiveDirectoryApplication: + switch authenticationMethod { + case azuread.ActiveDirectoryServicePrincipal, azuread.ActiveDirectoryApplication: query.Set("clientcertpath", os.Getenv("AZURE_CLIENT_CERTIFICATE_PATH")) case azuread.ActiveDirectoryInteractive: + loginTimeout := query.Get("connection timeout") + loginTimeoutSeconds := 0 + if loginTimeout != "" { + _, _ = fmt.Sscanf(loginTimeout, "%d", &loginTimeoutSeconds) + } // AAD interactive needs minutes at minimum - if s.Connect.LoginTimeoutSeconds < 120 { + if loginTimeoutSeconds > 0 && loginTimeoutSeconds < 120 { query.Set("connection timeout", "120") } } diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 91aed59a..146c07cb 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -10,6 +10,8 @@ import ( "sort" "strings" "syscall" + + "github.com/alecthomas/kong" ) // Command defines a sqlcmd action which can be intermixed with the SQL batch @@ -80,6 +82,11 @@ func newCommands() Commands { action: listCommand, name: "LIST", }, + "CONNECT": { + regex: regexp.MustCompile(`(?im)^[ \t]*:CONNECT(?:[ \t]+(.*$)|$)`), + action: connectCommand, + name: "CONNECT", + }, } } @@ -285,3 +292,39 @@ func listCommand(s *Sqlcmd, args []string, line uint) error { return nil } + +type connectData struct { + Server string `arg:""` + Database string `short:"D"` + Username string `short:"U"` + Password string `short:"P"` + LoginTimeout int `short:"l"` + AuthenticationMethod string `short:"G"` +} + +func connectCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 || strings.TrimSpace(args[0]) == "" { + return InvalidCommandError("CONNECT", line) + } + arguments := &connectData{} + parser, err := kong.New(arguments) + if err != nil { + return InvalidCommandError("CONNECT", line) + } + if _, err = parser.Parse(strings.Split(args[0], " ")); err != nil { + return InvalidCommandError("CONNECT", line) + } + + connect := s.Connect + connect.UserName = arguments.Username + connect.Password = arguments.Password + connect.ServerName = arguments.Server + if arguments.LoginTimeout > 0 { + connect.LoginTimeoutSeconds = arguments.LoginTimeout + } + connect.AuthenticationMethod = arguments.AuthenticationMethod + // If no user name is provided we switch to integrated auth + _ = s.ConnectDb(&connect, s.lineIo == nil) + // ConnectDb prints connection errors already, and failure to connect is not fatal even with -b option + return nil +} diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index d71ef646..11bedd75 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -5,6 +5,8 @@ package sqlcmd import ( "bytes" + "fmt" + "os" "strings" "testing" @@ -41,6 +43,7 @@ func TestCommandParsing(t *testing.T) { {`:EXIT (select 100 as count)`, "EXIT", []string{"(select 100 as count)"}}, {`:EXIT ( )`, "EXIT", []string{"( )"}}, {`EXIT `, "EXIT", []string{""}}, + {`:Connect someserver -U someuser`, "CONNECT", []string{"someserver -U someuser"}}, } for _, test := range commands { @@ -151,3 +154,39 @@ func TestListCommand(t *testing.T) { o := buf.buf.String() assert.Equal(t, o, "select 1"+SqlcmdEol, ":list output not equal to batch") } + +func TestConnectCommand(t *testing.T) { + s, _ := setupSqlCmdWithMemoryOutput(t) + prompted := false + s.lineIo = &testConsole{ + OnPasswordPrompt: func(prompt string) ([]byte, error) { + prompted = true + return []byte{}, nil + }, + } + err := connectCommand(s, []string{"someserver -U someuser"}, 1) + assert.NoError(t, err, "connectCommand with valid arguments doesn't return an error on connect failure") + assert.True(t, prompted, "connectCommand with user name and no password should prompt for password") + assert.NotEqual(t, "someserver", s.Connect.ServerName, "On error, sqlCmd.Connect does not copy inputs") + + err = connectCommand(s, []string{}, 2) + assert.EqualError(t, err, InvalidCommandError("CONNECT", 2).Error(), ":Connect with no arguments should return an error") + c := newConnect(t) + + authenticationMethod := "" + if c.Password == "" { + c.UserName = os.Getenv("AZURE_CLIENT_ID") + "@" + os.Getenv("AZURE_TENANT_ID") + c.Password = os.Getenv("AZURE_CLIENT_SECRET") + authenticationMethod = "-G ActiveDirectoryServicePrincipal" + if c.Password == "" { + t.Log("Not trying :Connect with valid password due to no password being available") + return + } + err = connectCommand(s, []string{fmt.Sprintf("%s -U %s -P %s %s", c.ServerName, c.UserName, c.Password, authenticationMethod)}, 3) + assert.NoError(t, err, "connectCommand with valid parameters should not return an error") + // not using assert to avoid printing passwords in the log + if s.Connect.UserName != c.UserName || c.Password != s.Connect.Password { + t.Fatal("After connect, sqlCmd.Connect is not updated") + } + } +} diff --git a/pkg/sqlcmd/connect.go b/pkg/sqlcmd/connect.go new file mode 100644 index 00000000..f5637836 --- /dev/null +++ b/pkg/sqlcmd/connect.go @@ -0,0 +1,131 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "fmt" + "net/url" + + "github.com/denisenkom/go-mssqldb/azuread" +) + +// ConnectSettings specifies the settings for connections +type ConnectSettings struct { + // ServerName is the full name including instance and port + ServerName string + // UseTrustedConnection indicates integrated auth is used when no user name is provided + UseTrustedConnection bool + // TrustServerCertificate sets the TrustServerCertificate setting on the connection string + TrustServerCertificate bool + // AuthenticationMethod defines the authentication method for connecting to Azure SQL Database + AuthenticationMethod string + // DisableEnvironmentVariables determines if sqlcmd resolves scripting variables from the process environment + DisableEnvironmentVariables bool + // DisableVariableSubstitution determines if scripting variables should be evaluated + DisableVariableSubstitution bool + // UserName is the username for the SQL connection + UserName string + // Password is the password used with SQL authentication or AAD authentications that require a password + Password string + // Encrypt is the choice of encryption + Encrypt string + // PacketSize is the size of the packet for TDS communication + PacketSize int + // LoginTimeoutSeconds specifies the timeout for establishing a connection + LoginTimeoutSeconds int + // WorkstationName is the string to use to identify the host in server DMVs + WorkstationName string + // ApplicationIntent can only be empty or "ReadOnly" + ApplicationIntent string + // LogLevel is the mssql driver log level + LogLevel int + // ExitOnError specifies whether to exit the app on an error + ExitOnError bool + // ErrorSeverityLevel sets the minimum SQL severity level to treat as an error + ErrorSeverityLevel uint8 + // Database is the name of the database for the connection + Database string +} + +func (c ConnectSettings) authenticationMethod() string { + if c.AuthenticationMethod == "" { + return NotSpecified + } + return c.AuthenticationMethod +} + +func (connect ConnectSettings) integratedAuthentication() bool { + return connect.UseTrustedConnection || (connect.UserName == "" && connect.authenticationMethod() == NotSpecified) +} + +func (connect ConnectSettings) sqlAuthentication() bool { + return connect.authenticationMethod() == SqlPassword || + (!connect.UseTrustedConnection && connect.authenticationMethod() == NotSpecified && connect.UserName != "") +} + +func (connect ConnectSettings) requiresPassword() bool { + requiresPassword := connect.sqlAuthentication() + if !requiresPassword { + switch connect.authenticationMethod() { + case azuread.ActiveDirectoryApplication, azuread.ActiveDirectoryPassword, azuread.ActiveDirectoryServicePrincipal: + requiresPassword = true + } + } + return requiresPassword +} + +// ConnectionString returns the go-mssql connection string to use for queries +func (connect ConnectSettings) ConnectionString() (connectionString string, err error) { + serverName, instance, port, err := splitServer(connect.ServerName) + if serverName == "" { + serverName = "." + } + if err != nil { + return "", err + } + query := url.Values{} + connectionURL := &url.URL{ + Scheme: "sqlserver", + Path: instance, + } + + if connect.sqlAuthentication() || connect.authenticationMethod() == azuread.ActiveDirectoryPassword || connect.authenticationMethod() == azuread.ActiveDirectoryServicePrincipal || connect.authenticationMethod() == azuread.ActiveDirectoryApplication { + connectionURL.User = url.UserPassword(connect.UserName, connect.Password) + } + if (connect.authenticationMethod() == azuread.ActiveDirectoryMSI || connect.authenticationMethod() == azuread.ActiveDirectoryManagedIdentity) && connect.UserName != "" { + connectionURL.User = url.UserPassword(connect.UserName, connect.Password) + } + if port > 0 { + connectionURL.Host = fmt.Sprintf("%s:%d", serverName, port) + } else { + connectionURL.Host = serverName + } + if connect.Database != "" { + query.Add("database", connect.Database) + } + + if connect.TrustServerCertificate { + query.Add("trustservercertificate", "true") + } + if connect.ApplicationIntent != "" && connect.ApplicationIntent != "default" { + query.Add("applicationintent", connect.ApplicationIntent) + } + if connect.LoginTimeoutSeconds > 0 { + query.Add("connection timeout", fmt.Sprint(connect.LoginTimeoutSeconds)) + } + if connect.PacketSize > 0 { + query.Add("packet size", fmt.Sprint(connect.PacketSize)) + } + if connect.WorkstationName != "" { + query.Add("workstation id", connect.WorkstationName) + } + if connect.Encrypt != "" && connect.Encrypt != "default" { + query.Add("encrypt", connect.Encrypt) + } + if connect.LogLevel > 0 { + query.Add("log", fmt.Sprint(connect.LogLevel)) + } + connectionURL.RawQuery = query.Encode() + return connectionURL.String(), nil +} diff --git a/pkg/sqlcmd/format_test.go b/pkg/sqlcmd/format_test.go index 2e5b4236..877af3a0 100644 --- a/pkg/sqlcmd/format_test.go +++ b/pkg/sqlcmd/format_test.go @@ -56,7 +56,7 @@ func TestCalcColumnDetails(t *testing.T) { }, } - db, err := ConnectDb() + db, err := ConnectDb(t) if assert.NoError(t, err, "ConnectDB failed") { defer db.Close() for _, test := range tests { diff --git a/pkg/sqlcmd/sqlcmd.go b/pkg/sqlcmd/sqlcmd.go index 53ac6a21..efca05d3 100644 --- a/pkg/sqlcmd/sqlcmd.go +++ b/pkg/sqlcmd/sqlcmd.go @@ -11,7 +11,6 @@ import ( "errors" "fmt" "io" - "net/url" "os" "os/signal" osuser "os/user" @@ -21,7 +20,6 @@ import ( mssql "github.com/denisenkom/go-mssqldb" "github.com/denisenkom/go-mssqldb/msdsn" - "github.com/gohxs/readline" "github.com/golang-sql/sqlexp" ) @@ -34,41 +32,14 @@ var ( ErrCtrlC = errors.New(WarningPrefix + "The last operation was terminated because the user pressed CTRL+C") ) -// ConnectSettings are the settings for connections that can't be -// inferred from scripting variables -type ConnectSettings struct { - // UseTrustedConnection indicates integrated auth is used when no user name is provided - UseTrustedConnection bool - // TrustServerCertificate sets the TrustServerCertificate setting on the connection string - TrustServerCertificate bool - AuthenticationMethod string - // DisableEnvironmentVariables determines if sqlcmd resolves scripting variables from the process environment - DisableEnvironmentVariables bool - // DisableVariableSubstitution determines if scripting variables should be evaluated - DisableVariableSubstitution bool - // Password is the password used with SQL authentication - Password string - // Encrypt is the choice of encryption - Encrypt string - // PacketSize is the size of the packet for TDS communication - PacketSize int - // LoginTimeoutSeconds specifies the timeout for establishing a connection - LoginTimeoutSeconds int - // WorkstationName is the string to use to identify the host in server DMVs - WorkstationName string - // ApplicationIntent can only be empty or "ReadOnly" - ApplicationIntent string - // mssql driver log level - LogLevel int - ExitOnError bool - ErrorSeverityLevel uint8 -} - -func (c ConnectSettings) authenticationMethod() string { - if c.AuthenticationMethod == "" { - return NotSpecified - } - return c.AuthenticationMethod +// Console defines methods used for console input and output +type Console interface { + // Readline returns the next line of input. + Readline() (string, error) + // Readpassword displays the given prompt and returns a password + ReadPassword(prompt string) ([]byte, error) + // SetPrompt sets the prompt text shown to input the next line + SetPrompt(s string) } // Sqlcmd is the core processor for text lines. @@ -77,7 +48,7 @@ func (c ConnectSettings) authenticationMethod() string { // When the batch delimiter is encountered it sends the current batch to the active connection and prints // the results to the output writer type Sqlcmd struct { - lineIo *readline.Instance + lineIo Console workingDirectory string db *sql.DB out io.WriteCloser @@ -93,7 +64,7 @@ type Sqlcmd struct { } // New creates a new Sqlcmd instance -func New(l *readline.Instance, workingDirectory string, vars *Variables) *Sqlcmd { +func New(l Console, workingDirectory string, vars *Variables) *Sqlcmd { s := &Sqlcmd{ lineIo: l, workingDirectory: workingDirectory, @@ -136,12 +107,8 @@ func (s *Sqlcmd) Run(once bool, processAll bool) error { } else { cmd, args, err = s.batch.Next() } - switch { - case err == readline.ErrInterrupt: - // Ignore any error printing the ctrl-c notice since we are exiting - _, _ = s.GetOutput().Write([]byte(ErrCtrlC.Error() + SqlcmdEol)) - return nil - case err != nil: + + if err != nil { if err == io.EOF { if s.batch.Length == 0 { return lastError @@ -150,6 +117,9 @@ func (s *Sqlcmd) Run(once bool, processAll bool) error { if !execute { return nil } + } else if err.Error() == "Interrupt" { + // Ignore any error printing the ctrl-c notice since we are exiting + _, _ = s.GetOutput().Write([]byte(ErrCtrlC.Error() + SqlcmdEol)) } else { _, _ = s.GetOutput().Write([]byte(err.Error() + SqlcmdEol)) } @@ -229,106 +199,33 @@ func (s *Sqlcmd) SetError(e io.WriteCloser) { s.err = e } -// ConnectionString returns the go-mssql connection string to use for queries -func (s *Sqlcmd) ConnectionString() (connectionString string, err error) { - serverName, instance, port, err := s.vars.SQLCmdServer() - if serverName == "" { - serverName = "." - } - if err != nil { - return "", err - } - query := url.Values{} - connectionURL := &url.URL{ - Scheme: "sqlserver", - Path: instance, - } - - if s.sqlAuthentication() { - connectionURL.User = url.UserPassword(s.vars.SQLCmdUser(), s.Connect.Password) - } - if port > 0 { - connectionURL.Host = fmt.Sprintf("%s:%d", serverName, port) - } else { - connectionURL.Host = serverName - } - if s.vars.SQLCmdDatabase() != "" { - query.Add("database", s.vars.SQLCmdDatabase()) - } - - if s.Connect.TrustServerCertificate { - query.Add("trustservercertificate", "true") - } - if s.Connect.ApplicationIntent != "" && s.Connect.ApplicationIntent != "default" { - query.Add("applicationintent", s.Connect.ApplicationIntent) - } - if s.Connect.LoginTimeoutSeconds > 0 { - query.Add("connection timeout", fmt.Sprint(s.Connect.LoginTimeoutSeconds)) - } - if s.Connect.PacketSize > 0 { - query.Add("packet size", fmt.Sprint(s.Connect.PacketSize)) - } - if s.Connect.WorkstationName != "" { - query.Add("workstation id", s.Connect.WorkstationName) - } - if s.Connect.Encrypt != "" && s.Connect.Encrypt != "default" { - query.Add("encrypt", s.Connect.Encrypt) - } - if s.Connect.LogLevel > 0 { - query.Add("log", fmt.Sprint(s.Connect.LogLevel)) - } - connectionURL.RawQuery = query.Encode() - return connectionURL.String(), nil -} - // ConnectDb opens a connection to the database with the given modifications to the connection -func (s *Sqlcmd) ConnectDb(server string, user string, password string, nopw bool) error { - if user != "" && password == "" && !nopw { - return ErrNeedPassword +// nopw == true means don't prompt for a password if the auth type requires it +// if connect is nil, ConnectDb uses the current connection. If non-nil and the connection succeeds, +// s.Connect is replaced with the new value. +func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error { + newConnection := connect != nil + if connect == nil { + connect = &s.Connect } - connstr, err := s.ConnectionString() - if err != nil { - return err + var connector driver.Connector + useAad := !connect.sqlAuthentication() && !connect.integratedAuthentication() + if connect.requiresPassword() && !nopw && connect.Password == "" { + var err error + if connect.Password, err = s.promptPassword(); err != nil { + return err + } } - - connectionURL, err := url.Parse(connstr) + connstr, err := connect.ConnectionString() if err != nil { return err } - if server != "" { - serverName, instance, port, err := splitServer(server) - if err != nil { - return err - } - connectionURL.Path = instance - if port > 0 { - connectionURL.Host = fmt.Sprintf("%s:%d", serverName, port) - } else { - connectionURL.Host = serverName - } - } - - var connector driver.Connector - // To determine whether to use Sql auth/windows auth/aad auth, compare the current ConnectSettings with the new parameters - // If sqlcmd was started with sql auth or windows auth, :connect will not switch to AAD - // if sqlcmd was started with AAD auth, it will remain in some variant of AAD auth depending on the user/password combination - useAad := !s.sqlAuthentication() && !s.integratedAuthentication() - if password == "" { - password = s.Connect.Password - } if !useAad { - if user != "" { - connectionURL.User = url.UserPassword(user, password) - } - - connector, err = mssql.NewConnector(connectionURL.String()) + connector, err = mssql.NewConnector(connstr) } else { - if user == "" { - user = s.vars.SQLCmdUser() - } - connector, err = s.GetTokenBasedConnection(connectionURL.String(), user, password) + connector, err = GetTokenBasedConnection(connstr, connect.authenticationMethod()) } if err != nil { return err @@ -344,32 +241,40 @@ func (s *Sqlcmd) ConnectDb(server string, user string, password string, nopw boo s.db.Close() } s.db = db - if server != "" { - s.vars.Set(SQLCMDSERVER, server) - } - if user != "" { - s.vars.Set(SQLCMDUSER, user) - s.Connect.UseTrustedConnection = false - s.Connect.Password = password - } else if s.vars.SQLCmdUser() == "" { + s.vars.Set(SQLCMDSERVER, connect.ServerName) + s.vars.Set(SQLCMDDBNAME, connect.Database) + if connect.UserName != "" { + s.vars.Set(SQLCMDUSER, connect.UserName) + } else { u, e := osuser.Current() if e != nil { panic("Unable to get user name") } - if !useAad { - s.Connect.UseTrustedConnection = true - } s.vars.Set(SQLCMDUSER, u.Username) } - + if newConnection { + s.Connect = *connect + } if s.batch != nil { s.batch.batchline = 1 } return nil } -// IncludeFile opens the given file and processes its batches -// When processAll is true text not followed by a go statement is run as a query +func (s *Sqlcmd) promptPassword() (string, error) { + if s.lineIo == nil { + return "", nil + } + pwd, err := s.lineIo.ReadPassword("Password:") + if err != nil { + return "", err + } + + return string(pwd), nil +} + +// IncludeFile opens the given file and processes its batches. +// When processAll is true, text not followed by a go statement is run as a query func (s *Sqlcmd) IncludeFile(path string, processAll bool) error { f, err := os.Open(path) if err != nil { @@ -451,15 +356,6 @@ func setupCloseHandler(s *Sqlcmd) { }() } -func (s *Sqlcmd) integratedAuthentication() bool { - return s.Connect.UseTrustedConnection || (s.vars.SQLCmdUser() == "" && s.Connect.authenticationMethod() == NotSpecified) -} - -func (s *Sqlcmd) sqlAuthentication() bool { - return s.Connect.authenticationMethod() == SqlPassword || - (!s.Connect.UseTrustedConnection && s.Connect.authenticationMethod() == NotSpecified && s.vars.SQLCmdUser() != "") -} - // runQuery runs the query and prints the results // The return value is based on the first cell of the last column of the last result set. // If it's numeric, it will be converted to int diff --git a/pkg/sqlcmd/sqlcmd_test.go b/pkg/sqlcmd/sqlcmd_test.go index 6acb9a26..fd152420 100644 --- a/pkg/sqlcmd/sqlcmd_test.go +++ b/pkg/sqlcmd/sqlcmd_test.go @@ -24,7 +24,6 @@ const oneRowAffected = "(1 row affected)" func TestConnectionStringFromSqlCmd(t *testing.T) { type connectionStringTest struct { settings *ConnectSettings - setup func(*Variables) connectionString string } @@ -32,59 +31,38 @@ func TestConnectionStringFromSqlCmd(t *testing.T) { commands := []connectionStringTest{ - {nil, nil, "sqlserver://."}, + {&ConnectSettings{}, "sqlserver://."}, { - &ConnectSettings{TrustServerCertificate: true, WorkstationName: "mystation"}, - func(vars *Variables) { - vars.Set(SQLCMDDBNAME, "somedatabase") - }, + &ConnectSettings{TrustServerCertificate: true, WorkstationName: "mystation", Database: "somedatabase"}, "sqlserver://.?database=somedatabase&trustservercertificate=true&workstation+id=mystation", }, { - &ConnectSettings{WorkstationName: "mystation", Encrypt: "false"}, - func(vars *Variables) { - vars.Set(SQLCMDDBNAME, "somedatabase") - }, + &ConnectSettings{WorkstationName: "mystation", Encrypt: "false", Database: "somedatabase"}, "sqlserver://.?database=somedatabase&encrypt=false&workstation+id=mystation", }, { - &ConnectSettings{TrustServerCertificate: true, Password: pwd}, - func(vars *Variables) { - vars.Set(SQLCMDSERVER, `someserver/instance`) - vars.Set(SQLCMDDBNAME, "somedatabase") - vars.Set(SQLCMDUSER, "someuser") - }, + &ConnectSettings{TrustServerCertificate: true, Password: pwd, ServerName: `someserver/instance`, Database: "somedatabase", UserName: "someuser"}, fmt.Sprintf("sqlserver://someuser:%s@someserver/instance?database=somedatabase&trustservercertificate=true", pwd), }, { - &ConnectSettings{TrustServerCertificate: true, UseTrustedConnection: true, Password: pwd}, - func(vars *Variables) { - vars.Set(SQLCMDSERVER, `tcp:someserver,1045`) - vars.Set(SQLCMDUSER, "someuser") - }, + &ConnectSettings{TrustServerCertificate: true, UseTrustedConnection: true, Password: pwd, ServerName: `tcp:someserver,1045`, UserName: "someuser"}, "sqlserver://someserver:1045?trustservercertificate=true", }, { - nil, - func(vars *Variables) { - vars.Set(SQLCMDSERVER, `tcp:someserver,1045`) - }, + &ConnectSettings{ServerName: `tcp:someserver,1045`}, "sqlserver://someserver:1045", }, + { + &ConnectSettings{ServerName: "someserver", AuthenticationMethod: azuread.ActiveDirectoryServicePrincipal, UserName: "myapp@mytenant", Password: pwd}, + fmt.Sprintf("sqlserver://myapp%%40mytenant:%s@someserver", pwd), + }, } - for _, test := range commands { - v := InitializeVariables(false) - if test.setup != nil { - test.setup(v) - } - s := &Sqlcmd{vars: v} - if test.settings != nil { - s.Connect = *test.settings - } - connectionString, err := s.ConnectionString() - if assert.NoError(t, err, "Unexpected error from %+v", s) { - assert.Equal(t, test.connectionString, connectionString, "Wrong connection string from: %+v", *s) + for i, test := range commands { + + connectionString, err := test.settings.ConnectionString() + if assert.NoError(t, err, "Unexpected error from [%d] %+v", i, test.settings) { + assert.Equal(t, test.connectionString, connectionString, "Wrong connection string from [%d]: %+v", i, test.settings) } } } @@ -97,13 +75,8 @@ set will be to localhost using Windows auth. func TestSqlCmdConnectDb(t *testing.T) { v := InitializeVariables(true) s := &Sqlcmd{vars: v} - if canTestAzureAuth() { - s.Connect.AuthenticationMethod = azuread.ActiveDirectoryDefault - } else { - s.Connect.Password = os.Getenv(SQLCMDPASSWORD) - } - - err := s.ConnectDb("", "", "", false) + s.Connect = newConnect(t) + err := s.ConnectDb(nil, false) if assert.NoError(t, err, "ConnectDb should succeed") { sqlcmduser := os.Getenv(SQLCMDUSER) if sqlcmduser == "" { @@ -114,15 +87,11 @@ func TestSqlCmdConnectDb(t *testing.T) { } } -func ConnectDb() (*sql.DB, error) { +func ConnectDb(t testing.TB) (*sql.DB, error) { v := InitializeVariables(true) s := &Sqlcmd{vars: v} - if canTestAzureAuth() { - s.Connect.AuthenticationMethod = azuread.ActiveDirectoryDefault - } else { - s.Connect.Password = os.Getenv(SQLCMDPASSWORD) - } - err := s.ConnectDb("", "", "", false) + s.Connect = newConnect(t) + err := s.ConnectDb(nil, false) return s.db, err } @@ -229,7 +198,6 @@ func TestGetRunnableQuery(t *testing.T) { r = s.getRunnableQuery(test.raw) assert.Equalf(t, test.raw, r, `runnableQuery without variable subs for "%s"`, test.raw) } - } func TestExitInitialQuery(t *testing.T) { @@ -301,6 +269,91 @@ func TestSqlCmdSetErrorLevel(t *testing.T) { assert.Equal(t, 16, s.Exitcode, "Select error should be the exit code") } +type testConsole struct { + PromptText string + OnPasswordPrompt func(prompt string) ([]byte, error) + OnReadLine func() (string, error) +} + +func (tc *testConsole) Readline() (string, error) { + return tc.OnReadLine() +} + +func (tc *testConsole) ReadPassword(prompt string) ([]byte, error) { + return tc.OnPasswordPrompt(prompt) +} + +func (tc *testConsole) SetPrompt(s string) { + tc.PromptText = s +} + +func TestPromptForPasswordNegative(t *testing.T) { + prompted := false + console := &testConsole{ + OnPasswordPrompt: func(prompt string) ([]byte, error) { + assert.Equal(t, "Password:", prompt, "Incorrect password prompt") + prompted = true + return []byte{}, nil + }, + OnReadLine: func() (string, error) { + assert.Fail(t, "ReadLine should not be called") + return "", nil + }, + } + v := InitializeVariables(true) + s := New(console, "", v) + s.Connect.UserName = "someuser" + err := s.ConnectDb(nil, false) + assert.True(t, prompted, "Password prompt not shown for SQL auth") + assert.Error(t, err, "ConnectDb") + prompted = false + s.Connect.AuthenticationMethod = azuread.ActiveDirectoryPassword + err = s.ConnectDb(nil, false) + assert.True(t, prompted, "Password prompt not shown for AD Password auth") + assert.Error(t, err, "ConnectDb") + prompted = false +} + +func TestPromptForPasswordPositive(t *testing.T) { + prompted := false + c := newConnect(t) + if c.Password == "" { + // See if azure variables are set for activedirectoryserviceprincipal + c.UserName = os.Getenv("AZURE_CLIENT_ID") + "@" + os.Getenv("AZURE_TENANT_ID") + c.Password = os.Getenv("AZURE_CLIENT_SECRET") + c.AuthenticationMethod = azuread.ActiveDirectoryServicePrincipal + if c.Password == "" { + t.Skip("No password available") + } + } + password := c.Password + c.Password = "" + console := &testConsole{ + OnPasswordPrompt: func(prompt string) ([]byte, error) { + assert.Equal(t, "Password:", prompt, "Incorrect password prompt") + prompted = true + return []byte(password), nil + }, + OnReadLine: func() (string, error) { + assert.Fail(t, "ReadLine should not be called") + return "", nil + }, + } + v := InitializeVariables(true) + s := New(console, "", v) + // attempt without password prompt + err := s.ConnectDb(&c, true) + assert.False(t, prompted, "ConnectDb with nopw=true should not prompt for password") + assert.Error(t, err, "ConnectDb with nopw==true and no password provided") + err = s.ConnectDb(&c, false) + assert.True(t, prompted, "ConnectDb with !nopw should prompt for password") + assert.NoError(t, err, "ConnectDb with !nopw and valid password returned from prompt") + if s.Connect.Password != password { + t.Fatal(t, err, "Password not stored in the connection") + } +} + +// runSqlCmd uses lines as input for sqlcmd instead of relying on file or console input func runSqlCmd(t testing.TB, s *Sqlcmd, lines []string) error { t.Helper() i := 0 @@ -320,16 +373,11 @@ func setupSqlCmdWithMemoryOutput(t testing.TB) (*Sqlcmd, *memoryBuffer) { v := InitializeVariables(true) v.Set(SQLCMDMAXVARTYPEWIDTH, "0") s := New(nil, "", v) - if canTestAzureAuth() { - t.Log("Using ActiveDirectoryDefault") - s.Connect.AuthenticationMethod = azuread.ActiveDirectoryDefault - } else { - s.Connect.Password = os.Getenv(SQLCMDPASSWORD) - } + s.Connect = newConnect(t) s.Format = NewSQLCmdDefaultFormatter(true) buf := &memoryBuffer{buf: new(bytes.Buffer)} s.SetOutput(buf) - err := s.ConnectDb("", "", "", true) + err := s.ConnectDb(nil, true) assert.NoError(t, err, "s.ConnectDB") return s, buf } @@ -339,17 +387,12 @@ func setupSqlcmdWithFileOutput(t testing.TB) (*Sqlcmd, *os.File) { v := InitializeVariables(true) v.Set(SQLCMDMAXVARTYPEWIDTH, "0") s := New(nil, "", v) - if canTestAzureAuth() { - t.Log("Using ActiveDirectoryDefault") - s.Connect.AuthenticationMethod = azuread.ActiveDirectoryDefault - } else { - s.Connect.Password = os.Getenv(SQLCMDPASSWORD) - } + s.Connect = newConnect(t) s.Format = NewSQLCmdDefaultFormatter(true) file, err := os.CreateTemp("", "sqlcmdout") assert.NoError(t, err, "os.CreateTemp") s.SetOutput(file) - err = s.ConnectDb("", "", "", true) + err = s.ConnectDb(nil, true) assert.NoError(t, err, "s.ConnectDB") return s, file } @@ -360,3 +403,18 @@ func canTestAzureAuth() bool { userName := os.Getenv(SQLCMDUSER) return strings.Contains(server, ".database.windows.net") && userName == "" } + +func newConnect(t testing.TB) ConnectSettings { + t.Helper() + connect := ConnectSettings{ + UserName: os.Getenv(SQLCMDUSER), + Database: os.Getenv(SQLCMDDBNAME), + ServerName: os.Getenv(SQLCMDSERVER), + Password: os.Getenv(SQLCMDPASSWORD), + } + if canTestAzureAuth() { + t.Log("Using ActiveDirectoryDefault") + connect.AuthenticationMethod = azuread.ActiveDirectoryDefault + } + return connect +}