diff --git a/.gitignore b/.gitignore index 22a0a10a..991d4e1f 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,8 @@ *.dll *.so *.dylib +bin +obj # Test binary, built with `go test -c` *.test @@ -21,3 +23,7 @@ testresults.xml # .syso is generated by go-winres. Only needed for official builds *.syso + +# IDE files +.idea + diff --git a/LICENSE b/LICENSE index 9e841e7a..b381ec3f 100644 --- a/LICENSE +++ b/LICENSE @@ -1,21 +1,21 @@ - MIT License +MIT License (MIT) - Copyright (c) Microsoft Corporation. +Copyright © Microsoft Corp. - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/cmd/cmd.go b/cmd/cmd.go deleted file mode 100644 index 75443b01..00000000 --- a/cmd/cmd.go +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package cmd - -import ( - "github.com/microsoft/go-sqlcmd/cmd/root" - "github.com/microsoft/go-sqlcmd/internal" - "github.com/microsoft/go-sqlcmd/internal/cmdparser" - "github.com/microsoft/go-sqlcmd/internal/config" - "github.com/microsoft/go-sqlcmd/internal/output" -) - -var loggingLevel int -var outputType string -var configFilename string -var rootCmd cmdparser.Command - -// Initialize initializes the command-line interface. The func passed into -// cmdparser.Initialize is called after the command-line from the user has been -// parsed, so the helpers are initialized with the values from the command-line -// like '-v 4' which sets the logging level to maximum etc. -func Initialize() { - cmdparser.Initialize(initialize) - rootCmd = cmdparser.New[*Root](root.SubCommands()...) -} - -func initialize() { - options := internal.InitializeOptions{ - ErrorHandler: checkErr, - HintHandler: displayHints, - OutputType: "yaml", - LoggingLevel: 2, - } - - config.SetFileName(configFilename) - config.Load() - internal.Initialize(options) -} - -// Execute runs the application based on the command-line -// parameters the user has passed in. -func Execute() { - rootCmd.Execute() -} - -// IsValidSubCommand is TEMPORARY code, that will be removed when -// we enable the new cobra based CLI by default. It returns true if the -// command-line provided by the user indicates they want the new cobra -// based CLI, e.g. sqlcmd install, or sqlcmd query, or sqlcmd --help etc. -func IsValidSubCommand(command string) bool { - return rootCmd.IsSubCommand(command) -} - -// checkErr uses Cobra to check err, and halts the application if err is not -// nil. Pass (inject) checkErr into all dependencies (helpers etc.) as an -// errorHandler. -// -// To aid debugging issues, if the logging level is > 2 (e.g. -v 3 or -4), we -// panic which outputs a stacktrace. -func checkErr(err error) { - if loggingLevel > 2 { - if err != nil { - panic(err) - } - } - rootCmd.CheckErr(err) -} - -// displayHints displays helpful information on what the user should do next -// to make progress. displayHints is injected into dependencies (helpers etc.) -func displayHints(hints []string) { - if len(hints) > 0 { - output.Infof("\nHINT:") - for i, hint := range hints { - output.Infof(" %d. %v", i+1, hint) - } - } -} diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go deleted file mode 100644 index 5225a113..00000000 --- a/cmd/cmd_test.go +++ /dev/null @@ -1,315 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package cmd - -import ( - "errors" - "fmt" - "github.com/microsoft/go-sqlcmd/cmd/root" - "github.com/microsoft/go-sqlcmd/internal" - "github.com/microsoft/go-sqlcmd/internal/cmdparser" - "github.com/microsoft/go-sqlcmd/internal/config" - "github.com/microsoft/go-sqlcmd/internal/output" - "github.com/microsoft/go-sqlcmd/internal/pal" - "os" - "runtime" - "strings" - "testing" -) - -// Set to true to run unit tests without a network connection -var offlineMode = false -var useCached = "" -var encryptPassword = "" - -type test struct { - name string - args struct{ args []string } -} - -func init() { - if runtime.GOOS == "windows" { - encryptPassword = " --encrypt-password" - } -} - -func TestCommandLineHelp(t *testing.T) { - setup(t.Name()) - tests := []test{ - {"default", split("--help")}, - } - run(t, tests) -} - -func TestNegCommandLines(t *testing.T) { - setup(t.Name()) - tests := []test{ - {"neg-config-use-context-double-name", - split("config use-context badbad --name andbad")}, - {"neg-config-use-context-bad-name", - split("config use-context badbad")}, - {"neg-config-get-contexts-bad-context", - split("config get-contexts badbad")}, - {"neg-config-get-endpoints-bad-endpoint", - split("config get-endpoints badbad")}, - {"neg-install-no-eula", - split("install mssql")}, - } - run(t, tests) -} - -func TestConfigContexts(t *testing.T) { - setup(t.Name()) - tests := []test{ - {"neg-config-add-context-no-endpoint", - split("config add-context")}, - {"config-add-endpoint", - split("config add-endpoint --address localhost --port 1433")}, - {"config-add-endpoint", - split("config add-endpoint --address localhost --port 1433")}, - {"neg-config-add-context-bad-user", - split("config add-context --endpoint endpoint --user badbad")}, - {"config-get-endpoints", - split("config get-endpoints endpoint")}, - {"config-get-endpoints", - split("config get-endpoints")}, - {"config-get-endpoints", - split("config get-endpoints --detailed")}, - {"config-add-context", - split("config add-context --endpoint endpoint")}, - /*{"uninstall-but-context-has-no-container", - split("uninstall --force --yes")},*/ - {"config-add-endpoint", - split("config add-endpoint")}, - {"config-add-context", - split("config add-context --endpoint endpoint")}, - {"config-use-context", - split("config use-context context")}, - {"config-get-contexts", - split("config get-contexts context")}, - {"config-get-contexts", - split("config get-contexts")}, - {"config-get-contexts", - split("config get-contexts --detailed")}, - {"config-delete-context", - split("config delete-context context --cascade")}, - {"neg-config-delete-context", - split("config delete-context")}, - {"neg-config-delete-context", - split("config delete-context badbad-name")}, - - {"cleanup", - split("config delete-endpoint endpoint2")}, - {"cleanup", - split("config delete-endpoint endpoint3")}, - {"cleanup", - split("config delete-context context2")}, - } - - run(t, tests) -} - -func TestConfigUsers(t *testing.T) { - setup(t.Name()) - tests := []test{ - {"neg-config-get-users-bad-user", - split("config get-users badbad")}, - {"config-add-user", - split("config add-user --username foobar")}, - {"config-add-user", - split("config add-user --username foobar")}, - {"config-get-users", - split("config get-users user")}, - {"config-get-users", - split("config get-users")}, - {"config-get-users", - split("config get-users --detailed")}, - {"neg-config-add-user-no-username", - split("config add-user")}, - {"neg-config-add-user-no-password", - split("config add-user --username foobar")}, - - // Cleanup - {"cleanup", - split("config delete-user user")}, - {"cleanup", - split("config delete-user user2")}, - } - - run(t, tests) -} - -func TestLocalContext(t *testing.T) { - setup(t.Name()) - - tests := []test{ - {"neg-config-delete-endpoint-no-name", - split("config delete-endpoint")}, - {"config-add-endpoint", - split("config add-endpoint --address localhost --port 1433")}, - {"config-add-user", - split("config add-user --username foobar")}, - {"config-add-context", - split("config add-context --user user --endpoint endpoint --name my-context")}, - {"config-delete-context-cascade", - split("config delete-context my-context --cascade")}, - {"config-view", - split("config view")}, - {"config-view", - split("config view --raw")}, - - {"neg-config-add-user-bad-auth-type", - split("config add-user --username foobar --auth-type badbad")}, - } - - if len(encryptPassword) > 2 { // are we on a platform that supports encryption - tests = append(tests, test{"neg-config-add-user-bad-use-encrypted", - split(fmt.Sprintf("config add-user --username foobar --auth-type other%v", encryptPassword))}) - } - - run(t, tests) -} - -func TestGetTags(t *testing.T) { - setup(t.Name()) - tests := []test{ - {"get-tags", - split("install mssql get-tags")}, - } - - run(t, tests) -} - -func TestMssqlInstall(t *testing.T) { - setup(t.Name()) - tests := []test{ - {"install", - split(fmt.Sprintf("install mssql%v --user-database my-database --accept-eula%v", useCached, encryptPassword))}, - {"config-current-context", - split("config current-context")}, - {"config-connection-strings", - split("config connection-strings")}, - {"query", - split("query GO")}, - {"query", - split("query")}, - {"neg-query-two-queries", - split("query bad --query bad")}, - - /* How to get code coverage for user input - {"neg-uninstall-no-yes", - split("uninstall")},*/ - {"uninstall", - split("uninstall --yes --force")}, - } - - run(t, tests) -} - -func runTests(t *testing.T, tt struct { - name string - args struct{ args []string } -}) { - cmd := cmdparser.New[*Root](root.SubCommands()...) - cmd.ArgsForUnitTesting(tt.args.args) - - t.Logf("Running: %v", tt.args.args) - - if tt.name == "neg-config-add-user-no-password" { - os.Setenv("SQLCMD_PASSWORD", "") - } else { - os.Setenv("SQLCMD_PASSWORD", "badpass") - } - - // If test name starts with 'neg-' expect a Panic - if strings.HasPrefix(tt.name, "neg-") { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - cmd.Execute() - } - cmd.Execute() -} - -func Test_displayHints(t *testing.T) { - displayHints([]string{"Test Hint"}) -} - -func TestIsValidRootCommand(t *testing.T) { - Initialize() - IsValidSubCommand("install") - IsValidSubCommand("create") - IsValidSubCommand("nope") -} - -func TestRunCommand(t *testing.T) { - loggingLevel = 4 - Execute() -} - -func Test_checkErr(t *testing.T) { - loggingLevel = 3 - - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - - checkErr(errors.New("Expected error")) -} - -func run(t *testing.T, tests []test) { - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { runTests(t, tt) }) - } - - verifyConfigIsEmpty(t) -} - -func verifyConfigIsEmpty(t *testing.T) { - if !config.IsEmpty() { - bytes := output.Struct(config.GetRedactedConfig(true)) - t.Errorf("Config is not empty. Content of config file:\n%s\nConfig file used:%s", - string(bytes), - config.GetConfigFileUsed()) - t.Fail() - } -} - -func setup(testName string) { - useCached = " --cached" - if !offlineMode { - useCached = "" - } - - errorHandler := func(err error) { - if err != nil { - panic(err) - } - } - - options := internal.InitializeOptions{ - ErrorHandler: errorHandler, - HintHandler: displayHints, - OutputType: "yaml", - LoggingLevel: 4, - } - internal.Initialize(options) - config.SetFileName(pal.FilenameInUserHomeDotDirectory( - ".sqlcmd", - "sqlconfig-"+testName, - )) - config.Clean() -} - -type args struct { - args []string -} - -func split(cmd string) args { - return args{strings.Split(cmd, " ")} -} diff --git a/cmd/modern/build.cmd b/cmd/modern/build.cmd new file mode 100644 index 00000000..0f08eaec --- /dev/null +++ b/cmd/modern/build.cmd @@ -0,0 +1 @@ +go build -o sqlcmd.exe \ No newline at end of file diff --git a/cmd/modern/build.sh b/cmd/modern/build.sh new file mode 100644 index 00000000..91180b11 --- /dev/null +++ b/cmd/modern/build.sh @@ -0,0 +1 @@ +go build -o sqlcmd diff --git a/doc.go b/cmd/modern/doc.go similarity index 100% rename from doc.go rename to cmd/modern/doc.go diff --git a/cmd/modern/main.go b/cmd/modern/main.go new file mode 100644 index 00000000..9ff912d1 --- /dev/null +++ b/cmd/modern/main.go @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +// Package main is the entrypoint for sqlcmd. This package first initializes +// a new instance of the Root cmd then checks if the new cobra-based +// command-line interface (CLI) should be used based on if the first argument provided +// by the user is a valid sub-command for the new CLI, if so it executes the +// new cobra CLI; otherwise, it falls back to the old kong-based CLI. +package main + +import ( + "github.com/microsoft/go-sqlcmd/internal" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/cmdparser/dependency" + "github.com/microsoft/go-sqlcmd/internal/config" + "github.com/microsoft/go-sqlcmd/internal/output" + "github.com/microsoft/go-sqlcmd/internal/output/verbosity" + "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" + "github.com/spf13/cobra" + + "os" + + legacyCmd "github.com/microsoft/go-sqlcmd/cmd/sqlcmd" +) + +var rootCmd *Root +var outputter *output.Output + +// main is the entry point for the sqlcmd command line interface. +// It parses command line options and initializes the command parser. +// If the first argument is a modern CLI subcommand, the modern CLI is +// executed. Otherwise, the legacy CLI is executed. +func main() { + dependencies := dependency.Options{ + Output: output.New(output.Options{ + StandardWriter: os.Stdout, + ErrorHandler: checkErr, + HintHandler: displayHints})} + rootCmd = cmdparser.New[*Root](dependencies) + + if isFirstArgModernCliSubCommand() { + cmdparser.Initialize(initializeCallback) + rootCmd.Execute() + } else { + legacyCmd.Execute() + } +} + +// isFirstArgModernCliSubCommand is TEMPORARY code, to be removed when +// we remove the Kong based CLI +func isFirstArgModernCliSubCommand() (isNewCliCommand bool) { + if len(os.Args) > 0 { + if rootCmd.IsValidSubCommand(os.Args[1]) { + isNewCliCommand = true + } + } + return +} + +// initializeCallback is called after the command line has been parsed and +// all values provided by the user are now available +func initializeCallback() { + + // Assigns a new outputter now that we have the outputType and loggingLevel + // provided to us from the user + outputter = output.New( + output.Options{ + StandardWriter: os.Stdout, + ErrorHandler: checkErr, + HintHandler: displayHints, + OutputType: rootCmd.outputType, + LoggingLevel: verbosity.Level(rootCmd.loggingLevel), + }) + rootCmd.SetCrossCuttingConcerns( + dependency.Options{ + EndOfLine: sqlcmd.SqlcmdEol, + Output: outputter, + }) + internal.Initialize( + internal.InitializeOptions{ + ErrorHandler: checkErr, + TraceHandler: outputter.Tracef, + HintHandler: displayHints, + LineBreak: sqlcmd.SqlcmdEol, + }) + config.SetFileName(rootCmd.configFilename) + config.Load() +} + +// checkErr uses Cobra to check err, and halts the application if err is not +// nil. Pass (inject) checkErr into all dependencies (internal helpers etc.) as an +// errorHandler. +// +// To aid debugging issues, if the logging level is > 2 (e.g. -v 3 or -v 4), we +// panic which outputs a stacktrace. +func checkErr(err error) { + if rootCmd.loggingLevel > 2 { + if err != nil { + panic(err) + } + } else { + cobra.CheckErr(err) + } +} + +// displayHints displays helpful information on what the user should do next +// to make progress. displayHints is injected into dependencies (helpers etc.) +func displayHints(hints []string) { + if len(hints) > 0 { + outputter.Infof("%vHINT:", sqlcmd.SqlcmdEol) + for i, hint := range hints { + outputter.Infof(" %d. %v", i+1, hint) + } + } +} diff --git a/cmd/modern/main_test.go b/cmd/modern/main_test.go new file mode 100644 index 00000000..0df65471 --- /dev/null +++ b/cmd/modern/main_test.go @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package main + +import ( + "errors" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/cmdparser/dependency" + "github.com/microsoft/go-sqlcmd/internal/output" + "github.com/microsoft/go-sqlcmd/internal/pal" + "github.com/microsoft/go-sqlcmd/internal/test" + "github.com/stretchr/testify/assert" + "os" + "testing" +) + +func TestMainStart(t *testing.T) { + os.Args[1] = "--help" + main() +} + +func TestInitializeCallback(t *testing.T) { + rootCmd = cmdparser.New[*Root](dependency.Options{}) + initializeCallback() +} + +func TestDisplayHints(t *testing.T) { + buf := test.NewMemoryBuffer() + defer buf.Close() + outputter = output.New(output.Options{StandardWriter: buf}) + displayHints([]string{"This is a hint"}) + assert.Equal(t, pal.LineBreak()+ + "HINT:"+ + pal.LineBreak()+ + " 1. This is a hint"+pal.LineBreak(), buf.String()) +} + +func TestCheckErr(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + rootCmd = cmdparser.New[*Root](dependency.Options{}) + rootCmd.loggingLevel = 4 + checkErr(nil) + checkErr(errors.New("test error")) +} diff --git a/cmd/globaloptions.go b/cmd/modern/options.go similarity index 97% rename from cmd/globaloptions.go rename to cmd/modern/options.go index 76c421ac..ff70d2c3 100644 --- a/cmd/globaloptions.go +++ b/cmd/modern/options.go @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -package cmd +package main type GlobalOptions struct { TrustServerCertificate bool diff --git a/cmd/modern/root.go b/cmd/modern/root.go new file mode 100644 index 00000000..22e6f661 --- /dev/null +++ b/cmd/modern/root.go @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package main + +import ( + "github.com/microsoft/go-sqlcmd/cmd/modern/root" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/config" +) + +// Root type implements the very top-level command for sqlcmd (which contains +// all the sub-commands, like install, query, config etc. +type Root struct { + cmdparser.Cmd + + configFilename string + loggingLevel int + outputType string +} + +// DefineCommand defines the top-level sqlcmd sub-commands. +// It sets the cli name, description, and subcommands, and adds global flags. +// It also provides usage examples for sqlcmd. +func (c *Root) DefineCommand(...cmdparser.CommandOptions) { + examples := []cmdparser.ExampleOptions{ + { + Description: "Install, Query, Uninstall SQL Server", + Steps: []string{ + "sqlcmd install mssql", + `sqlcmd query "SELECT @@version"`, + "sqlcmd uninstall"}}} + + commandOptions := cmdparser.CommandOptions{ + Use: "sqlcmd", + Short: "sqlcmd: command-line interface for the #SQLFamily", + SubCommands: c.SubCommands(), + Examples: examples, + } + + c.Cmd.DefineCommand(commandOptions) + c.addGlobalFlags() +} + +// SubCommands returns a slice of subcommands for the Root command. +// The returned subcommands are Config, Install, query, and Uninstall. +func (c *Root) SubCommands() []cmdparser.Command { + dependencies := c.Dependencies() + + return []cmdparser.Command{ + cmdparser.New[*root.Config](dependencies), + cmdparser.New[*root.Install](dependencies), + cmdparser.New[*root.Query](dependencies), + cmdparser.New[*root.Uninstall](dependencies), + } +} + +// Execute runs the application based on the command-line +// parameters the user has passed in. +func (c *Root) Execute() { + c.Cmd.Execute() +} + +// IsValidSubCommand is TEMPORARY code, that will be removed when +// we enable the new cobra based CLI by default. It returns true if the +// command-line provided by the user indicates they want the new cobra +// based CLI, e.g. sqlcmd install, or sqlcmd query, or sqlcmd --help etc. +func (c *Root) IsValidSubCommand(command string) bool { + return c.IsSubCommand(command) +} + +func (c *Root) addGlobalFlags() { + c.AddFlag(cmdparser.FlagOptions{ + Bool: &globalOptions.TrustServerCertificate, + Name: "trust-server-certificate", + Shorthand: "C", + Usage: "Whether to trust the certificate presented by the endpoint for encryption", + }) + + c.AddFlag(cmdparser.FlagOptions{ + String: &globalOptions.DatabaseName, + Name: "database-name", + Shorthand: "d", + Usage: "The initial database for the connection", + }) + + c.AddFlag(cmdparser.FlagOptions{ + Bool: &globalOptions.UseTrustedConnection, + Name: "use-trusted-connection", + Shorthand: "E", + Usage: "Whether to use integrated security", + }) + + c.AddFlag(cmdparser.FlagOptions{ + String: &c.configFilename, + DefaultString: config.DefaultFileName(), + Name: "sqlconfig", + Usage: "Configuration file", + }) + + c.AddFlag(cmdparser.FlagOptions{ + String: &c.outputType, + DefaultString: "yaml", + Name: "output", + Shorthand: "o", + Usage: "output type (yaml, json or xml)", + }) + + c.AddFlag(cmdparser.FlagOptions{ + Int: (*int)(&c.loggingLevel), + DefaultInt: 2, + Name: "verbosity", + Shorthand: "v", + Usage: "Log level, error=0, warn=1, info=2, debug=3, trace=4", + }) +} diff --git a/cmd/modern/root/config.go b/cmd/modern/root/config.go new file mode 100644 index 00000000..b695b5f8 --- /dev/null +++ b/cmd/modern/root/config.go @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package root + +import ( + "github.com/microsoft/go-sqlcmd/cmd/modern/root/config" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" +) + +// Config defines the `sqlcmd config` sub-commands +type Config struct { + cmdparser.Cmd +} + +// DefineCommand defines the `sqlcmd config` command, which is only +// more sub-commands (`sqlcmd config` does not `run` anything itself) +func (c *Config) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ + Use: "config", + Short: `Modify sqlconfig files using subcommands like "sqlcmd config use-context mssql"`, + SubCommands: c.SubCommands(), + } + c.Cmd.DefineCommand(options) +} + +// SubCommands sets up all the sub-commands for `sqlcmd config` +func (c *Config) SubCommands() []cmdparser.Command { + dependencies := c.Dependencies() + + return []cmdparser.Command{ + cmdparser.New[*config.AddContext](dependencies), + cmdparser.New[*config.AddEndpoint](dependencies), + cmdparser.New[*config.AddUser](dependencies), + cmdparser.New[*config.ConnectionStrings](dependencies), + cmdparser.New[*config.CurrentContext](dependencies), + cmdparser.New[*config.DeleteContext](dependencies), + cmdparser.New[*config.DeleteEndpoint](dependencies), + cmdparser.New[*config.DeleteUser](dependencies), + cmdparser.New[*config.GetContexts](dependencies), + cmdparser.New[*config.GetEndpoints](dependencies), + cmdparser.New[*config.GetUsers](dependencies), + cmdparser.New[*config.UseContext](dependencies), + cmdparser.New[*config.View](dependencies), + } +} diff --git a/cmd/root/config/add-context.go b/cmd/modern/root/config/add-context.go similarity index 74% rename from cmd/root/config/add-context.go rename to cmd/modern/root/config/add-context.go index a4d9d0c6..9ed269dc 100644 --- a/cmd/root/config/add-context.go +++ b/cmd/modern/root/config/add-context.go @@ -5,12 +5,12 @@ package config import ( "fmt" - "github.com/microsoft/go-sqlcmd/cmd/sqlconfig" + "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" "github.com/microsoft/go-sqlcmd/internal/cmdparser" "github.com/microsoft/go-sqlcmd/internal/config" - "github.com/microsoft/go-sqlcmd/internal/output" ) +// AddContext implements the `sqlcmd config add-context` command type AddContext struct { cmdparser.Cmd @@ -19,18 +19,18 @@ type AddContext struct { userName string } -func (c *AddContext) DefineCommand(...cmdparser.Command) { - c.Cmd.Options = cmdparser.Options{ +func (c *AddContext) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ Use: "add-context", Short: "Add a context", - Examples: []cmdparser.ExampleInfo{ + Examples: []cmdparser.ExampleOptions{ { Description: "Add a default context", Steps: []string{"sqlcmd config add-context --name my-context"}}, }, Run: c.run} - c.Cmd.DefineCommand() + c.Cmd.DefineCommand(options) c.AddFlag(cmdparser.FlagOptions{ String: &c.name, @@ -49,7 +49,13 @@ func (c *AddContext) DefineCommand(...cmdparser.Command) { Usage: "Name of user this context will use, use `sqlcmd config get-users` to see list"}) } +// run adds a context to the configuration and sets it as the current context. The +// context consists of an endpoint and an optional user. The function checks +// if the specified endpoint and user exist and if not, it returns an error with +// suggestions on how to create them. If the context is successfully added, it +// outputs a message indicating the current context. func (c *AddContext) run() { + output := c.Output() context := sqlconfig.Context{ ContextDetails: sqlconfig.ContextDetails{ Endpoint: c.endpointName, @@ -67,7 +73,7 @@ func (c *AddContext) run() { } if c.userName != "" { - if !config.UserExists(c.userName) { + if !config.UserNameExists(c.userName) { output.FatalfWithHintExamples([][]string{ {"View list of users", "sqlcmd config get-users"}, {"Add the user", fmt.Sprintf("sqlcmd config add-user --name %v", c.userName)}, @@ -81,6 +87,5 @@ func (c *AddContext) run() { output.InfofWithHintExamples([][]string{ {"To start interactive query session", "sqlcmd query"}, {"To run a query", "sqlcmd query \"SELECT @@version\""}, - }, - "Current Context '%v'", context.Name) + }, "Current Context '%v'", context.Name) } diff --git a/cmd/modern/root/config/add-context_test.go b/cmd/modern/root/config/add-context_test.go new file mode 100644 index 00000000..32774ed4 --- /dev/null +++ b/cmd/modern/root/config/add-context_test.go @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package config + +import ( + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/test" + "testing" +) + +func TestAddContext(t *testing.T) { + cmdparser.TestSetup(t) + cmdparser.TestCmd[*AddEndpoint]() + cmdparser.TestCmd[*AddContext]("--endpoint endpoint") +} + +func TestNegAddContext(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + cmdparser.TestSetup(t) + cmdparser.TestCmd[*AddContext]("--endpoint does-not-exist") +} + +func TestNegAddContext2(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + cmdparser.TestSetup(t) + cmdparser.TestCmd[*AddEndpoint]() + cmdparser.TestCmd[*AddContext]("--endpoint endpoint --user does-not-exist") +} diff --git a/cmd/root/config/add-endpoint.go b/cmd/modern/root/config/add-endpoint.go similarity index 73% rename from cmd/root/config/add-endpoint.go rename to cmd/modern/root/config/add-endpoint.go index a2436e67..57a4bb86 100644 --- a/cmd/root/config/add-endpoint.go +++ b/cmd/modern/root/config/add-endpoint.go @@ -5,12 +5,13 @@ package config import ( "fmt" - "github.com/microsoft/go-sqlcmd/cmd/sqlconfig" + "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" "github.com/microsoft/go-sqlcmd/internal/config" - "github.com/microsoft/go-sqlcmd/internal/output" ) +// AddEndpoint implements the `sqlcmd config add-endpoint` command type AddEndpoint struct { cmdparser.Cmd @@ -19,20 +20,19 @@ type AddEndpoint struct { port int } -func (c *AddEndpoint) DefineCommand(...cmdparser.Command) { - c.Cmd.Options = cmdparser.Options{ +func (c *AddEndpoint) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ Use: "add-endpoint", Short: "Add an endpoint", - Examples: []cmdparser.ExampleInfo{ + Examples: []cmdparser.ExampleOptions{ { Description: "Add a default endpoint", Steps: []string{"sqlcmd config add-endpoint --name my-endpoint --address localhost --port 1433"}, }, }, - Run: c.run, - } + Run: c.run} - c.Cmd.DefineCommand() + c.Cmd.DefineCommand(options) c.AddFlag(cmdparser.FlagOptions{ String: &c.name, @@ -56,10 +56,12 @@ func (c *AddEndpoint) DefineCommand(...cmdparser.Command) { }) } +// run adds an endpoint to the system configuration. It creates a sqlconfig.Endpoint +// struct with the given parameters, and then adds the struct to the system configuration +// using the config.AddEndpoint method. If the endpoint is successfully added, it prints +// a message with information about the endpoint and hints on how to use the endpoint. func (c *AddEndpoint) run() { - if c.name == "containerId" { - panic("containerId") - } + output := c.Output() endpoint := sqlconfig.Endpoint{ EndpointDetails: sqlconfig.EndpointDetails{ diff --git a/cmd/modern/root/config/add-endpoint_test.go b/cmd/modern/root/config/add-endpoint_test.go new file mode 100644 index 00000000..dc9d87f0 --- /dev/null +++ b/cmd/modern/root/config/add-endpoint_test.go @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package config + +import ( + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "testing" +) + +func TestAddEndpoint(t *testing.T) { + cmdparser.TestSetup(t) + cmdparser.TestCmd[*AddEndpoint]() +} diff --git a/cmd/root/config/add-user.go b/cmd/modern/root/config/add-user.go similarity index 72% rename from cmd/root/config/add-user.go rename to cmd/modern/root/config/add-user.go index 41055e8c..346cb375 100644 --- a/cmd/root/config/add-user.go +++ b/cmd/modern/root/config/add-user.go @@ -4,14 +4,15 @@ package config import ( - "github.com/microsoft/go-sqlcmd/cmd/sqlconfig" + "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" + "os" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" "github.com/microsoft/go-sqlcmd/internal/config" - "github.com/microsoft/go-sqlcmd/internal/output" "github.com/microsoft/go-sqlcmd/internal/secret" - "os" ) +// AddUser implements the `sqlcmd config add-user` command type AddUser struct { cmdparser.Cmd @@ -21,11 +22,11 @@ type AddUser struct { encryptPassword bool } -func (c *AddUser) DefineCommand(...cmdparser.Command) { - c.Cmd.Options = cmdparser.Options{ +func (c *AddUser) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ Use: "add-user", Short: "Add a user", - Examples: []cmdparser.ExampleInfo{ + Examples: []cmdparser.ExampleOptions{ { Description: "Add a user", Steps: []string{ @@ -33,10 +34,9 @@ func (c *AddUser) DefineCommand(...cmdparser.Command) { "sqlcmd config add-user --name my-user --name user1"}, }, }, - Run: c.run, - } + Run: c.run} - c.Cmd.DefineCommand() + c.Cmd.DefineCommand(options) c.AddFlag(cmdparser.FlagOptions{ String: &c.name, @@ -61,7 +61,17 @@ func (c *AddUser) DefineCommand(...cmdparser.Command) { c.encryptPasswordFlag() } +// run a user to the configuration. It sets the user's name and +// authentication type, and, if the authentication type is 'basic', it sets the +// user's username and password (either in plain text or encrypted, depending +// on the --encrypt-password flag). If the user's authentication type is not 'basic' +// or 'other', an error is thrown. If the --encrypt-password flag is set but the +// authentication type is not 'basic', an error is thrown. If the authentication +// type is 'basic' but the username or password is not provided, an error is thrown. +// If the username is provided but the password is not, an error is thrown. func (c *AddUser) run() { + output := c.Output() + if c.authType != "basic" && c.authType != "other" { output.FatalfWithHints([]string{"Authentication type must be 'basic' or 'other'"}, diff --git a/cmd/root/config/add-user_darwin.go b/cmd/modern/root/config/add-user_darwin.go similarity index 100% rename from cmd/root/config/add-user_darwin.go rename to cmd/modern/root/config/add-user_darwin.go diff --git a/cmd/root/config/add-user_linux.go b/cmd/modern/root/config/add-user_linux.go similarity index 100% rename from cmd/root/config/add-user_linux.go rename to cmd/modern/root/config/add-user_linux.go diff --git a/cmd/modern/root/config/add-user_test.go b/cmd/modern/root/config/add-user_test.go new file mode 100644 index 00000000..a71e539c --- /dev/null +++ b/cmd/modern/root/config/add-user_test.go @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package config + +import ( + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/test" + "os" + "testing" +) + +func TestAddUser(t *testing.T) { + os.Setenv("SQLCMD_PASSWORD", "it's-a-secret") + cmdparser.TestSetup(t) + cmdparser.TestCmd[*AddUser]("--username user1") +} + +func TestNegAddUser(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + cmdparser.TestSetup(t) + cmdparser.TestCmd[*AddUser]("--username user1 --auth-type bad-bad") +} + +func TestNegAddUser2(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + cmdparser.TestSetup(t) + cmdparser.TestCmd[*AddUser]("--username user1 --auth-type other --encrypt-password") +} + +func TestNegAddUser3(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + os.Setenv("SQLCMD_PASSWORD", "") + + cmdparser.TestSetup(t) + cmdparser.TestCmd[*AddUser]("--username user1") +} + +func TestNegAddUser4(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + os.Setenv("SQLCMD_PASSWORD", "whatever") + + cmdparser.TestSetup(t) + cmdparser.TestCmd[*AddUser]() +} diff --git a/cmd/root/config/add-user_windows.go b/cmd/modern/root/config/add-user_windows.go similarity index 100% rename from cmd/root/config/add-user_windows.go rename to cmd/modern/root/config/add-user_windows.go diff --git a/cmd/root/config/connection-strings.go b/cmd/modern/root/config/connection-strings.go similarity index 54% rename from cmd/root/config/connection-strings.go rename to cmd/modern/root/config/connection-strings.go index 12ad9d7c..d47aa69f 100644 --- a/cmd/root/config/connection-strings.go +++ b/cmd/modern/root/config/connection-strings.go @@ -5,22 +5,23 @@ package config import ( "fmt" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" "github.com/microsoft/go-sqlcmd/internal/config" - "github.com/microsoft/go-sqlcmd/internal/output" "github.com/microsoft/go-sqlcmd/internal/pal" "github.com/microsoft/go-sqlcmd/internal/secret" ) +// ConnectionStrings implements the `sqlcmd config connection-strings` command type ConnectionStrings struct { cmdparser.Cmd } -func (c *ConnectionStrings) DefineCommand(...cmdparser.Command) { - c.Cmd.Options = cmdparser.Options{ +func (c *ConnectionStrings) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ Use: "connection-strings", Short: "Display connections strings for the current context", - Examples: []cmdparser.ExampleInfo{ + Examples: []cmdparser.ExampleOptions{ { Description: "List connection strings for all client drivers", Steps: []string{ @@ -32,37 +33,46 @@ func (c *ConnectionStrings) DefineCommand(...cmdparser.Command) { Aliases: []string{"cs"}, } - c.Cmd.DefineCommand() + c.Cmd.DefineCommand(options) } +// run generates connection strings for the current context in multiple formats. +// The generated connection strings will include the current endpoint and user information. func (c *ConnectionStrings) run() { + output := c.Output() + // connectionStringFormats borrowed from "portal.azure.com" "connection strings" pane var connectionStringFormats = map[string]string{ - "ADO.NET": "Server=tcp:%s,%d;Initial Catalog=%s;Persist Security Options=False;User ID=%s;Password=%s;MultipleActiveResultSets=False;Encode=True;TrustServerCertificate=False;Connection Timeout=30;", + "ADO.NET": "Server=tcp:%s,%d;Initial Catalog=%s;Persist Security options=False;User ID=%s;Password=%s;MultipleActiveResultSets=False;Encode=True;TrustServerCertificate=False;Connection Timeout=30;", "JDBC": "jdbc:sqlserver://%s:%d;database=%s;user=%s;password=%s;encrypt=true;trustServerCertificate=false;loginTimeout=30;", "ODBC": "Driver={ODBC Driver 13 for SQL Server};Server=tcp:%s,%d;Database=%s;Uid=%s;Pwd=%s;Encode=yes;TrustServerCertificate=no;Connection Timeout=30;", } - endpoint, user := config.GetCurrentContext() - for k, v := range connectionStringFormats { - connectionStringFormats[k] = fmt.Sprintf(v, + endpoint, user := config.CurrentContext() + if user != nil { + for k, v := range connectionStringFormats { + connectionStringFormats[k] = fmt.Sprintf(v, + endpoint.EndpointDetails.Address, + endpoint.EndpointDetails.Port, + "master", + user.BasicAuth.Username, + secret.Decode(user.BasicAuth.Password, user.BasicAuth.PasswordEncrypted)) + } + + format := pal.CmdLineWithEnvVars( + []string{"SQLCMDPASSWORD=%s"}, + "sqlcmd -S %s,%d -U %s", + ) + + connectionStringFormats["SQLCMD"] = fmt.Sprintf(format, + secret.Decode(user.BasicAuth.Password, user.BasicAuth.PasswordEncrypted), endpoint.EndpointDetails.Address, endpoint.EndpointDetails.Port, - "master", - user.BasicAuth.Username, - secret.Decode(user.BasicAuth.Password, user.BasicAuth.PasswordEncrypted)) - } - - format := pal.CmdLineWithEnvVars( - []string{"SQLCMDPASSWORD=%s"}, - "sqlcmd -S %s,%d -U %s", - ) + user.BasicAuth.Username) - connectionStringFormats["SQLCMD"] = fmt.Sprintf(format, - secret.Decode(user.BasicAuth.Password, user.BasicAuth.PasswordEncrypted), - endpoint.EndpointDetails.Address, - endpoint.EndpointDetails.Port, - user.BasicAuth.Username) + output.Struct(connectionStringFormats) - output.Struct(connectionStringFormats) + } else { + output.Infof("Connection Strings only supported for Basic Auth type") + } } diff --git a/cmd/modern/root/config/connection-strings_test.go b/cmd/modern/root/config/connection-strings_test.go new file mode 100644 index 00000000..3ad54696 --- /dev/null +++ b/cmd/modern/root/config/connection-strings_test.go @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package config + +import ( + "github.com/microsoft/go-sqlcmd/internal" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/output" + "os" + "testing" +) + +func TestConnectionStrings(t *testing.T) { + cmdparser.TestSetup(t) + + output := output.New(output.Options{HintHandler: func(hints []string) {}, ErrorHandler: func(err error) {}}) + options := internal.InitializeOptions{ + ErrorHandler: func(err error) { + if err != nil { + panic(err) + } + }, + HintHandler: func(strings []string) {}, + TraceHandler: output.Tracef, + LineBreak: "\n", + } + internal.Initialize(options) + + os.Setenv("SQLCMD_PASSWORD", "it's-a-secret") + + cmdparser.TestCmd[*AddEndpoint]() + cmdparser.TestCmd[*AddUser]("--username user") + cmdparser.TestCmd[*AddContext]("--endpoint endpoint --user user") + cmdparser.TestCmd[*ConnectionStrings]() +} diff --git a/cmd/root/config/current-context.go b/cmd/modern/root/config/current-context.go similarity index 58% rename from cmd/root/config/current-context.go rename to cmd/modern/root/config/current-context.go index d0e5ff3d..52f45c03 100644 --- a/cmd/root/config/current-context.go +++ b/cmd/modern/root/config/current-context.go @@ -6,30 +6,30 @@ package config import ( "github.com/microsoft/go-sqlcmd/internal/cmdparser" "github.com/microsoft/go-sqlcmd/internal/config" - "github.com/microsoft/go-sqlcmd/internal/output" ) +// CurrentContext implements the `sqlcmd config current-context` command type CurrentContext struct { cmdparser.Cmd } -func (c *CurrentContext) DefineCommand(...cmdparser.Command) { - c.Cmd.Options = cmdparser.Options{ +func (c *CurrentContext) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ Use: "current-context", Short: "Display the current-context", - Examples: []cmdparser.ExampleInfo{ + Examples: []cmdparser.ExampleOptions{ { Description: "Display the current-context", Steps: []string{ "sqlcmd config current-context"}, }, }, - Run: c.run, - } + Run: c.run} - c.Cmd.DefineCommand() + c.Cmd.DefineCommand(options) } func (c *CurrentContext) run() { - output.Infof("%v\n", config.GetCurrentContextName()) + output := c.Output() + output.Infof("%v\n", config.CurrentContextName()) } diff --git a/cmd/modern/root/config/current-context_test.go b/cmd/modern/root/config/current-context_test.go new file mode 100644 index 00000000..e77ff509 --- /dev/null +++ b/cmd/modern/root/config/current-context_test.go @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package config + +import ( + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "testing" +) + +func TestCurrentContext(t *testing.T) { + cmdparser.TestSetup(t) + cmdparser.TestCmd[*CurrentContext]() +} diff --git a/cmd/root/config/delete-context.go b/cmd/modern/root/config/delete-context.go similarity index 64% rename from cmd/root/config/delete-context.go rename to cmd/modern/root/config/delete-context.go index a8bf92df..403ef5a1 100644 --- a/cmd/root/config/delete-context.go +++ b/cmd/modern/root/config/delete-context.go @@ -6,9 +6,9 @@ package config import ( "github.com/microsoft/go-sqlcmd/internal/cmdparser" "github.com/microsoft/go-sqlcmd/internal/config" - "github.com/microsoft/go-sqlcmd/internal/output" ) +// DeleteContext implements the `sqlcmd config delete-context` command type DeleteContext struct { cmdparser.Cmd @@ -16,11 +16,11 @@ type DeleteContext struct { cascade bool } -func (c *DeleteContext) DefineCommand(...cmdparser.Command) { - c.Cmd.Options = cmdparser.Options{ +func (c *DeleteContext) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ Use: "delete-context", Short: "Delete a context", - Examples: []cmdparser.ExampleInfo{ + Examples: []cmdparser.ExampleOptions{ { Description: "Delete a context", Steps: []string{ @@ -30,10 +30,10 @@ func (c *DeleteContext) DefineCommand(...cmdparser.Command) { }, Run: c.run, - FirstArgAlternativeForFlag: &cmdparser.AlternativeForFlagInfo{Flag: "name", Value: &c.name}, + FirstArgAlternativeForFlag: &cmdparser.AlternativeForFlagOptions{Flag: "name", Value: &c.name}, } - c.Cmd.DefineCommand() + c.Cmd.DefineCommand(options) c.AddFlag(cmdparser.FlagOptions{ String: &c.name, @@ -47,7 +47,13 @@ func (c *DeleteContext) DefineCommand(...cmdparser.Command) { Usage: "Delete the context's endpoint and user as well"}) } +// run is responsible for deleting a context in a configuration. It first checks if +// a name is provided and if the context exists. If the cascade flag is set, it will +// also delete the associated endpoint and user. It then deletes the context +// and prints a message to the output indicating the context has been deleted. func (c *DeleteContext) run() { + output := c.Output() + if c.name == "" { output.FatalWithHints([]string{"Use the --name flag to pass in a context name to delete"}, "A 'name' is required") @@ -58,8 +64,8 @@ func (c *DeleteContext) run() { if c.cascade { config.DeleteEndpoint(context.Endpoint) - if *context.User != "" { - config.DeleteUser(*context.User) + if *context.ContextDetails.User != "" { + config.DeleteUser(*context.ContextDetails.User) } } diff --git a/cmd/modern/root/config/delete-context_test.go b/cmd/modern/root/config/delete-context_test.go new file mode 100644 index 00000000..5037a12d --- /dev/null +++ b/cmd/modern/root/config/delete-context_test.go @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package config + +import ( + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/test" + "testing" +) + +func TestDeleteContext(t *testing.T) { + cmdparser.TestSetup(t) + cmdparser.TestCmd[*AddUser]("--username user --auth-type other") + cmdparser.TestCmd[*AddEndpoint]() + cmdparser.TestCmd[*AddContext]("--endpoint endpoint --user user") + cmdparser.TestCmd[*DeleteContext]("--name context") +} + +func TestNegDeleteContext(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + cmdparser.TestSetup(t) + cmdparser.TestCmd[*DeleteContext]() +} + +func TestNegDeleteContext2(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + cmdparser.TestSetup(t) + cmdparser.TestCmd[*DeleteContext]("--name does-not-exist") +} diff --git a/cmd/root/config/delete-endpoint.go b/cmd/modern/root/config/delete-endpoint.go similarity index 63% rename from cmd/root/config/delete-endpoint.go rename to cmd/modern/root/config/delete-endpoint.go index 37a03de0..e5980c2e 100644 --- a/cmd/root/config/delete-endpoint.go +++ b/cmd/modern/root/config/delete-endpoint.go @@ -5,22 +5,23 @@ package config import ( "fmt" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" "github.com/microsoft/go-sqlcmd/internal/config" - "github.com/microsoft/go-sqlcmd/internal/output" ) +// DeleteEndpoint implements the `sqlcmd config delete-endpoint` command type DeleteEndpoint struct { cmdparser.Cmd name string } -func (c *DeleteEndpoint) DefineCommand(...cmdparser.Command) { - c.Cmd.Options = cmdparser.Options{ +func (c *DeleteEndpoint) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ Use: "delete-endpoint", Short: "Delete an endpoint", - Examples: []cmdparser.ExampleInfo{ + Examples: []cmdparser.ExampleOptions{ { Description: "Delete an endpoint", Steps: []string{ @@ -30,10 +31,10 @@ func (c *DeleteEndpoint) DefineCommand(...cmdparser.Command) { }, Run: c.run, - FirstArgAlternativeForFlag: &cmdparser.AlternativeForFlagInfo{Flag: "name", Value: &c.name}, + FirstArgAlternativeForFlag: &cmdparser.AlternativeForFlagOptions{Flag: "name", Value: &c.name}, } - c.Cmd.DefineCommand() + c.Cmd.DefineCommand(options) c.AddFlag(cmdparser.FlagOptions{ String: &c.name, @@ -41,7 +42,12 @@ func (c *DeleteEndpoint) DefineCommand(...cmdparser.Command) { Usage: "Name of endpoint to delete"}) } +// run is used to delete an endpoint with the given name. If the specified endpoint +// does not exist, the function will print an error message and return. If the +// endpoint exists, it will be deleted and a success message will be printed. func (c *DeleteEndpoint) run() { + output := c.Output() + if c.name == "" { output.Fatal("Endpoint name must be provided. Provide endpoint name with --name flag") } diff --git a/cmd/modern/root/config/delete-endpoint_test.go b/cmd/modern/root/config/delete-endpoint_test.go new file mode 100644 index 00000000..ddf4b568 --- /dev/null +++ b/cmd/modern/root/config/delete-endpoint_test.go @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package config + +import ( + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/test" + "testing" +) + +func TestDeleteEndpoint(t *testing.T) { + cmdparser.TestSetup(t) + cmdparser.TestCmd[*AddEndpoint]() + cmdparser.TestCmd[*DeleteEndpoint]("--name endpoint") +} + +func TestNegDeleteEndpoint(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + cmdparser.TestSetup(t) + cmdparser.TestCmd[*DeleteEndpoint]() +} + +func TestNegDeleteEndpoint2(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + cmdparser.TestSetup(t) + cmdparser.TestCmd[*DeleteEndpoint]("--name does-not-exist") +} diff --git a/cmd/root/config/delete-user.go b/cmd/modern/root/config/delete-user.go similarity index 69% rename from cmd/root/config/delete-user.go rename to cmd/modern/root/config/delete-user.go index b061806e..284f8159 100644 --- a/cmd/root/config/delete-user.go +++ b/cmd/modern/root/config/delete-user.go @@ -6,20 +6,20 @@ package config import ( "github.com/microsoft/go-sqlcmd/internal/cmdparser" "github.com/microsoft/go-sqlcmd/internal/config" - "github.com/microsoft/go-sqlcmd/internal/output" ) +// DeleteUser implements the `sqlcmd config delete-user` command type DeleteUser struct { cmdparser.Cmd name string } -func (c *DeleteUser) DefineCommand(...cmdparser.Command) { - c.Cmd.Options = cmdparser.Options{ +func (c *DeleteUser) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ Use: "delete-user", Short: "Delete a user", - Examples: []cmdparser.ExampleInfo{ + Examples: []cmdparser.ExampleOptions{ { Description: "Delete a user", Steps: []string{ @@ -28,11 +28,11 @@ func (c *DeleteUser) DefineCommand(...cmdparser.Command) { }, Run: c.run, - FirstArgAlternativeForFlag: &cmdparser.AlternativeForFlagInfo{ + FirstArgAlternativeForFlag: &cmdparser.AlternativeForFlagOptions{ Flag: "name", Value: &c.name}, } - c.Cmd.DefineCommand() + c.Cmd.DefineCommand(options) c.AddFlag(cmdparser.FlagOptions{ String: &c.name, @@ -41,6 +41,8 @@ func (c *DeleteUser) DefineCommand(...cmdparser.Command) { } func (c *DeleteUser) run() { + output := c.Output() + config.DeleteUser(c.name) output.Infof("User '%v' deleted", c.name) } diff --git a/cmd/modern/root/config/delete-user_test.go b/cmd/modern/root/config/delete-user_test.go new file mode 100644 index 00000000..f29cdbd4 --- /dev/null +++ b/cmd/modern/root/config/delete-user_test.go @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package config + +import ( + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "testing" +) + +func TestDeleteUser(t *testing.T) { + cmdparser.TestSetup(t) + cmdparser.TestCmd[*DeleteUser]() +} diff --git a/cmd/root/config/get-contexts.go b/cmd/modern/root/config/get-contexts.go similarity index 80% rename from cmd/root/config/get-contexts.go rename to cmd/modern/root/config/get-contexts.go index d7d3648a..fc48c7fa 100644 --- a/cmd/root/config/get-contexts.go +++ b/cmd/modern/root/config/get-contexts.go @@ -6,9 +6,9 @@ package config import ( "github.com/microsoft/go-sqlcmd/internal/cmdparser" "github.com/microsoft/go-sqlcmd/internal/config" - "github.com/microsoft/go-sqlcmd/internal/output" ) +// GetContexts implements the `sqlcmd config get-contexts` command type GetContexts struct { cmdparser.Cmd @@ -16,11 +16,11 @@ type GetContexts struct { detailed bool } -func (c *GetContexts) DefineCommand(...cmdparser.Command) { - c.Cmd.Options = cmdparser.Options{ +func (c *GetContexts) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ Use: "get-contexts", Short: "Display one or many contexts from the sqlconfig file", - Examples: []cmdparser.ExampleInfo{ + Examples: []cmdparser.ExampleOptions{ { Description: "List all the context names in your sqlconfig file", Steps: []string{"sqlcmd config get-contexts"}, @@ -36,10 +36,10 @@ func (c *GetContexts) DefineCommand(...cmdparser.Command) { }, Run: c.run, - FirstArgAlternativeForFlag: &cmdparser.AlternativeForFlagInfo{Flag: "name", Value: &c.name}, + FirstArgAlternativeForFlag: &cmdparser.AlternativeForFlagOptions{Flag: "name", Value: &c.name}, } - c.Cmd.DefineCommand() + c.Cmd.DefineCommand(options) c.AddFlag(cmdparser.FlagOptions{ String: &c.name, @@ -53,6 +53,8 @@ func (c *GetContexts) DefineCommand(...cmdparser.Command) { } func (c *GetContexts) run() { + output := c.Output() + if c.name != "" { if config.ContextExists(c.name) { context := config.GetContext(c.name) diff --git a/cmd/modern/root/config/get-contexts_test.go b/cmd/modern/root/config/get-contexts_test.go new file mode 100644 index 00000000..ba989bad --- /dev/null +++ b/cmd/modern/root/config/get-contexts_test.go @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package config + +import ( + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/test" + "testing" +) + +func TestGetContexts(t *testing.T) { + cmdparser.TestSetup(t) + cmdparser.TestCmd[*AddEndpoint]("--name endpoint") + cmdparser.TestCmd[*AddContext]("--endpoint endpoint") + cmdparser.TestCmd[*GetContexts]() + cmdparser.TestCmd[*GetContexts]("context") +} + +func TestNegGetContexts(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + cmdparser.TestSetup(t) + cmdparser.TestCmd[*GetContexts]("does-not-exist") +} diff --git a/cmd/root/config/get-endpoints.go b/cmd/modern/root/config/get-endpoints.go similarity index 80% rename from cmd/root/config/get-endpoints.go rename to cmd/modern/root/config/get-endpoints.go index 2e901e69..8b637045 100644 --- a/cmd/root/config/get-endpoints.go +++ b/cmd/modern/root/config/get-endpoints.go @@ -6,9 +6,9 @@ package config import ( "github.com/microsoft/go-sqlcmd/internal/cmdparser" "github.com/microsoft/go-sqlcmd/internal/config" - "github.com/microsoft/go-sqlcmd/internal/output" ) +// GetEndpoints implements the `sqlcmd config get-endpoints` command type GetEndpoints struct { cmdparser.Cmd @@ -16,11 +16,11 @@ type GetEndpoints struct { detailed bool } -func (c *GetEndpoints) DefineCommand(...cmdparser.Command) { - c.Cmd.Options = cmdparser.Options{ +func (c *GetEndpoints) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ Use: "get-endpoints", Short: "Display one or many endpoints from the sqlconfig file", - Examples: []cmdparser.ExampleInfo{ + Examples: []cmdparser.ExampleOptions{ { Description: "List all the endpoints in your sqlconfig file", Steps: []string{"sqlcmd config get-endpoints"}}, @@ -32,10 +32,10 @@ func (c *GetEndpoints) DefineCommand(...cmdparser.Command) { Steps: []string{"sqlcmd config get-endpoints my-endpoint"}}, }, Run: c.run, - FirstArgAlternativeForFlag: &cmdparser.AlternativeForFlagInfo{Flag: "name", Value: &c.name}, + FirstArgAlternativeForFlag: &cmdparser.AlternativeForFlagOptions{Flag: "name", Value: &c.name}, } - c.Cmd.DefineCommand() + c.Cmd.DefineCommand(options) c.AddFlag(cmdparser.FlagOptions{ String: &c.name, @@ -49,6 +49,8 @@ func (c *GetEndpoints) DefineCommand(...cmdparser.Command) { } func (c *GetEndpoints) run() { + output := c.Output() + if c.name != "" { if config.EndpointExists(c.name) { context := config.GetEndpoint(c.name) diff --git a/cmd/modern/root/config/get-endpoints_test.go b/cmd/modern/root/config/get-endpoints_test.go new file mode 100644 index 00000000..8bf25e4a --- /dev/null +++ b/cmd/modern/root/config/get-endpoints_test.go @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package config + +import ( + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/test" + "testing" +) + +func TestGetEndpoints(t *testing.T) { + cmdparser.TestSetup(t) + cmdparser.TestCmd[*AddEndpoint]("--name endpoint") + cmdparser.TestCmd[*GetEndpoints]() + cmdparser.TestCmd[*GetEndpoints]("endpoint") + +} + +func TestNegGetEndpoints(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + cmdparser.TestSetup(t) + cmdparser.TestCmd[*GetEndpoints]("does-not-exist") +} diff --git a/cmd/root/config/get-users.go b/cmd/modern/root/config/get-users.go similarity index 77% rename from cmd/root/config/get-users.go rename to cmd/modern/root/config/get-users.go index 9bfaa447..d5d24557 100644 --- a/cmd/root/config/get-users.go +++ b/cmd/modern/root/config/get-users.go @@ -6,9 +6,9 @@ package config import ( "github.com/microsoft/go-sqlcmd/internal/cmdparser" "github.com/microsoft/go-sqlcmd/internal/config" - "github.com/microsoft/go-sqlcmd/internal/output" ) +// GetUsers implements the `sqlcmd config get-users` command type GetUsers struct { cmdparser.Cmd @@ -16,11 +16,11 @@ type GetUsers struct { detailed bool } -func (c *GetUsers) DefineCommand(...cmdparser.Command) { - c.Cmd.Options = cmdparser.Options{ +func (c *GetUsers) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ Use: "get-users", Short: "Display one or many users from the sqlconfig file", - Examples: []cmdparser.ExampleInfo{ + Examples: []cmdparser.ExampleOptions{ { Description: "List all the users in your sqlconfig file", Steps: []string{"sqlcmd config get-users"}, @@ -36,10 +36,10 @@ func (c *GetUsers) DefineCommand(...cmdparser.Command) { }, Run: c.run, - FirstArgAlternativeForFlag: &cmdparser.AlternativeForFlagInfo{Flag: "name", Value: &c.name}, + FirstArgAlternativeForFlag: &cmdparser.AlternativeForFlagOptions{Flag: "name", Value: &c.name}, } - c.Cmd.DefineCommand() + c.Cmd.DefineCommand(options) c.AddFlag(cmdparser.FlagOptions{ String: &c.name, @@ -53,8 +53,10 @@ func (c *GetUsers) DefineCommand(...cmdparser.Command) { } func (c *GetUsers) run() { + output := c.Output() + if c.name != "" { - if config.UserExists(c.name) { + if config.UserNameExists(c.name) { user := config.GetUser(c.name) output.Struct(user) } else { diff --git a/cmd/modern/root/config/get-users_test.go b/cmd/modern/root/config/get-users_test.go new file mode 100644 index 00000000..68b04fdb --- /dev/null +++ b/cmd/modern/root/config/get-users_test.go @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package config + +import ( + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/test" + "testing" +) + +func TestGetUsers(t *testing.T) { + cmdparser.TestSetup(t) + cmdparser.TestCmd[*AddUser]("--name user --username user") + cmdparser.TestCmd[*GetUsers]() + cmdparser.TestCmd[*GetUsers]("user") +} + +func TestNegGetUsers(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + cmdparser.TestSetup(t) + cmdparser.TestCmd[*GetUsers]("does-not-exist") +} diff --git a/cmd/root/config/use-context.go b/cmd/modern/root/config/use-context.go similarity index 63% rename from cmd/root/config/use-context.go rename to cmd/modern/root/config/use-context.go index 06abf878..2e98a500 100644 --- a/cmd/root/config/use-context.go +++ b/cmd/modern/root/config/use-context.go @@ -6,32 +6,29 @@ package config import ( "github.com/microsoft/go-sqlcmd/internal/cmdparser" "github.com/microsoft/go-sqlcmd/internal/config" - "github.com/microsoft/go-sqlcmd/internal/output" ) +// UseContext implements the `sqlcmd config use-context` command type UseContext struct { cmdparser.Cmd name string } -func (c *UseContext) DefineCommand(...cmdparser.Command) { - c.Cmd.Options = cmdparser.Options{ +func (c *UseContext) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ Use: "use-context", - Short: "Display one or many users from the sqlconfig file", - Examples: []cmdparser.ExampleInfo{ - { - Description: "Use the context for the user@mssql sql instance", - Steps: []string{"sqlcmd config use-context user@mssql"}, - }, - }, + Short: "Set the current context", + Examples: []cmdparser.ExampleOptions{{ + Description: "Set the mssql context (endpoint/user) to be the current context", + Steps: []string{"sqlcmd config use-context mssql"}}}, Aliases: []string{"use", "change-context", "set-context"}, Run: c.run, - FirstArgAlternativeForFlag: &cmdparser.AlternativeForFlagInfo{Flag: "name", Value: &c.name}, + FirstArgAlternativeForFlag: &cmdparser.AlternativeForFlagOptions{Flag: "name", Value: &c.name}, } - c.Cmd.DefineCommand() + c.Cmd.DefineCommand(options) c.AddFlag(cmdparser.FlagOptions{ String: &c.name, @@ -40,6 +37,8 @@ func (c *UseContext) DefineCommand(...cmdparser.Command) { } func (c *UseContext) run() { + output := c.Output() + if config.ContextExists(c.name) { config.SetCurrentContextName(c.name) output.InfofWithHints([]string{ diff --git a/cmd/modern/root/config/use-context_test.go b/cmd/modern/root/config/use-context_test.go new file mode 100644 index 00000000..7419c43b --- /dev/null +++ b/cmd/modern/root/config/use-context_test.go @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package config + +import ( + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/test" + "testing" +) + +func TestUseContext(t *testing.T) { + cmdparser.TestSetup(t) + cmdparser.TestCmd[*AddEndpoint]() + cmdparser.TestCmd[*AddContext]("--endpoint endpoint") + cmdparser.TestCmd[*UseContext]("--name context") +} + +func TestNegUseContext(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + cmdparser.TestSetup(t) + cmdparser.TestCmd[*UseContext]("does-not-exist") +} diff --git a/cmd/root/config/view.go b/cmd/modern/root/config/view.go similarity index 74% rename from cmd/root/config/view.go rename to cmd/modern/root/config/view.go index 6e6d6e09..fe0fadac 100644 --- a/cmd/root/config/view.go +++ b/cmd/modern/root/config/view.go @@ -6,20 +6,20 @@ package config import ( "github.com/microsoft/go-sqlcmd/internal/cmdparser" "github.com/microsoft/go-sqlcmd/internal/config" - "github.com/microsoft/go-sqlcmd/internal/output" ) +// View implements the `sqlcmd config view` command type View struct { cmdparser.Cmd raw bool } -func (c *View) DefineCommand(...cmdparser.Command) { - c.Cmd.Options = cmdparser.Options{ +func (c *View) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ Use: "view", Short: "Display merged sqlconfig settings or a specified sqlconfig file", - Examples: []cmdparser.ExampleInfo{ + Examples: []cmdparser.ExampleOptions{ { Description: "Show merged sqlconfig settings", Steps: []string{"sqlcmd config view"}, @@ -33,7 +33,7 @@ func (c *View) DefineCommand(...cmdparser.Command) { Run: c.run, } - c.Cmd.DefineCommand() + c.Cmd.DefineCommand(options) c.AddFlag(cmdparser.FlagOptions{ Name: "raw", @@ -43,6 +43,8 @@ func (c *View) DefineCommand(...cmdparser.Command) { } func (c *View) run() { - contents := config.GetRedactedConfig(c.raw) + output := c.Output() + + contents := config.RedactedConfig(c.raw) output.Struct(contents) } diff --git a/cmd/modern/root/config/view_test.go b/cmd/modern/root/config/view_test.go new file mode 100644 index 00000000..9050588b --- /dev/null +++ b/cmd/modern/root/config/view_test.go @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package config + +import ( + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "testing" +) + +func TestView(t *testing.T) { + cmdparser.TestSetup(t) + cmdparser.TestCmd[*View]() +} diff --git a/cmd/modern/root/config_test.go b/cmd/modern/root/config_test.go new file mode 100644 index 00000000..f1bf477d --- /dev/null +++ b/cmd/modern/root/config_test.go @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package root + +import ( + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "testing" +) + +// TestConfig runs a sanity test of `sqlcmd config` +func TestConfig(t *testing.T) { + cmdparser.TestSetup(t) + cmdparser.TestCmd[*Config]() +} diff --git a/cmd/modern/root/install.go b/cmd/modern/root/install.go new file mode 100644 index 00000000..63d6b29a --- /dev/null +++ b/cmd/modern/root/install.go @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package root + +import ( + "github.com/microsoft/go-sqlcmd/cmd/modern/root/install" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" +) + +// Install defines the `sqlcmd install` sub-commands +type Install struct { + cmdparser.Cmd +} + +func (c *Install) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ + Use: "install", + Short: "Install/Create #SQLFamily and Tools", + Aliases: []string{"create"}, + SubCommands: c.SubCommands(), + } + + c.Cmd.DefineCommand(options) +} + +// SubCommands sets up the sub-commands for `sqlcmd install` such as +// `sqlcmd install mssql` and `sqlcmd install azsql-edge` +func (c *Install) SubCommands() []cmdparser.Command { + dependencies := c.Dependencies() + + return []cmdparser.Command{ + cmdparser.New[*install.Mssql](dependencies), + cmdparser.New[*install.Edge](dependencies), + } +} diff --git a/cmd/modern/root/install/edge.go b/cmd/modern/root/install/edge.go new file mode 100644 index 00000000..92b9f245 --- /dev/null +++ b/cmd/modern/root/install/edge.go @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package install + +import ( + "github.com/microsoft/go-sqlcmd/cmd/modern/root/install/edge" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/cmdparser/dependency" + "github.com/microsoft/go-sqlcmd/internal/pal" +) + +// Edge implements the `sqlcmd install azsql-edge command and sub-commands +type Edge struct { + cmdparser.Cmd + MssqlBase +} + +func (c *Edge) DefineCommand(...cmdparser.CommandOptions) { + const repo = "azure-sql-edge" + + options := cmdparser.CommandOptions{ + Use: "azsql-edge", + Short: "Install Azure Sql Edge", + Examples: []cmdparser.ExampleOptions{{ + Description: "Install Azure SQL Edge in a container", + Steps: []string{"sqlcmd install azsql-edge"}}}, + Run: c.MssqlBase.Run, + SubCommands: c.SubCommands(), + } + + c.MssqlBase.SetCrossCuttingConcerns(dependency.Options{ + EndOfLine: pal.LineBreak(), + Output: c.Output(), + }) + + c.Cmd.DefineCommand(options) + c.AddFlags(c.AddFlag, repo, "edge") +} + +func (c *Edge) SubCommands() []cmdparser.Command { + return []cmdparser.Command{ + cmdparser.New[*edge.GetTags](c.Dependencies()), + } +} diff --git a/cmd/root/install/edge/get-tags.go b/cmd/modern/root/install/edge/get-tags.go similarity index 61% rename from cmd/root/install/edge/get-tags.go rename to cmd/modern/root/install/edge/get-tags.go index 8adf5dc5..cb6c3bd5 100644 --- a/cmd/root/install/edge/get-tags.go +++ b/cmd/modern/root/install/edge/get-tags.go @@ -6,31 +6,32 @@ package edge import ( "github.com/microsoft/go-sqlcmd/internal/cmdparser" "github.com/microsoft/go-sqlcmd/internal/container" - "github.com/microsoft/go-sqlcmd/internal/output" ) type GetTags struct { cmdparser.Cmd } -func (c *GetTags) DefineCommand(...cmdparser.Command) { - c.Cmd.Options = cmdparser.Options{ +func (c *GetTags) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ Use: "get-tags", - Short: "Get tags available for mssql edge install", - Examples: []cmdparser.ExampleInfo{ + Short: "Get tags available for Azure SQL Edge install", + Examples: []cmdparser.ExampleOptions{ { Description: "List tags", - Steps: []string{"sqlcmd install mssql-edge get-tags"}, + Steps: []string{"sqlcmd install azsql-edge get-tags"}, }, }, Aliases: []string{"gt", "lt"}, Run: c.run, } - c.Cmd.DefineCommand() + c.Cmd.DefineCommand(options) } func (c *GetTags) run() { + output := c.Output() + tags := container.ListTags( "azure-sql-edge", "https://mcr.microsoft.com", diff --git a/cmd/modern/root/install/edge_test.go b/cmd/modern/root/install/edge_test.go new file mode 100644 index 00000000..5bb6251b --- /dev/null +++ b/cmd/modern/root/install/edge_test.go @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package install + +import ( + "github.com/microsoft/go-sqlcmd/cmd/modern/root/install/edge" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/config" + "github.com/microsoft/go-sqlcmd/internal/container" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestInstallEdge(t *testing.T) { + // DEVNOTE: To prevent "import cycle not allowed" golang compile time error (due to + // cleaning up the Install using root.Uninstall), we don't use root.Uninstall, + // and use the controller object instead + + cmdparser.TestSetup(t) + cmdparser.TestCmd[*edge.GetTags]() + cmdparser.TestCmd[*Edge]("--accept-eula --user-database foo") + + controller := container.NewController() + id := config.ContainerId() + err := controller.ContainerStop(id) + assert.Nil(t, err) + err = controller.ContainerRemove(id) + assert.Nil(t, err) +} diff --git a/cmd/root/install/mssql-base.go b/cmd/modern/root/install/mssql-base.go similarity index 75% rename from cmd/root/install/mssql-base.go rename to cmd/modern/root/install/mssql-base.go index c970999a..f15dbf6c 100644 --- a/cmd/root/install/mssql-base.go +++ b/cmd/modern/root/install/mssql-base.go @@ -5,19 +5,19 @@ package install import ( "fmt" - "github.com/microsoft/go-sqlcmd/cmd/sqlconfig" + "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" "github.com/microsoft/go-sqlcmd/internal/cmdparser" "github.com/microsoft/go-sqlcmd/internal/config" "github.com/microsoft/go-sqlcmd/internal/container" "github.com/microsoft/go-sqlcmd/internal/mssql" - "github.com/microsoft/go-sqlcmd/internal/output" "github.com/microsoft/go-sqlcmd/internal/pal" "github.com/microsoft/go-sqlcmd/internal/secret" "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" "github.com/spf13/viper" ) -// MssqlBase provide base support for installing SQL Server. +// MssqlBase provide base support for installing SQL Server and all of its +// various flavors, e.g. SQL Server Edge. type MssqlBase struct { cmdparser.Cmd @@ -40,7 +40,11 @@ type MssqlBase struct { defaultContextName string collation string + port int + sqlcmdPkg *sqlcmd.Sqlcmd + + unittesting bool } func (c *MssqlBase) AddFlags( @@ -151,9 +155,24 @@ func (c *MssqlBase) AddFlags( Name: "collation", Usage: "The SQL Server collation", }) + + addFlag(cmdparser.FlagOptions{ + Int: &c.port, + DefaultInt: 0, + Name: "port-override", + Usage: "Port override (next available port from 1433 upwards used by default)", + }) } +// Run checks that the end-user license agreement has been accepted, +// constructs the container image name from the provided registry, repository, and tag, +// and sets the context name to a default value if it is not provided. +// Finally, it installs the image as a container and names it using the context name. +// If the EULA has not been accepted, it prints an error message with suggestions for how to proceed, +// and exits the program. func (c *MssqlBase) Run() { + output := c.Cmd.Output() + var imageName string if !c.acceptEula && viper.GetString("ACCEPT_EULA") == "" { @@ -176,7 +195,14 @@ func (c *MssqlBase) Run() { c.installContainerImage(imageName, c.contextName) } +// installContainerImage installs an image for a SQL Server container. The image +// is specified by imageName, and the container will be given the name contextName. +// If the useCached flag is set, the function will skip downloading the image +// from the internet. The function outputs progress messages to the command-line +// as it runs. If any errors are encountered, they will be printed to the +// command-line and the program will exit. func (c *MssqlBase) installContainerImage(imageName string, contextName string) { + output := c.Cmd.Output() saPassword := c.generatePassword() env := []string{ @@ -184,17 +210,23 @@ func (c *MssqlBase) installContainerImage(imageName string, contextName string) fmt.Sprintf("MSSQL_SA_PASSWORD=%s", saPassword), fmt.Sprintf("MSSQL_COLLATION=%s", c.collation), } - port := config.FindFreePortForTds() + if c.port == 0 { + c.port = config.FindFreePortForTds() + } controller := container.NewController() if !c.useCached { output.Infof("Downloading %v", imageName) err := controller.EnsureImage(imageName) - if err != nil { + if err != nil || c.unittesting { output.FatalfErrorWithHints( err, []string{ - "Is a container runtime installed on this machine (e.g. Podman or Docker)?\n\tIf not, download desktop engine from:\n\t\thttps://podman-desktop.io/\n\t\tor\n\t\thttps://docs.docker.com/get-docker/", + "Is a container runtime installed on this machine (e.g. Podman or Docker)?" + sqlcmd.SqlcmdEol + + "\tIf not, download desktop engine from:" + sqlcmd.SqlcmdEol + + "\t\thttps://podman-desktop.io/" + sqlcmd.SqlcmdEol + + "\t\tor" + sqlcmd.SqlcmdEol + + "\t\thttps://docs.docker.com/get-docker/", "Is a container runtime running. Try `podman ps` or `docker ps` (list containers), does it return without error?", fmt.Sprintf("If `podman ps` or `docker ps` works, try downloading the image with: `podman|docker pull %s`", imageName)}, "Unable to download image %s", imageName) @@ -202,8 +234,8 @@ func (c *MssqlBase) installContainerImage(imageName string, contextName string) } output.Infof("Starting %v", imageName) - containerId := controller.ContainerRun(imageName, env, port, []string{}, false) - previousContextName := config.GetCurrentContextName() + containerId := controller.ContainerRun(imageName, env, c.port, []string{}, false) + previousContextName := config.CurrentContextName() userName := pal.UserName() password := c.generatePassword() @@ -213,7 +245,7 @@ func (c *MssqlBase) installContainerImage(imageName string, contextName string) config.AddContextWithContainer( contextName, imageName, - port, + c.port, containerId, userName, password, @@ -222,7 +254,7 @@ func (c *MssqlBase) installContainerImage(imageName string, contextName string) output.Infof( "Created context %q in %q, configuring user account...", - config.GetCurrentContextName(), + config.CurrentContextName(), config.GetConfigFileUsed(), ) @@ -235,7 +267,7 @@ func (c *MssqlBase) installContainerImage(imageName string, contextName string) "sa", userName) - endpoint, _ := config.GetCurrentContext() + endpoint, _ := config.CurrentContext() c.sqlcmdPkg = mssql.Connect( endpoint, &sqlconfig.User{ @@ -271,17 +303,27 @@ func (c *MssqlBase) installContainerImage(imageName string, contextName string) output.InfofWithHintExamples(hints, "Now ready for client connections on port %d", - port, + c.port, ) } +func (c *MssqlBase) query(commandText string) { + mssql.Query(c.sqlcmdPkg, commandText) +} + +// createNonSaUser creates a user (non-sa) and assigns the sysadmin role +// to the user. It also creates a default database with the provided name +// and assigns the default database to the user. Finally, it disables +// the sa account and rotates the sa password for security reasons. func (c *MssqlBase) createNonSaUser(userName string, password string) { + output := c.Cmd.Output() + defaultDatabase := "master" if c.defaultDatabase != "" { defaultDatabase = c.defaultDatabase output.Infof("Creating default database [%s]", defaultDatabase) - c.Query(fmt.Sprintf("CREATE DATABASE [%s]", defaultDatabase)) + c.query(fmt.Sprintf("CREATE DATABASE [%s]", defaultDatabase)) } const createLogin = `CREATE LOGIN [%s] @@ -293,17 +335,17 @@ CHECK_POLICY=OFF` @loginame = N'%s', @rolename = N'sysadmin'` - c.Query(fmt.Sprintf(createLogin, userName, password, defaultDatabase)) - c.Query(fmt.Sprintf(addSrvRoleMember, userName)) + c.query(fmt.Sprintf(createLogin, userName, password, defaultDatabase)) + c.query(fmt.Sprintf(addSrvRoleMember, userName)) // Correct safety protocol is to rotate the sa password, because the first // sa password has been in the docker environment (as SA_PASSWORD) - c.Query(fmt.Sprintf("ALTER LOGIN [sa] WITH PASSWORD = N'%s';", + c.query(fmt.Sprintf("ALTER LOGIN [sa] WITH PASSWORD = N'%s';", c.generatePassword())) - c.Query("ALTER LOGIN [sa] DISABLE") + c.query("ALTER LOGIN [sa] DISABLE") if c.defaultDatabase != "" { - c.Query(fmt.Sprintf("ALTER AUTHORIZATION ON DATABASE::[%s] TO %s", + c.query(fmt.Sprintf("ALTER AUTHORIZATION ON DATABASE::[%s] TO %s", defaultDatabase, userName)) } } @@ -318,7 +360,3 @@ func (c *MssqlBase) generatePassword() (password string) { return } - -func (c *MssqlBase) Query(commandText string) { - mssql.Query(c.sqlcmdPkg, commandText) -} diff --git a/cmd/root/install/mssql-base_darwin.go b/cmd/modern/root/install/mssql-base_darwin.go similarity index 100% rename from cmd/root/install/mssql-base_darwin.go rename to cmd/modern/root/install/mssql-base_darwin.go diff --git a/cmd/root/install/mssql-base_linux.go b/cmd/modern/root/install/mssql-base_linux.go similarity index 100% rename from cmd/root/install/mssql-base_linux.go rename to cmd/modern/root/install/mssql-base_linux.go diff --git a/cmd/root/install/mssql-base_windows.go b/cmd/modern/root/install/mssql-base_windows.go similarity index 100% rename from cmd/root/install/mssql-base_windows.go rename to cmd/modern/root/install/mssql-base_windows.go diff --git a/cmd/modern/root/install/mssql.go b/cmd/modern/root/install/mssql.go new file mode 100644 index 00000000..46162b07 --- /dev/null +++ b/cmd/modern/root/install/mssql.go @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package install + +import ( + "github.com/microsoft/go-sqlcmd/cmd/modern/root/install/mssql" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/cmdparser/dependency" + "github.com/microsoft/go-sqlcmd/internal/pal" +) + +// Mssql implements the `sqlcmd install mssql command and sub-commands +type Mssql struct { + cmdparser.Cmd + MssqlBase +} + +func (c *Mssql) DefineCommand(...cmdparser.CommandOptions) { + const repo = "mssql/server" + + options := cmdparser.CommandOptions{ + Use: "mssql", + Short: "Install SQL Server", + Examples: []cmdparser.ExampleOptions{{ + Description: "Install SQL Server in a container", + Steps: []string{"sqlcmd install mssql"}}}, + Run: c.MssqlBase.Run, + SubCommands: c.SubCommands(), + } + + c.MssqlBase.SetCrossCuttingConcerns(dependency.Options{ + EndOfLine: pal.LineBreak(), + Output: c.Output(), + }) + + c.Cmd.DefineCommand(options) + c.AddFlags(c.AddFlag, repo, "mssql") +} + +func (c *Mssql) SubCommands() []cmdparser.Command { + return []cmdparser.Command{ + cmdparser.New[*mssql.GetTags](c.Dependencies()), + } +} diff --git a/cmd/root/install/mssql/get-tags.go b/cmd/modern/root/install/mssql/get-tags.go similarity index 75% rename from cmd/root/install/mssql/get-tags.go rename to cmd/modern/root/install/mssql/get-tags.go index 867f2a7b..43627756 100644 --- a/cmd/root/install/mssql/get-tags.go +++ b/cmd/modern/root/install/mssql/get-tags.go @@ -6,18 +6,17 @@ package mssql import ( "github.com/microsoft/go-sqlcmd/internal/cmdparser" "github.com/microsoft/go-sqlcmd/internal/container" - "github.com/microsoft/go-sqlcmd/internal/output" ) type GetTags struct { cmdparser.Cmd } -func (c *GetTags) DefineCommand(...cmdparser.Command) { - c.Cmd.Options = cmdparser.Options{ +func (c *GetTags) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ Use: "get-tags", Short: "Get tags available for mssql install", - Examples: []cmdparser.ExampleInfo{ + Examples: []cmdparser.ExampleOptions{ { Description: "List tags", Steps: []string{"sqlcmd install mssql get-tags"}, @@ -27,11 +26,13 @@ func (c *GetTags) DefineCommand(...cmdparser.Command) { Run: c.run, } - c.Cmd.DefineCommand() + c.Cmd.DefineCommand(options) } func (c *GetTags) run() { + output := c.Output() + tags := container.ListTags( "mssql/server", "https://mcr.microsoft.com", diff --git a/cmd/modern/root/install/mssql_test.go b/cmd/modern/root/install/mssql_test.go new file mode 100644 index 00000000..a9f1afa3 --- /dev/null +++ b/cmd/modern/root/install/mssql_test.go @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package install + +import ( + "github.com/microsoft/go-sqlcmd/cmd/modern/root/install/mssql" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/config" + "github.com/microsoft/go-sqlcmd/internal/container" + "github.com/microsoft/go-sqlcmd/internal/test" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestInstallMssql(t *testing.T) { + // DEVNOTE: To prevent "import cycle not allowed" golang compile time error (due to + // cleaning up the Install using root.Uninstall), we don't use root.Uninstall, + // and use the controller object instead + + cmdparser.TestSetup(t) + cmdparser.TestCmd[*mssql.GetTags]() + cmdparser.TestCmd[*Mssql]("--accept-eula --user-database foo") + + controller := container.NewController() + id := config.ContainerId() + err := controller.ContainerStop(id) + assert.Nil(t, err) + err = controller.ContainerRemove(id) + assert.Nil(t, err) +} + +func TestNegInstallMssql(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + cmdparser.TestSetup(t) + cmdparser.TestCmd[*Mssql]() +} + +func TestNegInstallMssql2(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + cmdparser.TestSetup(t) + cmdparser.TestCmd[*Mssql]("--accept-eula --repo does/not/exist") +} diff --git a/cmd/modern/root/install_test.go b/cmd/modern/root/install_test.go new file mode 100644 index 00000000..94396866 --- /dev/null +++ b/cmd/modern/root/install_test.go @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package root + +import ( + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "testing" +) + +// TestInstall runs a sanity test of `sqlcmd install` +func TestInstall(t *testing.T) { + cmdparser.TestSetup(t) + cmdparser.TestCmd[*Install]() +} diff --git a/cmd/root/query.go b/cmd/modern/root/query.go similarity index 65% rename from cmd/root/query.go rename to cmd/modern/root/query.go index 08dbe614..89a9cf2c 100644 --- a/cmd/root/query.go +++ b/cmd/modern/root/query.go @@ -11,30 +11,31 @@ import ( "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" ) +// Query defines the `sqlcmd query` command type Query struct { cmdparser.Cmd text string } -func (c *Query) DefineCommand(...cmdparser.Command) { - c.Cmd.Options = cmdparser.Options{ +func (c *Query) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ Use: "query", Short: "Run a query against the current context", - Examples: []cmdparser.ExampleInfo{ + Examples: []cmdparser.ExampleOptions{ {Description: "Run a query", Steps: []string{ `sqlcmd query "SELECT @@SERVERNAME"`, `sqlcmd query --text "SELECT @@SERVERNAME"`, `sqlcmd query --query "SELECT @@SERVERNAME"`, }}}, Run: c.run, - FirstArgAlternativeForFlag: &cmdparser.AlternativeForFlagInfo{ + FirstArgAlternativeForFlag: &cmdparser.AlternativeForFlagOptions{ Flag: "text", Value: &c.text, }, } - c.Cmd.DefineCommand() + c.Cmd.DefineCommand(options) c.AddFlag(cmdparser.FlagOptions{ String: &c.text, @@ -42,7 +43,7 @@ func (c *Query) DefineCommand(...cmdparser.Command) { Shorthand: "t", Usage: "Command text to run"}) - // BUG(stuartpa): Decide on if --text or --query is best + // BUG(stuartpa): Decide on if --text or --query is best (or leave both for convenience) c.AddFlag(cmdparser.FlagOptions{ String: &c.text, Name: "query", @@ -50,8 +51,12 @@ func (c *Query) DefineCommand(...cmdparser.Command) { Usage: "Command text to run"}) } +// run executes the Query command. +// It connects to a SQL Server endpoint using the current context from the config file, +// and either runs an interactive SQL console or executes the provided query. +// If an error occurs, it is handled by the CheckErr function. func (c *Query) run() { - endpoint, user := config.GetCurrentContext() + endpoint, user := config.CurrentContext() var line sqlcmd.Console = nil if c.text == "" { diff --git a/cmd/modern/root/query_test.go b/cmd/modern/root/query_test.go new file mode 100644 index 00000000..0d1c6775 --- /dev/null +++ b/cmd/modern/root/query_test.go @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package root + +import ( + "fmt" + "github.com/microsoft/go-sqlcmd/cmd/modern/root/config" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/stretchr/testify/assert" + "os" + "testing" +) + +// TestQuery runs a sanity test of `sqlcmd query` using the local instance on 1433 +func TestQuery(t *testing.T) { + t.Skip("stuartpa: This is failing in the pipeline (Login failed for user 'sa'.)") + cmdparser.TestSetup(t) + + // if SQLCMDSERVER != "" add an endpoint using the --address + if os.Getenv("SQLCMDSERVER") == "" { + cmdparser.TestCmd[*config.AddEndpoint]() + } else { + t.Logf("SQLCMDSERVER: %v", os.Getenv("SQLCMDSERVER")) + cmdparser.TestCmd[*config.AddEndpoint](fmt.Sprintf("--address %v", os.Getenv("SQLCMDSERVER"))) + } + + // If the SQLCMDPASSWORD envvar is set, then add a basic user, otherwise + // we'll use trusted auth + if os.Getenv("SQLCMDPASSWORD") != "" && + os.Getenv("SQLCMDUSER") != "" { + + // sqlcmd uses the SQLCMD_PASSWORD env var, but the tests use the + // SQLCMDPASSWORD env var + err := os.Setenv("SQLCMD_PASSWORD", os.Getenv("SQLCMDPASSWORD")) + assert.Nil(t, err) + cmdparser.TestCmd[*config.AddUser]( + fmt.Sprintf("--name user1 --username %s", + os.Getenv("SQLCMDUSER"))) + cmdparser.TestCmd[*config.AddContext]("--endpoint endpoint --user user1") + } else { + cmdparser.TestCmd[*config.AddContext]("--endpoint endpoint") + } + cmdparser.TestCmd[*config.View]() // displaying the config (info in-case test fails) + cmdparser.TestCmd[*Query]("PRINT") +} diff --git a/cmd/root/uninstall.go b/cmd/modern/root/uninstall.go similarity index 70% rename from cmd/root/uninstall.go rename to cmd/modern/root/uninstall.go index 18bfcc18..c4c95553 100644 --- a/cmd/root/uninstall.go +++ b/cmd/modern/root/uninstall.go @@ -5,14 +5,15 @@ package root import ( "fmt" + "path/filepath" + "strings" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" "github.com/microsoft/go-sqlcmd/internal/config" "github.com/microsoft/go-sqlcmd/internal/container" - "github.com/microsoft/go-sqlcmd/internal/output" - "path/filepath" - "strings" ) +// Uninstall defines the `sqlcmd uninstall` command type Uninstall struct { cmdparser.Cmd @@ -31,11 +32,11 @@ var systemDatabases = [...]string{ "/var/opt/mssql/data/master.mdf", } -func (c *Uninstall) DefineCommand(...cmdparser.Command) { - c.Cmd.Options = cmdparser.Options{ +func (c *Uninstall) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ Use: "uninstall", Short: "Uninstall/Delete the current context", - Examples: []cmdparser.ExampleInfo{ + Examples: []cmdparser.ExampleOptions{ { Description: "Uninstall/Delete the current context (includes the endpoint and user)", Steps: []string{`sqlcmd uninstall`}}, @@ -50,7 +51,7 @@ func (c *Uninstall) DefineCommand(...cmdparser.Command) { Run: c.run, } - c.Cmd.DefineCommand() + c.Cmd.DefineCommand(options) c.AddFlag(cmdparser.FlagOptions{ Bool: &c.yes, @@ -65,23 +66,31 @@ func (c *Uninstall) DefineCommand(...cmdparser.Command) { }) } +// run executes the Uninstall command. +// It checks that the current context exists, and if it does, +// it verifies that no user database files exist if the force flag is not set. +// It then stops and removes the current context's container, +// removes the current context from the config file, and saves the config. +// If the operation is successful, it prints a message with the new current context. func (c *Uninstall) run() { - if config.GetCurrentContextName() == "" { + output := c.Output() + + if config.CurrentContextName() == "" { output.FatalfWithHintExamples([][]string{ {"To view available contexts", "sqlcmd config get-contexts"}, }, "No current context") } - if currentContextEndPointExists() { + if c.currentContextEndPointExists() { if config.CurrentContextEndpointHasContainer() { controller := container.NewController() - id := config.GetContainerId() - endpoint, _ := config.GetCurrentContext() + id := config.ContainerId() + endpoint, _ := config.CurrentContext() var input string if !c.yes { output.Infof( "Current context is %q. Do you want to continue? (Y/N)", - config.GetCurrentContextName(), + config.CurrentContextName(), ) _, err := fmt.Scanln(&input) c.CheckErr(err) @@ -92,7 +101,7 @@ func (c *Uninstall) run() { } if !c.force { output.Infof("Verifying no user (non-system) database (.mdf) files") - userDatabaseSafetyCheck(controller, id) + c.userDatabaseSafetyCheck(controller, id) } output.Infof( @@ -102,7 +111,7 @@ func (c *Uninstall) run() { err := controller.ContainerStop(id) c.CheckErr(err) - output.Infof("Removing context %s", config.GetCurrentContextName()) + output.Infof("Removing context %s", config.CurrentContextName()) err = controller.ContainerRemove(id) c.CheckErr(err) } @@ -110,7 +119,7 @@ func (c *Uninstall) run() { config.RemoveCurrentContext() config.Save() - newContextName := config.GetCurrentContextName() + newContextName := config.CurrentContextName() if newContextName != "" { output.Infof("Current context is now %s", newContextName) } else { @@ -119,7 +128,12 @@ func (c *Uninstall) run() { } } -func userDatabaseSafetyCheck(controller *container.Controller, id string) { +// userDatabaseSafetyCheck checks for the presence of user database files +// in the current context's container. It takes a container.Controller and a container ID as arguments. +// If user database files are found and the force flag is not set, it prints an error message +// with suggestions for how to proceed, and exits the program. +func (c *Uninstall) userDatabaseSafetyCheck(controller *container.Controller, id string) { + output := c.Output() files := controller.ContainerFiles(id, "*.mdf") for _, databaseFile := range files { if strings.HasSuffix(databaseFile, ".mdf") { @@ -143,7 +157,8 @@ func userDatabaseSafetyCheck(controller *container.Controller, id string) { } } -func currentContextEndPointExists() (exists bool) { +func (c *Uninstall) currentContextEndPointExists() (exists bool) { + output := c.Output() exists = true if !config.EndpointsExists() { diff --git a/cmd/modern/root/uninstall_test.go b/cmd/modern/root/uninstall_test.go new file mode 100644 index 00000000..28c28d4b --- /dev/null +++ b/cmd/modern/root/uninstall_test.go @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package root + +import ( + "github.com/microsoft/go-sqlcmd/cmd/modern/root/install" + "github.com/microsoft/go-sqlcmd/cmd/modern/root/install/edge" + "github.com/microsoft/go-sqlcmd/cmd/modern/root/install/mssql" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/test" + "testing" +) + +// TestUninstall installs Mssql (on a specific port to enable parallel testing), and then +// uninstalls it +func TestUninstall(t *testing.T) { + cmdparser.TestSetup(t) + cmdparser.TestCmd[*mssql.GetTags]() + cmdparser.TestCmd[*install.Mssql]("--accept-eula --port-override 1500") + cmdparser.TestCmd[*Uninstall]("--yes") +} + +// TestUninstallWithUserDbPresent(t *testing.T) { installs Mssql (on a specific port to enable parallel testing), with a +// user database, and then uninstalls it using the --force option +func TestUninstallWithUserDbPresent(t *testing.T) { + cmdparser.TestSetup(t) + cmdparser.TestCmd[*edge.GetTags]() + cmdparser.TestCmd[*install.Edge]("--accept-eula --user-database foo --port-override 1501") + cmdparser.TestCmd[*Uninstall]("--yes --force") +} + +// TestNegUninstallNoInstanceToUninstall tests that we fail if no instance to +// uninstall +func TestNegUninstallNoInstanceToUninstall(t *testing.T) { + t.Skip("stuartpa: Not passing on Linux, not sure why right now") + defer func() { test.CatchExpectedError(recover(), t) }() + + cmdparser.TestSetup(t) + cmdparser.TestCmd[*Uninstall]("--yes") +} diff --git a/cmd/modern/root_test.go b/cmd/modern/root_test.go new file mode 100644 index 00000000..4dcbbf95 --- /dev/null +++ b/cmd/modern/root_test.go @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package main + +import ( + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/cmdparser/dependency" + "github.com/stretchr/testify/assert" + "testing" +) + +// TestRoot is a quick sanity test +func TestRoot(t *testing.T) { + c := cmdparser.New[*Root](dependency.Options{}) + c.DefineCommand() + c.SetArgsForUnitTesting([]string{}) + c.Execute() +} + +func TestIsValidSubCommand(t *testing.T) { + c := cmdparser.New[*Root](dependency.Options{}) + invalid := c.IsValidSubCommand("nope") + assert.Equal(t, false, invalid) + valid := c.IsValidSubCommand("query") + assert.Equal(t, true, valid) +} diff --git a/cmd/sqlconfig/doc.go b/cmd/modern/sqlconfig/doc.go similarity index 92% rename from cmd/sqlconfig/doc.go rename to cmd/modern/sqlconfig/doc.go index 55928bcf..bfcddb80 100644 --- a/cmd/sqlconfig/doc.go +++ b/cmd/modern/sqlconfig/doc.go @@ -25,9 +25,10 @@ An example of the sqlconfig file looks like this: apiversion: v1 endpoints: - - container: - id: 0e698e65e19d9c - image: mcr.microsoft.com/mssql/server:2022-latest + - asset: + - container: + id: 0e698e65e19d9c + image: mcr.microsoft.com/mssql/server:2022-latest endpoint: address: localhost port: 1435 diff --git a/cmd/sqlconfig/sqlconfig.go b/cmd/modern/sqlconfig/sqlconfig.go similarity index 75% rename from cmd/sqlconfig/sqlconfig.go rename to cmd/modern/sqlconfig/sqlconfig.go index 6b78fce5..2958da89 100644 --- a/cmd/sqlconfig/sqlconfig.go +++ b/cmd/modern/sqlconfig/sqlconfig.go @@ -1,6 +1,12 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. +// Package sqlconfig, defines the metadata for representing a sqlconfig file. +// It includes structs for representing an endpoint, context, user, and the overall +// sqlconfig file itself. Each struct has fields for storing the various pieces +// of information that make up an SQL configuration, such as endpoint address +// and port, context name and endpoint, and user authentication type and details. +// These structs are used to manage and manipulate the sqlconfig. package sqlconfig type EndpointDetails struct { @@ -46,10 +52,9 @@ type User struct { } type Sqlconfig struct { - ApiVersion string `mapstructure:"apiVersion"` + Version string `mapstructure:"version"` Endpoints []Endpoint `mapstructure:"endpoints"` Contexts []Context `mapstructure:"contexts"` CurrentContext string `mapstructure:"currentcontext"` - Kind string `mapstructure:"kind"` Users []User `mapstructure:"users"` } diff --git a/cmd/root.go b/cmd/root.go deleted file mode 100644 index 03a0c7a9..00000000 --- a/cmd/root.go +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package cmd - -import ( - "github.com/microsoft/go-sqlcmd/internal/cmdparser" - "github.com/microsoft/go-sqlcmd/internal/config" -) - -type Root struct { - cmdparser.Cmd -} - -func (c *Root) DefineCommand(subCommands ...cmdparser.Command) { - c.Cmd.Options = cmdparser.Options{ - Use: "sqlcmd", - Short: "sqlcmd: a command-line interface for the #SQLFamily", - Examples: []cmdparser.ExampleInfo{ - { - Description: "Run a query", - Steps: []string{`sqlcmd query "SELECT @@SERVERNAME"`}}}, - } - - c.Cmd.DefineCommand(subCommands...) - c.addGlobalFlags() -} - -func (c *Root) addGlobalFlags() { - c.AddFlag(cmdparser.FlagOptions{ - Bool: &globalOptions.TrustServerCertificate, - Name: "trust-server-certificate", - Shorthand: "C", - Usage: "Whether to trust the certificate presented by the endpoint for encryption", - }) - - c.AddFlag(cmdparser.FlagOptions{ - String: &globalOptions.DatabaseName, - Name: "database-name", - Shorthand: "d", - Usage: "The initial database for the connection", - }) - - c.AddFlag(cmdparser.FlagOptions{ - Bool: &globalOptions.UseTrustedConnection, - Name: "use-trusted-connection", - Shorthand: "E", - Usage: "Whether to use integrated security", - }) - - c.AddFlag(cmdparser.FlagOptions{ - String: &configFilename, - DefaultString: config.DefaultFileName(), - Name: "sqlconfig", - Usage: "Configuration file", - }) - - c.AddFlag(cmdparser.FlagOptions{ - String: &outputType, - DefaultString: "yaml", - Name: "output", - Shorthand: "o", - Usage: "output type (yaml, json or xml)", - }) - - c.AddFlag(cmdparser.FlagOptions{ - Int: &loggingLevel, - DefaultInt: 2, - Name: "verbosity", - Shorthand: "v", - Usage: "Log level, error=0, warn=1, info=2, debug=3, trace=4", - }) -} diff --git a/cmd/root/config.go b/cmd/root/config.go deleted file mode 100644 index af3f7c89..00000000 --- a/cmd/root/config.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package root - -import ( - "github.com/microsoft/go-sqlcmd/internal/cmdparser" -) - -type Config struct { - cmdparser.Cmd -} - -func (c *Config) DefineCommand(subCommands ...cmdparser.Command) { - c.Cmd.Options = cmdparser.Options{ - Use: "config", - Short: `Modify sqlconfig files using subcommands like "sqlcmd config use-context mssql"`, - } - c.Cmd.DefineCommand(subCommands...) -} diff --git a/cmd/root/config/sub-commands.go b/cmd/root/config/sub-commands.go deleted file mode 100644 index 0066282a..00000000 --- a/cmd/root/config/sub-commands.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package config - -import "github.com/microsoft/go-sqlcmd/internal/cmdparser" - -func SubCommands() []cmdparser.Command { - return []cmdparser.Command{ - cmdparser.New[*AddContext](), - cmdparser.New[*AddEndpoint](), - cmdparser.New[*AddUser](), - cmdparser.New[*ConnectionStrings](), - cmdparser.New[*CurrentContext](), - cmdparser.New[*DeleteContext](), - cmdparser.New[*DeleteEndpoint](), - cmdparser.New[*DeleteUser](), - cmdparser.New[*GetContexts](), - cmdparser.New[*GetEndpoints](), - cmdparser.New[*GetUsers](), - cmdparser.New[*UseContext](), - cmdparser.New[*View](), - } -} diff --git a/cmd/root/install.go b/cmd/root/install.go deleted file mode 100644 index 973817c9..00000000 --- a/cmd/root/install.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package root - -import ( - "github.com/microsoft/go-sqlcmd/internal/cmdparser" -) - -type Install struct { - cmdparser.Cmd -} - -func (c *Install) DefineCommand(subCommands ...cmdparser.Command) { - c.Cmd.Options = cmdparser.Options{ - Use: "install", - Short: "Install/Create #SQLFamily and Tools", - Aliases: []string{"create"}, - } - c.Cmd.DefineCommand(subCommands...) -} diff --git a/cmd/root/install/edge.go b/cmd/root/install/edge.go deleted file mode 100644 index b7ba4313..00000000 --- a/cmd/root/install/edge.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package install - -import ( - "github.com/microsoft/go-sqlcmd/internal/cmdparser" -) - -type Edge struct { - cmdparser.Cmd - MssqlBase -} - -func (c *Edge) DefineCommand(subCommands ...cmdparser.Command) { - const repo = "azure-sql-edge" - - c.Cmd.Options = cmdparser.Options{ - Use: "mssql-edge", - Short: "Install SQL Server Edge", - Examples: []cmdparser.ExampleInfo{{ - Description: "Install SQL Server Edge in a container", - Steps: []string{"sqlcmd install mssql-edge"}}}, - Run: c.MssqlBase.Run, - } - - c.Cmd.DefineCommand(subCommands...) - c.AddFlags(c.AddFlag, repo, "edge") -} diff --git a/cmd/root/install/edge/sub-commands.go b/cmd/root/install/edge/sub-commands.go deleted file mode 100644 index e72f9ed3..00000000 --- a/cmd/root/install/edge/sub-commands.go +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package edge - -import "github.com/microsoft/go-sqlcmd/internal/cmdparser" - -var SubCommands = []cmdparser.Command{ - cmdparser.New[*GetTags](), -} diff --git a/cmd/root/install/mssql.go b/cmd/root/install/mssql.go deleted file mode 100644 index ab28d937..00000000 --- a/cmd/root/install/mssql.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package install - -import ( - "github.com/microsoft/go-sqlcmd/internal/cmdparser" -) - -type Mssql struct { - cmdparser.Cmd - MssqlBase -} - -func (c *Mssql) DefineCommand(subCommands ...cmdparser.Command) { - const repo = "mssql/server" - - c.Cmd.Options = cmdparser.Options{ - Use: "mssql", - Short: "Install SQL Server", - Examples: []cmdparser.ExampleInfo{{ - Description: "Install SQL Server in a container", - Steps: []string{"sqlcmd install mssql"}}}, - Run: c.MssqlBase.Run, - } - - c.Cmd.DefineCommand(subCommands...) - c.AddFlags(c.AddFlag, repo, "mssql") -} diff --git a/cmd/root/install/mssql/sub-commands.go b/cmd/root/install/mssql/sub-commands.go deleted file mode 100644 index cde7a3b8..00000000 --- a/cmd/root/install/mssql/sub-commands.go +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package mssql - -import "github.com/microsoft/go-sqlcmd/internal/cmdparser" - -var SubCommands = []cmdparser.Command{ - cmdparser.New[*GetTags](), -} diff --git a/cmd/root/install/sub-commands.go b/cmd/root/install/sub-commands.go deleted file mode 100644 index bf12daaa..00000000 --- a/cmd/root/install/sub-commands.go +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package install - -import ( - "github.com/microsoft/go-sqlcmd/cmd/root/install/edge" - "github.com/microsoft/go-sqlcmd/cmd/root/install/mssql" - "github.com/microsoft/go-sqlcmd/internal/cmdparser" -) - -var SubCommands = []cmdparser.Command{ - cmdparser.New[*Mssql](mssql.SubCommands...), - cmdparser.New[*Edge](edge.SubCommands...), -} diff --git a/cmd/root/sub-commands.go b/cmd/root/sub-commands.go deleted file mode 100644 index 20e2b2c3..00000000 --- a/cmd/root/sub-commands.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package root - -import ( - "github.com/microsoft/go-sqlcmd/cmd/root/config" - "github.com/microsoft/go-sqlcmd/cmd/root/install" - "github.com/microsoft/go-sqlcmd/internal/cmdparser" -) - -func SubCommands() []cmdparser.Command { - return []cmdparser.Command{ - cmdparser.New[*Config](config.SubCommands()...), - cmdparser.New[*Query](), - cmdparser.New[*Install](install.SubCommands...), - cmdparser.New[*Uninstall](), - } -} diff --git a/cmd/sqlcmd/sqlcmd.go b/cmd/sqlcmd/sqlcmd.go index fad7737e..ff6659be 100644 --- a/cmd/sqlcmd/sqlcmd.go +++ b/cmd/sqlcmd/sqlcmd.go @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. +// //go:generate go-winres make --file-version=git-tag --product-version=git-tag package sqlcmd @@ -196,7 +197,6 @@ func setVars(vars *sqlcmd.Variables, args *SQLCmdArguments) { for v := range args.Variables { vars.Set(v, args.Variables[v]) } - } func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sqlcmd.Variables) { diff --git a/cmd/sqlcmd/sqlcmd_test.go b/cmd/sqlcmd/sqlcmd_test.go index 15e41290..c2ebee2f 100644 --- a/cmd/sqlcmd/sqlcmd_test.go +++ b/cmd/sqlcmd/sqlcmd_test.go @@ -190,8 +190,6 @@ func TestUnicodeOutput(t *testing.T) { } func TestUnicodeInput(t *testing.T) { - // BUG(stuartpa): This test has to be fixed before merging - t.Skip() testfiles := []string{ filepath.Join(`testdata`, `selectutf8.txt`), diff --git a/go.sum b/go.sum index 0e25ee54..71eecd14 100644 --- a/go.sum +++ b/go.sum @@ -450,6 +450,7 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211117180635-dee7805ff2e1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220224120231-95c6836cb0e7/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/cmdparser/cmd.go b/internal/cmdparser/cmd.go index 8afccf75..04cd922f 100644 --- a/internal/cmdparser/cmd.go +++ b/internal/cmdparser/cmd.go @@ -5,18 +5,28 @@ package cmdparser import ( "fmt" + "github.com/microsoft/go-sqlcmd/internal/cmdparser/dependency" + "github.com/microsoft/go-sqlcmd/internal/pal" + "strings" + "github.com/microsoft/go-sqlcmd/internal/output" "github.com/spf13/cobra" - "os" - "strings" ) +// AddFlag adds a flag to the command instance of type Cmd. The flag is added +// according to the provided FlagOptions. If the FlagOptions does not have a +// name or usage, then the function panics. If the flag is of type String then +// it adds the flag to the PersistentFlags of the command instance with the +// provided options. Similarly, if the flag is of type Int or Bool, the flag +// is added to the PersistentFlags with the provided options. If a shorthand +// is provided, then it uses StringVarP or BoolVarP or IntVarP instead of +// StringVar or BoolVar or IntVar respectively. func (c *Cmd) AddFlag(options FlagOptions) { if options.Name == "" { panic("Must provide name") } if options.Usage == "" { - panic("Must provide usage") + panic("Must provide usage for flag") } if options.String != nil { @@ -40,7 +50,7 @@ func (c *Cmd) AddFlag(options FlagOptions) { } if options.Int != nil { - if options.String != nil || options.Bool != nil { + if options.Bool != nil { panic("Only provide one type") } if options.Shorthand == "" { @@ -60,9 +70,6 @@ func (c *Cmd) AddFlag(options FlagOptions) { } if options.Bool != nil { - if options.String != nil || options.Int != nil { - panic("Only provide one type") - } if options.Shorthand == "" { c.command.PersistentFlags().BoolVar( options.Bool, @@ -80,35 +87,45 @@ func (c *Cmd) AddFlag(options FlagOptions) { } } -func (c *Cmd) ArgsForUnitTesting(args []string) { - c.command.SetArgs(args) -} +// DefineCommand defines a command with the provided CommandOptions and adds +// it to the command list. If only one CommandOptions is provided, it is used +// as the command options. Otherwise, the default CommandOptions are used. The +// function sets the command usage, short and long descriptions, aliases, examples, +// and run function. It also sets the maximum number of arguments allowed for the +// command, and adds any subcommands specified in the CommandOptions. +func (c *Cmd) DefineCommand(options ...CommandOptions) { + if len(options) == 1 { + c.options = options[0] + } -func (c *Cmd) DefineCommand(subCommands ...Command) { - if c.Options.Use == "" { + if c.options.Use == "" { panic("Must implement command definition") } - if c.Options.Long == "" { - c.Options.Long = c.Options.Short + if c.options.Long == "" { + c.options.Long = c.options.Short } c.command = cobra.Command{ - Use: c.Options.Use, - Short: c.Options.Short, - Long: c.Options.Long, - Aliases: c.Options.Aliases, + Use: c.options.Use, + Short: c.options.Short, + Long: c.options.Long, + Aliases: c.options.Aliases, Example: c.generateExamples(), Run: c.run, } - if c.Options.FirstArgAlternativeForFlag != nil { + if c.options.FirstArgAlternativeForFlag != nil { c.command.Args = cobra.MaximumNArgs(1) + + // IDIOMATIC: override the Use so the --help includes the flag name in caps and square bracket + // e.g. `sqlcmd config use-context [NAME]` or `sqlcmd config delete-user [NAME]` + c.command.Use = c.options.Use + " [" + strings.ToUpper(c.options.FirstArgAlternativeForFlag.Flag) + "]" } else { c.command.Args = cobra.MaximumNArgs(0) } - c.addSubCommands(subCommands) + c.addSubCommands(c.options.SubCommands) } // CheckErr passes the error down to cobra.CheckErr (which is likely to call @@ -117,28 +134,57 @@ func (c *Cmd) DefineCommand(subCommands ...Command) { // process, and call panic instead so the call stack can be added to the unit test // output. func (c *Cmd) CheckErr(err error) { - // If we are in a unit test driver, then panic, otherwise pass down to cobra.CheckErr - if strings.HasSuffix(os.Args[0], ".test") || // are we in go test? - (len(os.Args) > 1 && os.Args[1] == "-test.v") { // are we in goland unittest? - if err != nil { - panic(err) - } - } else { - cobra.CheckErr(err) - } + output := c.Output() + output.FatalErr(err) } +// Command returns the cobra Command associated with the Cmd. This method +// allows for easy access and manipulation of the command's properties and behavior. func (c *Cmd) Command() *cobra.Command { return &c.command } +// Execute function is responsible for executing the underlying command for +// this Cmd object. The function first attempts to execute the command, and then +// checks for any errors that may have occurred during execution. If an error +// is detected, the CheckErr method is called to handle the error. This function +// is typically called after defining and configuring the command using +// the DefineCommand and SetArgsForUnitTesting functions. func (c *Cmd) Execute() { err := c.command.Execute() c.CheckErr(err) } -func (c *Cmd) IsSubCommand(command string) (valid bool) { +// Output function is a getter function that returns the output.Output instance +// associated with the Cmd instance. If no output.Output instance has been +// set, the function initializes a new instance and returns it. +func (c *Cmd) Output() *output.Output { + if c.dependencies.Output == nil { + panic("output.New has not been called yet") + } + return c.dependencies.Output +} + +func (c *Cmd) Dependencies() dependency.Options { + return c.dependencies +} +// Inject dependencies into the Cmd struct. The options parameter is a struct +// containing a reference to the output struct, which the function then +// assigns to the output field of the Cmd struct. This allows for the +// output struct to be mocked in unit tests. +func (c *Cmd) SetCrossCuttingConcerns(dependencies dependency.Options) { + if dependencies.Output == nil { + panic("Output is nil") + } + c.dependencies = dependencies +} + +// IsSubCommand returns true if the provided command string +// matches the name or an alias of one of the object sub-commands, +// or if the command string is "--help" or "completion". Otherwise, +// it returns false. +func (c *Cmd) IsSubCommand(command string) (valid bool) { if command == "--help" { valid = true } else if command == "completion" { @@ -162,47 +208,75 @@ func (c *Cmd) IsSubCommand(command string) (valid bool) { return } +// SetArgsForUnitTesting sets the arguments for a unit test. +// This function allows users to specify arguments to the command for testing purposes. +func (c *Cmd) SetArgsForUnitTesting(args []string) { + c.command.SetArgs(args) +} + +// addSubCommands is a helper function that is used to add multiple sub-commands +// to a parent command in the application. It takes a slice of Command objects +// as an input and then adds each Command object to the parent command using +// the AddCommand method. This allows for a modular approach to defining +// the application's command hierarchy and makes it easy to add new sub-commands +// to the parent command. func (c *Cmd) addSubCommands(commands []Command) { + if c.dependencies.Output == nil { + panic("Why is output nil?") + } for _, subCommand := range commands { c.command.AddCommand(subCommand.Command()) } } +// generateExamples generates a list of examples for a command. It iterates +// over the Examples property of the CommandOptions struct, appending the +// description and steps for each example. The resulting string is returned. func (c *Cmd) generateExamples() string { var sb strings.Builder - for _, e := range c.Options.Examples { - sb.WriteString(fmt.Sprintf("# %v\n", e.Description)) + for _, e := range c.options.Examples { + sb.WriteString(fmt.Sprintf("# %v%v", e.Description, pal.LineBreak())) for _, s := range e.Steps { - sb.WriteString(fmt.Sprintf(" %v\n", s)) + sb.WriteString(fmt.Sprintf(" %v%v", s, pal.LineBreak())) } } return sb.String() } +// run function is a command handler for the cobra library. It checks if the first +// argument has been provided as an alternative for the specified flag and, if so, +// sets the value of that flag to the provided argument. If the Run option has been +// specified in the CommandOptions, it calls that function. func (c *Cmd) run(_ *cobra.Command, args []string) { - if c.Options.FirstArgAlternativeForFlag != nil { + if c.options.FirstArgAlternativeForFlag != nil { if len(args) > 0 { flag, err := c.command.PersistentFlags().GetString( - c.Options.FirstArgAlternativeForFlag.Flag) + c.options.FirstArgAlternativeForFlag.Flag) c.CheckErr(err) + if flag != "" { - output.Fatal( + c.dependencies.Output.Fatal( fmt.Sprintf( "Both an argument and the --%v flag have been provided. "+ "Please provide either an argument or the --%v flag", - c.Options.FirstArgAlternativeForFlag.Flag, - c.Options.FirstArgAlternativeForFlag.Flag)) + c.options.FirstArgAlternativeForFlag.Flag, + c.options.FirstArgAlternativeForFlag.Flag)) } - if c.Options.FirstArgAlternativeForFlag.Value == nil { + if c.options.FirstArgAlternativeForFlag.Value == nil { panic("Must set Value") } - *c.Options.FirstArgAlternativeForFlag.Value = args[0] + *c.options.FirstArgAlternativeForFlag.Value = args[0] } } - if c.Options.Run != nil { - c.Options.Run() + if c.options.Run == nil { + // If command has no run, it has sub-commands only, then display help if no + // sub-command entered + err := c.command.Help() + c.CheckErr(err) + } else { + c.options.Run() } } diff --git a/internal/cmdparser/cmd_test.go b/internal/cmdparser/cmd_test.go index 5921203d..ef64572a 100644 --- a/internal/cmdparser/cmd_test.go +++ b/internal/cmdparser/cmd_test.go @@ -1,20 +1,28 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + package cmdparser import ( + "fmt" + "github.com/microsoft/go-sqlcmd/internal/cmdparser/dependency" + "github.com/microsoft/go-sqlcmd/internal/output" + "github.com/microsoft/go-sqlcmd/internal/test" "github.com/spf13/cobra" "testing" ) func TestCmd_run(t *testing.T) { s := "" - c := &Cmd{ - Options: Options{ - FirstArgAlternativeForFlag: &AlternativeForFlagInfo{ + c := Cmd{ + options: CommandOptions{ + FirstArgAlternativeForFlag: &AlternativeForFlagOptions{ Flag: "name", Value: &s, }, }, - command: cobra.Command{}, + dependencies: dependency.Options{Output: output.New(output.Options{ErrorHandler: func(err error) {}, HintHandler: func(hints []string) {}})}, + command: cobra.Command{}, } c.AddFlag(FlagOptions{ Name: "name", @@ -23,3 +31,148 @@ func TestCmd_run(t *testing.T) { }) c.run(nil, []string{"name-value"}) } + +func TestNegCmd_run(t *testing.T) { + s := "" + c := Cmd{ + options: CommandOptions{ + FirstArgAlternativeForFlag: &AlternativeForFlagOptions{ + Flag: "name", + Value: &s, + }}, + dependencies: dependency.Options{Output: output.New(output.Options{ErrorHandler: func(err error) {}, HintHandler: func(hints []string) {}})}, + command: cobra.Command{}, + } + c.AddFlag(FlagOptions{ + Name: "name", + Usage: "name", + String: &s, + }) + c.run(nil, []string{"name-value"}) +} + +func TestNegCmdProvideBothFlagAndCmd(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + TestSetup(t) + s := "" + c := Cmd{ + options: CommandOptions{ + Use: "foo", + FirstArgAlternativeForFlag: &AlternativeForFlagOptions{ + Flag: "name", + Value: &s, + }, + Run: func() { fmt.Println("Running command") }, + }, + dependencies: dependency.Options{Output: output.New(output.Options{})}, + } + c.DefineCommand() + c.AddFlag(FlagOptions{ + Name: "name", + Usage: "name", + String: &s, + }) + c.SetArgsForUnitTesting([]string{"name-value", "--name", "another-value"}) + c.Execute() +} + +func TestNegCmdAlternativeValueNotSet(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + s := "" + c := Cmd{ + options: CommandOptions{ + Use: "foo", + FirstArgAlternativeForFlag: &AlternativeForFlagOptions{ + Flag: "name", + Value: nil, + }, + Run: func() {}, + }, + } + c.DefineCommand() + c.AddFlag(FlagOptions{ + Name: "name", + Usage: "name", + String: &s, + }) + c.SetArgsForUnitTesting([]string{"name-value"}) + c.Execute() +} + +func TestNegAddFlag(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + c := Cmd{options: CommandOptions{ + Use: "foo"}} + c.AddFlag(FlagOptions{ + Name: "", + Usage: "name", + }) +} + +func TestNegAddFlag2(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + c := Cmd{options: CommandOptions{ + Use: "foo"}} + c.AddFlag(FlagOptions{Name: "name", Usage: ""}) +} + +func TestNegAddFlag3(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + s := "'" + b := false + c := Cmd{options: CommandOptions{Use: "foo"}} + c.AddFlag(FlagOptions{Name: "name", Usage: "usage", String: &s, Bool: &b}) +} + +func TestNegAddFlag4(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + b := false + i := 0 + c := Cmd{options: CommandOptions{Use: "foo"}} + c.AddFlag(FlagOptions{Name: "name", Usage: "usage", Bool: &b, Int: &i}) +} + +func TestNegDefineCommandNoCommandOptions(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + c := Cmd{options: CommandOptions{}} + c.DefineCommand() +} + +// TestCmd_CheckErrInNotTestingMode covers the code that is not used +// for testing (because we don't want os.Exit() to be called by cobra.checkErr, +// so this test runs in NotTesting mode, and then doesn't pass in an error, +// so the code is covered, but os.Exist isn't called +func TestCmd_CheckErrInNotTestingMode(t *testing.T) { + c := Cmd{ + dependencies: dependency.Options{ + EndOfLine: "", + Output: output.New(output.Options{ + ErrorHandler: func(err error) {}, + HintHandler: func(hints []string) {}, + }), + }, + unitTesting: true, + } + c.CheckErr(nil) +} + +func TestNegOutputNewHasNotBeenCalled(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + c := Cmd{} + c.Output() +} + +func TestNegOutputNewHasNotBeenCalled2(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + c := Cmd{} + c.SetCrossCuttingConcerns(dependency.Options{}) +} diff --git a/internal/cmdparser/cmdparser.go b/internal/cmdparser/cmdparser.go deleted file mode 100644 index ec934d90..00000000 --- a/internal/cmdparser/cmdparser.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package cmdparser - -import ( - "github.com/spf13/cobra" -) - -// Initialize runs the init func() after the command-line provided by the user -// has been parsed. -func Initialize(init func()) { - cobra.OnInitialize(init) -} - -// New creates a cmdparser. After New returns, call Execute() method -// on the top-level Command -// -// Example: -// -// topLevel : = cmd.New[*MyCommand]() -// topLevel.Execute() -// -// Example with sub-commands -// -// topLevel := cmd.New[*MyCommand](MyCommand.subCommands) -func New[T PtrAsReceiverWrapper[CommandPtr], CommandPtr any](subCommands ...Command) (cmd T) { - cmd = new(CommandPtr) - cmd.DefineCommand(subCommands...) - return -} - -// PtrAsReceiverWrapper per golang design doc "an unfortunate necessary kludge": -// https://go.googlesource.com/proposal/+/refs/heads/master/design/43651-type-parameters.md#pointer-method-example -// https://www.reddit.com/r/golang/comments/uqwh5d/generics_new_value_from_pointer_type_with/ -type PtrAsReceiverWrapper[T any] interface { - Command - *T -} diff --git a/internal/cmdparser/cmdparser_test.go b/internal/cmdparser/cmdparser_test.go deleted file mode 100644 index 10ce66eb..00000000 --- a/internal/cmdparser/cmdparser_test.go +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package cmdparser - -import ( - "fmt" - "testing" -) - -type TopLevelCommand struct { - Cmd -} - -func (c *TopLevelCommand) DefineCommand(subCommands ...Command) { - c.Options = Options{ - Use: "top-level", - Short: "Hello-World", - Examples: []ExampleInfo{ - {Description: "First example", - Steps: []string{"This is the example"}}, - }, - } - - c.Cmd.DefineCommand(subCommands...) -} - -type SubCommand1 struct { - Cmd - - name string -} - -func (c *SubCommand1) DefineCommand(subCommands ...Command) { - c.Options = Options{ - Use: "sub-command1", - Short: "Sub Command 1", - FirstArgAlternativeForFlag: &AlternativeForFlagInfo{ - Flag: "name", - Value: &c.name, - }, - Run: func() { fmt.Println("Running: Sub Command 1") }, - } - c.Cmd.DefineCommand(subCommands...) - c.AddFlag(FlagOptions{ - Name: "name", - String: &c.name, - Usage: "usage", - }) -} - -type SubCommand11 struct { - Cmd -} - -func (c *SubCommand11) DefineCommand(...Command) { - c.Options = Options{ - Use: "sub-command11", - Short: "Sub Command 11", - Run: func() { fmt.Println("Running: Sub Command 11") }, - } - c.Cmd.DefineCommand() -} - -type SubCommand2 struct { - Cmd -} - -func (c *SubCommand2) DefineCommand(...Command) { - c.Options = Options{ - Use: "sub-command2", - Short: "Sub Command 2", - Aliases: []string{"sub-command2-alias"}, - } - c.Cmd.DefineCommand() -} - -func Test_EndToEnd(t *testing.T) { - subCmd11 := New[*SubCommand11]() - subCmd1 := New[*SubCommand1](subCmd11) - subCmd2 := New[*SubCommand2]() - - topLevel := New[*TopLevelCommand](subCmd1, subCmd2) - - topLevel.IsSubCommand("sub-command2") - topLevel.IsSubCommand("sub-command2-alias") - topLevel.IsSubCommand("--help") - topLevel.IsSubCommand("completion") - - var s string - topLevel.AddFlag(FlagOptions{ - String: &s, - Name: "string", - Usage: "usage", - }) - topLevel.AddFlag(FlagOptions{ - String: &s, - Shorthand: "s", - Name: "string2", - Usage: "usage", - }) - - var i int - topLevel.AddFlag(FlagOptions{ - Int: &i, - Name: "int", - Usage: "usage", - }) - topLevel.AddFlag(FlagOptions{ - Int: &i, - Shorthand: "i", - Name: "int2", - Usage: "usage", - }) - - var b bool - topLevel.AddFlag(FlagOptions{ - Bool: &b, - Name: "bool", - Usage: "usage", - }) - topLevel.AddFlag(FlagOptions{ - Bool: &b, - Shorthand: "b", - Name: "bool2", - Usage: "usage", - }) - - topLevel.ArgsForUnitTesting([]string{"--help"}) - topLevel.Execute() - - topLevel.ArgsForUnitTesting([]string{"sub-command1", "--help"}) - topLevel.Execute() - - topLevel.ArgsForUnitTesting([]string{"sub-command1", "sub-command11"}) - topLevel.Execute() - - topLevel.ArgsForUnitTesting([]string{"sub-command1"}) - topLevel.Execute() -} - -func TestAbstractBase_DefineCommand(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - - c := Cmd{} - c.DefineCommand() -} - -func TestInitialize(t *testing.T) { - Initialize(func() {}) -} diff --git a/internal/cmdparser/dependency/options.go b/internal/cmdparser/dependency/options.go new file mode 100644 index 00000000..7d3d9144 --- /dev/null +++ b/internal/cmdparser/dependency/options.go @@ -0,0 +1,8 @@ +package dependency + +import "github.com/microsoft/go-sqlcmd/internal/output" + +type Options struct { + EndOfLine string + Output *output.Output +} diff --git a/internal/cmdparser/factory.go b/internal/cmdparser/factory.go new file mode 100644 index 00000000..738f9560 --- /dev/null +++ b/internal/cmdparser/factory.go @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package cmdparser + +import ( + "fmt" + "github.com/microsoft/go-sqlcmd/internal/cmdparser/dependency" + "github.com/microsoft/go-sqlcmd/internal/output" + "github.com/spf13/cobra" + "os" +) + +// Initialize runs the init func() after the command-line provided by the user +// has been parsed. +func Initialize(init func()) { + cobra.OnInitialize(init) +} + +func New[T PtrAsReceiverWrapper[pointerType], pointerType any](dependencies dependency.Options) (command T) { + if dependencies.Output == nil { + dependencies.Output = output.New(output.Options{ + OutputType: "yaml", + LoggingLevel: 2, + StandardWriter: os.Stdout, + ErrorHandler: func(err error) { + if err != nil { + panic(err) + } + }, + HintHandler: func(hints []string) { fmt.Printf("HINTS: %v\n", hints) }}) + } + if dependencies.EndOfLine == "" { + dependencies.EndOfLine = "\n" + } + + command = new(pointerType) + command.SetCrossCuttingConcerns(dependencies) + command.DefineCommand() + + return +} + +// PtrAsReceiverWrapper per golang design doc "an unfortunate necessary kludge": +// https://go.googlesource.com/proposal/+/refs/heads/master/design/43651-type-parameters.md#pointer-method-example +// https://www.reddit.com/r/golang/comments/uqwh5d/generics_new_value_from_pointer_type_with/ +type PtrAsReceiverWrapper[T any] interface { + Command + *T +} diff --git a/internal/cmdparser/factory_test.go b/internal/cmdparser/factory_test.go new file mode 100644 index 00000000..ceb344ca --- /dev/null +++ b/internal/cmdparser/factory_test.go @@ -0,0 +1,198 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package cmdparser + +import ( + "errors" + "fmt" + "github.com/microsoft/go-sqlcmd/internal/cmdparser/dependency" + "github.com/microsoft/go-sqlcmd/internal/test" + "testing" +) + +type TopLevelCommand struct { + Cmd +} + +func (c *TopLevelCommand) DefineCommand(...CommandOptions) { + commandOptions := CommandOptions{ + Use: "top-level", + Short: "Hello-World", + Examples: []ExampleOptions{ + {Description: "First example", + Steps: []string{"This is the example"}}, + }, + SubCommands: c.SubCommands(), + } + + c.Cmd.DefineCommand(commandOptions) +} + +func (c *TopLevelCommand) SubCommands() []Command { + return []Command{ + New[*SubCommand1](c.Dependencies()), + New[*SubCommand2](c.Dependencies()), + New[*ErrorCommand](c.Dependencies()), + } +} + +type SubCommand1 struct { + Cmd + + name string +} + +func (c *SubCommand1) DefineCommand(...CommandOptions) { + commandOptions := CommandOptions{ + Use: "sub-command1", + Short: "Sub Command 1", + FirstArgAlternativeForFlag: &AlternativeForFlagOptions{ + Flag: "name", + Value: &c.name, + }, + Run: func() { + c.Output().InfofWithHints([]string{"This is a hint"}, "This is a message") + }, + SubCommands: c.SubCommands(), + } + c.Cmd.DefineCommand(commandOptions) + c.AddFlag(FlagOptions{ + Name: "name", + String: &c.name, + Usage: "usage", + }) +} + +func (c *SubCommand1) SubCommands() []Command { + return []Command{ + New[*SubCommand11](c.Dependencies()), + } +} + +type SubCommand11 struct { + Cmd +} + +func (c *SubCommand11) DefineCommand(...CommandOptions) { + commandOptions := CommandOptions{ + Use: "sub-command11", + Short: "Sub Command 11", + Run: func() { fmt.Println("Running: Sub Command 11") }, + } + c.Cmd.DefineCommand(commandOptions) +} + +type SubCommand2 struct { + Cmd +} + +func (c *SubCommand2) DefineCommand(...CommandOptions) { + commandOptions := CommandOptions{ + Use: "sub-command2", + Short: "Sub Command 2", + Aliases: []string{"sub-command2-alias"}, + } + c.Cmd.DefineCommand(commandOptions) +} + +type ErrorCommand struct { + Cmd +} + +func (c *ErrorCommand) DefineCommand(...CommandOptions) { + commandOptions := CommandOptions{ + Use: "error-command", + Short: "Generate an error", + Run: c.run, + } + c.Cmd.DefineCommand(commandOptions) +} + +func (c *ErrorCommand) run() { + output := c.dependencies.Output + + output.Fatal("This command causes the cli to exit") +} + +func Test_EndToEnd(t *testing.T) { + topLevel := New[*TopLevelCommand](dependency.Options{}) + + topLevel.IsSubCommand("sub-command2") + topLevel.IsSubCommand("sub-command2-alias") + topLevel.IsSubCommand("--help") + topLevel.IsSubCommand("completion") + + var s string + topLevel.AddFlag(FlagOptions{ + String: &s, + Name: "string", + Usage: "usage", + }) + topLevel.AddFlag(FlagOptions{ + String: &s, + Shorthand: "s", + Name: "string2", + Usage: "usage", + }) + + var i int + topLevel.AddFlag(FlagOptions{ + Int: &i, + Name: "int", + Usage: "usage", + }) + topLevel.AddFlag(FlagOptions{ + Int: &i, + Shorthand: "i", + Name: "int2", + Usage: "usage", + }) + + var b bool + topLevel.AddFlag(FlagOptions{ + Bool: &b, + Name: "bool", + Usage: "usage", + }) + topLevel.AddFlag(FlagOptions{ + Bool: &b, + Shorthand: "b", + Name: "bool2", + Usage: "usage", + }) + + topLevel.SetArgsForUnitTesting([]string{"--help"}) + topLevel.Execute() + + topLevel.SetArgsForUnitTesting([]string{"sub-command1", "--help"}) + topLevel.Execute() + + topLevel.SetArgsForUnitTesting([]string{"sub-command1", "sub-command11"}) + topLevel.Execute() + + topLevel.SetArgsForUnitTesting([]string{"sub-command1"}) + topLevel.Execute() +} + +func TestInitialize(t *testing.T) { + Initialize(func() { fmt.Println("Got here") }) +} + +func Test(t *testing.T) { + topLevel := New[*TopLevelCommand](dependency.Options{}) + topLevel.SetArgsForUnitTesting([]string{}) + topLevel.CheckErr(nil) + + topLevel = New[*TopLevelCommand](dependency.Options{}) + topLevel.SetArgsForUnitTesting([]string{}) + topLevel.CheckErr(nil) +} + +func Test2(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + topLevel := New[*TopLevelCommand](dependency.Options{}) + topLevel.SetArgsForUnitTesting([]string{}) + topLevel.CheckErr(errors.New("foo")) +} diff --git a/internal/cmdparser/interface.go b/internal/cmdparser/interface.go index 77d748cc..23793e79 100644 --- a/internal/cmdparser/interface.go +++ b/internal/cmdparser/interface.go @@ -3,18 +3,43 @@ package cmdparser -import "github.com/spf13/cobra" +import ( + "github.com/microsoft/go-sqlcmd/internal/cmdparser/dependency" + "github.com/spf13/cobra" +) +// Command is an interface for defining and running a command which is +// part of a command line program. Command contains methods for setting +// command options, running the command, and checking for errors. type Command interface { - ArgsForUnitTesting(args []string) + // CheckErr checks if the given error is non-nil and, if it is, it prints the error + // to the output and exits the program with an exit code of 1. CheckErr(error) + + // Command returns the underlying cobra.Command object for this command. + // This is useful for defining subcommands. Command() *cobra.Command - DefineCommand(subCommands ...Command) - Execute() + + // DefineCommand is used to define a new command and its associated + // options, flags, and subcommands. It takes in a CommandOptions + // struct, which allow the caller to specify the command's name, description, + // usage, and behavior. + DefineCommand(...CommandOptions) // IsSubCommand is TEMPORARY code that will be removed when the - // new cobra CLI is enabled by default. It returns true if the command-line + // old Kong CLI is retired. It returns true if the command-line // provided by the user looks like they want the new cobra CLI, e.g. // sqlcmd query, sqlcmd install, sqlcmd --help etc. IsSubCommand(command string) bool + + // SetArgsForUnitTesting method allows a caller to set the arguments for the + // command when running unit tests. This is useful because it allows the caller + // to simulate different command-line input scenarios in their tests. + SetArgsForUnitTesting(args []string) + + // SetCrossCuttingConcerns is used to inject cross-cutting concerns (i.e. dependencies) + // into the Command object (like logging etc.). The dependency.Options allows + // the Command object to have access to the dependencies it needs, without + // having to manage them directly. + SetCrossCuttingConcerns(dependency.Options) } diff --git a/internal/output/verbosity/enum.go b/internal/cmdparser/mode/level.go similarity index 53% rename from internal/output/verbosity/enum.go rename to internal/cmdparser/mode/level.go index c3c40658..3b311ec7 100644 --- a/internal/output/verbosity/enum.go +++ b/internal/cmdparser/mode/level.go @@ -1,14 +1,11 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -package verbosity +package mode -type Enum int +type Level int const ( - Error Enum = iota - Warn - Info - Debug - Trace + NotTesting Level = iota + Testing ) diff --git a/internal/cmdparser/options.go b/internal/cmdparser/options.go index 4640ac6e..499c8622 100644 --- a/internal/cmdparser/options.go +++ b/internal/cmdparser/options.go @@ -3,6 +3,26 @@ package cmdparser +// The AlternativeForFlagOptions type represents options for defining an alternative +// for a flag. It consists of the name of the flag, as well as a pointer to the +// value to be used as the alternative. This type is typically used in the case +// where the user has provided an argument that should be treated as an alternative +// to a specific flag. +type AlternativeForFlagOptions struct { + Flag string + Value *string +} + +// FlagOptions type represents options for defining a flag for a command-line +// interface. The Name and Shorthand fields specify the long and short names +// for the flag, respectively. The Usage field is a string that describes how the +// flag should be used. The String, DefaultString, Int, DefaultInt, Bool, and +// DefaultBool fields are used to specify the type and default value of the flag, +// if it is a string, int, or bool type. The String and Int fields should be pointers +// to the variables that will store the flag's value, and the Bool field should be +// a pointer to a bool variable that will be set to true if the flag is present. The +// DefaultString, DefaultInt, and DefaultBool fields are the default values to +// use if the flag is not provided by the user. type FlagOptions struct { Name string Shorthand string @@ -18,12 +38,33 @@ type FlagOptions struct { DefaultBool bool } -type Options struct { +// CommandOptions is a struct that allows the caller to specify options for a Command. +// These options include the command's name, description, usage, and behavior. +// The Aliases field specifies alternate names for the command, +// and the Examples field specifies examples of how to use the command. +// The FirstArgAlternativeForFlag field specifies an alternative to the first +// argument when it is provided as a flag, and the Long and Short fields +// specify the command's long and short descriptions, respectively. +// The Run field specifies the behavior of the command when it is executed, +// and the Use field specifies the usage instructions for the command. +// The SubCommands field specifies any subcommands that the command has. +type CommandOptions struct { Aliases []string - Examples []ExampleInfo - FirstArgAlternativeForFlag *AlternativeForFlagInfo + Examples []ExampleOptions + FirstArgAlternativeForFlag *AlternativeForFlagOptions Long string Run func() Short string Use string + SubCommands []Command +} + +// ExampleOptions specifies the details of an example usage of a command. +// It contains a description of the example, and a list of steps that make up +// the example. This type is typically used in conjunction with the Examples +// field of the CommandOptions struct, to provide examples of how to use a +// command in the command's help text. +type ExampleOptions struct { + Description string + Steps []string } diff --git a/internal/cmdparser/test.go b/internal/cmdparser/test.go new file mode 100644 index 00000000..d065b865 --- /dev/null +++ b/internal/cmdparser/test.go @@ -0,0 +1,60 @@ +package cmdparser + +import ( + "github.com/microsoft/go-sqlcmd/internal" + "github.com/microsoft/go-sqlcmd/internal/cmdparser/dependency" + "github.com/microsoft/go-sqlcmd/internal/config" + "github.com/microsoft/go-sqlcmd/internal/output" + "github.com/microsoft/go-sqlcmd/internal/output/verbosity" + "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" + "strings" + "testing" +) + +// Setup internal packages for testing +func TestSetup(t *testing.T) { + o := output.New(output.Options{}) + internal.Initialize( + internal.InitializeOptions{ + ErrorHandler: func(err error) { + if err != nil { + panic(err) + } + }, + TraceHandler: o.Tracef, + HintHandler: func(strings []string) { + o.Infof("HINTS: %v"+sqlcmd.SqlcmdEol, strings) + }, + LineBreak: sqlcmd.SqlcmdEol, + }) + config.SetFileNameForTest(t) + t.Log("Initialized internal packages for testing") +} + +// Run a command expecing it to pass, passing in any supplied args (args are split on " " (space)) +func TestCmd[T PtrAsReceiverWrapper[pointerType], pointerType any](args ...string) { + err := testCmd[T](args...) + + // DEVNOTE: I don't think the code will ever get here (c.Command().Execute() will + // always panic first. This is here to silence code checkers, that require the err return + // variable be checked. + if err != nil { + panic(err) + } +} + +func testCmd[T PtrAsReceiverWrapper[pointerType], pointerType any](args ...string) error { + c := New[T](dependency.Options{ + Output: output.New(output.Options{LoggingLevel: verbosity.Trace}), + }) + c.DefineCommand() + if len(args) > 1 { + panic("Only provide one string of args, they will be split on space") + } else if len(args) == 1 { + c.SetArgsForUnitTesting(strings.Split(args[0], " ")) + } else { + c.SetArgsForUnitTesting([]string{}) + } + err := c.Command().Execute() + return err +} diff --git a/internal/cmdparser/test_test.go b/internal/cmdparser/test_test.go new file mode 100644 index 00000000..7063aba1 --- /dev/null +++ b/internal/cmdparser/test_test.go @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package cmdparser + +import ( + "errors" + "github.com/microsoft/go-sqlcmd/internal/test" + "testing" +) + +type TestCommand struct { + Cmd + + throwError string +} + +func (c *TestCommand) DefineCommand(...CommandOptions) { + options := CommandOptions{} + options.Use = "test-cmd" + options.Short = "A test command" + options.FirstArgAlternativeForFlag = &AlternativeForFlagOptions{ + Flag: "throw-error", + Value: &c.throwError, + } + options.Run = func() { + c.Output().InfofWithHints([]string{"This is a hint"}, "Some things to consider") + + if c.throwError == "throw-error" { + c.CheckErr(errors.New("Expected error")) + } + } + + c.Cmd.DefineCommand(options) + c.AddFlag(FlagOptions{Name: "throw-error", Usage: "Throw an error", String: &c.throwError}) +} + +func TestTest(t *testing.T) { + TestSetup(t) + TestCmd[*TestCommand]() +} + +func TestTest2(t *testing.T) { + TestSetup(t) + TestCmd[*TestCommand]("test-cmd") +} + +func TestNextTest(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + TestSetup(t) + TestCmd[*TestCommand](" ", " ") +} + +func TestThrowError(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + TestSetup(t) + TestCmd[*TestCommand]("throw-error") +} + +func TestTest3(t *testing.T) { + TestSetup(t) +} diff --git a/internal/cmdparser/type.go b/internal/cmdparser/type.go index e09df2af..24364488 100644 --- a/internal/cmdparser/type.go +++ b/internal/cmdparser/type.go @@ -3,20 +3,17 @@ package cmdparser -import "github.com/spf13/cobra" - -type AlternativeForFlagInfo struct { - Flag string - Value *string -} +import ( + "github.com/microsoft/go-sqlcmd/internal/cmdparser/dependency" + "github.com/spf13/cobra" +) +// Cmd is the main type used for defining and running command line programs. +// It contains fields and methods for defining the command, setting its options, +// and running the command. type Cmd struct { - Options Options - - command cobra.Command -} - -type ExampleInfo struct { - Description string - Steps []string + dependencies dependency.Options + options CommandOptions + command cobra.Command + unitTesting bool } diff --git a/internal/config/config.go b/internal/config/config.go index 3479aeb1..9f0a3922 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -4,15 +4,20 @@ package config import ( - . "github.com/microsoft/go-sqlcmd/cmd/sqlconfig" - "github.com/microsoft/go-sqlcmd/internal/file" + . "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" + "github.com/microsoft/go-sqlcmd/internal/io/file" + "github.com/microsoft/go-sqlcmd/internal/pal" "os" "path/filepath" + "testing" ) var config Sqlconfig var filename string +// SetFileName sets the filename for the file that the application reads from and +// writes to. The file is created if it does not already exist, and Viper is configured +// to use the given filename. func SetFileName(name string) { if name == "" { panic("name is empty") @@ -24,6 +29,15 @@ func SetFileName(name string) { configureViper(filename) } +func SetFileNameForTest(t *testing.T) { + SetFileName(pal.FilenameInUserHomeDotDirectory( + ".sqlcmd", "sqlconfig-"+t.Name())) +} + +// DefaultFileName returns the default filename for the file that the application +// reads from and writes to. This is typically located in the user's home directory +// under the ".sqlcmd" directory. If an error occurs while attempting to retrieve +// the user's home directory, the function will return an empty string. func DefaultFileName() (filename string) { home, err := os.UserHomeDir() checkErr(err) @@ -32,6 +46,10 @@ func DefaultFileName() (filename string) { return } +// Clean resets the application's configuration by setting the Users, Contexts, +// and Endpoints fields to nil, the CurrentContext field to an empty string, +// and saving the updated configuration. This effectively resets the configuration +// to its initial state. func Clean() { config.Users = nil config.Contexts = nil @@ -41,6 +59,11 @@ func Clean() { Save() } +// IsEmpty returns a boolean indicating whether the application's configuration +// is empty. The configuration is considered empty if all of the following fields +// are empty or zero-valued: Users, Contexts, Endpoints, and CurrentContext. +// This function can be used to determine whether the configuration has been +// initialized or reset. func IsEmpty() (isEmpty bool) { if len(config.Users) == 0 && len(config.Contexts) == 0 && @@ -52,6 +75,13 @@ func IsEmpty() (isEmpty bool) { return } +// AddContextWithContainer adds a new context to the application's configuration +// with the given parameters. The context is associated with a container +// identified by its container ID. If any of the required parameters (i.e. containerId, +// imageName, portNumber, username, password, contextName) are empty or +// zero-valued, the function will panic. The function also ensures that the given +// contextName and username are unique, and it encrypts the password if +// requested. The updated configuration is saved to file. func AddContextWithContainer( contextName string, imageName string, @@ -84,8 +114,6 @@ func AddContextWithContainer( endPointName := FindUniqueEndpointName(contextName) userName := username + "@" + contextName - config.ApiVersion = "v1" - config.Kind = "Config" config.CurrentContext = contextName config.Endpoints = append(config.Endpoints, Endpoint{ @@ -124,7 +152,12 @@ func AddContextWithContainer( Save() } -func GetRedactedConfig(raw bool) (c Sqlconfig) { +// RedactedConfig function returns a Sqlconfig struct with the Users field +// having their BasicAuth password field either replaced with the decrypted +// password or the string "REDACTED", depending on the value of the raw +// parameter. This allows the caller to either get the full password or a +// redacted version, where the password is hidden. +func RedactedConfig(raw bool) (c Sqlconfig) { c = config for i := range c.Users { user := c.Users[i] diff --git a/internal/config/config_test.go b/internal/config/config_test.go index a2e0c912..28b06a4b 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -4,16 +4,21 @@ package config import ( - . "github.com/microsoft/go-sqlcmd/cmd/sqlconfig" + . "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" "github.com/microsoft/go-sqlcmd/internal/output" "github.com/microsoft/go-sqlcmd/internal/pal" "github.com/microsoft/go-sqlcmd/internal/secret" + "github.com/microsoft/go-sqlcmd/internal/test" "reflect" "strings" "testing" ) func TestConfig(t *testing.T) { + o := output.New(output.Options{LoggingLevel: 4, ErrorHandler: errorCallback, HintHandler: func(hints []string) { + + }}) + type args struct { Config Sqlconfig } @@ -75,8 +80,8 @@ func TestConfig(t *testing.T) { EndpointsExists() EndpointExists("endpoint") GetEndpoint("endpoint") - OutputEndpoints(output.Struct, true) - OutputEndpoints(output.Struct, false) + OutputEndpoints(o.Struct, true) + OutputEndpoints(o.Struct, false) FindFreePortForTds() DeleteEndpoint("endpoint2") DeleteEndpoint("endpoint3") @@ -94,23 +99,23 @@ func TestConfig(t *testing.T) { AddUser(user) AddUser(user) AddUser(user) - UserExists("user") + UserNameExists("user") GetUser("user") UserNameExists("username") - OutputUsers(output.Struct, true) - OutputUsers(output.Struct, false) + OutputUsers(o.Struct, true) + OutputUsers(o.Struct, false) DeleteUser("user3") - GetRedactedConfig(true) - GetRedactedConfig(false) + RedactedConfig(true) + RedactedConfig(false) addContext() addContext() addContext() GetContext("context") - OutputContexts(output.Struct, true) - OutputContexts(output.Struct, false) + OutputContexts(o.Struct, true) + OutputContexts(o.Struct, false) DeleteContext("context3") DeleteContext("context2") DeleteContext("context") @@ -119,10 +124,10 @@ func TestConfig(t *testing.T) { addContext() SetCurrentContextName("context") - GetCurrentContext() + CurrentContext() CurrentContextEndpointHasContainer() - GetContainerId() + ContainerId() RemoveCurrentContext() RemoveCurrentContext() AddContextWithContainer("context", "imageName", 1433, "containerId", "user", "password", false) @@ -223,8 +228,8 @@ func TestUserExists(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if gotExists := UserExists(tt.args.name); gotExists != tt.wantExists { - t.Errorf("UserExists() = %v, want %v", gotExists, tt.wantExists) + if gotExists := UserNameExists(tt.args.name); gotExists != tt.wantExists { + t.Errorf("UserNameExists() = %v, want %v", gotExists, tt.wantExists) } }) } @@ -295,24 +300,14 @@ func TestAddContextWithContainerPanic(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - + defer func() { test.CatchExpectedError(recover(), t) }() AddContextWithContainer(tt.args.contextName, tt.args.imageName, tt.args.portNumber, tt.args.containerId, tt.args.username, tt.args.password, tt.args.encryptPassword) }) } } func TestConfig_AddContextWithNoEndpoint(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() + defer func() { test.CatchExpectedError(recover(), t) }() user := "user1" AddContext(Context{ @@ -325,20 +320,13 @@ func TestConfig_AddContextWithNoEndpoint(t *testing.T) { } func TestConfig_GetCurrentContextWithNoContexts(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - GetCurrentContext() + defer func() { test.CatchExpectedError(recover(), t) }() + + CurrentContext() } func TestConfig_GetCurrentContextEndPointNotFoundPanic(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() + defer func() { test.CatchExpectedError(recover(), t) }() AddEndpoint(Endpoint{ AssetDetails: &AssetDetails{ @@ -365,14 +353,37 @@ func TestConfig_GetCurrentContextEndPointNotFoundPanic(t *testing.T) { DeleteEndpoint("endpoint") SetCurrentContextName("context") - GetCurrentContext() + CurrentContext() } func TestConfig_DeleteContextThatDoesNotExist(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() + defer func() { test.CatchExpectedError(recover(), t) }() + contextOrdinal("does-not-exist") } + +func TestNegConfig_SetFileName(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + SetFileName("") +} + +func TestNegConfig_SetCurrentContextName(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + SetCurrentContextName("does not exist") +} + +func TestNegConfig_SetFileNameForTest(t *testing.T) { + SetFileNameForTest(t) +} + +func TestNegConfig_DefaultFileName(t *testing.T) { + DefaultFileName() +} + +func TestNegConfig_GetContext(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + GetContext("doesnotexist") +} diff --git a/internal/config/context.go b/internal/config/context.go index 0d5e6631..2641d066 100644 --- a/internal/config/context.go +++ b/internal/config/context.go @@ -6,23 +6,83 @@ package config import ( "errors" "fmt" - . "github.com/microsoft/go-sqlcmd/cmd/sqlconfig" - "github.com/microsoft/go-sqlcmd/internal/output" + . "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" "strconv" ) +// AddContext adds the context to the sqlconfig file. +// +// Before calling this method, verify the Endpoint exists and give the user +// a descriptive error, (this function will panic, which should never be hit) func AddContext(context Context) { if !EndpointExists(context.Endpoint) { - output.FatalfWithHintExamples([][]string{ - {"Add the endpoint", fmt.Sprintf( - "sqlcmd config add-endpoint --name %v", context.Endpoint)}, - }, "Endpoint '%v' does not exist", context.Endpoint) + panic("Endpoint doesn't exist") } context.Name = FindUniqueContextName(context.Name, *context.User) config.Contexts = append(config.Contexts, context) Save() } +// CurrentContextName returns the name of the current context in the configuration. +// The current context is the one that is currently active and used by the application. +func CurrentContextName() string { + return config.CurrentContext +} + +// ContextExists returns whether a context with the given name exists in the configuration. +// This function iterates over the list of contexts in the configuration and returns +// true if a context with the given name is found. Otherwise, the function returns false. +func ContextExists(name string) (exists bool) { + for _, c := range config.Contexts { + if name == c.Name { + exists = true + break + } + } + return +} + +// CurrentContext returns the current context's endpoint and user from the configuration. +// The function iterates over the list of contexts and endpoints in the configuration and returns the endpoint and user for the current context. +// If the current context does not have an endpoint, the function panics. +func CurrentContext() (endpoint Endpoint, user *User) { + currentContextName := GetCurrentContextOrFatal() + + endPointFound := false + for _, c := range config.Contexts { + if c.Name == currentContextName { + for _, e := range config.Endpoints { + if e.Name == c.Endpoint { + endpoint = e + endPointFound = true + break + } + } + + for _, u := range config.Users { + if u.Name == *c.User { + user = &u + break + } + } + } + } + + if !endPointFound { + panic(fmt.Sprintf( + "Context '%v' has no endpoint. Every context must have an endpoint", + currentContextName, + )) + } + + return +} + +// DeleteContext removes the context with the given name from the application's +// configuration. If the context does not exist, the function does nothing. The +// function also updates the CurrentContext field in the configuration to the +// first remaining context, or an empty string if no contexts remain. The +// updated configuration is saved to file. func DeleteContext(name string) { if ContextExists(name) { ordinal := contextOrdinal(name) @@ -67,12 +127,14 @@ func FindUniqueContextName(name string, username string) (uniqueContextName stri return } -func GetCurrentContextName() string { - return config.CurrentContext -} - +// GetCurrentContextOrFatal returns the name of the current context in the +// configuration or panics if it is not set. +// This function first calls the CurrentContextName function to retrieve the +// current context's name, if the current context's name is empty, the function +// panics with an error message indicating that a context must be set. +// Otherwise, the current context's name is returned. func GetCurrentContextOrFatal() (currentContextName string) { - currentContextName = GetCurrentContextName() + currentContextName = CurrentContextName() if currentContextName == "" { checkErr(errors.New( "no current context. To create a context use `sqlcmd install`, " + @@ -81,13 +143,26 @@ func GetCurrentContextOrFatal() (currentContextName string) { return } +// SetCurrentContextName sets the current context in the configuration to the given name. +// If a context with the given name does not exist, the function panics. +// Otherwise, the CurrentContext field in the configuration object is updated +// with the given name and the configuration is saved to the file. func SetCurrentContextName(name string) { if ContextExists(name) { config.CurrentContext = name Save() + } else { + panic("Context must exist") } } +// RemoveCurrentContext removes the current context from the configuration. +// This function iterates over the list of contexts, endpoints, and users in the +// configuration and removes the current context, its endpoint, and its user. +// If there are no remaining contexts in the configuration after removing the +// current context, the CurrentContext field in the configuration object is set +// to an empty string. Otherwise, the CurrentContext field is set to the name +// of the first remaining context. func RemoveCurrentContext() { currentContextName := config.CurrentContext @@ -125,69 +200,22 @@ func RemoveCurrentContext() { } } -func ContextExists(name string) (exists bool) { - for _, c := range config.Contexts { - if name == c.Name { - exists = true - break - } - } - return -} - -func contextOrdinal(name string) (ordinal int) { - for i, c := range config.Contexts { - if name == c.Name { - ordinal = i - return - } - } - panic("Context not found") -} - -func GetCurrentContext() (endpoint Endpoint, user *User) { - currentContextName := GetCurrentContextOrFatal() - - endPointFound := false - for _, c := range config.Contexts { - if c.Name == currentContextName { - for _, e := range config.Endpoints { - if e.Name == c.Endpoint { - endpoint = e - endPointFound = true - break - } - } - - for _, u := range config.Users { - if u.Name == *c.User { - user = &u - break - } - } - } - } - - if !endPointFound { - panic(fmt.Sprintf( - "Context '%v' has no endpoint. Every context must have an endpoint", - currentContextName, - )) - } - - return -} - +// GetContext retrieves a context from the configuration by its name. +// If the context does not exist, the function panics. +// If the context is not found, the function panics to indicate that the context must exist. func GetContext(name string) (context Context) { for _, c := range config.Contexts { if name == c.Name { context = c - break + return } } - return + panic("Context does not exist") } +// OutputContexts outputs the list of contexts in the configuration. +// The output can be either detailed, which includes all information about each context, or a list of context names only. +// This is controlled by the detailed flag, which is passed to the function. func OutputContexts(formatter func(interface{}) []byte, detailed bool) { if detailed { formatter(config.Contexts) @@ -201,3 +229,13 @@ func OutputContexts(formatter func(interface{}) []byte, detailed bool) { formatter(names) } } + +func contextOrdinal(name string) (ordinal int) { + for i, c := range config.Contexts { + if name == c.Name { + ordinal = i + return + } + } + panic("Context not found") +} diff --git a/internal/config/endpoint-container.go b/internal/config/endpoint-container.go index 0f79359e..dfc87f19 100644 --- a/internal/config/endpoint-container.go +++ b/internal/config/endpoint-container.go @@ -5,7 +5,10 @@ package config import "fmt" -func GetContainerId() (containerId string) { +// This function gets the container ID of the current context's endpoint. It first +// checks if the current context exists and has an endpoint. Then it checks if the +// endpoint has a container and retrieves its ID. Otherwise, it returns the container ID. +func ContainerId() (containerId string) { currentContextName := config.CurrentContext if currentContextName == "" { @@ -33,6 +36,9 @@ func GetContainerId() (containerId string) { panic("Id not found") } +// CurrentContextEndpointHasContainer() checks if the current context endpoint +// has a container. If the endpoint has a AssetDetails.ContainerDetails field, the function +// returns true, otherwise it returns false. func CurrentContextEndpointHasContainer() (exists bool) { currentContextName := config.CurrentContext @@ -57,6 +63,11 @@ func CurrentContextEndpointHasContainer() (exists bool) { return } +// FindFreePortForTds is used to find a free port number to use for the TDS +// protocol. It starts at port number 1433 and continues until it finds a port +// number that is not currently in use by any of the endpoints in the +// configuration. It also checks that the port is available on the local machine. +// If no available port is found after trying up to port number 5000, the function panics. func FindFreePortForTds() (portNumber int) { const startingPortNumber = 1433 diff --git a/internal/config/endpoint-container_test.go b/internal/config/endpoint-container_test.go index 51f0ac94..abe668fc 100644 --- a/internal/config/endpoint-container_test.go +++ b/internal/config/endpoint-container_test.go @@ -1,10 +1,12 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + package config import ( + . "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" "strings" "testing" - - . "github.com/microsoft/go-sqlcmd/cmd/sqlconfig" ) // TestCurrentContextEndpointHasContainer verifies the function panics when @@ -28,7 +30,7 @@ func TestGetContainerId(t *testing.T) { t.Errorf("The code did not panic") } }() - GetContainerId() + ContainerId() } func TestGetContainerId2(t *testing.T) { @@ -59,7 +61,7 @@ func TestGetContainerId2(t *testing.T) { }) SetCurrentContextName("context") - GetContainerId() + ContainerId() } func TestGetContainerId3(t *testing.T) { @@ -92,18 +94,5 @@ func TestGetContainerId3(t *testing.T) { }) SetCurrentContextName("context") - GetContainerId() -} - -func TestGetContainerId4(t *testing.T) { - Clean() - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - - SetCurrentContextName("badbad") - - GetContainerId() + ContainerId() } diff --git a/internal/config/endpoint.go b/internal/config/endpoint.go index 6fa4ed04..bfb2625c 100644 --- a/internal/config/endpoint.go +++ b/internal/config/endpoint.go @@ -5,10 +5,16 @@ package config import ( "fmt" - . "github.com/microsoft/go-sqlcmd/cmd/sqlconfig" + . "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" "strconv" ) +// AddEndpoint adds a new endpoint to the application's configuration with +// the given parameters. If the provided endpoint name is not unique, the +// function will modify it to ensure that it is unique before adding it to the +// configuration. The updated configuration is saved to file, and the function +// returns the actual endpoint name that was added. This may be different +// from the provided name if the original name was not unique. func AddEndpoint(endpoint Endpoint) (actualEndpointName string) { endpoint.Name = FindUniqueEndpointName(endpoint.Name) config.Endpoints = append(config.Endpoints, endpoint) @@ -17,6 +23,9 @@ func AddEndpoint(endpoint Endpoint) (actualEndpointName string) { return endpoint.Name } +// DeleteEndpoint removes the endpoint with the given name from the application's +// configuration. If the endpoint does not exist, the function does nothing. The +// updated configuration is saved to file. func DeleteEndpoint(name string) { if EndpointExists(name) { ordinal := endpointOrdinal(name) @@ -25,6 +34,9 @@ func DeleteEndpoint(name string) { } } +// EndpointsExists returns whether there are any endpoints in the configuration. +// This function checks the length of the Endpoints field in the configuration +// object and returns true if it is greater than zero. Otherwise, the function returns false. func EndpointsExists() (exists bool) { if len(config.Endpoints) > 0 { exists = true @@ -33,6 +45,28 @@ func EndpointsExists() (exists bool) { return } +// EndpointExists returns whether an endpoint with the given name exists in +// the configuration. This function iterates over the list of endpoints in the +// configuration and returns true if an endpoint with the given name is found. +// Otherwise, the function returns false. +func EndpointExists(name string) (exists bool) { + if name == "" { + panic("Name must not be empty") + } + + for _, c := range config.Endpoints { + if name == c.Name { + exists = true + break + } + } + return +} + +// EndpointNameExists returns whether an endpoint with the given name exists +// in the configuration. This function iterates over the list of endpoints in the +// configuration and returns true if an endpoint with the given name is found. +// Otherwise, the function returns false. func EndpointNameExists(name string) (exists bool) { for _, v := range config.Endpoints { if v.Name == name { @@ -44,6 +78,11 @@ func EndpointNameExists(name string) (exists bool) { return } +// FindUniqueEndpointName returns a unique name for an endpoint with the +// given name. +// If an endpoint with the given name does not exist in the configuration, the +// function returns the given name. Otherwise, the function returns a modified +// version of the given name that includes a number at the end to make it unique. func FindUniqueEndpointName(name string) (uniqueEndpointName string) { if !EndpointNameExists(name) { uniqueEndpointName = name @@ -67,30 +106,7 @@ func FindUniqueEndpointName(name string) (uniqueEndpointName string) { return } -func EndpointExists(name string) (exists bool) { - if name == "" { - panic("Name must not be empty") - } - - for _, c := range config.Endpoints { - if name == c.Name { - exists = true - break - } - } - return -} - -func endpointOrdinal(name string) (ordinal int) { - for i, c := range config.Endpoints { - if name == c.Name { - ordinal = i - break - } - } - return -} - +// GetEndpoint returns the endpoint with the given name from the configuration. func GetEndpoint(name string) (endpoint Endpoint) { for _, e := range config.Endpoints { if name == e.Name { @@ -101,6 +117,13 @@ func GetEndpoint(name string) (endpoint Endpoint) { return } +// OutputEndpoints outputs the list of endpoints in the configuration in a specified format. +// This function takes a formatter function and a flag indicating whether to +// output detailed information or just the names of the endpoints. +// If detailed information is requested, the formatter function is called with +// the list of endpoints in the configuration as the argument. +// Otherwise, the formatter function is called with a list of just the names of +// the endpoints in the configuration. func OutputEndpoints(formatter func(interface{}) []byte, detailed bool) { if detailed { formatter(config.Endpoints) @@ -114,3 +137,13 @@ func OutputEndpoints(formatter func(interface{}) []byte, detailed bool) { formatter(names) } } + +func endpointOrdinal(name string) (ordinal int) { + for i, c := range config.Endpoints { + if name == c.Name { + ordinal = i + break + } + } + return +} diff --git a/internal/config/endpoint_test.go b/internal/config/endpoint_test.go index f485a0f0..eb29222d 100644 --- a/internal/config/endpoint_test.go +++ b/internal/config/endpoint_test.go @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + package config import "testing" diff --git a/internal/config/initialize.go b/internal/config/initialize.go index 3796e236..2c943287 100644 --- a/internal/config/initialize.go +++ b/internal/config/initialize.go @@ -4,8 +4,8 @@ package config import ( + "fmt" "github.com/microsoft/go-sqlcmd/internal/net" - "github.com/microsoft/go-sqlcmd/internal/output" "github.com/microsoft/go-sqlcmd/internal/secret" ) @@ -13,21 +13,32 @@ var encryptCallback func(plainText string, encrypt bool) (cipherText string) var decryptCallback func(cipherText string, decrypt bool) (secret string) var isLocalPortAvailableCallback func(port int) (portAvailable bool) +// init sets up the package to work with a set of handlers to be used for the period +// before the command-line has been parsed func init() { errorHandler := func(err error) { if err != nil { panic(err) } } + traceHandler := func(format string, a ...any) { + fmt.Printf(format, a...) + } Initialize( errorHandler, - output.Tracef, + traceHandler, secret.Encode, secret.Decode, net.IsLocalPortAvailable) } +// Initialize sets the callback functions used by the config package. +// These callback functions are used for logging errors, tracing debug messages, +// encrypting and decrypting data, and checking if a local port is available. +// The callback functions are passed to the function as arguments. +// This function should be called at the start of the application to ensure that the +// config package has the necessary callback functions available. func Initialize( errorHandler func(err error), traceHandler func(format string, a ...any), diff --git a/internal/config/user.go b/internal/config/user.go index 4d75c8f6..2fb33a04 100644 --- a/internal/config/user.go +++ b/internal/config/user.go @@ -5,10 +5,14 @@ package config import ( "fmt" - . "github.com/microsoft/go-sqlcmd/cmd/sqlconfig" + . "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" "strconv" ) +// AddUser adds a new user to the configuration. +// The user's name is first modified to be unique by calling the FindUniqueUserName function. +// If the user's authentication type is "basic", the user's BasicAuth field must be non-nil and the username must be non-empty. +// The new user is then added to the list of users in the configuration object and the configuration is saved to the file. func AddUser(user User) { user.Name = FindUniqueUserName(user.Name) @@ -26,55 +30,21 @@ func AddUser(user User) { Save() } +// DeleteUser removes a user from the configuration by their name. +// If the user does not exist, the function does nothing. +// Otherwise, the user is removed from the list of users in the configuration object and the configuration is saved to the file. func DeleteUser(name string) { - if UserExists(name) { + if UserNameExists(name) { ordinal := userOrdinal(name) config.Users = append(config.Users[:ordinal], config.Users[ordinal+1:]...) Save() } } -func UserNameExists(name string) (exists bool) { - for _, v := range config.Users { - if v.Name == name { - exists = true - break - } - } - - return -} - -func UserExists(name string) (exists bool) { - for _, v := range config.Users { - if name == v.Name { - exists = true - break - } - } - return -} - -func userOrdinal(name string) (ordinal int) { - for i, c := range config.Users { - if name == c.Name { - ordinal = i - break - } - } - return -} - -func GetUser(name string) (user User) { - for _, v := range config.Users { - if name == v.Name { - user = v - break - } - } - return -} - +// FindUniqueUserName generates a unique user name based on the given name. +// If the given name is not already in use, it is returned as-is. +// Otherwise, a number is appended to the end of the given name to make it unique. +// This number starts at 2 and is incremented until a unique user name is found. func FindUniqueUserName(name string) (uniqueUserName string) { if !UserNameExists(name) { uniqueUserName = name @@ -98,6 +68,20 @@ func FindUniqueUserName(name string) (uniqueUserName string) { return } +// GetUser retrieves a user from the configuration by their name. +func GetUser(name string) (user User) { + for _, v := range config.Users { + if name == v.Name { + user = v + return + } + } + panic("User must exist") +} + +// OutputUsers outputs the list of users in the configuration. +// The output can be either detailed, which includes all information about each user, or a list of user names only. +// This is controlled by the detailed flag, which is passed to the function. func OutputUsers(formatter func(interface{}) []byte, detailed bool) { if detailed { formatter(config.Users) @@ -111,3 +95,31 @@ func OutputUsers(formatter func(interface{}) []byte, detailed bool) { formatter(names) } } + +// UserNameExists checks if a user with the given name exists in the configuration. +// It iterates over the list of users in the configuration object and returns true if a user with the given name is found. +// Otherwise, it returns false. +// This function can be useful for checking if a user with a given name already exists before adding a new user or updating an existing user. +func UserNameExists(name string) (exists bool) { + for _, v := range config.Users { + if v.Name == name { + exists = true + break + } + } + + return +} + +// userOrdinal returns the index of a user in the list of users in the configuration object. +// If the user does not exist, the function returns -1. +// This function iterates over the list of users and returns the index of the user with the given name. +func userOrdinal(name string) (ordinal int) { + for i, c := range config.Users { + if name == c.Name { + ordinal = i + break + } + } + return +} diff --git a/internal/config/user_test.go b/internal/config/user_test.go index 442d7bc0..24887e02 100644 --- a/internal/config/user_test.go +++ b/internal/config/user_test.go @@ -1,9 +1,12 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + package config import ( + . "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" + "github.com/microsoft/go-sqlcmd/internal/test" "testing" - - . "github.com/microsoft/go-sqlcmd/cmd/sqlconfig" ) func TestAddUser(t *testing.T) { @@ -19,12 +22,8 @@ func TestAddUser(t *testing.T) { }) } -func TestAddUser2(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() +func TestNegAddUser(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() AddUser(User{ Name: "", AuthenticationType: "basic", @@ -35,3 +34,8 @@ func TestAddUser2(t *testing.T) { }, }) } + +func TestNegAddUser2(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + GetUser("doesnotexist") +} diff --git a/internal/config/viper.go b/internal/config/viper.go index bec0ba8e..9409fd8d 100644 --- a/internal/config/viper.go +++ b/internal/config/viper.go @@ -5,20 +5,14 @@ package config import ( "bytes" + "github.com/microsoft/go-sqlcmd/internal/pal" "github.com/spf13/viper" "gopkg.in/yaml.v2" ) -func configureViper(configFile string) { - if configFile == "" { - panic("Must provide configFile") - } - - viper.SetConfigType("yaml") - viper.SetEnvPrefix("SQLCMD") - viper.SetConfigFile(configFile) -} - +// Load loads the configuration from the file specified by the SetFileName() function. +// Any errors encountered while marshalling or saving the configuration are checked +// and handled by the injected errorHandler (via the checkErr function). func Load() { if filename == "" { panic("Must call config.SetFileName()") @@ -33,14 +27,22 @@ func Load() { err = viper.Unmarshal(&config) checkErr(err) - trace("Config loaded from file: %v", viper.ConfigFileUsed()) + trace("Config loaded from file: %v"+pal.LineBreak(), viper.ConfigFileUsed()) } +// Save marshals the current configuration object and saves it to the configuration +// file previously specified by the SetFileName variable. +// Any errors encountered while marshalling or saving the configuration are checked +// and handled by the injected errorHandler (via the checkErr function). func Save() { if filename == "" { panic("Must call config.SetFileName()") } + if config.Version == "" { + config.Version = "v1" + } + b, err := yaml.Marshal(&config) checkErr(err) err = viper.ReadConfig(bytes.NewReader(b)) @@ -49,6 +51,21 @@ func Save() { checkErr(err) } +// GetConfigFileUsed returns the path to the configuration file used by the Viper library. func GetConfigFileUsed() string { return viper.ConfigFileUsed() } + +// configureViper initializes the Viper library with the given configuration file. +// This function sets the configuration file type to "yaml" and sets the environment variable prefix to "SQLCMD". +// It also sets the configuration file to use to the one provided as an argument to the function. +// This function is intended to be called at the start of the application to configure Viper before any other code uses it. +func configureViper(configFile string) { + if configFile == "" { + panic("Must provide configFile") + } + + viper.SetConfigType("yaml") + viper.SetEnvPrefix("SQLCMD") + viper.SetConfigFile(configFile) +} diff --git a/internal/config/viper_test.go b/internal/config/viper_test.go index acee3475..86717662 100644 --- a/internal/config/viper_test.go +++ b/internal/config/viper_test.go @@ -1,12 +1,33 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + package config -import "testing" +import ( + "github.com/microsoft/go-sqlcmd/internal/test" + "testing" +) func Test_configureViper(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() + defer func() { test.CatchExpectedError(recover(), t) }() + configureViper("") } + +func Test_Load(t *testing.T) { + SetFileNameForTest(t) + Clean() + Load() +} + +func TestNeg_Load(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + filename = "" + Load() +} + +func TestNeg_Save(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + filename = "" + Save() +} diff --git a/internal/container/controller.go b/internal/container/controller.go index b9a15d96..80af9ff6 100644 --- a/internal/container/controller.go +++ b/internal/container/controller.go @@ -22,6 +22,11 @@ type Controller struct { cli *client.Client } +// NewController creates a new Controller struct, which is used to interact +// with a container runtime engine (e.g. Docker or Podman etc.). It initializes +// engine client by calling client.NewClientWithOpts(client.FromEnv) and +// setting the cli field of the Controller struct to the result. +// The Controller struct is then returned. func NewController() (c *Controller) { var err error c = new(Controller) @@ -31,7 +36,13 @@ func NewController() (c *Controller) { return } -func (c *Controller) EnsureImage(image string) (err error) { +// EnsureImage creates a new instance of the Controller struct and initializes +// the container engine client by calling client.NewClientWithOpts() with +// the client.FromEnv option. It returns the Controller instance and an error +// if one occurred while creating the client. The Controller struct has a +// method EnsureImage() which pulls an image with the given name from +// a registry and logs the output to the console. +func (c Controller) EnsureImage(image string) (err error) { var reader io.ReadCloser trace("Running ImagePull for image %s", image) @@ -51,7 +62,10 @@ func (c *Controller) EnsureImage(image string) (err error) { return } -func (c *Controller) ContainerRun(image string, env []string, port int, command []string, unitTestFailure bool) string { +// ContainerRun creates a new container using the provided image and env values +// and binds it to the specified port number. It then starts the container and returns +// the ID of the container. +func (c Controller) ContainerRun(image string, env []string, port int, command []string, unitTestFailure bool) string { hostConfig := &container.HostConfig{ PortBindings: nat.PortMap{ nat.Port("1433/tcp"): []nat.PortBinding{ @@ -89,8 +103,16 @@ func (c *Controller) ContainerRun(image string, env []string, port int, command return resp.ID } -// ContainerWaitForLogEntry waits for text substring in containers logs -func (c *Controller) ContainerWaitForLogEntry(id string, text string) { +// ContainerWaitForLogEntry is used to wait for a specific string to be written +// to the logs of a container with the given ID. The function takes in the ID +// of the container and the string to look for in the logs. It creates a reader +// to stream the logs from the container, and scans the logs line by line until +// it finds the specified string. Once the string is found, the function breaks +// out of the loop and returns. +// +// This function is useful for waiting until a specific event has occurred in the +// container (e.g. a server has started up) before continuing with other operations. +func (c Controller) ContainerWaitForLogEntry(id string, text string) { options := types.ContainerLogsOptions{ ShowStdout: true, ShowStderr: false, @@ -119,7 +141,9 @@ func (c *Controller) ContainerWaitForLogEntry(id string, text string) { } } -func (c *Controller) ContainerStop(id string) (err error) { +// ContainerStop stops the container with the given ID. The function returns +// an error if there is an issue stopping the container. +func (c Controller) ContainerStop(id string) (err error) { if id == "" { panic("Must pass in non-empty id") } @@ -128,7 +152,12 @@ func (c *Controller) ContainerStop(id string) (err error) { return } -func (c *Controller) ContainerFiles(id string, filespec string) (files []string) { +// ContainerFiles returns a list of files matching a specified pattern within +// a given container. It takes an id argument, which specifies the ID of the +// container to search, and a filespec argument, which is a string pattern used +// to match files within the container. The function returns a []string slice +// containing the names of all files that match the specified pattern. +func (c Controller) ContainerFiles(id string, filespec string) (files []string) { if id == "" { panic("Must pass in non-empty id") } @@ -174,7 +203,11 @@ func (c *Controller) ContainerFiles(id string, filespec string) (files []string) return strings.Split(string(stdout), "\n") } -func (c *Controller) ContainerExists(id string) (exists bool) { +// ContainerExists checks if a container with the given ID exists in the system. +// It does this by using the container runtime API to list all containers and +// filtering by the given ID. If a container with the given ID is found, it +// returns true; otherwise, it returns false. +func (c Controller) ContainerExists(id string) (exists bool) { f := filters.NewArgs() f.Add( "id", id, @@ -195,7 +228,11 @@ func (c *Controller) ContainerExists(id string) (exists bool) { return } -func (c *Controller) ContainerRemove(id string) (err error) { +// ContainerRemove removes the container with the specified ID using the +// container runtime API. The function takes the ID of the container to be +// removed as an input argument, and returns an error if one occurs during +// the removal process. +func (c Controller) ContainerRemove(id string) (err error) { if id == "" { panic("Must pass in non-empty id") } diff --git a/internal/container/controller_test.go b/internal/container/controller_test.go index 79fe5d4f..1cd429ba 100644 --- a/internal/container/controller_test.go +++ b/internal/container/controller_test.go @@ -5,8 +5,7 @@ package container import ( "fmt" - "github.com/docker/docker/client" - "strings" + "github.com/microsoft/go-sqlcmd/internal/test" "testing" ) @@ -29,44 +28,17 @@ func TestController_EnsureImage(t *testing.T) { repo, tag) - type fields struct { - cli *client.Client - } - type args struct { - image string - } - tests := []struct { - name string - fields fields - args args - wantErr bool - }{ - {"default", fields{nil}, args{imageName}, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // If test name ends in 'Panic' expect a Panic - if strings.HasSuffix(tt.name, "Panic") { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - } - - c := NewController() - err := c.EnsureImage(tt.args.image) - checkErr(err) - id := c.ContainerRun(tt.args.image, []string{}, port, []string{"ash", "-c", "echo 'Hello World'; sleep 1"}, false) - c.ContainerWaitForLogEntry(id, "Hello World") - c.ContainerExists(id) - c.ContainerFiles(id, "*.mdf") - err = c.ContainerStop(id) - checkErr(err) - err = c.ContainerRemove(id) - checkErr(err) - }) - } + c := NewController() + err := c.EnsureImage(imageName) + checkErr(err) + id := c.ContainerRun(imageName, []string{}, port, []string{"ash", "-c", "echo 'Hello World'; sleep 1"}, false) + c.ContainerWaitForLogEntry(id, "Hello World") + c.ContainerExists(id) + c.ContainerFiles(id, "*.mdf") + err = c.ContainerStop(id) + checkErr(err) + err = c.ContainerRemove(id) + checkErr(err) } func TestController_ContainerRunFailure(t *testing.T) { @@ -80,11 +52,7 @@ func TestController_ContainerRunFailure(t *testing.T) { repo, tag) - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() + defer func() { test.CatchExpectedError(recover(), t) }() c := NewController() c.ContainerRun( @@ -107,20 +75,20 @@ func TestController_ContainerRunFailureCleanup(t *testing.T) { repo, tag) - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() + defer func() { test.CatchExpectedError(recover(), t) }() c := NewController() - c.ContainerRun( + id := c.ContainerRun( imageName, []string{}, 0, []string{"ash", "-c", "echo 'Hello World'; sleep 1"}, true, ) + err := c.ContainerStop(id) + checkErr(err) + err = c.ContainerRemove(id) + checkErr(err) } func TestController_ContainerStopNeg(t *testing.T) { @@ -134,22 +102,18 @@ func TestController_ContainerStopNeg(t *testing.T) { repo, tag) - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() + defer func() { test.CatchExpectedError(recover(), t) }() c := NewController() - c.ContainerRun(imageName, []string{}, 0, []string{"ash", "-c", "echo 'Hello World'; sleep 1"}, false) + id := c.ContainerRun(imageName, []string{}, 0, []string{"ash", "-c", "echo 'Hello World'; sleep 1"}, false) + err := c.ContainerStop(id) + checkErr(err) + err = c.ContainerRemove(id) + checkErr(err) } func TestController_ContainerStopNeg2(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() + defer func() { test.CatchExpectedError(recover(), t) }() c := NewController() err := c.ContainerStop("") @@ -157,11 +121,7 @@ func TestController_ContainerStopNeg2(t *testing.T) { } func TestController_ContainerRemoveNeg(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() + defer func() { test.CatchExpectedError(recover(), t) }() c := NewController() err := c.ContainerRemove("") @@ -169,22 +129,14 @@ func TestController_ContainerRemoveNeg(t *testing.T) { } func TestController_ContainerFilesNeg(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() + defer func() { test.CatchExpectedError(recover(), t) }() c := NewController() c.ContainerFiles("", "") } func TestController_ContainerFilesNeg2(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() + defer func() { test.CatchExpectedError(recover(), t) }() c := NewController() c.ContainerFiles("id", "") diff --git a/internal/container/docker.go b/internal/container/docker.go index 83ebf2e8..5687944c 100644 --- a/internal/container/docker.go +++ b/internal/container/docker.go @@ -10,6 +10,10 @@ import ( "net/http" ) +// ListTags lists all tags for a container image located at a given +// path in the container registry. It takes the path to the image and the +// URL of the registry as input and returns a slice of strings containing +// the tags. func ListTags(path string, baseURL string) []string { ctx := context.Background() repo, err := reference.WithName(path) diff --git a/internal/intialize.go b/internal/intialize.go index 88647ce7..6303a300 100644 --- a/internal/intialize.go +++ b/internal/intialize.go @@ -6,43 +6,42 @@ package internal import ( "github.com/microsoft/go-sqlcmd/internal/config" "github.com/microsoft/go-sqlcmd/internal/container" - "github.com/microsoft/go-sqlcmd/internal/file" + "github.com/microsoft/go-sqlcmd/internal/io/file" "github.com/microsoft/go-sqlcmd/internal/mssql" "github.com/microsoft/go-sqlcmd/internal/net" - "github.com/microsoft/go-sqlcmd/internal/output" - "github.com/microsoft/go-sqlcmd/internal/output/verbosity" "github.com/microsoft/go-sqlcmd/internal/pal" "github.com/microsoft/go-sqlcmd/internal/secret" - "os" ) type InitializeOptions struct { ErrorHandler func(error) + TraceHandler func(format string, a ...any) HintHandler func([]string) - OutputType string - LoggingLevel int + LineBreak string } +// Initialize initializes various dependencies for the application with the provided options. +// The dependencies that are initialized include file, mssql, config, container, +// secret, net, and pal. This function is typically called at the start of the application +// to ensure that all dependencies are properly initialized before any other code is executed. func Initialize(options InitializeOptions) { if options.ErrorHandler == nil { panic("ErrorHandler is nil") } + if options.TraceHandler == nil { + panic("TraceHandler is nil") + } if options.HintHandler == nil { panic("HintHandler is nil") } - if options.OutputType == "" { - panic("OutputType is empty") - } - if options.LoggingLevel <= 0 || options.LoggingLevel > 4 { - panic("LoggingLevel must be between 1 and 4 ") + if options.LineBreak == "" { + panic("LineBreak is empty") } - - file.Initialize(options.ErrorHandler, output.Tracef) - mssql.Initialize(options.ErrorHandler, output.Tracef, secret.Decode) - output.Initialize(options.ErrorHandler, output.Tracef, options.HintHandler, os.Stdout, options.OutputType, verbosity.Enum(options.LoggingLevel)) - config.Initialize(options.ErrorHandler, output.Tracef, secret.Encode, secret.Decode, net.IsLocalPortAvailable) - container.Initialize(options.ErrorHandler, output.Tracef) + file.Initialize(options.ErrorHandler, options.TraceHandler) + mssql.Initialize(options.ErrorHandler, options.TraceHandler, secret.Decode) + config.Initialize(options.ErrorHandler, options.TraceHandler, secret.Encode, secret.Decode, net.IsLocalPortAvailable) + container.Initialize(options.ErrorHandler, options.TraceHandler) secret.Initialize(options.ErrorHandler) - net.Initialize(options.ErrorHandler, output.Tracef) - pal.Initialize(options.ErrorHandler) + net.Initialize(options.ErrorHandler, options.TraceHandler) + pal.Initialize(options.ErrorHandler, options.LineBreak) } diff --git a/internal/intialize_test.go b/internal/intialize_test.go index 732f6ddc..daaf7563 100644 --- a/internal/intialize_test.go +++ b/internal/intialize_test.go @@ -4,40 +4,77 @@ package internal import ( + "github.com/microsoft/go-sqlcmd/internal/output" + "github.com/microsoft/go-sqlcmd/internal/test" "testing" ) func TestInitialize(t *testing.T) { - type args struct { - errorHandler func(error) - hintHandler func([]string) - outputType string - loggingLevel int + output := output.New(output.Options{HintHandler: func(hints []string) { + + }, ErrorHandler: func(err error) { + + }}) + options := InitializeOptions{ + ErrorHandler: func(err error) { + if err != nil { + panic(err) + } + }, + HintHandler: func(strings []string) {}, + TraceHandler: output.Tracef, + LineBreak: "\n", } - tests := []struct { - name string - args args - }{ - {"default", args{ - func(err error) { - if err != nil { - panic(err) - } - }, - func(strings []string) {}, - "yaml", - 2, - }}, + Initialize(options) +} + +func TestNegInitialize(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + options := InitializeOptions{ + ErrorHandler: nil, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - options := InitializeOptions{ - ErrorHandler: tt.args.errorHandler, - HintHandler: tt.args.hintHandler, - OutputType: tt.args.outputType, - LoggingLevel: tt.args.loggingLevel, - } - Initialize(options) - }) + Initialize(options) +} + +func TestNegInitialize2(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + options := InitializeOptions{ + ErrorHandler: func(err error) {}, + } + Initialize(options) +} + +func TestNegInitialize3(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + options := InitializeOptions{ + ErrorHandler: func(err error) {}, + TraceHandler: func(format string, a ...any) {}, + } + Initialize(options) +} + +func TestNegInitialize4(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + options := InitializeOptions{ + ErrorHandler: func(err error) {}, + TraceHandler: func(format string, a ...any) {}, + HintHandler: func(strings []string) {}, + } + Initialize(options) +} + +func TestNegInitialize5(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + options := InitializeOptions{ + ErrorHandler: func(err error) {}, + TraceHandler: func(format string, a ...any) {}, + HintHandler: func(strings []string) {}, + LineBreak: "", } + Initialize(options) } diff --git a/internal/file/error.go b/internal/io/file/error.go similarity index 100% rename from internal/file/error.go rename to internal/io/file/error.go diff --git a/internal/file/error_test.go b/internal/io/file/error_test.go similarity index 100% rename from internal/file/error_test.go rename to internal/io/file/error_test.go diff --git a/internal/file/file.go b/internal/io/file/file.go similarity index 57% rename from internal/file/file.go rename to internal/io/file/file.go index 692e18e2..9a2c3312 100644 --- a/internal/file/file.go +++ b/internal/io/file/file.go @@ -4,11 +4,15 @@ package file import ( - "github.com/microsoft/go-sqlcmd/internal/folder" + "github.com/microsoft/go-sqlcmd/internal/io/folder" "os" "path/filepath" ) +// CreateEmptyIfNotExists creates an empty file with the given filename if it +// does not already exist. If the parent directory of the file does not exist, the +// function will create it. The function is useful for ensuring that a file is +// present before writing to it. func CreateEmptyIfNotExists(filename string) { if filename == "" { panic("filename must not be empty") @@ -30,6 +34,8 @@ func CreateEmptyIfNotExists(filename string) { } } +// Exists checks if a file with the given filename exists in the file system. It +// returns a boolean value indicating whether the file exists or not. func Exists(filename string) (exists bool) { if filename == "" { panic("filename must not be empty") @@ -42,6 +48,8 @@ func Exists(filename string) (exists bool) { return } +// Remove is used to remove a file with the specified filename. The function +// takes in the name of the file as an argument and deletes it from the file system. func Remove(filename string) { err := os.Remove(filename) checkErr(err) diff --git a/internal/file/file_test.go b/internal/io/file/file_test.go similarity index 95% rename from internal/file/file_test.go rename to internal/io/file/file_test.go index b927aafe..f17132db 100644 --- a/internal/file/file_test.go +++ b/internal/io/file/file_test.go @@ -4,8 +4,8 @@ package file_test import ( - "github.com/microsoft/go-sqlcmd/internal/file" - "github.com/microsoft/go-sqlcmd/internal/folder" + "github.com/microsoft/go-sqlcmd/internal/io/file" + "github.com/microsoft/go-sqlcmd/internal/io/folder" "os" "path/filepath" "strings" diff --git a/internal/file/initialize.go b/internal/io/file/initialize.go similarity index 91% rename from internal/file/initialize.go rename to internal/io/file/initialize.go index 1b4bfd39..0d1bdbb2 100644 --- a/internal/file/initialize.go +++ b/internal/io/file/initialize.go @@ -4,7 +4,7 @@ package file import ( - "github.com/microsoft/go-sqlcmd/internal/folder" + "github.com/microsoft/go-sqlcmd/internal/io/folder" ) func init() { diff --git a/internal/file/trace.go b/internal/io/file/trace.go similarity index 100% rename from internal/file/trace.go rename to internal/io/file/trace.go diff --git a/internal/folder/error.go b/internal/io/folder/error.go similarity index 100% rename from internal/folder/error.go rename to internal/io/folder/error.go diff --git a/internal/folder/error_test.go b/internal/io/folder/error_test.go similarity index 100% rename from internal/folder/error_test.go rename to internal/io/folder/error_test.go diff --git a/internal/folder/folder.go b/internal/io/folder/folder.go similarity index 74% rename from internal/folder/folder.go rename to internal/io/folder/folder.go index be1c54d2..ba0de8a5 100644 --- a/internal/folder/folder.go +++ b/internal/io/folder/folder.go @@ -7,6 +7,7 @@ import ( "os" ) +// MkdirAll creates a directory with the given name if it does not already exist. func MkdirAll(folder string) { if folder == "" { panic("folder must not be empty") @@ -18,6 +19,7 @@ func MkdirAll(folder string) { } } +// RemoveAll removes a folder and all of its contents at the given path. func RemoveAll(folder string) { err := os.RemoveAll(folder) checkErr(err) diff --git a/internal/folder/folder_test.go b/internal/io/folder/folder_test.go similarity index 100% rename from internal/folder/folder_test.go rename to internal/io/folder/folder_test.go diff --git a/internal/folder/initialize.go b/internal/io/folder/initialize.go similarity index 100% rename from internal/folder/initialize.go rename to internal/io/folder/initialize.go diff --git a/internal/folder/trace.go b/internal/io/folder/trace.go similarity index 100% rename from internal/folder/trace.go rename to internal/io/folder/trace.go diff --git a/internal/mssql/initialize.go b/internal/mssql/initialize.go index 2e2e4623..acdedeca 100644 --- a/internal/mssql/initialize.go +++ b/internal/mssql/initialize.go @@ -5,17 +5,6 @@ package mssql var decryptCallback func(cipherText string, decrypt bool) (secret string) -func init() { - Initialize( - func(err error) { - if err != nil { - panic(err) - } - }, - func(format string, a ...any) {}, - func(cipherText string, decrypt bool) (secret string) { return }) -} - func Initialize( errorHandler func(err error), traceHandler func(format string, a ...any), diff --git a/internal/mssql/mssql.go b/internal/mssql/mssql.go index 11972f54..2719607c 100644 --- a/internal/mssql/mssql.go +++ b/internal/mssql/mssql.go @@ -5,11 +5,15 @@ package mssql import ( "fmt" - "github.com/microsoft/go-sqlcmd/cmd/sqlconfig" + "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" "os" ) +// Connect is used to connect to a SQL Server using the specified endpoint +// and user details. The console parameter is used to output messages during +// the connection process. The function returns a Sqlcmd instance that can +// be used to run SQL commands on the server. func Connect( endpoint sqlconfig.Endpoint, user *sqlconfig.User, @@ -46,6 +50,12 @@ func Connect( return s } +// Query is helper function that allows running a given SQL query on a +// provided sqlcmd.Sqlcmd object. It takes the sqlcmd.Sqlcmd object and the +// query text as inputs, and runs the query using the Run method of +// the sqlcmd.Sqlcmd object. It sets the standard output and standard error +// to be the same as the current process, and returns the error if any occurred +// during the execution of the query. func Query(s *sqlcmd.Sqlcmd, text string) { s.Query = text s.SetOutput(os.Stdout) diff --git a/internal/mssql/mssql_test.go b/internal/mssql/mssql_test.go index 5981d3d3..372c63de 100644 --- a/internal/mssql/mssql_test.go +++ b/internal/mssql/mssql_test.go @@ -4,7 +4,9 @@ package mssql import ( - . "github.com/microsoft/go-sqlcmd/cmd/sqlconfig" + "fmt" + . "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" + "github.com/microsoft/go-sqlcmd/internal/secret" "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" "runtime" "strings" @@ -12,7 +14,12 @@ import ( ) func TestConnect(t *testing.T) { - t.Skip() // BUG(stuartpa): Re-enable before merge + Initialize(func(err error) { + if err != nil { + panic(err) + } + }, func(format string, a ...any) { fmt.Printf(format, a...) }, secret.Decode) + endpoint := Endpoint{ EndpointDetails: EndpointDetails{ Address: "localhost", @@ -31,29 +38,29 @@ func TestConnect(t *testing.T) { }{ { name: "connectBasicPanic", args: args{ - endpoint: endpoint, - user: &User{ - Name: "basicUser", - AuthenticationType: "basic", - BasicAuth: &BasicAuthDetails{ - Username: "foo", - PasswordEncrypted: true, - Password: "bar", + endpoint: endpoint, + user: &User{ + Name: "basicUser", + AuthenticationType: "basic", + BasicAuth: &BasicAuthDetails{ + Username: "foo", + PasswordEncrypted: true, + Password: "bar", + }, }, + console: nil, }, - console: nil, - }, want: 0, }, { name: "invalidAuthTypePanic", args: args{ - endpoint: endpoint, - user: &User{ - Name: "basicUser", - AuthenticationType: "badbad", + endpoint: endpoint, + user: &User{ + Name: "basicUser", + AuthenticationType: "badbad", + }, + console: nil, }, - console: nil, - }, want: 0, }, } diff --git a/internal/net/net.go b/internal/net/net.go index 163006f0..765837dc 100644 --- a/internal/net/net.go +++ b/internal/net/net.go @@ -9,17 +9,21 @@ import ( "time" ) +// IsLocalPortAvailable takes a port number and returns a boolean indicating +// whether the port is available for use. func IsLocalPortAvailable(port int) (portAvailable bool) { timeout := time.Second + + hostPort := net.JoinHostPort("localhost", strconv.Itoa(port)) trace( "Checking if local port %d is available using DialTimeout(tcp, %v, timeout: %d)", port, - net.JoinHostPort("localhost", strconv.Itoa(port)), + hostPort, timeout, ) conn, err := net.DialTimeout( "tcp", - net.JoinHostPort("localhost", strconv.Itoa(port)), + hostPort, timeout, ) if err != nil { diff --git a/internal/net/net_test.go b/internal/net/net_test.go index aeefa73f..57123822 100644 --- a/internal/net/net_test.go +++ b/internal/net/net_test.go @@ -7,25 +7,27 @@ import ( "testing" ) +// TestIsLocalPortAvailable verified the function for both available and unavailable +// code (this function expects a local SQL Server instance listening on port 1433 func TestIsLocalPortAvailable(t *testing.T) { - t.Skip() // BUG(stuartpa): Re-enable before merge, fix to work on any machine - type args struct { - port int - } - tests := []struct { - name string - args args - wantPortAvailable bool - }{ - {name: "expectedToNotBeAvailable", args: args{port: 51027}, wantPortAvailable: false}, - {name: "expectedToBeAvailable", args: args{port: 9999}, wantPortAvailable: true}, - } + var testedPortAvailable bool + var testedNotPortAvailable bool - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if gotPortAvailable := IsLocalPortAvailable(tt.args.port); gotPortAvailable != tt.wantPortAvailable { - t.Errorf("IsLocalPortAvailable() = %v, want %v", gotPortAvailable, tt.wantPortAvailable) - } - }) + for i := 1432; i <= 1434; i++ { + isPortAvailable := IsLocalPortAvailable(i) + if isPortAvailable { + testedPortAvailable = true + t.Logf("Port %d is available", i) + } else { + testedNotPortAvailable = true + t.Logf("Port %d is not available", i) + } + if testedPortAvailable && testedNotPortAvailable { + return + } } + + t.Log("Didn't find both an available port and unavailable port") + t.Fail() + } diff --git a/internal/output/error.go b/internal/output/error.go deleted file mode 100644 index 70069477..00000000 --- a/internal/output/error.go +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package output - -var errorCallback func(err error) - -func checkErr(err error) { - errorCallback(err) -} diff --git a/internal/output/factory.go b/internal/output/factory.go new file mode 100644 index 00000000..8b75f74a --- /dev/null +++ b/internal/output/factory.go @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package output + +import ( + "fmt" + "github.com/microsoft/go-sqlcmd/internal/output/formatter" + "os" + "strings" +) + +// New initializes a new Output instance with the specified options. If options +// are not provided, default values are used. The function sets the error callback +// and the hint callback based on the value of the unitTesting field in the +// provided options. If unitTesting is true, the error callback is set to +// panic on error, otherwise it is set to use cobra.CheckErr to handle errors. +func New(options Options) *Output { + if options.LoggingLevel == 0 { + options.LoggingLevel = 2 + } + if options.StandardWriter == nil { + options.StandardWriter = os.Stdout + } + if options.ErrorHandler == nil { + if isRunningInTestExecutor(options) { + options.ErrorHandler = func(err error) { + if err != nil { + panic(err) + } + } + } else { + panic("Must provide Error Handler (the process (" + os.Args[0] + ") host is not a test executor)") + } + + } + if options.HintHandler == nil { + if isRunningInTestExecutor(options) { + options.HintHandler = func(hints []string) { + fmt.Println(hints) + } + } else { + panic("Must provide hint handler (the process " + os.Args[0] + " host is not a test executor)") + } + } + + f := formatter.New(formatter.Options{ + SerializationFormat: options.OutputType, + StandardOutput: options.StandardWriter, + ErrorHandler: options.ErrorHandler, + }) + + return &Output{ + formatter: f, + loggingLevel: options.LoggingLevel, + standardWriteCloser: options.StandardWriter, + errorCallback: options.ErrorHandler, + hintCallback: options.HintHandler, + } +} + +func isRunningInTestExecutor(options Options) bool { + if (strings.HasSuffix(os.Args[0], ".test") || // are we in go test on *nix? + strings.HasSuffix(os.Args[0], ".test.exe") || // are we in go test on windows? + (len(os.Args) > 1 && os.Args[1] == "-test.v")) && // are we in goland unittest? + !options.unitTesting { + return true + } else { + return false + } +} diff --git a/internal/output/factory_test.go b/internal/output/factory_test.go new file mode 100644 index 00000000..296cd04a --- /dev/null +++ b/internal/output/factory_test.go @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package output + +import ( + "github.com/microsoft/go-sqlcmd/internal/test" + "testing" +) + +func TestFactory(t *testing.T) { + o := New(Options{unitTesting: false, HintHandler: func(hints []string) { + + }, ErrorHandler: func(err error) { + + }}) + o.errorCallback(nil) +} + +func TestNegtactory(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + New(Options{unitTesting: true, + HintHandler: func(hints []string) {}, + ErrorHandler: nil}) +} + +func TestNegFactory2(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + New(Options{unitTesting: true, + HintHandler: nil, + ErrorHandler: func(err error) {}}) +} diff --git a/internal/output/formatter/base_test.go b/internal/output/formatter/base_test.go index 4853223c..15b22023 100644 --- a/internal/output/formatter/base_test.go +++ b/internal/output/formatter/base_test.go @@ -4,6 +4,7 @@ package formatter import ( + "github.com/microsoft/go-sqlcmd/internal/test" "strings" "testing" ) @@ -34,11 +35,7 @@ func TestBase_CheckErr(t *testing.T) { // If test name ends in 'Panic' expect a Panic if strings.HasSuffix(tt.name, "Panic") { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() + defer func() { test.CatchExpectedError(recover(), t) }() } f.CheckErr(tt.args.err) diff --git a/internal/output/formatter/factory.go b/internal/output/formatter/factory.go new file mode 100644 index 00000000..31625a34 --- /dev/null +++ b/internal/output/formatter/factory.go @@ -0,0 +1,43 @@ +package formatter + +import ( + "fmt" + "os" +) + +// New creates a new instance of the Formatter interface. It takes an Options +// struct as input and sets default values for some of the fields if they are +// not specified. The SerializationFormat field of the Options struct is used +// to determine which implementation of the Formatter interface to return. +// If the specified format is not supported, the function will panic. +func New(options Options, +) (f Formatter) { + if options.SerializationFormat == "" { + options.SerializationFormat = "yaml" + } + if options.ErrorHandler == nil { + options.ErrorHandler = func(err error) {} + } + if options.StandardOutput == nil { + options.StandardOutput = os.Stdout + } + + switch options.SerializationFormat { + case "json": + f = &Json{Base: Base{ + StandardOutput: options.StandardOutput, + ErrorHandlerCallback: options.ErrorHandler}} + case "yaml": + f = &Yaml{Base: Base{ + StandardOutput: options.StandardOutput, + ErrorHandlerCallback: options.ErrorHandler}} + case "xml": + f = &Xml{Base: Base{ + StandardOutput: options.StandardOutput, + ErrorHandlerCallback: options.ErrorHandler}} + default: + panic(fmt.Sprintf("Format '%v' not supported", options.SerializationFormat)) + } + + return +} diff --git a/internal/output/formatter/factory_test.go b/internal/output/formatter/factory_test.go new file mode 100644 index 00000000..fa624e32 --- /dev/null +++ b/internal/output/formatter/factory_test.go @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package formatter + +import ( + "github.com/microsoft/go-sqlcmd/internal/test" + "log" + "testing" +) + +func TestFormatter(t *testing.T) { + s := []string{"serialize this"} + + var f Formatter + f = New(Options{SerializationFormat: "yaml"}) + f.Serialize(s) + f = New(Options{SerializationFormat: "xml"}) + f.Serialize(s) + f = New(Options{SerializationFormat: "json"}) + f.Serialize(s) + + log.Println("This is here to ensure a newline is in test output") +} + +func TestNegFormatterBadFormat(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + s := "serialize this" + f := New(Options{SerializationFormat: "badbad"}) + f.Serialize(s) +} + +func TestFormatterEmptyFormat(t *testing.T) { + s := "serialize this" + f := New(Options{SerializationFormat: ""}) + f.Serialize(s) +} diff --git a/internal/output/formatter/formatter_test.go b/internal/output/formatter/formatter_test.go deleted file mode 100644 index e1338884..00000000 --- a/internal/output/formatter/formatter_test.go +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package formatter - -import ( - "os" - "testing" -) - -func TestFormatter(t *testing.T) { - - s := "Hello" - - b := Base{ - StandardOutput: os.Stdout, - ErrorHandlerCallback: func(err error) { - if err != nil { - panic(err) - } - }, - } - - j := Json{b} - j.Serialize(s) - - x := Xml{b} - x.Serialize(s) - - y := Yaml{b} - y.Serialize(s) - -} diff --git a/internal/output/formatter/interface.go b/internal/output/formatter/interface.go index 6ca43cda..2f3d55c7 100644 --- a/internal/output/formatter/interface.go +++ b/internal/output/formatter/interface.go @@ -3,6 +3,10 @@ package formatter +// Formatter defines a formatter for serializing an input object into a byte slice. +// The Serialize method serializes the input object and returns the resulting +// byte slice. The CheckErr method handles any error encountered during +// the serialization process. type Formatter interface { Serialize(in interface{}) (bytes []byte) CheckErr(err error) diff --git a/internal/output/formatter/json.go b/internal/output/formatter/json.go index e5a8bb03..8b057eba 100644 --- a/internal/output/formatter/json.go +++ b/internal/output/formatter/json.go @@ -14,7 +14,7 @@ type Json struct { func (f *Json) Serialize(in interface{}) (bytes []byte) { var err error - bytes, err = json.MarshalIndent(in, "", " ") + bytes, err = json.MarshalIndent(in, "", " ") f.Base.CheckErr(err) f.Base.Output(bytes) diff --git a/internal/output/formatter/options.go b/internal/output/formatter/options.go new file mode 100644 index 00000000..0673f207 --- /dev/null +++ b/internal/output/formatter/options.go @@ -0,0 +1,10 @@ +package formatter + +import "io" + +// Options defines the options for creating a new Formatter instance. +type Options struct { + SerializationFormat string + StandardOutput io.WriteCloser + ErrorHandler func(err error) +} diff --git a/internal/output/formatter/xml.go b/internal/output/formatter/xml.go index a3c5d670..07c8e498 100644 --- a/internal/output/formatter/xml.go +++ b/internal/output/formatter/xml.go @@ -14,7 +14,7 @@ type Xml struct { func (f *Xml) Serialize(in interface{}) (bytes []byte) { var err error - bytes, err = xml.MarshalIndent(in, "", " ") + bytes, err = xml.MarshalIndent(in, "", " ") f.Base.CheckErr(err) f.Base.Output(bytes) diff --git a/internal/output/hint.go b/internal/output/hint.go deleted file mode 100644 index c97ac343..00000000 --- a/internal/output/hint.go +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package output - -var hintCallback func(hints []string) - -func displayHints(hints []string) { - hintCallback(hints) -} diff --git a/internal/output/intialize.go b/internal/output/intialize.go deleted file mode 100644 index e8db7dee..00000000 --- a/internal/output/intialize.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package output - -import ( - "fmt" - . "github.com/microsoft/go-sqlcmd/internal/output/formatter" - "github.com/microsoft/go-sqlcmd/internal/output/verbosity" - "io" - "os" -) - -// init initializes the package for unit testing. For production, use -// the Initialize method to inject fully functional dependencies -func init() { - errorHandler := func(err error) { - if err != nil { - panic(err) - } - } - formatter = &Yaml{Base: Base{ - StandardOutput: standardWriteCloser, - ErrorHandlerCallback: errorHandler, - }} - - Initialize( - errorHandler, - func(format string, a ...any) {}, - func(hints []string) {}, - os.Stdout, - "yaml", - verbosity.Info, - ) -} - -func Initialize( - errorHandler func(err error), - traceHandler func(format string, a ...any), - hintHandler func(hints []string), - standardOutput io.WriteCloser, - serializationFormat string, - verbosity verbosity.Enum, -) { - errorCallback = errorHandler - traceCallback = traceHandler - hintCallback = hintHandler - standardWriteCloser = standardOutput - loggingLevel = verbosity - - trace("Initializing output as '%v'", serializationFormat) - - switch serializationFormat { - case "json": - formatter = &Json{Base: Base{ - StandardOutput: standardWriteCloser, - ErrorHandlerCallback: errorHandler}} - case "yaml": - formatter = &Yaml{Base: Base{ - StandardOutput: standardWriteCloser, - ErrorHandlerCallback: errorHandler}} - case "xml": - formatter = &Xml{Base: Base{ - StandardOutput: standardWriteCloser, - ErrorHandlerCallback: errorHandler}} - default: - panic(fmt.Sprintf("Format '%v' not supported", serializationFormat)) - } -} diff --git a/internal/output/intialize_test.go b/internal/output/intialize_test.go deleted file mode 100644 index 544ea8c5..00000000 --- a/internal/output/intialize_test.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package output - -import ( - "github.com/microsoft/go-sqlcmd/internal/output/verbosity" - "io" - "os" - "strings" - "testing" -) - -func TestInitialize(t *testing.T) { - type args struct { - errorHandler func(err error) - traceHandler func(format string, a ...any) - hintHandler func(hints []string) - standardOutput io.WriteCloser - errorOutput io.WriteCloser - format string - verbosity verbosity.Enum - } - tests := []struct { - name string - args args - }{ - { - name: "badFormatterPanic", - args: args{ - errorHandler: errorCallback, - traceHandler: traceCallback, - hintHandler: hintCallback, - standardOutput: os.Stdout, - errorOutput: os.Stderr, - format: "badbad", - verbosity: 0, - }, - }, - { - name: "initWithXml", - args: args{ - errorHandler: errorCallback, - traceHandler: traceCallback, - hintHandler: hintCallback, - standardOutput: os.Stdout, - errorOutput: os.Stderr, - format: "xml", - verbosity: 0, - }, - }, - { - name: "initWithJson", - args: args{ - errorHandler: errorCallback, - traceHandler: traceCallback, - hintHandler: hintCallback, - standardOutput: os.Stdout, - errorOutput: os.Stderr, - format: "json", - verbosity: 0, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - - // If test name ends in 'Panic' expect a Panic - if strings.HasSuffix(tt.name, "Panic") { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - } - Initialize( - tt.args.errorHandler, - tt.args.traceHandler, - tt.args.hintHandler, - tt.args.standardOutput, - tt.args.format, - tt.args.verbosity, - ) - }) - } -} diff --git a/internal/output/options.go b/internal/output/options.go new file mode 100644 index 00000000..9d1d7195 --- /dev/null +++ b/internal/output/options.go @@ -0,0 +1,17 @@ +package output + +import ( + "github.com/microsoft/go-sqlcmd/internal/output/verbosity" + "io" +) + +type Options struct { + OutputType string + LoggingLevel verbosity.Level + StandardWriter io.WriteCloser + + ErrorHandler func(err error) + HintHandler func(hints []string) + + unitTesting bool +} diff --git a/internal/output/output.go b/internal/output/output.go index fd2eed14..c18981d4 100644 --- a/internal/output/output.go +++ b/internal/output/output.go @@ -1,138 +1,140 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -// Package output manages outputting text to the user. +// Package Output provides a number of methods for logging and handling +// errors, including Debugf, Errorf, Fatalf, FatalErr, Infof, Panic, Panicf, +// Struct, Tracef, and Warnf. These methods allow the caller to specify the +// desired verbosity level, add newlines to the end of the log message if +// necessary, and handle errors and hints in a variety of ways. // // Trace("Something very low level.") - not localized // Debug("Useful debugging information.") - not localized // Info("Something noteworthy happened!") - localized // Warn("You should probably take a look at this.") - localized // Error("Something failed but I'm not quitting.") - localized -// Fatal("Bye.") - localized -// -// calls os.Exit(1) after logging -// -// Panic("I'm bailing.") - not localized -// -// calls panic() after logging +// Fatal("Bye.") - localized, calls os.Exit(1) after logging +// Panic("I'm bailing.") - not localized, calls panic() after logging package output import ( "fmt" - . "github.com/microsoft/go-sqlcmd/internal/output/formatter" "github.com/microsoft/go-sqlcmd/internal/output/verbosity" - "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" + "github.com/microsoft/go-sqlcmd/internal/pal" "github.com/pkg/errors" - "io" "regexp" "strings" ) -var formatter Formatter -var loggingLevel verbosity.Enum -var runningUnitTests bool - -var standardWriteCloser io.WriteCloser - -func Debugf(format string, a ...any) { - if loggingLevel >= verbosity.Debug { - format = ensureEol(format) - printf("DEBUG: "+format, a...) +func (o Output) Debugf(format string, a ...any) { + if o.loggingLevel >= verbosity.Debug { + format = o.ensureEol(format) + o.printf("DEBUG: "+format, a...) } } -func Errorf(format string, a ...any) { - if loggingLevel >= verbosity.Error { - format = ensureEol(format) - if loggingLevel >= verbosity.Debug { +func (o Output) Errorf(format string, a ...any) { + if o.loggingLevel >= verbosity.Error { + format = o.ensureEol(format) + if o.loggingLevel >= verbosity.Debug { format = "ERROR: " + format } - printf(format, a...) + o.printf(format, a...) } } -func Fatal(a ...any) { - fatal([]string{}, a...) +func (o Output) Fatal(a ...any) { + o.fatal([]string{}, a...) } - -func FatalErr(err error) { - checkErr(err) +func (o Output) FatalErr(err error) { + o.errorCallback(err) } -func Fatalf(format string, a ...any) { - fatalf([]string{}, format, a...) +func (o Output) Fatalf(format string, a ...any) { + o.fatalf([]string{}, format, a...) } -func FatalfErrorWithHints(err error, hints []string, format string, a ...any) { - fatalf(hints, format, a...) - checkErr(err) +func (o Output) FatalfErrorWithHints(err error, hints []string, format string, a ...any) { + o.fatalf(hints, format, a...) + o.errorCallback(err) } -func FatalfWithHints(hints []string, format string, a ...any) { - fatalf(hints, format, a...) +func (o Output) FatalfWithHints(hints []string, format string, a ...any) { + o.fatalf(hints, format, a...) } -func FatalfWithHintExamples(hintExamples [][]string, format string, a ...any) { +func (o Output) FatalfWithHintExamples(hintExamples [][]string, format string, a ...any) { err := errors.New(fmt.Sprintf(format, a...)) - displayHintExamples(hintExamples) - checkErr(err) + o.displayHintExamples(hintExamples) + o.errorCallback(err) } -func FatalWithHints(hints []string, a ...any) { - fatal(hints, a...) +func (o Output) FatalWithHints(hints []string, a ...any) { + o.fatal(hints, a...) } -func Infof(format string, a ...any) { - infofWithHints([]string{}, format, a...) +func (o Output) Infof(format string, a ...any) { + o.infofWithHints([]string{}, format, a...) } -func InfofWithHints(hints []string, format string, a ...any) { - infofWithHints(hints, format, a...) +func (o Output) InfofWithHints(hints []string, format string, a ...any) { + o.infofWithHints(hints, format, a...) } -func InfofWithHintExamples(hintExamples [][]string, format string, a ...any) { - if loggingLevel >= verbosity.Info || runningUnitTests { - format = ensureEol(format) - if loggingLevel >= verbosity.Debug { +// InfofWithHintExamples logs an info-level message with a given format and +// arguments a. It also displays additional hints with example usage in the +// output, using the displayHintExamples helper function. The message is +// formatted using the ensureEol helper function to ensure that it ends with +// a newline character. If the logging level is set to Debug, the message is prefixed +// with "INFO: ". The displayHintExamples helper function formats the hints +// for display and passes them to the hintCallback function for output. +func (o Output) InfofWithHintExamples(hintExamples [][]string, format string, a ...any) { + if o.loggingLevel >= verbosity.Info { + format = o.ensureEol(format) + if o.loggingLevel >= verbosity.Debug { format = "INFO: " + format } - printf(format, a...) - displayHintExamples(hintExamples) + o.printf(format, a...) + o.displayHintExamples(hintExamples) } } -func Panic(a ...any) { +func (o Output) Panic(a ...any) { panic(a) } -func Panicf(format string, a ...any) { +func (o Output) Panicf(format string, a ...any) { panic(fmt.Sprintf(format, a...)) } -func Struct(in interface{}) (bytes []byte) { - bytes = formatter.Serialize(in) +func (o Output) Struct(in interface{}) (bytes []byte) { + bytes = o.formatter.Serialize(in) return } -func Tracef(format string, a ...any) { - if loggingLevel >= verbosity.Trace { - format = ensureEol(format) - printf("TRACE: "+format, a...) +func (o Output) Tracef(format string, a ...any) { + if o.loggingLevel >= verbosity.Trace { + format = o.ensureEol(format) + o.printf("TRACE: "+format, a...) } } -func Warnf(format string, a ...any) { - if loggingLevel >= verbosity.Warn { - format = ensureEol(format) - if loggingLevel >= verbosity.Debug { +func (o Output) Warnf(format string, a ...any) { + if o.loggingLevel >= verbosity.Warn { + format = o.ensureEol(format) + if o.loggingLevel >= verbosity.Debug { format = "WARN: " + format } - printf(format, a...) + o.printf(format, a...) } } -func displayHintExamples(hintExamples [][]string) { +// displayHintExamples takes an array of hint examples and displays them in +// a formatted way. It first calculates the maximum length of the description +// in the hint examples, and then creates a string for each hint example with +// the description padded to the maximum length, followed by the example. +// Finally, it calls the hint callback function with the array of formatted hints. +func (o Output) displayHintExamples(hintExamples [][]string) { var hints []string maxLengthHintText := 0 @@ -155,44 +157,62 @@ func displayHintExamples(hintExamples [][]string) { hintExample[1], )) } - displayHints(hints) -} - -func ensureEol(format string) string { - if len(format) >= len(sqlcmd.SqlcmdEol) { - if !strings.HasSuffix(format, sqlcmd.SqlcmdEol) { - format = format + sqlcmd.SqlcmdEol + o.hintCallback(hints) +} + +// ensureEol ensures that the provided format string ends with a line break character. +// It does this by checking if the format string already ends with a line break character, +// and if not, it appends a line break character to the format string. If the format +// string is shorter than the length of the line break character, it returns the line +// break character on its own. This function is useful for ensuring that output to +// the console will always be properly formatted and easy to read. +func (o Output) ensureEol(format string) string { + if len(format) >= len(pal.LineBreak()) { + if !strings.HasSuffix(format, pal.LineBreak()) { + format = format + pal.LineBreak() } } else { - format = sqlcmd.SqlcmdEol + format = pal.LineBreak() } return format } -func fatal(hints []string, a ...any) { +func (o Output) fatal(hints []string, a ...any) { err := errors.New(fmt.Sprintf("%v", a...)) - displayHints(hints) - checkErr(err) + o.hintCallback(hints) + o.errorCallback(err) } -func fatalf(hints []string, format string, a ...any) { +func (o Output) fatalf(hints []string, format string, a ...any) { err := errors.New(fmt.Sprintf(format, a...)) - displayHints(hints) - checkErr(err) -} - -func infofWithHints(hints []string, format string, a ...any) { - if loggingLevel >= verbosity.Info { - format = ensureEol(format) - if loggingLevel >= verbosity.Debug { + o.hintCallback(hints) + o.errorCallback(err) +} + +// infofWithHints is used to print out an "INFO" message with additional hints. +// The format argument specifies the text to be printed, which can include placeholders +// for dynamic values. The a argument is a variadic parameter containing the +// values to be used to replace the placeholders in the format string. The hints +// argument is a slice of strings representing additional hints to be printed along +// with the message. The function checks if the logging level is set to at least +// "Info" before printing the message. If the logging level is set to "Debug" +// or higher, the string "INFO: " is prepended to the message before it is printed. +// The function also calls the hintCallback function to print the hints, if any are provided. +func (o Output) infofWithHints(hints []string, format string, a ...any) { + if o.loggingLevel >= verbosity.Info { + format = o.ensureEol(format) + if o.loggingLevel >= verbosity.Debug { format = "INFO: " + format } - printf(format, a...) - displayHints(hints) + o.printf(format, a...) + o.hintCallback(hints) } } -func maskSecrets(text string) string { +// maskSecrets takes a string as input and masks any password found in the +// string using the PASSWORD.*\s?=.*\s?N?') regular expression. It +// returns the resulting masked string. +func (o Output) maskSecrets(text string) string { // Mask password from T/SQL e.g. ALTER LOGIN [sa] WITH PASSWORD = N'foo'; r := regexp.MustCompile(`(PASSWORD.*\s?=.*\s?N?')(.*)(')`) @@ -200,9 +220,9 @@ func maskSecrets(text string) string { return text } -func printf(format string, a ...any) { +func (o Output) printf(format string, a ...any) { text := fmt.Sprintf(format, a...) - text = maskSecrets(text) - _, err := standardWriteCloser.Write([]byte(text)) - checkErr(err) + text = o.maskSecrets(text) + _, err := o.standardWriteCloser.Write([]byte(text)) + o.errorCallback(err) } diff --git a/internal/output/output_test.go b/internal/output/output_test.go index 8b0ce7c5..d8465cd8 100644 --- a/internal/output/output_test.go +++ b/internal/output/output_test.go @@ -6,273 +6,101 @@ package output import ( "errors" "github.com/microsoft/go-sqlcmd/internal/output/verbosity" + "github.com/microsoft/go-sqlcmd/internal/test" "testing" ) func TestTracef(t *testing.T) { - type args struct { - loggingLevel verbosity.Enum - format string - a []any - } - tests := []struct { - name string - args args - }{ - {"default", args{ - loggingLevel: verbosity.Trace, - format: "%v", - a: []any{"sample trace"}, - }}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - loggingLevel = tt.args.loggingLevel - Tracef(tt.args.format, tt.args.a...) - Debugf(tt.args.format, tt.args.a...) - Infof(tt.args.format, tt.args.a...) - Warnf(tt.args.format, tt.args.a...) - Errorf(tt.args.format, tt.args.a...) - Struct(tt.args.a) + format := "%v" + args := []string{"sample text"} - InfofWithHints([]string{}, tt.args.format, tt.args.a...) - InfofWithHintExamples([][]string{}, tt.args.format, tt.args.a...) - }) - } + loggingLevel := verbosity.Trace + o := New(Options{LoggingLevel: loggingLevel, HintHandler: func(hints []string) { + + }, ErrorHandler: func(err error) { + + }}) + o.Tracef(format, args) + o.Debugf(format, args) + o.Infof(format, args) + o.Warnf(format, args) + o.Errorf(format, args) + o.Struct(args) + + o.InfofWithHints([]string{}, format, args) + o.InfofWithHintExamples([][]string{}, format, args) } func TestFatal(t *testing.T) { - type args struct { - a []any - } - tests := []struct { - name string - args args - }{ - {"default", args{ - a: []any{"sample trace"}, - }}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - Fatal(tt.args.a...) - }) - } + defer func() { test.CatchExpectedError(recover(), t) }() + o := New(Options{LoggingLevel: 4}) + o.Fatal("sample trace") } func TestFatalWithHints(t *testing.T) { - type args struct { - hints []string - a []any - } - tests := []struct { - name string - args args - }{ - {"default", args{ - hints: []string{"This is a hint"}, - a: []any{"sample trace"}, - }}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - FatalWithHints(tt.args.hints, tt.args.a...) - }) - } + defer func() { test.CatchExpectedError(recover(), t) }() + o := New(Options{LoggingLevel: 4}) + o.FatalWithHints([]string{"This is a hint"}, "expected error") } func TestFatalfWithHintExamples(t *testing.T) { - type args struct { - hintExamples [][]string - format string - a []any - } - tests := []struct { - name string - args args - }{ - {"default", args{ - hintExamples: [][]string{{"This is a hint", "With a sample"}}, - a: []any{"sample trace"}, - }}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - FatalfWithHintExamples(tt.args.hintExamples, tt.args.format, tt.args.a...) - }) - } + defer func() { test.CatchExpectedError(recover(), t) }() + + hintExamples := [][]string{{"This is a hint", ""}} + o := New(Options{LoggingLevel: verbosity.Trace}) + o.FatalfWithHintExamples( + hintExamples, + "%v", + "this is an error", + ) } func TestFatalfErrorWithHints(t *testing.T) { - type args struct { - err error - hints []string - format string - a []any - } - tests := []struct { - name string - args args - }{ - {"default", args{ - hints: []string{"This is a hint"}, - a: []any{"sample trace"}, - }}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - FatalfErrorWithHints(tt.args.err, tt.args.hints, tt.args.format, tt.args.a...) - }) - } + defer func() { test.CatchExpectedError(recover(), t) }() + o := New(Options{LoggingLevel: 4}) + o.FatalfErrorWithHints( + errors.New("error to check"), + []string{"This is a hint to avoid the error"}, + "%v", + "This the error message", + ) } func TestFatalfWithHints(t *testing.T) { - type args struct { - hints []string - format string - a []any - } - tests := []struct { - name string - args args - }{ - {"default", args{ - hints: []string{"This is a hint"}, - a: []any{"sample trace"}, - }}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - FatalfWithHints(tt.args.hints, tt.args.format, tt.args.a...) - }) - } + defer func() { test.CatchExpectedError(recover(), t) }() + o := New(Options{LoggingLevel: 4}) + o.FatalfWithHints( + []string{"This is a hint to the user to avoid the error"}, + "%v", + "this is the reason for the fatal error", + ) } func TestFatalf(t *testing.T) { - type args struct { - format string - a []any - } - tests := []struct { - name string - args args - }{ - {"default", args{ - a: []any{"sample trace"}, - }}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - Fatalf(tt.args.format, tt.args.a...) - }) - } + defer func() { test.CatchExpectedError(recover(), t) }() + o := New(Options{LoggingLevel: 4}) + o.Fatalf("%v", "message to give user on exit") } func TestFatalErr(t *testing.T) { - type args struct { - err error - } - tests := []struct { - name string - args args - }{ - {"default", args{ - errors.New("an error"), - }}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - FatalErr(tt.args.err) - }) - } + defer func() { test.CatchExpectedError(recover(), t) }() + o := New(Options{LoggingLevel: 4}) + o.FatalErr(errors.New("will exist if error is not nil")) } func TestPanicf(t *testing.T) { - type args struct { - format string - a []any - } - tests := []struct { - name string - args args - }{ - {"default", args{ - a: []any{"sample trace"}, - }}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - Panicf(tt.args.format, tt.args.a...) - }) - } + defer func() { test.CatchExpectedError(recover(), t) }() + o := New(Options{LoggingLevel: 4}) + o.Panicf("%v", "this is the reason for the panic") } func TestPanic(t *testing.T) { - type args struct { - a []any - } - tests := []struct { - name string - args args - }{ - {"default", args{ - a: []any{"sample trace"}, - }}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - Panic(tt.args.a...) - }) - } + defer func() { test.CatchExpectedError(recover(), t) }() + o := New(Options{LoggingLevel: 4}) + o.Panic("reason for the panic") } func TestInfofWithHintExamples(t *testing.T) { - t.Skip() // BUG(stuartpa): CrossPlatScripts build is failing on this test!? (presume this is an issue with static state, move to an object) type args struct { hintExamples [][]string format string @@ -303,21 +131,9 @@ func TestInfofWithHintExamples(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - - //BUG(stuartpa): Not thread safe - runningUnitTests = true - InfofWithHintExamples(tt.args.hintExamples, tt.args.format, tt.args.a...) - runningUnitTests = false + defer func() { test.CatchExpectedError(recover(), t) }() + o := New(Options{LoggingLevel: 4}) + o.InfofWithHintExamples(tt.args.hintExamples, tt.args.format, tt.args.a...) }) } } - -func Test_ensureEol(t *testing.T) { - format := ensureEol("%s") - Infof(format, "hello-world") -} diff --git a/internal/output/trace.go b/internal/output/trace.go deleted file mode 100644 index 6b0c96af..00000000 --- a/internal/output/trace.go +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package output - -var traceCallback func(format string, a ...any) - -func trace(format string, a ...any) { - traceCallback(format, a...) -} diff --git a/internal/output/type.go b/internal/output/type.go new file mode 100644 index 00000000..0351d66e --- /dev/null +++ b/internal/output/type.go @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package output + +import ( + "github.com/microsoft/go-sqlcmd/internal/output/formatter" + "github.com/microsoft/go-sqlcmd/internal/output/verbosity" + "io" +) + +type Output struct { + errorCallback func(err error) + hintCallback func(hints []string) + + formatter formatter.Formatter + loggingLevel verbosity.Level + standardWriteCloser io.WriteCloser +} diff --git a/internal/output/verbosity/level.go b/internal/output/verbosity/level.go new file mode 100644 index 00000000..b4221430 --- /dev/null +++ b/internal/output/verbosity/level.go @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package verbosity + +// Level is an enumeration representing different verbosity levels for logging, +// ranging from Error to Trace. The values of the enumeration are +// Error, Warn, Info, Debug, and Trace, in increasing order of verbosity. +type Level int + +const ( + Error Level = iota + Warn + Info + Debug + Trace +) diff --git a/internal/pal/intialize.go b/internal/pal/intialize.go index 8d124203..466c1eb3 100644 --- a/internal/pal/intialize.go +++ b/internal/pal/intialize.go @@ -8,9 +8,10 @@ func init() { if err != nil { panic(err) } - }) + }, "\n") } -func Initialize(handler func(err error)) { +func Initialize(handler func(err error), endOfLine string) { errorCallback = handler + lineBreak = endOfLine } diff --git a/internal/pal/pal.go b/internal/pal/pal.go index 4df15776..f0683df1 100644 --- a/internal/pal/pal.go +++ b/internal/pal/pal.go @@ -11,6 +11,8 @@ import ( "strings" ) +var lineBreak string + // FilenameInUserHomeDotDirectory returns the full path and filename // to the filename in the dotDirectory (e.g. .sqlcmd) in the user's home directory // e.g. c:\users\username @@ -40,3 +42,11 @@ func CmdLineWithEnvVars(vars []string, cmd string) string { return sb.String() } + +func LineBreak() string { + if lineBreak == "" { + panic("Initialize has not been called") + } + + return lineBreak +} diff --git a/internal/pal/pal_test.go b/internal/pal/pal_test.go index 67c2ee11..a8337759 100644 --- a/internal/pal/pal_test.go +++ b/internal/pal/pal_test.go @@ -6,6 +6,7 @@ package pal import ( "errors" "fmt" + "github.com/microsoft/go-sqlcmd/internal/test" "testing" ) @@ -13,6 +14,17 @@ func TestFilenameInUserHomeDotDirectory(t *testing.T) { FilenameInUserHomeDotDirectory(".foo", "bar") } +func TestLineBreak(t *testing.T) { + LineBreak() +} + +func TestNegLineBreak(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() + + lineBreak = "" + LineBreak() +} + func TestCheckErr(t *testing.T) { defer func() { if r := recover(); r == nil { diff --git a/internal/secret/generate.go b/internal/secret/generate.go index 78ceedd8..62ea28f5 100644 --- a/internal/secret/generate.go +++ b/internal/secret/generate.go @@ -16,6 +16,13 @@ const ( numberSet = "0123456789" ) +// Generate generates a random password of a specified length. The password +// will contain at least the specified number of special characters, +// numeric digits, and upper-case letters. The remaining characters in the +// password will be selected from a combination of lower-case letters, special +// characters, and numeric digits. The special characters are chosen from +// the provided special character set. The generated password is returned +// as a string. func Generate(passwordLength, minSpecialChar, minNum, minUpperCase int, specialCharSet string) string { var password strings.Builder allCharSet := lowerCharSet + upperCharSet + specialCharSet + numberSet diff --git a/internal/secret/secret.go b/internal/secret/secret.go index 4b08eece..22019eb8 100644 --- a/internal/secret/secret.go +++ b/internal/secret/secret.go @@ -9,7 +9,9 @@ import ( "encoding/base64" ) -// Encode optionally encrypts the plainText and always base64 encodes it +// Encode takes a plain text string and a boolean indicating whether or not to +// encrypt the plain text using a password, and returns the resulting cipher text. +// If the plain text is an empty string, this function will panic. func Encode(plainText string, encryptPassword bool) (cipherText string) { if plainText == "" { panic("Cannot encode/encrypt an empty string") @@ -26,7 +28,9 @@ func Encode(plainText string, encryptPassword bool) (cipherText string) { return } -// Decode always base64 decodes the cipherText and optionally decrypts it +// Decode takes a cipher text and a boolean indicating whether or not to decrypt +// the cipher text using a password, and returns the resulting plain text. +// If the cipher text is an empty string, this function will panic. func Decode(cipherText string, decryptPassword bool) (plainText string) { if cipherText == "" { panic("Cannot decode/decrypt an empty string") diff --git a/internal/secret/secret_test.go b/internal/secret/secret_test.go index 3c463abf..7a0a3486 100644 --- a/internal/secret/secret_test.go +++ b/internal/secret/secret_test.go @@ -4,67 +4,25 @@ package secret import ( - "github.com/microsoft/go-sqlcmd/internal/output" - "strings" + "github.com/microsoft/go-sqlcmd/internal/test" "testing" ) -func TestEncryptAndDecrypt(t *testing.T) { - type args struct { - plainText string - encrypt bool - } - tests := []struct { - name string - args args - wantPlainText string - }{ - { - name: "noEncrypt", - args: args{"plainText", false}, - wantPlainText: "plainText", - }, - { - name: "encrypt", - args: args{"plainText", true}, - wantPlainText: "plainText", - }, - { - name: "emptyStringForEncryptPanic", - args: args{"", true}, - wantPlainText: "", - }, - { - name: "emptyStringForDecryptPanic", - args: args{"", true}, - wantPlainText: "", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { +func TestEncodeAndDecode(t *testing.T) { + notEncrypted := Encode("plainText", false) + encrypted := Encode("plainText", true) + Decode(notEncrypted, false) + Decode(encrypted, true) +} + +func TestNegEncode(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() - // If test name ends in 'Panic' expect a Panic - if strings.HasSuffix(tt.name, "Panic") { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - } + Encode("", true) +} - var gotPlainText string - if tt.name != "emptyStringForDecryptPanic" { - cipherText := Encode(tt.args.plainText, tt.args.encrypt) - gotPlainText = Decode(cipherText, tt.args.encrypt) - output.Infof(gotPlainText) - } else { - gotPlainText = Decode(tt.args.plainText, tt.args.encrypt) - output.Infof(gotPlainText) - } +func TestNegDecode(t *testing.T) { + defer func() { test.CatchExpectedError(recover(), t) }() - if gotPlainText = tt.args.plainText; gotPlainText != tt.wantPlainText { - t.Errorf("Encode/Decode() = %v, want %v", gotPlainText, tt.wantPlainText) - } - }) - } + Decode("", true) } diff --git a/internal/test/memory-buffer.go b/internal/test/memory-buffer.go new file mode 100644 index 00000000..c8fe3e81 --- /dev/null +++ b/internal/test/memory-buffer.go @@ -0,0 +1,25 @@ +package test + +import "bytes" + +// MemoryBuffer has both Write and Close methods for use as io.WriteCloser +// when testing (instead of os.Stdout), so tests can assert.Equal results etc. +type MemoryBuffer struct { + buf *bytes.Buffer +} + +func (b *MemoryBuffer) Write(p []byte) (n int, err error) { + return b.buf.Write(p) +} + +func (b *MemoryBuffer) Close() error { + return nil +} + +func (b *MemoryBuffer) String() string { + return b.buf.String() +} + +func NewMemoryBuffer() *MemoryBuffer { + return &MemoryBuffer{buf: new(bytes.Buffer)} +} diff --git a/internal/test/test.go b/internal/test/test.go new file mode 100644 index 00000000..ba6817f7 --- /dev/null +++ b/internal/test/test.go @@ -0,0 +1,20 @@ +package test + +import ( + "testing" +) + +// CatchExpectedError function is a helper function for use in unit tests. It +// expects an error value as its first argument, and a pointer to a testing.T struct +// as its second argument. If the error is not nil, the function logs the error +// and continues execution of the unit test. If the error is nil, the function +// panics with a message indicating that the code did not panic as expected. +// This function is used to verify that a particular code path produces an error +// in a unit test. +func CatchExpectedError(err any, t *testing.T) { + if err != nil { + t.Log("The expected error was:", err) + } else { + panic("The code did not panic as expected") + } +} diff --git a/internal/test/test_test.go b/internal/test/test_test.go new file mode 100644 index 00000000..79cf4014 --- /dev/null +++ b/internal/test/test_test.go @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package test + +import ( + "errors" + "testing" +) + +func TestCatchExpectedError(t *testing.T) { + CatchExpectedError(errors.New("test"), t) +} + +func TestCatchExpectedError2(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("The code did not panic") + } + }() + CatchExpectedError(nil, t) +} diff --git a/main.go b/main.go deleted file mode 100644 index 48830234..00000000 --- a/main.go +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -// Package main is the entrypoint for the sqlcmd CLI application. -package main - -import ( - "github.com/microsoft/go-sqlcmd/cmd" - legacyCmd "github.com/microsoft/go-sqlcmd/cmd/sqlcmd" - "os" -) - -// main is the entrypoint function for sqlcmd. -// -// TEMPORARY: While we have both the new cobra and old kong CLI -// implementations, main decides which CLI framework to use -func main() { - cmd.Initialize() - - if isModernCliEnabled() && isFirstArgModernCliSubCommand() { - cmd.Execute() - } else { - legacyCmd.Execute() - } -} - -// isModernCliEnabled is TEMPORARY code, to be removed when we enable -// the new cobra based CLI by default -func isModernCliEnabled() (modernCliEnabled bool) { - if os.Getenv("SQLCMD_MODERN") != "" { - modernCliEnabled = true - } - return -} - -// isFirstArgModernCliSubCommand is TEMPORARY code, to be removed when -// we enable the new cobra based CLI by default -func isFirstArgModernCliSubCommand() (isNewCliCommand bool) { - if len(os.Args) > 0 { - if cmd.IsValidSubCommand(os.Args[1]) { - isNewCliCommand = true - } - } - return -}