Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ We will be implementing command line switches and behaviors over time. Several s
- Some behaviors that were kept to maintain compatibility with `OSQL` may be changed, such as alignment of column headers for some data types.
- All commands must fit on one line, even `EXIT`. Interactive mode will not check for open parentheses or quotes for commands and prompt for successive lines. The ODBC sqlcmd allows the query run by `EXIT(query)` to span multiple lines.

### Miscellaneous enhancements

- `:Connect` now has an optional `-G` parameter to select one of the authentication methods for Azure SQL Database - `SqlAuthentication`, `ActiveDirectoryDefault`, `ActiveDirectoryIntegrated`, `ActiveDirectoryServicePrincipal`, `ActiveDirectoryManagedIdentity`, `ActiveDirectoryPassword`. If `-G` is not provided, either Integrated security or SQL Authentication will be used, dependent on the presence of a `-U` user name parameter.
- The new `--driver-logging-level` command line parameter allows you to see traces from the `go-mssqldb` client driver. Use `64` to see all traces.

### Azure Active Directory Authentication

This version of sqlcmd supports a broader range of AAD authentication models, based on the [azidentity package](https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity). The implementation relies on an AAD Connector in the [driver](https://github.com/denisenkom/go-mssqldb).
Expand Down
59 changes: 37 additions & 22 deletions cmd/sqlcmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func main() {
setVars(vars, &args)

// so far sqlcmd prints all the errors itself so ignore it
exitCode, _ := run(vars)
exitCode, _ := run(vars, &args)
os.Exit(exitCode)
}

Expand Down Expand Up @@ -156,43 +156,57 @@ func setVars(vars *sqlcmd.Variables, args *SQLCmdArguments) {

}

func setConnect(s *sqlcmd.Sqlcmd, args *SQLCmdArguments) {
func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sqlcmd.Variables) {
if !args.DisableCmdAndWarn {
s.Connect.Password = os.Getenv(sqlcmd.SQLCMDPASSWORD)
}
s.Connect.UseTrustedConnection = args.UseTrustedConnection
s.Connect.TrustServerCertificate = args.TrustServerCertificate
s.Connect.AuthenticationMethod = args.authenticationMethod(s.Connect.Password != "")
s.Connect.DisableEnvironmentVariables = args.DisableCmdAndWarn
s.Connect.DisableVariableSubstitution = args.DisableVariableSubstitution
s.Connect.ApplicationIntent = args.ApplicationIntent
s.Connect.LoginTimeoutSeconds = args.LoginTimeout
s.Connect.Encrypt = args.EncryptConnection
s.Connect.PacketSize = args.PacketSize
s.Connect.WorkstationName = args.WorkstationName
s.Connect.LogLevel = args.DriverLoggingLevel
s.Connect.ExitOnError = args.ExitOnError
s.Connect.ErrorSeverityLevel = args.ErrorSeverityLevel
connect.Password = os.Getenv(sqlcmd.SQLCMDPASSWORD)
}
connect.ServerName = args.Server
if connect.ServerName == "" {
connect.ServerName, _ = vars.Get(sqlcmd.SQLCMDSERVER)
}
connect.Database = args.DatabaseName
if connect.Database == "" {
connect.Database, _ = vars.Get(sqlcmd.SQLCMDDBNAME)
}
connect.UserName = args.UserName
if connect.UserName == "" {
connect.UserName, _ = vars.Get(sqlcmd.SQLCMDUSER)
}
connect.UseTrustedConnection = args.UseTrustedConnection
connect.TrustServerCertificate = args.TrustServerCertificate
connect.AuthenticationMethod = args.authenticationMethod(connect.Password != "")
connect.DisableEnvironmentVariables = args.DisableCmdAndWarn
connect.DisableVariableSubstitution = args.DisableVariableSubstitution
connect.ApplicationIntent = args.ApplicationIntent
connect.LoginTimeoutSeconds = args.LoginTimeout
connect.Encrypt = args.EncryptConnection
connect.PacketSize = args.PacketSize
connect.WorkstationName = args.WorkstationName
connect.LogLevel = args.DriverLoggingLevel
connect.ExitOnError = args.ExitOnError
connect.ErrorSeverityLevel = args.ErrorSeverityLevel
}

func run(vars *sqlcmd.Variables) (int, error) {
func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
wd, err := os.Getwd()
if err != nil {
return 1, err
}

iactive := args.InputFile == nil && args.Query == ""
var console sqlcmd.Console = nil
var line *readline.Instance
if iactive {
line, err = readline.New(">")
if err != nil {
return 1, err
}
console = line
defer line.Close()
}

s := sqlcmd.New(line, wd, vars)

s := sqlcmd.New(console, wd, vars)
setConnect(&s.Connect, args, vars)
if args.BatchTerminator != "GO" {
err = s.Cmd.SetBatchTerminator(args.BatchTerminator)
if err != nil {
Expand All @@ -203,7 +217,7 @@ func run(vars *sqlcmd.Variables) (int, error) {
return 1, err
}

setConnect(s, &args)
setConnect(&s.Connect, args, vars)
s.Format = sqlcmd.NewSQLCmdDefaultFormatter(false)
if args.OutputFile != "" {
err = s.RunCommand(s.Cmd["OUT"], []string{args.OutputFile})
Expand All @@ -218,7 +232,8 @@ func run(vars *sqlcmd.Variables) (int, error) {
once = true
s.Query = args.Query
}
err = s.ConnectDb("", "", "", !iactive)
// connect using no overrides
err = s.ConnectDb(nil, !iactive)
if err != nil {
return 1, err
}
Expand Down
7 changes: 3 additions & 4 deletions cmd/sqlcmd/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func TestRunInputFiles(t *testing.T) {
vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0")
setVars(vars, &args)

exitCode, err := run(vars)
exitCode, err := run(vars, &args)
assert.NoError(t, err, "run")
assert.Equal(t, 0, exitCode, "exitCode")
bytes, err := os.ReadFile(o.Name())
Expand All @@ -148,7 +148,7 @@ func TestQueryAndExit(t *testing.T) {
vars.Set("VAR1", "100")
setVars(vars, &args)

exitCode, err := run(vars)
exitCode, err := run(vars, &args)
assert.NoError(t, err, "run")
assert.Equal(t, 0, exitCode, "exitCode")
bytes, err := os.ReadFile(o.Name())
Expand All @@ -173,8 +173,7 @@ func TestAzureAuth(t *testing.T) {
vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0")
setVars(vars, &args)

exitCode, err := run(vars)
exitCode, err := run(vars, &args)
assert.NoError(t, err, "run")
assert.Equal(t, 0, exitCode, "exitCode")
bytes, err := os.ReadFile(o.Name())
Expand Down
3 changes: 1 addition & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ module github.com/microsoft/go-sqlcmd
go 1.16

require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0
github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0
github.com/alecthomas/kong v0.2.18-0.20210621093454-54558f65e86f
github.com/chzyer/logex v1.1.10 // indirect
github.com/chzyer/test v0.0.0-20210722231415-061457976a23 // indirect
Expand All @@ -15,3 +13,4 @@ require (
github.com/stretchr/testify v1.7.0
)

replace github.com/denisenkom/go-mssqldb => github.com/shueybubbles/go-mssqldb v0.10.1-0.20220303143659-8896461e4ec7
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ github.com/chzyer/test v0.0.0-20210722231415-061457976a23/go.mod h1:Q3SI9o4m/ZMn
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/denisenkom/go-mssqldb v0.12.0 h1:VtrkII767ttSPNRfFekePK3sctr+joXgO58stqQbtUA=
github.com/denisenkom/go-mssqldb v0.12.0/go.mod h1:iiK0YP1ZeepvmBQk/QpLEhhTNJgfzrpArPY/aFvc9yU=
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
github.com/gohxs/readline v0.0.0-20171011095936-a780388e6e7c h1:yE35fKFwcelIte3q5q1/cPiY7pI7vvf5/j/0ddxNCKs=
github.com/gohxs/readline v0.0.0-20171011095936-a780388e6e7c/go.mod h1:9S/fKAutQ6wVHqm1jnp9D9sc5hu689s9AaTWFS92LaU=
Expand All @@ -31,6 +29,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/shueybubbles/go-mssqldb v0.10.1-0.20220303143659-8896461e4ec7 h1:4CIaYagSRCGr0/Gh6cfF5cQx3RVE3qrQukZn8iMO6Y8=
github.com/shueybubbles/go-mssqldb v0.10.1-0.20220303143659-8896461e4ec7/go.mod h1:iiK0YP1ZeepvmBQk/QpLEhhTNJgfzrpArPY/aFvc9yU=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
Expand Down
22 changes: 11 additions & 11 deletions pkg/sqlcmd/azure_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package sqlcmd

import (
"database/sql/driver"
"fmt"
"net/url"
"os"

Expand All @@ -24,28 +25,27 @@ func getSqlClientId() string {
return sqlClientId
}

func (s *Sqlcmd) GetTokenBasedConnection(connstr string, user string, password string) (driver.Connector, error) {
func GetTokenBasedConnection(connstr string, authenticationMethod string) (driver.Connector, error) {

connectionUrl, err := url.Parse(connstr)
if err != nil {
return nil, err
}

if user != "" {
connectionUrl.User = url.UserPassword(user, password)
}

query := connectionUrl.Query()
query.Set("fedauth", s.Connect.authenticationMethod())
query.Set("fedauth", authenticationMethod)
query.Set("applicationclientid", getSqlClientId())

switch s.Connect.AuthenticationMethod {
case azuread.ActiveDirectoryServicePrincipal:
case azuread.ActiveDirectoryApplication:
switch authenticationMethod {
case azuread.ActiveDirectoryServicePrincipal, azuread.ActiveDirectoryApplication:
query.Set("clientcertpath", os.Getenv("AZURE_CLIENT_CERTIFICATE_PATH"))
case azuread.ActiveDirectoryInteractive:
loginTimeout := query.Get("connection timeout")
loginTimeoutSeconds := 0
if loginTimeout != "" {
_, _ = fmt.Sscanf(loginTimeout, "%d", &loginTimeoutSeconds)
}
// AAD interactive needs minutes at minimum
if s.Connect.LoginTimeoutSeconds < 120 {
if loginTimeoutSeconds > 0 && loginTimeoutSeconds < 120 {
query.Set("connection timeout", "120")
}
}
Expand Down
43 changes: 43 additions & 0 deletions pkg/sqlcmd/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"sort"
"strings"
"syscall"

"github.com/alecthomas/kong"
)

// Command defines a sqlcmd action which can be intermixed with the SQL batch
Expand Down Expand Up @@ -80,6 +82,11 @@ func newCommands() Commands {
action: listCommand,
name: "LIST",
},
"CONNECT": {
regex: regexp.MustCompile(`(?im)^[ \t]*:CONNECT(?:[ \t]+(.*$)|$)`),
action: connectCommand,
name: "CONNECT",
},
}

}
Expand Down Expand Up @@ -285,3 +292,39 @@ func listCommand(s *Sqlcmd, args []string, line uint) error {

return nil
}

type connectData struct {
Server string `arg:""`
Database string `short:"D"`
Username string `short:"U"`
Password string `short:"P"`
LoginTimeout int `short:"l"`
AuthenticationMethod string `short:"G"`
}

func connectCommand(s *Sqlcmd, args []string, line uint) error {
if len(args) == 0 || strings.TrimSpace(args[0]) == "" {
return InvalidCommandError("CONNECT", line)
}
arguments := &connectData{}
parser, err := kong.New(arguments)
if err != nil {
return InvalidCommandError("CONNECT", line)
}
if _, err = parser.Parse(strings.Split(args[0], " ")); err != nil {
return InvalidCommandError("CONNECT", line)
}

connect := s.Connect
connect.UserName = arguments.Username
connect.Password = arguments.Password
connect.ServerName = arguments.Server
if arguments.LoginTimeout > 0 {
connect.LoginTimeoutSeconds = arguments.LoginTimeout
}
connect.AuthenticationMethod = arguments.AuthenticationMethod
// If no user name is provided we switch to integrated auth
_ = s.ConnectDb(&connect, s.lineIo == nil)
// ConnectDb prints connection errors already, and failure to connect is not fatal even with -b option
return nil
}
39 changes: 39 additions & 0 deletions pkg/sqlcmd/commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package sqlcmd

import (
"bytes"
"fmt"
"os"
"strings"
"testing"

Expand Down Expand Up @@ -41,6 +43,7 @@ func TestCommandParsing(t *testing.T) {
{`:EXIT (select 100 as count)`, "EXIT", []string{"(select 100 as count)"}},
{`:EXIT ( )`, "EXIT", []string{"( )"}},
{`EXIT `, "EXIT", []string{""}},
{`:Connect someserver -U someuser`, "CONNECT", []string{"someserver -U someuser"}},
}

for _, test := range commands {
Expand Down Expand Up @@ -151,3 +154,39 @@ func TestListCommand(t *testing.T) {
o := buf.buf.String()
assert.Equal(t, o, "select 1"+SqlcmdEol, ":list output not equal to batch")
}

func TestConnectCommand(t *testing.T) {
s, _ := setupSqlCmdWithMemoryOutput(t)
prompted := false
s.lineIo = &testConsole{
OnPasswordPrompt: func(prompt string) ([]byte, error) {
prompted = true
return []byte{}, nil
},
}
err := connectCommand(s, []string{"someserver -U someuser"}, 1)
assert.NoError(t, err, "connectCommand with valid arguments doesn't return an error on connect failure")
assert.True(t, prompted, "connectCommand with user name and no password should prompt for password")
assert.NotEqual(t, "someserver", s.Connect.ServerName, "On error, sqlCmd.Connect does not copy inputs")

err = connectCommand(s, []string{}, 2)
assert.EqualError(t, err, InvalidCommandError("CONNECT", 2).Error(), ":Connect with no arguments should return an error")
c := newConnect(t)

authenticationMethod := ""
if c.Password == "" {
c.UserName = os.Getenv("AZURE_CLIENT_ID") + "@" + os.Getenv("AZURE_TENANT_ID")
c.Password = os.Getenv("AZURE_CLIENT_SECRET")
authenticationMethod = "-G ActiveDirectoryServicePrincipal"
if c.Password == "" {
t.Log("Not trying :Connect with valid password due to no password being available")
return
}
err = connectCommand(s, []string{fmt.Sprintf("%s -U %s -P %s %s", c.ServerName, c.UserName, c.Password, authenticationMethod)}, 3)
assert.NoError(t, err, "connectCommand with valid parameters should not return an error")
// not using assert to avoid printing passwords in the log
if s.Connect.UserName != c.UserName || c.Password != s.Connect.Password {
t.Fatal("After connect, sqlCmd.Connect is not updated")
}
}
}
Loading