Skip to content

Commit f05207d

Browse files
authored
support variables in some commands (#123)
1 parent eeea072 commit f05207d

File tree

4 files changed

+128
-27
lines changed

4 files changed

+128
-27
lines changed

pkg/sqlcmd/commands.go

Lines changed: 76 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"os"
99
"regexp"
1010
"sort"
11+
"strconv"
1112
"strings"
1213

1314
"github.com/alecthomas/kong"
@@ -57,8 +58,7 @@ func newCommands() Commands {
5758
regex: regexp.MustCompile(`(?im)^[ \t]*:ERROR(?:[ \t]+(.*$)|$)`),
5859
action: errorCommand,
5960
name: "ERROR",
60-
},
61-
"READFILE": {
61+
}, "READFILE": {
6262
regex: regexp.MustCompile(`(?im)^[ \t]*:R(?:[ \t]+(.*$)|$)`),
6363
action: readFileCommand,
6464
name: "READFILE",
@@ -143,7 +143,13 @@ func exitCommand(s *Sqlcmd, args []string, line uint) error {
143143
}
144144
}
145145
query = strings.TrimSpace(params[1 : len(params)-1])
146-
if query != "" {
146+
s.batch.Reset([]rune(query))
147+
_, _, err := s.batch.Next()
148+
if err != nil {
149+
return err
150+
}
151+
query = s.batch.String()
152+
if s.batch.String() != "" {
147153
query = s.getRunnableQuery(query)
148154
s.Exitcode, _ = s.runQuery(query)
149155
}
@@ -239,7 +245,7 @@ func readFileCommand(s *Sqlcmd, args []string, line uint) error {
239245
if args == nil || len(args) != 1 {
240246
return InvalidCommandError(":R", line)
241247
}
242-
return s.IncludeFile(args[0], false)
248+
return s.IncludeFile(resolveArgumentVariables(s, []rune(args[0])), false)
243249
}
244250

245251
// setVarCommand parses a variable setting and applies it to the current Sqlcmd variables
@@ -313,33 +319,91 @@ type connectData struct {
313319
Database string `short:"D"`
314320
Username string `short:"U"`
315321
Password string `short:"P"`
316-
LoginTimeout int `short:"l"`
322+
LoginTimeout string `short:"l"`
317323
AuthenticationMethod string `short:"G"`
318324
}
319325

320326
func connectCommand(s *Sqlcmd, args []string, line uint) error {
321-
if len(args) == 0 || strings.TrimSpace(args[0]) == "" {
327+
328+
if len(args) == 0 {
329+
return InvalidCommandError("CONNECT", line)
330+
}
331+
cmdLine := strings.TrimSpace(args[0])
332+
if cmdLine == "" {
322333
return InvalidCommandError("CONNECT", line)
323334
}
324335
arguments := &connectData{}
325336
parser, err := kong.New(arguments)
326337
if err != nil {
327338
return InvalidCommandError("CONNECT", line)
328339
}
329-
if _, err = parser.Parse(strings.Split(args[0], " ")); err != nil {
340+
341+
// Fields removes extra whitespace.
342+
// Note :connect doesn't support passwords with spaces
343+
if _, err = parser.Parse(strings.Fields(cmdLine)); err != nil {
330344
return InvalidCommandError("CONNECT", line)
331345
}
332346

333347
connect := s.Connect
334-
connect.UserName = arguments.Username
335-
connect.Password = arguments.Password
336-
connect.ServerName = arguments.Server
337-
if arguments.LoginTimeout > 0 {
338-
connect.LoginTimeoutSeconds = arguments.LoginTimeout
348+
connect.UserName = resolveArgumentVariables(s, []rune(arguments.Username))
349+
connect.Password = resolveArgumentVariables(s, []rune(arguments.Password))
350+
connect.ServerName = resolveArgumentVariables(s, []rune(arguments.Server))
351+
timeout := resolveArgumentVariables(s, []rune(arguments.LoginTimeout))
352+
if timeout != "" {
353+
if timeoutSeconds, err := strconv.ParseInt(timeout, 10, 32); err == nil {
354+
if timeoutSeconds < 0 {
355+
return InvalidCommandError("CONNECT", line)
356+
}
357+
connect.LoginTimeoutSeconds = int(timeoutSeconds)
358+
}
339359
}
340360
connect.AuthenticationMethod = arguments.AuthenticationMethod
341361
// If no user name is provided we switch to integrated auth
342362
_ = s.ConnectDb(&connect, s.lineIo == nil)
343363
// ConnectDb prints connection errors already, and failure to connect is not fatal even with -b option
344364
return nil
345365
}
366+
367+
func resolveArgumentVariables(s *Sqlcmd, arg []rune) string {
368+
var b *strings.Builder
369+
end := len(arg)
370+
for i := 0; i < end; {
371+
c, next := arg[i], grab(arg, i+1, end)
372+
switch {
373+
case c == '$' && next == '(':
374+
vl, ok := readVariableReference(arg, i+2, end)
375+
if ok {
376+
varName := string(arg[i+2 : vl])
377+
val, ok := s.resolveVariable(varName)
378+
if ok {
379+
if b == nil {
380+
b = new(strings.Builder)
381+
b.Grow(len(arg))
382+
b.WriteString(string(arg[0:i]))
383+
}
384+
b.WriteString(val)
385+
} else {
386+
_, _ = s.GetError().Write([]byte(UndefinedVariable(varName).Error() + SqlcmdEol))
387+
if b != nil {
388+
b.WriteString(string(arg[i : vl+1]))
389+
}
390+
}
391+
i += ((vl - i) + 1)
392+
} else {
393+
if b != nil {
394+
b.WriteString("$(")
395+
}
396+
i += 2
397+
}
398+
default:
399+
if b != nil {
400+
b.WriteRune(c)
401+
}
402+
i++
403+
}
404+
}
405+
if b == nil {
406+
return string(arg)
407+
}
408+
return b.String()
409+
}

pkg/sqlcmd/commands_test.go

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"strings"
1111
"testing"
1212

13+
"github.com/microsoft/go-mssqldb/azuread"
1314
"github.com/stretchr/testify/assert"
1415
"github.com/stretchr/testify/require"
1516
)
@@ -44,6 +45,7 @@ func TestCommandParsing(t *testing.T) {
4445
{`:EXIT ( )`, "EXIT", []string{"( )"}},
4546
{`EXIT `, "EXIT", []string{""}},
4647
{`:Connect someserver -U someuser`, "CONNECT", []string{"someserver -U someuser"}},
48+
{`:r c:\$(var)\file.sql`, "READFILE", []string{`c:\$(var)\file.sql`}},
4749
}
4850

4951
for _, test := range commands {
@@ -156,7 +158,7 @@ func TestListCommand(t *testing.T) {
156158
}
157159

158160
func TestConnectCommand(t *testing.T) {
159-
s, _ := setupSqlCmdWithMemoryOutput(t)
161+
s, buf := setupSqlCmdWithMemoryOutput(t)
160162
prompted := false
161163
s.lineIo = &testConsole{
162164
OnPasswordPrompt: func(prompt string) ([]byte, error) {
@@ -174,19 +176,26 @@ func TestConnectCommand(t *testing.T) {
174176
c := newConnect(t)
175177

176178
authenticationMethod := ""
177-
if c.Password == "" {
178-
c.UserName = os.Getenv("AZURE_CLIENT_ID") + "@" + os.Getenv("AZURE_TENANT_ID")
179-
c.Password = os.Getenv("AZURE_CLIENT_SECRET")
180-
authenticationMethod = "-G ActiveDirectoryServicePrincipal"
181-
if c.Password == "" {
182-
t.Log("Not trying :Connect with valid password due to no password being available")
183-
return
184-
}
185-
err = connectCommand(s, []string{fmt.Sprintf("%s -U %s -P %s %s", c.ServerName, c.UserName, c.Password, authenticationMethod)}, 3)
186-
assert.NoError(t, err, "connectCommand with valid parameters should not return an error")
179+
password := ""
180+
username := ""
181+
if canTestAzureAuth() {
182+
authenticationMethod = "-G " + azuread.ActiveDirectoryDefault
183+
}
184+
if c.Password != "" {
185+
password = "-P " + c.Password
186+
}
187+
if c.UserName != "" {
188+
username = "-U " + c.UserName
189+
}
190+
s.vars.Set("servername", c.ServerName)
191+
s.vars.Set("to", "111")
192+
buf.buf.Reset()
193+
err = connectCommand(s, []string{fmt.Sprintf("$(servername) %s %s %s -l $(to)", username, password, authenticationMethod)}, 3)
194+
if assert.NoError(t, err, "connectCommand with valid parameters should not return an error") {
187195
// not using assert to avoid printing passwords in the log
188-
if s.Connect.UserName != c.UserName || c.Password != s.Connect.Password {
189-
t.Fatal("After connect, sqlCmd.Connect is not updated")
196+
assert.NotContains(t, buf.buf.String(), "$(servername)", "ConnectDB should have succeeded")
197+
if s.Connect.UserName != c.UserName || c.Password != s.Connect.Password || s.Connect.LoginTimeoutSeconds != 111 {
198+
t.Fatalf("After connect, sqlCmd.Connect is not updated %+v", s.Connect)
190199
}
191200
}
192201
}
@@ -212,3 +221,30 @@ func TestErrorCommand(t *testing.T) {
212221
assert.Regexp(t, "Msg 50000, Level 16, State 1, Server .*, Line 2"+SqlcmdEol+"Error"+SqlcmdEol, string(errText), "Error file contents")
213222
}
214223
}
224+
225+
func TestResolveArgumentVariables(t *testing.T) {
226+
type argTest struct {
227+
arg string
228+
val string
229+
err string
230+
}
231+
232+
args := []argTest{
233+
{"$(var1)", "var1val", ""},
234+
{"$(var1", "$(var1", ""},
235+
{`C:\folder\$(var1)\$(var2)\$(var1)\file.sql`, `C:\folder\var1val\$(var2)\var1val\file.sql`, "Sqlcmd: Error: 'var2' scripting variable not defined."},
236+
{`C:\folder\$(var1\$(var2)\$(var1)\file.sql`, `C:\folder\$(var1\$(var2)\var1val\file.sql`, "Sqlcmd: Error: 'var2' scripting variable not defined."},
237+
}
238+
vars := InitializeVariables(false)
239+
s := New(nil, "", vars)
240+
s.vars.Set("var1", "var1val")
241+
buf := &memoryBuffer{buf: new(bytes.Buffer)}
242+
defer buf.Close()
243+
s.SetError(buf)
244+
for _, test := range args {
245+
actual := resolveArgumentVariables(s, []rune(test.arg))
246+
assert.Equal(t, test.val, actual, "Incorrect argument parsing of "+test.arg)
247+
assert.Contains(t, buf.buf.String(), test.err, "Error output mismatch for "+test.arg)
248+
buf.buf.Reset()
249+
}
250+
}

pkg/sqlcmd/parse.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func readCommand(c Commands, r []rune, i, end int) (*Command, []string, int) {
5858
return cmd, args, i
5959
}
6060

61-
// readVariableReference returns the length of the variable reference or false if it's not a valid identifier
61+
// readVariableReference returns the index of the end of the variable reference or false if it's not a valid identifier
6262
func readVariableReference(r []rune, i int, end int) (int, bool) {
6363
for ; i < end; i++ {
6464
if r[i] == ')' {

pkg/sqlcmd/sqlcmd_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,8 @@ func TestGetRunnableQuery(t *testing.T) {
217217
func TestExitInitialQuery(t *testing.T) {
218218
s, buf := setupSqlCmdWithMemoryOutput(t)
219219
defer buf.Close()
220-
s.Query = "EXIT(SELECT '1200', 2100)"
220+
_ = s.vars.Setvar("var1", "1200")
221+
s.Query = "EXIT(SELECT '$(var1)', 2100)"
221222
err := s.Run(true, false)
222223
if assert.NoError(t, err, "s.Run(once = true)") {
223224
s.SetOutput(nil)

0 commit comments

Comments
 (0)