From 0ca6f013e474d274577ee4cbee881095b79c37d4 Mon Sep 17 00:00:00 2001 From: Martin Asquino Date: Fri, 2 Feb 2024 00:30:53 +0000 Subject: [PATCH 01/20] gopls: add fill switch cases code action --- .../analysis/fillswitch/fillswitch.go | 359 ++++++++++++++++++ .../analysis/fillswitch/fillswitch_test.go | 38 ++ .../analysis/fillswitch/testdata/src/a/a.go | 74 ++++ .../analysis/fillswitch/testdata/src/b/b.go | 13 + gopls/internal/golang/codeaction.go | 20 + gopls/internal/golang/fix.go | 2 + .../testdata/codeaction/fill_switch.txt | 76 ++++ .../codeaction/fill_switch_resolve.txt | 87 +++++ 8 files changed, 669 insertions(+) create mode 100644 gopls/internal/analysis/fillswitch/fillswitch.go create mode 100644 gopls/internal/analysis/fillswitch/fillswitch_test.go create mode 100644 gopls/internal/analysis/fillswitch/testdata/src/a/a.go create mode 100644 gopls/internal/analysis/fillswitch/testdata/src/b/b.go create mode 100644 gopls/internal/test/marker/testdata/codeaction/fill_switch.txt create mode 100644 gopls/internal/test/marker/testdata/codeaction/fill_switch_resolve.txt diff --git a/gopls/internal/analysis/fillswitch/fillswitch.go b/gopls/internal/analysis/fillswitch/fillswitch.go new file mode 100644 index 00000000000..666c469a870 --- /dev/null +++ b/gopls/internal/analysis/fillswitch/fillswitch.go @@ -0,0 +1,359 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package fillswitch defines an Analyzer that automatically +// fills the missing cases in type switches or switches over named types. +// +// The analyzer's diagnostic is merely a prompt. +// The actual fix is created by a separate direct call from gopls to +// the SuggestedFixes function. +// Tests of Analyzer.Run can be found in ./testdata/src. +// Tests of the SuggestedFixes logic live in ../../testdata/fillswitch. +package fillswitch + +import ( + "bytes" + "context" + "errors" + "fmt" + "go/ast" + "go/token" + "go/types" + "slices" + "strings" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/ast/astutil" + "golang.org/x/tools/go/ast/inspector" + "golang.org/x/tools/gopls/internal/cache" + "golang.org/x/tools/gopls/internal/cache/parsego" +) + +const FixCategory = "fillswitch" // recognized by gopls ApplyFix + +// errNoSuggestedFix is returned when no suggested fix is available. This could +// be because all cases are already covered, or (in the case of a type switch) +// because the remaining cases are for types not accessible by the current +// package. +var errNoSuggestedFix = errors.New("no suggested fix") + +// Diagnose computes diagnostics for switch statements with missing cases +// overlapping with the provided start and end position. +// +// The diagnostic contains a lazy fix; the actual patch is computed +// (via the ApplyFix command) by a call to [SuggestedFix]. +// +// If either start or end is invalid, the entire package is inspected. +func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Package, info *types.Info) []analysis.Diagnostic { + var diags []analysis.Diagnostic + nodeFilter := []ast.Node{(*ast.SwitchStmt)(nil), (*ast.TypeSwitchStmt)(nil)} + inspect.Preorder(nodeFilter, func(n ast.Node) { + if expr, ok := n.(*ast.SwitchStmt); ok { + if (start.IsValid() && expr.End() < start) || (end.IsValid() && expr.Pos() > end) { + return // non-overlapping + } + + if defaultHandled(expr.Body) { + return + } + + namedType, err := namedTypeFromSwitch(expr, info) + if err != nil { + return + } + + if _, err := suggestedFixSwitch(expr, pkg, info); err != nil { + return + } + + diags = append(diags, analysis.Diagnostic{ + Message: "Switch has missing cases", + Pos: expr.Pos(), + End: expr.End(), + Category: FixCategory, + SuggestedFixes: []analysis.SuggestedFix{{ + Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()), + // No TextEdits => computed later by gopls. + }}, + }) + } + + if expr, ok := n.(*ast.TypeSwitchStmt); ok { + if (start.IsValid() && expr.End() < start) || (end.IsValid() && expr.Pos() > end) { + return // non-overlapping + } + + if defaultHandled(expr.Body) { + return + } + + namedType, err := namedTypeFromTypeSwitch(expr, info) + if err != nil { + return + } + + if _, err := suggestedFixTypeSwitch(expr, pkg, info); err != nil { + return + } + + diags = append(diags, analysis.Diagnostic{ + Message: "Switch has missing cases", + Pos: expr.Pos(), + End: expr.End(), + Category: FixCategory, + SuggestedFixes: []analysis.SuggestedFix{{ + Message: fmt.Sprintf("Add cases for %v", namedType.Obj().Name()), + // No TextEdits => computed later by gopls. + }}, + }) + } + }) + + return diags +} + +func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { + namedType, err := namedTypeFromTypeSwitch(stmt, info) + if err != nil { + return nil, err + } + + scope := namedType.Obj().Pkg().Scope() + variants := make([]string, 0) + for _, name := range scope.Names() { + obj := scope.Lookup(name) + if _, ok := obj.(*types.TypeName); !ok { + continue + } + + if types.Identical(obj.Type(), namedType.Obj().Type()) { + continue + } + + if types.AssignableTo(obj.Type(), namedType.Obj().Type()) { + if obj.Pkg().Name() != pkg.Name() { + if !obj.Exported() { + continue + } + + variants = append(variants, obj.Pkg().Name()+"."+obj.Name()) + } else { + variants = append(variants, obj.Name()) + } + } else if types.AssignableTo(types.NewPointer(obj.Type()), namedType.Obj().Type()) { + if obj.Pkg().Name() != pkg.Name() { + if !obj.Exported() { + continue + } + + variants = append(variants, "*"+obj.Pkg().Name()+"."+obj.Name()) + } else { + variants = append(variants, "*"+obj.Name()) + } + } + } + + handledVariants := getHandledVariants(stmt.Body) + if len(variants) == 0 || len(variants) == len(handledVariants) { + return nil, errNoSuggestedFix + } + + newText := buildNewText(variants, handledVariants) + return &analysis.SuggestedFix{ + Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()), + TextEdits: []analysis.TextEdit{{ + Pos: stmt.End() - 1, + End: stmt.End() - 1, + NewText: indent([]byte(newText), []byte{'\t'}), + }}, + }, nil +} + +func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { + namedType, err := namedTypeFromSwitch(stmt, info) + if err != nil { + return nil, err + } + + scope := namedType.Obj().Pkg().Scope() + variants := make([]string, 0) + for _, name := range scope.Names() { + obj := scope.Lookup(name) + if obj.Id() == namedType.Obj().Id() { + continue + } + + if types.Identical(obj.Type(), namedType.Obj().Type()) { + // TODO: comparing the package name like this feels wrong, is it? + if obj.Pkg().Name() != pkg.Name() { + if !obj.Exported() { + continue + } + + variants = append(variants, obj.Pkg().Name()+"."+obj.Name()) + } else { + variants = append(variants, obj.Name()) + } + } + } + + handledVariants := getHandledVariants(stmt.Body) + if len(variants) == 0 || len(variants) == len(handledVariants) { + return nil, errNoSuggestedFix + } + + newText := buildNewText(variants, handledVariants) + return &analysis.SuggestedFix{ + Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()), + TextEdits: []analysis.TextEdit{{ + Pos: stmt.End() - 1, + End: stmt.End() - 1, + NewText: indent([]byte(newText), []byte{'\t'}), + }}, + }, nil +} + +func namedTypeFromSwitch(stmt *ast.SwitchStmt, info *types.Info) (*types.Named, error) { + typ := info.TypeOf(stmt.Tag) + if typ == nil { + return nil, errors.New("expected switch statement to have a tag") + } + + namedType, ok := typ.(*types.Named) + if !ok { + return nil, errors.New("switch statement is not on a named type") + } + + return namedType, nil +} + +func namedTypeFromTypeSwitch(stmt *ast.TypeSwitchStmt, info *types.Info) (*types.Named, error) { + switch s := stmt.Assign.(type) { + case *ast.ExprStmt: + typ, ok := s.X.(*ast.TypeAssertExpr) + if !ok { + return nil, errors.New("type switch expression is not a type assert expression") + } + + namedType, ok := info.TypeOf(typ.X).(*types.Named) + if !ok { + return nil, errors.New("type switch expression is not on a named type") + } + + return namedType, nil + case *ast.AssignStmt: + for _, expr := range s.Rhs { + typ, ok := expr.(*ast.TypeAssertExpr) + if !ok { + continue + } + + namedType, ok := info.TypeOf(typ.X).(*types.Named) + if !ok { + continue + } + + return namedType, nil + } + + return nil, errors.New("expected type switch expression to have a named type") + default: + return nil, errors.New("node is not a type switch statement") + } +} + +func defaultHandled(body *ast.BlockStmt) bool { + for _, bl := range body.List { + if len(bl.(*ast.CaseClause).List) == 0 { + return true + } + } + + return false +} + +func buildNewText(variants []string, handledVariants []string) string { + var textBuilder strings.Builder + for _, c := range variants { + if slices.Contains(handledVariants, c) { + continue + } + + textBuilder.WriteString("case ") + textBuilder.WriteString(c) + textBuilder.WriteString(":\n") + } + + return textBuilder.String() +} + +func getHandledVariants(body *ast.BlockStmt) []string { + out := make([]string, 0) + for _, bl := range body.List { + for _, c := range bl.(*ast.CaseClause).List { + switch v := c.(type) { + case *ast.Ident: + out = append(out, v.Name) + case *ast.SelectorExpr: + out = append(out, v.X.(*ast.Ident).Name+"."+v.Sel.Name) + case *ast.StarExpr: + switch v := v.X.(type) { + case *ast.Ident: + out = append(out, "*"+v.Name) + case *ast.SelectorExpr: + out = append(out, "*"+v.X.(*ast.Ident).Name+"."+v.Sel.Name) + } + } + } + } + + return out +} + +// SuggestedFix computes the suggested fix for the kinds of +// diagnostics produced by the Analyzer above. +func SuggestedFix(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Package, pgf *parsego.File, start, end token.Pos) (*token.FileSet, *analysis.SuggestedFix, error) { + pos := start // don't use the end + path, _ := astutil.PathEnclosingInterval(pgf.File, pos, pos) + if len(path) < 2 { + return nil, nil, fmt.Errorf("no expression found") + } + + switch stmt := path[0].(type) { + case *ast.SwitchStmt: + fix, err := suggestedFixSwitch(stmt, pkg.GetTypes(), pkg.GetTypesInfo()) + if err != nil { + return nil, nil, err + } + + return pkg.FileSet(), fix, nil + case *ast.TypeSwitchStmt: + fix, err := suggestedFixTypeSwitch(stmt, pkg.GetTypes(), pkg.GetTypesInfo()) + if err != nil { + return nil, nil, err + } + + return pkg.FileSet(), fix, nil + default: + return nil, nil, fmt.Errorf("no switch statement found") + } +} + +// indent works line by line through str, prefixing each line with +// prefix. +func indent(str, prefix []byte) []byte { + split := bytes.Split(str, []byte("\n")) + newText := bytes.NewBuffer(nil) + for i, s := range split { + if i != 0 { + newText.Write(prefix) + } + + newText.Write(s) + if i < len(split)-1 { + newText.WriteByte('\n') + } + } + return newText.Bytes() +} diff --git a/gopls/internal/analysis/fillswitch/fillswitch_test.go b/gopls/internal/analysis/fillswitch/fillswitch_test.go new file mode 100644 index 00000000000..e0cc27ab520 --- /dev/null +++ b/gopls/internal/analysis/fillswitch/fillswitch_test.go @@ -0,0 +1,38 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fillswitch_test + +import ( + "go/token" + "testing" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/analysistest" + "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/inspector" + "golang.org/x/tools/gopls/internal/analysis/fillswitch" +) + +// analyzer allows us to test the fillswitch code action using the analysistest +// harness. (fillswitch used to be a gopls analyzer.) +var analyzer = &analysis.Analyzer{ + Name: "fillswitch", + Doc: "test only", + Requires: []*analysis.Analyzer{inspect.Analyzer}, + Run: func(pass *analysis.Pass) (any, error) { + inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) + for _, d := range fillswitch.Diagnose(inspect, token.NoPos, token.NoPos, pass.Pkg, pass.TypesInfo) { + pass.Report(d) + } + return nil, nil + }, + URL: "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/fillswitch", + RunDespiteErrors: true, +} + +func Test(t *testing.T) { + testdata := analysistest.TestData() + analysistest.Run(t, testdata, analyzer, "a") +} diff --git a/gopls/internal/analysis/fillswitch/testdata/src/a/a.go b/gopls/internal/analysis/fillswitch/testdata/src/a/a.go new file mode 100644 index 00000000000..05d1a9644eb --- /dev/null +++ b/gopls/internal/analysis/fillswitch/testdata/src/a/a.go @@ -0,0 +1,74 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fillswitch + +import ( + data "b" +) + +type typeA int + +const ( + typeAOne typeA = iota + typeATwo + typeAThree +) + +func doSwitch() { + var a typeA + switch a { // want `Switch has missing cases` + } + + switch a { // want `Switch has missing cases` + case typeAOne: + } + + switch a { + case typeAOne: + default: + } + + switch a { + case typeAOne: + case typeATwo: + case typeAThree: + } + + var b data.TypeB + switch b { // want `Switch has missing cases` + case data.TypeBOne: + } +} + +type notification interface { + isNotification() +} + +type notificationOne struct{} + +func (notificationOne) isNotification() {} + +type notificationTwo struct{} + +func (notificationTwo) isNotification() {} + +func doTypeSwitch() { + var not notification + switch not.(type) { // want `Switch has missing cases` + } + + switch not.(type) { // want `Switch has missing cases` + case notificationOne: + } + + switch not.(type) { + case notificationOne: + case notificationTwo: + } + + switch not.(type) { + default: + } +} diff --git a/gopls/internal/analysis/fillswitch/testdata/src/b/b.go b/gopls/internal/analysis/fillswitch/testdata/src/b/b.go new file mode 100644 index 00000000000..af91198410e --- /dev/null +++ b/gopls/internal/analysis/fillswitch/testdata/src/b/b.go @@ -0,0 +1,13 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fillswitch + +type TypeB int + +const ( + TypeBOne TypeB = iota + TypeBTwo + TypeBThree +) diff --git a/gopls/internal/golang/codeaction.go b/gopls/internal/golang/codeaction.go index df4ca513ce2..2a3ab0050ba 100644 --- a/gopls/internal/golang/codeaction.go +++ b/gopls/internal/golang/codeaction.go @@ -13,6 +13,7 @@ import ( "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/gopls/internal/analysis/fillstruct" + "golang.org/x/tools/gopls/internal/analysis/fillswitch" "golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/cache/parsego" "golang.org/x/tools/gopls/internal/file" @@ -328,6 +329,25 @@ func getRewriteCodeActions(pkg *cache.Package, pgf *parsego.File, fh file.Handle } } + for _, diag := range fillswitch.Diagnose(inspect, start, end, pkg.GetTypes(), pkg.GetTypesInfo()) { + rng, err := pgf.Mapper.PosRange(pgf.Tok, diag.Pos, diag.End) + if err != nil { + return nil, err + } + for _, fix := range diag.SuggestedFixes { + cmd, err := command.NewApplyFixCommand(fix.Message, command.ApplyFixArgs{ + Fix: diag.Category, + URI: pgf.URI, + Range: rng, + ResolveEdits: supportsResolveEdits(options), + }) + if err != nil { + return nil, err + } + commands = append(commands, cmd) + } + } + for i := range commands { actions = append(actions, newCodeAction(commands[i].Title, protocol.RefactorRewrite, &commands[i], nil, options)) } diff --git a/gopls/internal/golang/fix.go b/gopls/internal/golang/fix.go index 6f07cb869c5..5e4cdfd408e 100644 --- a/gopls/internal/golang/fix.go +++ b/gopls/internal/golang/fix.go @@ -14,6 +14,7 @@ import ( "golang.org/x/tools/go/analysis" "golang.org/x/tools/gopls/internal/analysis/embeddirective" "golang.org/x/tools/gopls/internal/analysis/fillstruct" + "golang.org/x/tools/gopls/internal/analysis/fillswitch" "golang.org/x/tools/gopls/internal/analysis/stubmethods" "golang.org/x/tools/gopls/internal/analysis/undeclaredname" "golang.org/x/tools/gopls/internal/analysis/unusedparams" @@ -107,6 +108,7 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file fillstruct.FixCategory: singleFile(fillstruct.SuggestedFix), stubmethods.FixCategory: stubMethodsFixer, undeclaredname.FixCategory: singleFile(undeclaredname.SuggestedFix), + fillswitch.FixCategory: fillswitch.SuggestedFix, // Ad-hoc fixers: these are used when the command is // constructed directly by logic in server/code_action. diff --git a/gopls/internal/test/marker/testdata/codeaction/fill_switch.txt b/gopls/internal/test/marker/testdata/codeaction/fill_switch.txt new file mode 100644 index 00000000000..84e347138cd --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/fill_switch.txt @@ -0,0 +1,76 @@ +This test checks the behavior of the 'fill switch' code action. +See fill_switch_resolve.txt for same test with resolve support. + +-- flags -- +-ignore_extra_diags + +-- go.mod -- +module golang.org/lsptests/fillswitch + +go 1.18 + +-- data/data.go -- +package data + +type TypeB int + +const ( + TypeBOne TypeB = iota + TypeBTwo + TypeBThree +) + +-- a.go -- +package fillswitch + +import ( + "golang.org/lsptests/fillswitch/data" +) + +type typeA int + +const ( + typeAOne typeA = iota + typeATwo + typeAThree +) + +type notification interface { + isNotification() +} + +type notificationOne struct{} + +func (notificationOne) isNotification() {} + +type notificationTwo struct{} + +func (notificationTwo) isNotification() {} + +func doSwitch() { + var b data.TypeB + switch b { + case data.TypeBOne: //@codeactionedit(":", "refactor.rewrite", a1) + } + + var a typeA + switch a { + case typeAThree: //@codeactionedit(":", "refactor.rewrite", a2) + } + + var n notification + switch n.(type) { //@codeactionedit("{", "refactor.rewrite", a3) + } +} +-- @a1/a.go -- +@@ -31 +31,2 @@ ++ case data.TypeBThree: ++ case data.TypeBTwo: +-- @a2/a.go -- +@@ -36 +36,2 @@ ++ case typeAOne: ++ case typeATwo: +-- @a3/a.go -- +@@ -40 +40,2 @@ ++ case notificationOne: ++ case notificationTwo: diff --git a/gopls/internal/test/marker/testdata/codeaction/fill_switch_resolve.txt b/gopls/internal/test/marker/testdata/codeaction/fill_switch_resolve.txt new file mode 100644 index 00000000000..1b4c1cdbc93 --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/fill_switch_resolve.txt @@ -0,0 +1,87 @@ +This test checks the behavior of the 'fill switch' code action, with resolve support. +See fill_switch.txt for same test without resolve support. + +-- capabilities.json -- +{ + "textDocument": { + "codeAction": { + "dataSupport": true, + "resolveSupport": { + "properties": ["edit"] + } + } + } +} +-- flags -- +-ignore_extra_diags + +-- go.mod -- +module golang.org/lsptests/fillswitch + +go 1.18 + +-- data/data.go -- +package data + +type TypeB int + +const ( + TypeBOne TypeB = iota + TypeBTwo + TypeBThree +) + +-- a.go -- +package fillswitch + +import ( + "golang.org/lsptests/fillswitch/data" +) + +type typeA int + +const ( + typeAOne typeA = iota + typeATwo + typeAThree +) + +type notification interface { + isNotification() +} + +type notificationOne struct{} + +func (notificationOne) isNotification() {} + +type notificationTwo struct{} + +func (notificationTwo) isNotification() {} + +func doSwitch() { + var b data.TypeB + switch b { + case data.TypeBOne: //@codeactionedit(":", "refactor.rewrite", a1) + } + + var a typeA + switch a { + case typeAThree: //@codeactionedit(":", "refactor.rewrite", a2) + } + + var n notification + switch n.(type) { //@codeactionedit("{", "refactor.rewrite", a3) + } +} +-- @a1/a.go -- +@@ -31 +31,2 @@ ++ case data.TypeBThree: ++ case data.TypeBTwo: +-- @a2/a.go -- +@@ -36 +36,2 @@ ++ case typeAOne: ++ case typeATwo: +-- @a3/a.go -- +@@ -40 +40,2 @@ ++ case notificationOne: ++ case notificationTwo: From 66395e156d799e450380d03990780f349f7feb17 Mon Sep 17 00:00:00 2001 From: Martin Asquino Date: Wed, 7 Feb 2024 08:13:29 +0000 Subject: [PATCH 02/20] add test case for unexported types that implement exported interfaces --- gopls/internal/analysis/fillswitch/testdata/src/a/a.go | 4 ++++ gopls/internal/analysis/fillswitch/testdata/src/b/b.go | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/gopls/internal/analysis/fillswitch/testdata/src/a/a.go b/gopls/internal/analysis/fillswitch/testdata/src/a/a.go index 05d1a9644eb..b70c7263ccb 100644 --- a/gopls/internal/analysis/fillswitch/testdata/src/a/a.go +++ b/gopls/internal/analysis/fillswitch/testdata/src/a/a.go @@ -71,4 +71,8 @@ func doTypeSwitch() { switch not.(type) { default: } + + var t data.ExportedInterface + switch t { + } } diff --git a/gopls/internal/analysis/fillswitch/testdata/src/b/b.go b/gopls/internal/analysis/fillswitch/testdata/src/b/b.go index af91198410e..f65f3a7e6f2 100644 --- a/gopls/internal/analysis/fillswitch/testdata/src/b/b.go +++ b/gopls/internal/analysis/fillswitch/testdata/src/b/b.go @@ -11,3 +11,11 @@ const ( TypeBTwo TypeBThree ) + +type ExportedInterface interface { + isExportedInterface() +} + +type notExportedType struct{} + +func (notExportedType) isExportedInterface() {} From b97cb4f91d61dd3c6fd927bbb19b350a2c98f404 Mon Sep 17 00:00:00 2001 From: Martin Asquino Date: Wed, 7 Feb 2024 08:14:18 +0000 Subject: [PATCH 03/20] address review comments --- .../analysis/fillswitch/fillswitch.go | 126 ++++++++---------- .../analysis/fillswitch/fillswitch_test.go | 2 +- 2 files changed, 59 insertions(+), 69 deletions(-) diff --git a/gopls/internal/analysis/fillswitch/fillswitch.go b/gopls/internal/analysis/fillswitch/fillswitch.go index 666c469a870..bf5e5fd2947 100644 --- a/gopls/internal/analysis/fillswitch/fillswitch.go +++ b/gopls/internal/analysis/fillswitch/fillswitch.go @@ -1,9 +1,9 @@ -// Copyright 2020 The Go Authors. All rights reserved. +// Copyright 2024 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package fillswitch defines an Analyzer that automatically -// fills the missing cases in type switches or switches over named types. +// Package fillswitch provides diagnostics and fixes to fills the missing cases +// in type switches or switches over named types. // // The analyzer's diagnostic is merely a prompt. // The actual fix is created by a separate direct call from gopls to @@ -32,12 +32,6 @@ import ( const FixCategory = "fillswitch" // recognized by gopls ApplyFix -// errNoSuggestedFix is returned when no suggested fix is available. This could -// be because all cases are already covered, or (in the case of a type switch) -// because the remaining cases are for types not accessible by the current -// package. -var errNoSuggestedFix = errors.New("no suggested fix") - // Diagnose computes diagnostics for switch statements with missing cases // overlapping with the provided start and end position. // @@ -49,8 +43,10 @@ func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Pac var diags []analysis.Diagnostic nodeFilter := []ast.Node{(*ast.SwitchStmt)(nil), (*ast.TypeSwitchStmt)(nil)} inspect.Preorder(nodeFilter, func(n ast.Node) { - if expr, ok := n.(*ast.SwitchStmt); ok { - if (start.IsValid() && expr.End() < start) || (end.IsValid() && expr.Pos() > end) { + switch expr := n.(type) { + case *ast.SwitchStmt: + if start.IsValid() && expr.End() < start || + end.IsValid() && expr.Pos() > end { return // non-overlapping } @@ -63,7 +59,7 @@ func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Pac return } - if _, err := suggestedFixSwitch(expr, pkg, info); err != nil { + if fix, err := suggestedFixSwitch(expr, pkg, info); err != nil || fix == nil { return } @@ -77,10 +73,9 @@ func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Pac // No TextEdits => computed later by gopls. }}, }) - } - - if expr, ok := n.(*ast.TypeSwitchStmt); ok { - if (start.IsValid() && expr.End() < start) || (end.IsValid() && expr.Pos() > end) { + case *ast.TypeSwitchStmt: + if start.IsValid() && expr.End() < start || + end.IsValid() && expr.Pos() > end { return // non-overlapping } @@ -93,7 +88,7 @@ func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Pac return } - if _, err := suggestedFixTypeSwitch(expr, pkg, info); err != nil { + if fix, err := suggestedFixTypeSwitch(expr, pkg, info); err != nil || fix == nil { return } @@ -120,43 +115,40 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * } scope := namedType.Obj().Pkg().Scope() - variants := make([]string, 0) + var variants []string for _, name := range scope.Names() { obj := scope.Lookup(name) if _, ok := obj.(*types.TypeName); !ok { - continue + continue // not a type } if types.Identical(obj.Type(), namedType.Obj().Type()) { continue } - if types.AssignableTo(obj.Type(), namedType.Obj().Type()) { - if obj.Pkg().Name() != pkg.Name() { - if !obj.Exported() { - continue - } + if types.IsInterface(obj.Type()) { + continue + } - variants = append(variants, obj.Pkg().Name()+"."+obj.Name()) - } else { - variants = append(variants, obj.Name()) + name := obj.Name() + samePkg := obj.Pkg() == pkg + if !samePkg { + if !obj.Exported() { + continue // inaccessible } - } else if types.AssignableTo(types.NewPointer(obj.Type()), namedType.Obj().Type()) { - if obj.Pkg().Name() != pkg.Name() { - if !obj.Exported() { - continue - } + name = obj.Pkg().Name() + name + } - variants = append(variants, "*"+obj.Pkg().Name()+"."+obj.Name()) - } else { - variants = append(variants, "*"+obj.Name()) - } + if types.AssignableTo(obj.Type(), namedType.Obj().Type()) { + variants = append(variants, name) + } else if types.AssignableTo(types.NewPointer(obj.Type()), namedType.Obj().Type()) { + variants = append(variants, "*"+name) } } - handledVariants := getHandledVariants(stmt.Body) + handledVariants := caseTypes(stmt.Body, info) if len(variants) == 0 || len(variants) == len(handledVariants) { - return nil, errNoSuggestedFix + return nil, nil } newText := buildNewText(variants, handledVariants) @@ -165,7 +157,7 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * TextEdits: []analysis.TextEdit{{ Pos: stmt.End() - 1, End: stmt.End() - 1, - NewText: indent([]byte(newText), []byte{'\t'}), + NewText: bytes.ReplaceAll([]byte(newText), []byte("\n"), []byte("\n\t")), }}, }, nil } @@ -198,9 +190,9 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In } } - handledVariants := getHandledVariants(stmt.Body) + handledVariants := caseTypes(stmt.Body, info) if len(variants) == 0 || len(variants) == len(handledVariants) { - return nil, errNoSuggestedFix + return nil, nil } newText := buildNewText(variants, handledVariants) @@ -209,7 +201,7 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In TextEdits: []analysis.TextEdit{{ Pos: stmt.End() - 1, End: stmt.End() - 1, - NewText: indent([]byte(newText), []byte{'\t'}), + NewText: bytes.ReplaceAll([]byte(newText), []byte("\n"), []byte("\n\t")), }}, }, nil } @@ -288,20 +280,36 @@ func buildNewText(variants []string, handledVariants []string) string { return textBuilder.String() } -func getHandledVariants(body *ast.BlockStmt) []string { - out := make([]string, 0) - for _, bl := range body.List { - for _, c := range bl.(*ast.CaseClause).List { - switch v := c.(type) { +func caseTypes(body *ast.BlockStmt, info *types.Info) []string { + var out []string + for _, stmt := range body.List { + for _, e := range stmt.(*ast.CaseClause).List { + switch e := e.(type) { case *ast.Ident: - out = append(out, v.Name) + out = append(out, e.Name) case *ast.SelectorExpr: - out = append(out, v.X.(*ast.Ident).Name+"."+v.Sel.Name) + if _, ok := e.X.(*ast.Ident); !ok { + continue + } + + out = append(out, e.X.(*ast.Ident).Name+"."+e.Sel.Name) case *ast.StarExpr: - switch v := v.X.(type) { + switch v := e.X.(type) { case *ast.Ident: + if !info.Types[v].IsType() { + continue + } + out = append(out, "*"+v.Name) case *ast.SelectorExpr: + if !info.Types[v].IsType() { + continue + } + + if _, ok := e.X.(*ast.Ident); !ok { + continue + } + out = append(out, "*"+v.X.(*ast.Ident).Name+"."+v.Sel.Name) } } @@ -339,21 +347,3 @@ func SuggestedFix(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Pack return nil, nil, fmt.Errorf("no switch statement found") } } - -// indent works line by line through str, prefixing each line with -// prefix. -func indent(str, prefix []byte) []byte { - split := bytes.Split(str, []byte("\n")) - newText := bytes.NewBuffer(nil) - for i, s := range split { - if i != 0 { - newText.Write(prefix) - } - - newText.Write(s) - if i < len(split)-1 { - newText.WriteByte('\n') - } - } - return newText.Bytes() -} diff --git a/gopls/internal/analysis/fillswitch/fillswitch_test.go b/gopls/internal/analysis/fillswitch/fillswitch_test.go index e0cc27ab520..15d3ef1dd70 100644 --- a/gopls/internal/analysis/fillswitch/fillswitch_test.go +++ b/gopls/internal/analysis/fillswitch/fillswitch_test.go @@ -16,7 +16,7 @@ import ( ) // analyzer allows us to test the fillswitch code action using the analysistest -// harness. (fillswitch used to be a gopls analyzer.) +// harness. var analyzer = &analysis.Analyzer{ Name: "fillswitch", Doc: "test only", From 6967f913281bb175ab459bd231c3233120e0d2c4 Mon Sep 17 00:00:00 2001 From: Martin Asquino Date: Wed, 7 Feb 2024 08:37:36 +0000 Subject: [PATCH 04/20] fix typo --- gopls/internal/analysis/fillswitch/fillswitch.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gopls/internal/analysis/fillswitch/fillswitch.go b/gopls/internal/analysis/fillswitch/fillswitch.go index bf5e5fd2947..e8b66901bd0 100644 --- a/gopls/internal/analysis/fillswitch/fillswitch.go +++ b/gopls/internal/analysis/fillswitch/fillswitch.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package fillswitch provides diagnostics and fixes to fills the missing cases +// Package fillswitch provides diagnostics and fixes to fill the missing cases // in type switches or switches over named types. // // The analyzer's diagnostic is merely a prompt. From 6292ec9eb2c25935b0e6c7c2ff5109a848b5aed6 Mon Sep 17 00:00:00 2001 From: Martin Asquino Date: Wed, 7 Feb 2024 08:39:41 +0000 Subject: [PATCH 05/20] move check for default case inside suggested fix functions --- gopls/internal/analysis/fillswitch/fillswitch.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/gopls/internal/analysis/fillswitch/fillswitch.go b/gopls/internal/analysis/fillswitch/fillswitch.go index e8b66901bd0..4f208c9caf1 100644 --- a/gopls/internal/analysis/fillswitch/fillswitch.go +++ b/gopls/internal/analysis/fillswitch/fillswitch.go @@ -50,10 +50,6 @@ func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Pac return // non-overlapping } - if defaultHandled(expr.Body) { - return - } - namedType, err := namedTypeFromSwitch(expr, info) if err != nil { return @@ -79,10 +75,6 @@ func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Pac return // non-overlapping } - if defaultHandled(expr.Body) { - return - } - namedType, err := namedTypeFromTypeSwitch(expr, info) if err != nil { return @@ -109,6 +101,10 @@ func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Pac } func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { + if defaultHandled(stmt.Body) { + return nil, nil + } + namedType, err := namedTypeFromTypeSwitch(stmt, info) if err != nil { return nil, err @@ -163,6 +159,10 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * } func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { + if defaultHandled(stmt.Body) { + return nil, nil + } + namedType, err := namedTypeFromSwitch(stmt, info) if err != nil { return nil, err From 901c2b01b8d1a70c5ef872aa220af78f0ab1a3db Mon Sep 17 00:00:00 2001 From: Martin Asquino Date: Thu, 8 Feb 2024 21:55:28 +0000 Subject: [PATCH 06/20] multiple fixes --- .../internal/analysis/fillswitch/fillswitch.go | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/gopls/internal/analysis/fillswitch/fillswitch.go b/gopls/internal/analysis/fillswitch/fillswitch.go index 4f208c9caf1..2d430e14b5a 100644 --- a/gopls/internal/analysis/fillswitch/fillswitch.go +++ b/gopls/internal/analysis/fillswitch/fillswitch.go @@ -101,7 +101,7 @@ func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Pac } func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { - if defaultHandled(stmt.Body) { + if hasDefaultCase(stmt.Body) { return nil, nil } @@ -118,10 +118,6 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * continue // not a type } - if types.Identical(obj.Type(), namedType.Obj().Type()) { - continue - } - if types.IsInterface(obj.Type()) { continue } @@ -159,7 +155,7 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * } func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { - if defaultHandled(stmt.Body) { + if hasDefaultCase(stmt.Body) { return nil, nil } @@ -169,16 +165,16 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In } scope := namedType.Obj().Pkg().Scope() - variants := make([]string, 0) + var variants []string for _, name := range scope.Names() { obj := scope.Lookup(name) - if obj.Id() == namedType.Obj().Id() { + _, ok := obj.(*types.Const) + if !ok { continue } if types.Identical(obj.Type(), namedType.Obj().Type()) { - // TODO: comparing the package name like this feels wrong, is it? - if obj.Pkg().Name() != pkg.Name() { + if obj.Pkg() != pkg { if !obj.Exported() { continue } @@ -255,7 +251,7 @@ func namedTypeFromTypeSwitch(stmt *ast.TypeSwitchStmt, info *types.Info) (*types } } -func defaultHandled(body *ast.BlockStmt) bool { +func hasDefaultCase(body *ast.BlockStmt) bool { for _, bl := range body.List { if len(bl.(*ast.CaseClause).List) == 0 { return true From c43c2d94717e142c94e5e5bc4781037fc2254f62 Mon Sep 17 00:00:00 2001 From: Martin Asquino Date: Fri, 9 Feb 2024 09:26:37 +0000 Subject: [PATCH 07/20] collect types/consts instead of strings --- .../analysis/fillswitch/fillswitch.go | 185 ++++++++++++++---- 1 file changed, 146 insertions(+), 39 deletions(-) diff --git a/gopls/internal/analysis/fillswitch/fillswitch.go b/gopls/internal/analysis/fillswitch/fillswitch.go index 2d430e14b5a..310627fccec 100644 --- a/gopls/internal/analysis/fillswitch/fillswitch.go +++ b/gopls/internal/analysis/fillswitch/fillswitch.go @@ -111,7 +111,7 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * } scope := namedType.Obj().Pkg().Scope() - var variants []string + var variants []types.Type for _, name := range scope.Names() { obj := scope.Lookup(name) if _, ok := obj.(*types.TypeName); !ok { @@ -122,34 +122,29 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * continue } - name := obj.Name() samePkg := obj.Pkg() == pkg - if !samePkg { - if !obj.Exported() { - continue // inaccessible - } - name = obj.Pkg().Name() + name + if !samePkg && !obj.Exported() { + continue // inaccessible } if types.AssignableTo(obj.Type(), namedType.Obj().Type()) { - variants = append(variants, name) + variants = append(variants, obj.Type()) } else if types.AssignableTo(types.NewPointer(obj.Type()), namedType.Obj().Type()) { - variants = append(variants, "*"+name) + variants = append(variants, types.NewPointer(obj.Type())) } } - handledVariants := caseTypes(stmt.Body, info) + handledVariants := typeSwitchCases(stmt.Body, info) if len(variants) == 0 || len(variants) == len(handledVariants) { return nil, nil } - newText := buildNewText(variants, handledVariants) return &analysis.SuggestedFix{ Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()), TextEdits: []analysis.TextEdit{{ Pos: stmt.End() - 1, End: stmt.End() - 1, - NewText: bytes.ReplaceAll([]byte(newText), []byte("\n"), []byte("\n\t")), + NewText: buildNewTypesText(variants, handledVariants, pkg), }}, }, nil } @@ -165,39 +160,35 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In } scope := namedType.Obj().Pkg().Scope() - var variants []string + var variants []*types.Const for _, name := range scope.Names() { obj := scope.Lookup(name) - _, ok := obj.(*types.Const) + c, ok := obj.(*types.Const) if !ok { continue } - if types.Identical(obj.Type(), namedType.Obj().Type()) { - if obj.Pkg() != pkg { - if !obj.Exported() { - continue - } + samePkg := obj.Pkg() != pkg + if samePkg && !obj.Exported() { + continue + } - variants = append(variants, obj.Pkg().Name()+"."+obj.Name()) - } else { - variants = append(variants, obj.Name()) - } + if types.Identical(obj.Type(), namedType.Obj().Type()) { + variants = append(variants, c) } } - handledVariants := caseTypes(stmt.Body, info) + handledVariants := caseConsts(stmt.Body, info) if len(variants) == 0 || len(variants) == len(handledVariants) { return nil, nil } - newText := buildNewText(variants, handledVariants) return &analysis.SuggestedFix{ Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()), TextEdits: []analysis.TextEdit{{ Pos: stmt.End() - 1, End: stmt.End() - 1, - NewText: bytes.ReplaceAll([]byte(newText), []byte("\n"), []byte("\n\t")), + NewText: buildNewConstsText(variants, handledVariants, pkg), }}, }, nil } @@ -261,7 +252,7 @@ func hasDefaultCase(body *ast.BlockStmt) bool { return false } -func buildNewText(variants []string, handledVariants []string) string { +func buildNewConstsText(variants []*types.Const, handledVariants []*types.Const, currentPkg *types.Package) []byte { var textBuilder strings.Builder for _, c := range variants { if slices.Contains(handledVariants, c) { @@ -269,44 +260,160 @@ func buildNewText(variants []string, handledVariants []string) string { } textBuilder.WriteString("case ") - textBuilder.WriteString(c) + if c.Pkg() != currentPkg { + textBuilder.WriteString(c.Pkg().Name() + "." + c.Name()) + } else { + textBuilder.WriteString(c.Name()) + } + textBuilder.WriteString(":\n") + } + + return bytes.ReplaceAll([]byte(textBuilder.String()), []byte("\n"), []byte("\n\t")) +} + +func isSameType(c, t types.Type) bool { + if types.Identical(c, t) { + return true + } + + if p, ok := c.(*types.Pointer); ok && types.Identical(p.Elem(), t) { + return true + } + + if p, ok := t.(*types.Pointer); ok && types.Identical(p.Elem(), c) { + return true + } + + return false +} + +func buildNewTypesText(variants []types.Type, handledVariants []types.Type, currentPkg *types.Package) []byte { + var textBuilder strings.Builder + for _, c := range variants { + if slices.ContainsFunc(handledVariants, func(t types.Type) bool { return isSameType(c, t) }) { + continue + } + + textBuilder.WriteString("case ") + switch t := c.(type) { + case *types.Named: + if t.Obj().Pkg() != currentPkg { + textBuilder.WriteString(t.Obj().Pkg().Name() + "." + t.Obj().Name()) + } else { + textBuilder.WriteString(t.Obj().Name()) + } + case *types.Pointer: + e, ok := t.Elem().(*types.Named) + if !ok { + continue + } + + if e.Obj().Pkg() != currentPkg { + textBuilder.WriteString("*" + e.Obj().Pkg().Name() + "." + e.Obj().Name()) + } else { + textBuilder.WriteString("*" + e.Obj().Name()) + } + } + textBuilder.WriteString(":\n") } - return textBuilder.String() + return bytes.ReplaceAll([]byte(textBuilder.String()), []byte("\n"), []byte("\n\t")) } -func caseTypes(body *ast.BlockStmt, info *types.Info) []string { - var out []string +func caseConsts(body *ast.BlockStmt, info *types.Info) []*types.Const { + var out []*types.Const for _, stmt := range body.List { for _, e := range stmt.(*ast.CaseClause).List { + if !info.Types[e].IsValue() { + continue + } + switch e := e.(type) { case *ast.Ident: - out = append(out, e.Name) + obj, ok := info.Uses[e] + if !ok { + continue + } + c, ok := obj.(*types.Const) + if !ok { + continue + } + + out = append(out, c) case *ast.SelectorExpr: - if _, ok := e.X.(*ast.Ident); !ok { + _, ok := e.X.(*ast.Ident) + if !ok { + continue + } + + obj, ok := info.Uses[e.Sel] + if !ok { + continue + } + + c, ok := obj.(*types.Const) + if !ok { + continue + } + + out = append(out, c) + } + } + } + + return out +} + +func typeSwitchCases(body *ast.BlockStmt, info *types.Info) []types.Type { + var out []types.Type + for _, stmt := range body.List { + for _, e := range stmt.(*ast.CaseClause).List { + if !info.Types[e].IsType() { + continue + } + + switch e := e.(type) { + case *ast.Ident: + obj, ok := info.Uses[e] + if !ok { + continue + } + + out = append(out, obj.Type()) + case *ast.SelectorExpr: + i, ok := e.X.(*ast.Ident) + if !ok { + continue + } + + obj, ok := info.Uses[i] + if !ok { continue } - out = append(out, e.X.(*ast.Ident).Name+"."+e.Sel.Name) + out = append(out, obj.Type()) case *ast.StarExpr: switch v := e.X.(type) { case *ast.Ident: - if !info.Types[v].IsType() { + obj, ok := info.Uses[v] + if !ok { continue } - out = append(out, "*"+v.Name) + out = append(out, obj.Type()) case *ast.SelectorExpr: - if !info.Types[v].IsType() { + i, ok := e.X.(*ast.Ident) + if !ok { continue } - if _, ok := e.X.(*ast.Ident); !ok { + obj, ok := info.Uses[i] + if !ok { continue } - out = append(out, "*"+v.X.(*ast.Ident).Name+"."+v.Sel.Name) + out = append(out, obj.Type()) } } } From c84559bd73202a61cf4613394fe722a19ef521cb Mon Sep 17 00:00:00 2001 From: Martin Asquino Date: Fri, 9 Feb 2024 14:29:35 +0000 Subject: [PATCH 08/20] return edits on diagnostic --- gopls/internal/analysis/fillswitch/doc.go | 65 +++++++++++++ .../analysis/fillswitch/fillswitch.go | 96 +++++-------------- .../analysis/fillswitch/testdata/src/a/a.go | 10 +- gopls/internal/golang/codeaction.go | 32 ++++--- gopls/internal/golang/fix.go | 2 - 5 files changed, 111 insertions(+), 94 deletions(-) create mode 100644 gopls/internal/analysis/fillswitch/doc.go diff --git a/gopls/internal/analysis/fillswitch/doc.go b/gopls/internal/analysis/fillswitch/doc.go new file mode 100644 index 00000000000..1a0fe4e9a1b --- /dev/null +++ b/gopls/internal/analysis/fillswitch/doc.go @@ -0,0 +1,65 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package fillswitch identifies switches with missing cases. +// +// It will provide diagnostics for type switches or switches over named types +// that are missing cases and provides a code action to fill those in. +// +// If the switch statement is over a named type, it will suggest cases for all +// const values that are assignable to the named type. +// +// type T int +// const ( +// A T = iota +// B +// C +// ) +// +// var t T +// switch t { +// case A: +// } +// +// It will provide a diagnostic with a suggested edit to fill in the remaining +// cases: +// +// var t T +// switch t { +// case A: +// case B: +// case C: +// } +// +// If the switch statement is over type of an interface, it will suggest cases for all types +// that implement the interface. +// +// type I interface { +// M() +// } +// +// type T struct{} +// func (t *T) M() {} +// +// type E struct{} +// func (e *E) M() {} +// +// var i I +// switch i.(type) { +// case *T: +// } +// +// It will provide a diagnostic with a suggested edit to fill in the remaining +// cases: +// +// var i I +// switch i.(type) { +// case *T: +// case *E: +// } +// +// The provided diagnostics will only suggest cases for types that are defined +// on the same package as the switch statement, or for types that are exported; +// and it will not suggest any case if the switch handles the default case. +package fillswitch diff --git a/gopls/internal/analysis/fillswitch/fillswitch.go b/gopls/internal/analysis/fillswitch/fillswitch.go index 310627fccec..6bbee377458 100644 --- a/gopls/internal/analysis/fillswitch/fillswitch.go +++ b/gopls/internal/analysis/fillswitch/fillswitch.go @@ -14,7 +14,6 @@ package fillswitch import ( "bytes" - "context" "errors" "fmt" "go/ast" @@ -24,20 +23,14 @@ import ( "strings" "golang.org/x/tools/go/analysis" - "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/ast/inspector" - "golang.org/x/tools/gopls/internal/cache" - "golang.org/x/tools/gopls/internal/cache/parsego" ) -const FixCategory = "fillswitch" // recognized by gopls ApplyFix +const FixCategory = "fillswitch" // Diagnose computes diagnostics for switch statements with missing cases // overlapping with the provided start and end position. // -// The diagnostic contains a lazy fix; the actual patch is computed -// (via the ApplyFix command) by a call to [SuggestedFix]. -// // If either start or end is invalid, the entire package is inspected. func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Package, info *types.Info) []analysis.Diagnostic { var diags []analysis.Diagnostic @@ -50,24 +43,17 @@ func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Pac return // non-overlapping } - namedType, err := namedTypeFromSwitch(expr, info) - if err != nil { - return - } - - if fix, err := suggestedFixSwitch(expr, pkg, info); err != nil || fix == nil { + fix, err := suggestedFixSwitch(expr, pkg, info) + if err != nil || fix == nil { return } diags = append(diags, analysis.Diagnostic{ - Message: "Switch has missing cases", - Pos: expr.Pos(), - End: expr.End(), - Category: FixCategory, - SuggestedFixes: []analysis.SuggestedFix{{ - Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()), - // No TextEdits => computed later by gopls. - }}, + Message: fix.Message, + Pos: expr.Pos(), + End: expr.End(), + Category: FixCategory, + SuggestedFixes: []analysis.SuggestedFix{*fix}, }) case *ast.TypeSwitchStmt: if start.IsValid() && expr.End() < start || @@ -75,24 +61,17 @@ func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Pac return // non-overlapping } - namedType, err := namedTypeFromTypeSwitch(expr, info) - if err != nil { - return - } - - if fix, err := suggestedFixTypeSwitch(expr, pkg, info); err != nil || fix == nil { + fix, err := suggestedFixTypeSwitch(expr, pkg, info) + if err != nil || fix == nil { return } diags = append(diags, analysis.Diagnostic{ - Message: "Switch has missing cases", - Pos: expr.Pos(), - End: expr.End(), - Category: FixCategory, - SuggestedFixes: []analysis.SuggestedFix{{ - Message: fmt.Sprintf("Add cases for %v", namedType.Obj().Name()), - // No TextEdits => computed later by gopls. - }}, + Message: fix.Message, + Pos: expr.Pos(), + End: expr.End(), + Category: FixCategory, + SuggestedFixes: []analysis.SuggestedFix{*fix}, }) } }) @@ -134,7 +113,7 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * } } - handledVariants := typeSwitchCases(stmt.Body, info) + handledVariants := caseTypes(stmt.Body, info) if len(variants) == 0 || len(variants) == len(handledVariants) { return nil, nil } @@ -144,7 +123,7 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * TextEdits: []analysis.TextEdit{{ Pos: stmt.End() - 1, End: stmt.End() - 1, - NewText: buildNewTypesText(variants, handledVariants, pkg), + NewText: buildTypesText(variants, handledVariants, pkg), }}, }, nil } @@ -170,7 +149,7 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In samePkg := obj.Pkg() != pkg if samePkg && !obj.Exported() { - continue + continue // inaccessible } if types.Identical(obj.Type(), namedType.Obj().Type()) { @@ -188,7 +167,7 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In TextEdits: []analysis.TextEdit{{ Pos: stmt.End() - 1, End: stmt.End() - 1, - NewText: buildNewConstsText(variants, handledVariants, pkg), + NewText: buildConstsText(variants, handledVariants, pkg), }}, }, nil } @@ -252,7 +231,7 @@ func hasDefaultCase(body *ast.BlockStmt) bool { return false } -func buildNewConstsText(variants []*types.Const, handledVariants []*types.Const, currentPkg *types.Package) []byte { +func buildConstsText(variants []*types.Const, handledVariants []*types.Const, currentPkg *types.Package) []byte { var textBuilder strings.Builder for _, c := range variants { if slices.Contains(handledVariants, c) { @@ -287,7 +266,7 @@ func isSameType(c, t types.Type) bool { return false } -func buildNewTypesText(variants []types.Type, handledVariants []types.Type, currentPkg *types.Package) []byte { +func buildTypesText(variants []types.Type, handledVariants []types.Type, currentPkg *types.Package) []byte { var textBuilder strings.Builder for _, c := range variants { if slices.ContainsFunc(handledVariants, func(t types.Type) bool { return isSameType(c, t) }) { @@ -309,6 +288,7 @@ func buildNewTypesText(variants []types.Type, handledVariants []types.Type, curr } if e.Obj().Pkg() != currentPkg { + // TODO: use the correct package name when the import is renamed textBuilder.WriteString("*" + e.Obj().Pkg().Name() + "." + e.Obj().Name()) } else { textBuilder.WriteString("*" + e.Obj().Name()) @@ -335,6 +315,7 @@ func caseConsts(body *ast.BlockStmt, info *types.Info) []*types.Const { if !ok { continue } + c, ok := obj.(*types.Const) if !ok { continue @@ -365,7 +346,7 @@ func caseConsts(body *ast.BlockStmt, info *types.Info) []*types.Const { return out } -func typeSwitchCases(body *ast.BlockStmt, info *types.Info) []types.Type { +func caseTypes(body *ast.BlockStmt, info *types.Info) []types.Type { var out []types.Type for _, stmt := range body.List { for _, e := range stmt.(*ast.CaseClause).List { @@ -421,32 +402,3 @@ func typeSwitchCases(body *ast.BlockStmt, info *types.Info) []types.Type { return out } - -// SuggestedFix computes the suggested fix for the kinds of -// diagnostics produced by the Analyzer above. -func SuggestedFix(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Package, pgf *parsego.File, start, end token.Pos) (*token.FileSet, *analysis.SuggestedFix, error) { - pos := start // don't use the end - path, _ := astutil.PathEnclosingInterval(pgf.File, pos, pos) - if len(path) < 2 { - return nil, nil, fmt.Errorf("no expression found") - } - - switch stmt := path[0].(type) { - case *ast.SwitchStmt: - fix, err := suggestedFixSwitch(stmt, pkg.GetTypes(), pkg.GetTypesInfo()) - if err != nil { - return nil, nil, err - } - - return pkg.FileSet(), fix, nil - case *ast.TypeSwitchStmt: - fix, err := suggestedFixTypeSwitch(stmt, pkg.GetTypes(), pkg.GetTypesInfo()) - if err != nil { - return nil, nil, err - } - - return pkg.FileSet(), fix, nil - default: - return nil, nil, fmt.Errorf("no switch statement found") - } -} diff --git a/gopls/internal/analysis/fillswitch/testdata/src/a/a.go b/gopls/internal/analysis/fillswitch/testdata/src/a/a.go index b70c7263ccb..06d01da5f1e 100644 --- a/gopls/internal/analysis/fillswitch/testdata/src/a/a.go +++ b/gopls/internal/analysis/fillswitch/testdata/src/a/a.go @@ -18,10 +18,10 @@ const ( func doSwitch() { var a typeA - switch a { // want `Switch has missing cases` + switch a { // want `Add cases for typeA` } - switch a { // want `Switch has missing cases` + switch a { // want `Add cases for typeA` case typeAOne: } @@ -37,7 +37,7 @@ func doSwitch() { } var b data.TypeB - switch b { // want `Switch has missing cases` + switch b { // want `Add cases for TypeB` case data.TypeBOne: } } @@ -56,10 +56,10 @@ func (notificationTwo) isNotification() {} func doTypeSwitch() { var not notification - switch not.(type) { // want `Switch has missing cases` + switch not.(type) { // want `Add cases for notification` } - switch not.(type) { // want `Switch has missing cases` + switch not.(type) { // want `Add cases for notification` case notificationOne: } diff --git a/gopls/internal/golang/codeaction.go b/gopls/internal/golang/codeaction.go index 2a3ab0050ba..b1f14fe31e5 100644 --- a/gopls/internal/golang/codeaction.go +++ b/gopls/internal/golang/codeaction.go @@ -99,7 +99,7 @@ func CodeActions(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, return nil, err } if want[protocol.RefactorRewrite] { - rewrites, err := getRewriteCodeActions(pkg, pgf, fh, rng, snapshot.Options()) + rewrites, err := getRewriteCodeActions(ctx, pkg, snapshot, pgf, fh, rng, snapshot.Options()) if err != nil { return nil, err } @@ -253,8 +253,7 @@ func newCodeAction(title string, kind protocol.CodeActionKind, cmd *protocol.Com return action } -// getRewriteCodeActions returns refactor.rewrite code actions available at the specified range. -func getRewriteCodeActions(pkg *cache.Package, pgf *parsego.File, fh file.Handle, rng protocol.Range, options *settings.Options) (_ []protocol.CodeAction, rerr error) { +func getRewriteCodeActions(ctx context.Context, pkg *cache.Package, snapshot *cache.Snapshot, pgf *parsego.File, fh file.Handle, rng protocol.Range, options *settings.Options) (_ []protocol.CodeAction, rerr error) { // golang/go#61693: code actions were refactored to run outside of the // analysis framework, but as a result they lost their panic recovery. // @@ -330,24 +329,27 @@ func getRewriteCodeActions(pkg *cache.Package, pgf *parsego.File, fh file.Handle } for _, diag := range fillswitch.Diagnose(inspect, start, end, pkg.GetTypes(), pkg.GetTypesInfo()) { - rng, err := pgf.Mapper.PosRange(pgf.Tok, diag.Pos, diag.End) + edits, err := suggestedFixToEdits(ctx, snapshot, pkg.FileSet(), &diag.SuggestedFixes[0]) if err != nil { return nil, err } - for _, fix := range diag.SuggestedFixes { - cmd, err := command.NewApplyFixCommand(fix.Message, command.ApplyFixArgs{ - Fix: diag.Category, - URI: pgf.URI, - Range: rng, - ResolveEdits: supportsResolveEdits(options), + + changes := []protocol.DocumentChanges{} // must be a slice + for _, edit := range edits { + edit := edit + changes = append(changes, protocol.DocumentChanges{ + TextDocumentEdit: &edit, }) - if err != nil { - return nil, err - } - commands = append(commands, cmd) } - } + actions = append(actions, protocol.CodeAction{ + Title: diag.Message, + Kind: protocol.RefactorRewrite, + Edit: &protocol.WorkspaceEdit{ + DocumentChanges: changes, + }, + }) + } for i := range commands { actions = append(actions, newCodeAction(commands[i].Title, protocol.RefactorRewrite, &commands[i], nil, options)) } diff --git a/gopls/internal/golang/fix.go b/gopls/internal/golang/fix.go index 5e4cdfd408e..6f07cb869c5 100644 --- a/gopls/internal/golang/fix.go +++ b/gopls/internal/golang/fix.go @@ -14,7 +14,6 @@ import ( "golang.org/x/tools/go/analysis" "golang.org/x/tools/gopls/internal/analysis/embeddirective" "golang.org/x/tools/gopls/internal/analysis/fillstruct" - "golang.org/x/tools/gopls/internal/analysis/fillswitch" "golang.org/x/tools/gopls/internal/analysis/stubmethods" "golang.org/x/tools/gopls/internal/analysis/undeclaredname" "golang.org/x/tools/gopls/internal/analysis/unusedparams" @@ -108,7 +107,6 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file fillstruct.FixCategory: singleFile(fillstruct.SuggestedFix), stubmethods.FixCategory: stubMethodsFixer, undeclaredname.FixCategory: singleFile(undeclaredname.SuggestedFix), - fillswitch.FixCategory: fillswitch.SuggestedFix, // Ad-hoc fixers: these are used when the command is // constructed directly by logic in server/code_action. From c4feaf32a2fab0d6940b7ee8050dda917131b257 Mon Sep 17 00:00:00 2001 From: Martin Asquino Date: Sat, 10 Feb 2024 13:30:18 +0000 Subject: [PATCH 09/20] more fixes --- gopls/internal/analysis/fillswitch/doc.go | 57 ++++----- .../analysis/fillswitch/fillswitch.go | 110 +++++++++--------- 2 files changed, 86 insertions(+), 81 deletions(-) diff --git a/gopls/internal/analysis/fillswitch/doc.go b/gopls/internal/analysis/fillswitch/doc.go index 1a0fe4e9a1b..b83c5ec2a18 100644 --- a/gopls/internal/analysis/fillswitch/doc.go +++ b/gopls/internal/analysis/fillswitch/doc.go @@ -4,36 +4,43 @@ // Package fillswitch identifies switches with missing cases. // -// It will provide diagnostics for type switches or switches over named types -// that are missing cases and provides a code action to fill those in. +// It reports a diagnostic for each type switch or 'enum' switch that +// has missing cases, and suggests a fix to fill them in. // -// If the switch statement is over a named type, it will suggest cases for all -// const values that are assignable to the named type. +// The possible cases are: for a type switch, each accessible named +// type T or pointer *T that is assignable to the interface type; and +// for an 'enum' switch, each accessible named constant of the same +// type as the switch value. // -// type T int -// const ( -// A T = iota -// B -// C -// ) +// For an 'enum' switch, it will suggest cases for all possible values of the +// type. // -// var t T -// switch t { -// case A: -// } +// type Suit int8 +// const ( +// Spades Suit = iota +// Hearts +// Diamonds +// Clubs +// ) +// +// var s Suit +// switch s { +// case Spades: +// } // -// It will provide a diagnostic with a suggested edit to fill in the remaining +// It will report a diagnostic with a suggested fix to fill in the remaining // cases: // -// var t T -// switch t { -// case A: -// case B: -// case C: +// var s Suit +// switch s { +// case Spades: +// case Hearts: +// case Diamons: +// case Clubs: // } // -// If the switch statement is over type of an interface, it will suggest cases for all types -// that implement the interface. +// For a type switch, it will suggest cases for all types that implement the +// interface. // // type I interface { // M() @@ -50,7 +57,7 @@ // case *T: // } // -// It will provide a diagnostic with a suggested edit to fill in the remaining +// It will report a diagnostic with a suggested fix to fill in the remaining // cases: // // var i I @@ -58,8 +65,4 @@ // case *T: // case *E: // } -// -// The provided diagnostics will only suggest cases for types that are defined -// on the same package as the switch statement, or for types that are exported; -// and it will not suggest any case if the switch handles the default case. package fillswitch diff --git a/gopls/internal/analysis/fillswitch/fillswitch.go b/gopls/internal/analysis/fillswitch/fillswitch.go index 6bbee377458..4c94a2662cd 100644 --- a/gopls/internal/analysis/fillswitch/fillswitch.go +++ b/gopls/internal/analysis/fillswitch/fillswitch.go @@ -26,8 +26,6 @@ import ( "golang.org/x/tools/go/ast/inspector" ) -const FixCategory = "fillswitch" - // Diagnose computes diagnostics for switch statements with missing cases // overlapping with the provided start and end position. // @@ -36,44 +34,35 @@ func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Pac var diags []analysis.Diagnostic nodeFilter := []ast.Node{(*ast.SwitchStmt)(nil), (*ast.TypeSwitchStmt)(nil)} inspect.Preorder(nodeFilter, func(n ast.Node) { - switch expr := n.(type) { - case *ast.SwitchStmt: - if start.IsValid() && expr.End() < start || - end.IsValid() && expr.Pos() > end { - return // non-overlapping - } + if start.IsValid() && n.End() < start || + end.IsValid() && n.Pos() > end { + return // non-overlapping + } - fix, err := suggestedFixSwitch(expr, pkg, info) - if err != nil || fix == nil { + var fix *analysis.SuggestedFix + switch n := n.(type) { + case *ast.SwitchStmt: + f, err := suggestedFixSwitch(n, pkg, info) + if err != nil || f == nil { return } - diags = append(diags, analysis.Diagnostic{ - Message: fix.Message, - Pos: expr.Pos(), - End: expr.End(), - Category: FixCategory, - SuggestedFixes: []analysis.SuggestedFix{*fix}, - }) + fix = f case *ast.TypeSwitchStmt: - if start.IsValid() && expr.End() < start || - end.IsValid() && expr.Pos() > end { - return // non-overlapping - } - - fix, err := suggestedFixTypeSwitch(expr, pkg, info) - if err != nil || fix == nil { + f, err := suggestedFixTypeSwitch(n, pkg, info) + if err != nil || f == nil { return } - diags = append(diags, analysis.Diagnostic{ - Message: fix.Message, - Pos: expr.Pos(), - End: expr.End(), - Category: FixCategory, - SuggestedFixes: []analysis.SuggestedFix{*fix}, - }) + fix = f } + + diags = append(diags, analysis.Diagnostic{ + Message: fix.Message, + Pos: n.Pos(), + End: n.End(), + SuggestedFixes: []analysis.SuggestedFix{*fix}, + }) }) return diags @@ -89,6 +78,8 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * return nil, err } + // Gather accessible package-level concrete types + // that implement the switch interface type. scope := namedType.Obj().Pkg().Scope() var variants []types.Type for _, name := range scope.Names() { @@ -108,13 +99,17 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * if types.AssignableTo(obj.Type(), namedType.Obj().Type()) { variants = append(variants, obj.Type()) - } else if types.AssignableTo(types.NewPointer(obj.Type()), namedType.Obj().Type()) { - variants = append(variants, types.NewPointer(obj.Type())) + } else if ptr := types.NewPointer(obj.Type()); types.AssignableTo(ptr, namedType.Obj().Type()) { + variants = append(variants, ptr) } } - handledVariants := caseTypes(stmt.Body, info) - if len(variants) == 0 || len(variants) == len(handledVariants) { + if len(variants) == 0 { + return nil, nil + } + + newText := buildTypesText(stmt.Body, variants, pkg, info) + if newText == nil { return nil, nil } @@ -123,7 +118,7 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * TextEdits: []analysis.TextEdit{{ Pos: stmt.End() - 1, End: stmt.End() - 1, - NewText: buildTypesText(variants, handledVariants, pkg), + NewText: newText, }}, }, nil } @@ -138,27 +133,24 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In return nil, err } + // Gather accessible named constants of the same type as the switch value. scope := namedType.Obj().Pkg().Scope() var variants []*types.Const for _, name := range scope.Names() { obj := scope.Lookup(name) - c, ok := obj.(*types.Const) - if !ok { - continue - } - - samePkg := obj.Pkg() != pkg - if samePkg && !obj.Exported() { - continue // inaccessible - } - - if types.Identical(obj.Type(), namedType.Obj().Type()) { + if c, ok := obj.(*types.Const); ok && + (obj.Pkg() == pkg || obj.Exported()) && // accessible + types.Identical(obj.Type(), namedType.Obj().Type()) { variants = append(variants, c) } } - handledVariants := caseConsts(stmt.Body, info) - if len(variants) == 0 || len(variants) == len(handledVariants) { + if len(variants) == 0 { + return nil, nil + } + + newText := buildConstsText(stmt.Body, variants, pkg, info) + if newText == nil { return nil, nil } @@ -167,7 +159,7 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In TextEdits: []analysis.TextEdit{{ Pos: stmt.End() - 1, End: stmt.End() - 1, - NewText: buildConstsText(variants, handledVariants, pkg), + NewText: newText, }}, }, nil } @@ -222,8 +214,8 @@ func namedTypeFromTypeSwitch(stmt *ast.TypeSwitchStmt, info *types.Info) (*types } func hasDefaultCase(body *ast.BlockStmt) bool { - for _, bl := range body.List { - if len(bl.(*ast.CaseClause).List) == 0 { + for _, clause := range body.List { + if len(clause.(*ast.CaseClause).List) == 0 { return true } } @@ -231,7 +223,12 @@ func hasDefaultCase(body *ast.BlockStmt) bool { return false } -func buildConstsText(variants []*types.Const, handledVariants []*types.Const, currentPkg *types.Package) []byte { +func buildConstsText(body *ast.BlockStmt, variants []*types.Const, currentPkg *types.Package, info *types.Info) []byte { + handledVariants := caseConsts(body, info) + if len(variants) == len(handledVariants) { + return nil + } + var textBuilder strings.Builder for _, c := range variants { if slices.Contains(handledVariants, c) { @@ -266,7 +263,12 @@ func isSameType(c, t types.Type) bool { return false } -func buildTypesText(variants []types.Type, handledVariants []types.Type, currentPkg *types.Package) []byte { +func buildTypesText(body *ast.BlockStmt, variants []types.Type, currentPkg *types.Package, info *types.Info) []byte { + handledVariants := caseTypes(body, info) + if len(variants) == len(handledVariants) { + return nil + } + var textBuilder strings.Builder for _, c := range variants { if slices.ContainsFunc(handledVariants, func(t types.Type) bool { return isSameType(c, t) }) { From 344eae75be52433efcce09d2ce3691203de672f1 Mon Sep 17 00:00:00 2001 From: Martin Asquino Date: Sat, 10 Feb 2024 15:06:33 +0000 Subject: [PATCH 10/20] introduce namedVariant --- .../analysis/fillswitch/fillswitch.go | 153 ++++++------------ 1 file changed, 53 insertions(+), 100 deletions(-) diff --git a/gopls/internal/analysis/fillswitch/fillswitch.go b/gopls/internal/analysis/fillswitch/fillswitch.go index 4c94a2662cd..e96afbd45b6 100644 --- a/gopls/internal/analysis/fillswitch/fillswitch.go +++ b/gopls/internal/analysis/fillswitch/fillswitch.go @@ -19,7 +19,6 @@ import ( "go/ast" "go/token" "go/types" - "slices" "strings" "golang.org/x/tools/go/analysis" @@ -81,7 +80,7 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * // Gather accessible package-level concrete types // that implement the switch interface type. scope := namedType.Obj().Pkg().Scope() - var variants []types.Type + var variants []namedVariant for _, name := range scope.Names() { obj := scope.Lookup(name) if _, ok := obj.(*types.TypeName); !ok { @@ -98,9 +97,19 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * } if types.AssignableTo(obj.Type(), namedType.Obj().Type()) { - variants = append(variants, obj.Type()) + named, ok := obj.Type().(*types.Named) + if !ok { + continue + } + + variants = append(variants, namedVariant{named: named, ptr: false}) } else if ptr := types.NewPointer(obj.Type()); types.AssignableTo(ptr, namedType.Obj().Type()) { - variants = append(variants, ptr) + named, ok := obj.Type().(*types.Named) + if !ok { + continue + } + + variants = append(variants, namedVariant{named: named, ptr: true}) } } @@ -231,7 +240,7 @@ func buildConstsText(body *ast.BlockStmt, variants []*types.Const, currentPkg *t var textBuilder strings.Builder for _, c := range variants { - if slices.Contains(handledVariants, c) { + if _, ok := handledVariants[c]; ok { continue } @@ -247,23 +256,7 @@ func buildConstsText(body *ast.BlockStmt, variants []*types.Const, currentPkg *t return bytes.ReplaceAll([]byte(textBuilder.String()), []byte("\n"), []byte("\n\t")) } -func isSameType(c, t types.Type) bool { - if types.Identical(c, t) { - return true - } - - if p, ok := c.(*types.Pointer); ok && types.Identical(p.Elem(), t) { - return true - } - - if p, ok := t.(*types.Pointer); ok && types.Identical(p.Elem(), c) { - return true - } - - return false -} - -func buildTypesText(body *ast.BlockStmt, variants []types.Type, currentPkg *types.Package, info *types.Info) []byte { +func buildTypesText(body *ast.BlockStmt, variants []namedVariant, currentPkg *types.Package, info *types.Info) []byte { handledVariants := caseTypes(body, info) if len(variants) == len(handledVariants) { return nil @@ -271,66 +264,41 @@ func buildTypesText(body *ast.BlockStmt, variants []types.Type, currentPkg *type var textBuilder strings.Builder for _, c := range variants { - if slices.ContainsFunc(handledVariants, func(t types.Type) bool { return isSameType(c, t) }) { - continue + if handledVariants[c] { + continue // already handled } textBuilder.WriteString("case ") - switch t := c.(type) { - case *types.Named: - if t.Obj().Pkg() != currentPkg { - textBuilder.WriteString(t.Obj().Pkg().Name() + "." + t.Obj().Name()) - } else { - textBuilder.WriteString(t.Obj().Name()) - } - case *types.Pointer: - e, ok := t.Elem().(*types.Named) - if !ok { - continue - } - - if e.Obj().Pkg() != currentPkg { - // TODO: use the correct package name when the import is renamed - textBuilder.WriteString("*" + e.Obj().Pkg().Name() + "." + e.Obj().Name()) - } else { - textBuilder.WriteString("*" + e.Obj().Name()) - } + if c.ptr { + textBuilder.WriteString("*") } + if pkg := c.named.Obj().Pkg(); pkg != currentPkg { + // TODO: use the correct package name when the import is renamed + textBuilder.WriteString(pkg.Name()) + textBuilder.WriteByte('.') + } + textBuilder.WriteString(c.named.Obj().Name()) textBuilder.WriteString(":\n") } return bytes.ReplaceAll([]byte(textBuilder.String()), []byte("\n"), []byte("\n\t")) } -func caseConsts(body *ast.BlockStmt, info *types.Info) []*types.Const { - var out []*types.Const +func caseConsts(body *ast.BlockStmt, info *types.Info) map[*types.Const]bool { + out := map[*types.Const]bool{} for _, stmt := range body.List { for _, e := range stmt.(*ast.CaseClause).List { if !info.Types[e].IsValue() { continue } - switch e := e.(type) { - case *ast.Ident: - obj, ok := info.Uses[e] - if !ok { - continue - } - - c, ok := obj.(*types.Const) - if !ok { - continue - } - - out = append(out, c) - case *ast.SelectorExpr: - _, ok := e.X.(*ast.Ident) - if !ok { - continue - } + if sel, ok := e.(*ast.SelectorExpr); ok { + e = sel.Sel // replace pkg.C with C + } - obj, ok := info.Uses[e.Sel] + if e, ok := e.(*ast.Ident); ok { + obj, ok := info.Uses[e] if !ok { continue } @@ -340,7 +308,7 @@ func caseConsts(body *ast.BlockStmt, info *types.Info) []*types.Const { continue } - out = append(out, c) + out[c] = true } } } @@ -348,56 +316,41 @@ func caseConsts(body *ast.BlockStmt, info *types.Info) []*types.Const { return out } -func caseTypes(body *ast.BlockStmt, info *types.Info) []types.Type { - var out []types.Type +type namedVariant struct { + named *types.Named + ptr bool +} + +func caseTypes(body *ast.BlockStmt, info *types.Info) map[namedVariant]bool { + out := map[namedVariant]bool{} for _, stmt := range body.List { for _, e := range stmt.(*ast.CaseClause).List { if !info.Types[e].IsType() { continue } - switch e := e.(type) { - case *ast.Ident: - obj, ok := info.Uses[e] - if !ok { - continue - } + var ptr bool + if str, ok := e.(*ast.StarExpr); ok { + ptr = true + e = str.X // replace *T with T + } + + if sel, ok := e.(*ast.SelectorExpr); ok { + e = sel.Sel // replace pkg.C with C + } - out = append(out, obj.Type()) - case *ast.SelectorExpr: - i, ok := e.X.(*ast.Ident) + if e, ok := e.(*ast.Ident); ok { + obj, ok := info.Uses[e] if !ok { continue } - obj, ok := info.Uses[i] + named, ok := obj.Type().(*types.Named) if !ok { continue } - out = append(out, obj.Type()) - case *ast.StarExpr: - switch v := e.X.(type) { - case *ast.Ident: - obj, ok := info.Uses[v] - if !ok { - continue - } - - out = append(out, obj.Type()) - case *ast.SelectorExpr: - i, ok := e.X.(*ast.Ident) - if !ok { - continue - } - - obj, ok := info.Uses[i] - if !ok { - continue - } - - out = append(out, obj.Type()) - } + out[namedVariant{named: named, ptr: ptr}] = true } } } From 4066149ee04ef2b2c875c53cc7665355246f1e85 Mon Sep 17 00:00:00 2001 From: Martin Asquino Date: Sat, 10 Feb 2024 17:28:48 +0000 Subject: [PATCH 11/20] update documentation --- gopls/internal/analysis/fillswitch/doc.go | 44 ++++++++++------------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/gopls/internal/analysis/fillswitch/doc.go b/gopls/internal/analysis/fillswitch/doc.go index b83c5ec2a18..33952c8deaa 100644 --- a/gopls/internal/analysis/fillswitch/doc.go +++ b/gopls/internal/analysis/fillswitch/doc.go @@ -15,18 +15,18 @@ // For an 'enum' switch, it will suggest cases for all possible values of the // type. // -// type Suit int8 -// const ( +// type Suit int8 +// const ( // Spades Suit = iota // Hearts // Diamonds -// Clubs -// ) +// Clubs +// ) // -// var s Suit -// switch s { -// case Spades: -// } +// var s Suit +// switch s { +// case Spades: +// } // // It will report a diagnostic with a suggested fix to fill in the remaining // cases: @@ -42,27 +42,21 @@ // For a type switch, it will suggest cases for all types that implement the // interface. // -// type I interface { -// M() -// } -// -// type T struct{} -// func (t *T) M() {} -// -// type E struct{} -// func (e *E) M() {} -// -// var i I -// switch i.(type) { -// case *T: +// var stmt ast.Stmt +// switch stmt.(type) { +// case *ast.IfStmt // } // // It will report a diagnostic with a suggested fix to fill in the remaining // cases: // -// var i I -// switch i.(type) { -// case *T: -// case *E: +// var stmt ast.Stmt +// switch stmt.(type) { +// case *ast.IfStmt +// case *ast.ForStmt +// case *ast.RangeStmt +// case *ast.AssignStmt +// case *ast.GoStmt +// ... // } package fillswitch From 2a470ab4e8efa7b656841898fa7c903dc80e81f7 Mon Sep 17 00:00:00 2001 From: Martin Asquino Date: Sat, 10 Feb 2024 21:46:35 +0000 Subject: [PATCH 12/20] remove unnecessary errors --- gopls/internal/analysis/fillswitch/doc.go | 14 +- .../analysis/fillswitch/fillswitch.go | 203 +++++++----------- 2 files changed, 79 insertions(+), 138 deletions(-) diff --git a/gopls/internal/analysis/fillswitch/doc.go b/gopls/internal/analysis/fillswitch/doc.go index 33952c8deaa..c5235577740 100644 --- a/gopls/internal/analysis/fillswitch/doc.go +++ b/gopls/internal/analysis/fillswitch/doc.go @@ -35,7 +35,7 @@ // switch s { // case Spades: // case Hearts: -// case Diamons: +// case Diamonds: // case Clubs: // } // @@ -44,7 +44,7 @@ // // var stmt ast.Stmt // switch stmt.(type) { -// case *ast.IfStmt +// case *ast.IfStmt: // } // // It will report a diagnostic with a suggested fix to fill in the remaining @@ -52,11 +52,11 @@ // // var stmt ast.Stmt // switch stmt.(type) { -// case *ast.IfStmt -// case *ast.ForStmt -// case *ast.RangeStmt -// case *ast.AssignStmt -// case *ast.GoStmt +// case *ast.IfStmt: +// case *ast.ForStmt: +// case *ast.RangeStmt: +// case *ast.AssignStmt: +// case *ast.GoStmt: // ... // } package fillswitch diff --git a/gopls/internal/analysis/fillswitch/fillswitch.go b/gopls/internal/analysis/fillswitch/fillswitch.go index e96afbd45b6..8cb247333b0 100644 --- a/gopls/internal/analysis/fillswitch/fillswitch.go +++ b/gopls/internal/analysis/fillswitch/fillswitch.go @@ -2,19 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package fillswitch provides diagnostics and fixes to fill the missing cases -// in type switches or switches over named types. -// -// The analyzer's diagnostic is merely a prompt. -// The actual fix is created by a separate direct call from gopls to -// the SuggestedFixes function. -// Tests of Analyzer.Run can be found in ./testdata/src. -// Tests of the SuggestedFixes logic live in ../../testdata/fillswitch. package fillswitch import ( - "bytes" - "errors" "fmt" "go/ast" "go/token" @@ -41,19 +31,13 @@ func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Pac var fix *analysis.SuggestedFix switch n := n.(type) { case *ast.SwitchStmt: - f, err := suggestedFixSwitch(n, pkg, info) - if err != nil || f == nil { - return - } - - fix = f + fix = suggestedFixSwitch(n, pkg, info) case *ast.TypeSwitchStmt: - f, err := suggestedFixTypeSwitch(n, pkg, info) - if err != nil || f == nil { - return - } + fix = suggestedFixTypeSwitch(n, pkg, info) + } - fix = f + if fix == nil { + return } diags = append(diags, analysis.Diagnostic{ @@ -67,20 +51,20 @@ func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Pac return diags } -func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { +func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info *types.Info) *analysis.SuggestedFix { if hasDefaultCase(stmt.Body) { - return nil, nil + return nil } - namedType, err := namedTypeFromTypeSwitch(stmt, info) - if err != nil { - return nil, err + namedType := namedTypeFromTypeSwitch(stmt, info) + if namedType == nil { + return nil } // Gather accessible package-level concrete types // that implement the switch interface type. scope := namedType.Obj().Pkg().Scope() - var variants []namedVariant + var variants []caseType for _, name := range scope.Names() { obj := scope.Lookup(name) if _, ok := obj.(*types.TypeName); !ok { @@ -102,45 +86,42 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * continue } - variants = append(variants, namedVariant{named: named, ptr: false}) + variants = append(variants, caseType{named, false}) } else if ptr := types.NewPointer(obj.Type()); types.AssignableTo(ptr, namedType.Obj().Type()) { named, ok := obj.Type().(*types.Named) if !ok { continue } - variants = append(variants, namedVariant{named: named, ptr: true}) + variants = append(variants, caseType{named, true}) } } if len(variants) == 0 { - return nil, nil + return nil } newText := buildTypesText(stmt.Body, variants, pkg, info) if newText == nil { - return nil, nil + return nil } return &analysis.SuggestedFix{ Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()), TextEdits: []analysis.TextEdit{{ - Pos: stmt.End() - 1, - End: stmt.End() - 1, + Pos: stmt.End() - token.Pos(len("}")), + End: stmt.End() - token.Pos(len("}")), NewText: newText, }}, - }, nil + } } -func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { +func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.Info) *analysis.SuggestedFix { if hasDefaultCase(stmt.Body) { - return nil, nil + return nil } - namedType, err := namedTypeFromSwitch(stmt, info) - if err != nil { - return nil, err - } + namedType := namedTypeFromSwitch(stmt, info) // Gather accessible named constants of the same type as the switch value. scope := namedType.Obj().Pkg().Scope() @@ -155,70 +136,53 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In } if len(variants) == 0 { - return nil, nil + return nil } newText := buildConstsText(stmt.Body, variants, pkg, info) if newText == nil { - return nil, nil + return nil } return &analysis.SuggestedFix{ Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()), TextEdits: []analysis.TextEdit{{ - Pos: stmt.End() - 1, - End: stmt.End() - 1, + Pos: stmt.End() - token.Pos(len("}")), + End: stmt.End() - token.Pos(len("}")), NewText: newText, }}, - }, nil -} - -func namedTypeFromSwitch(stmt *ast.SwitchStmt, info *types.Info) (*types.Named, error) { - typ := info.TypeOf(stmt.Tag) - if typ == nil { - return nil, errors.New("expected switch statement to have a tag") } +} - namedType, ok := typ.(*types.Named) +func namedTypeFromSwitch(stmt *ast.SwitchStmt, info *types.Info) *types.Named { + namedType, ok := info.TypeOf(stmt.Tag).(*types.Named) if !ok { - return nil, errors.New("switch statement is not on a named type") + return nil } - return namedType, nil + return namedType } -func namedTypeFromTypeSwitch(stmt *ast.TypeSwitchStmt, info *types.Info) (*types.Named, error) { +func namedTypeFromTypeSwitch(stmt *ast.TypeSwitchStmt, info *types.Info) *types.Named { switch s := stmt.Assign.(type) { case *ast.ExprStmt: - typ, ok := s.X.(*ast.TypeAssertExpr) - if !ok { - return nil, errors.New("type switch expression is not a type assert expression") - } - - namedType, ok := info.TypeOf(typ.X).(*types.Named) - if !ok { - return nil, errors.New("type switch expression is not on a named type") + if typ, ok := s.X.(*ast.TypeAssertExpr); ok { + if named, ok := info.TypeOf(typ.X).(*types.Named); ok { + return named + } } - return namedType, nil + return nil case *ast.AssignStmt: - for _, expr := range s.Rhs { - typ, ok := expr.(*ast.TypeAssertExpr) - if !ok { - continue - } - - namedType, ok := info.TypeOf(typ.X).(*types.Named) - if !ok { - continue + if typ, ok := s.Rhs[0].(*ast.TypeAssertExpr); ok { + if named, ok := info.TypeOf(typ.X).(*types.Named); ok { + return named } - - return namedType, nil } - return nil, errors.New("expected type switch expression to have a named type") + return nil default: - return nil, errors.New("node is not a type switch statement") + return nil } } @@ -238,59 +202,59 @@ func buildConstsText(body *ast.BlockStmt, variants []*types.Const, currentPkg *t return nil } - var textBuilder strings.Builder + var buf strings.Builder for _, c := range variants { if _, ok := handledVariants[c]; ok { continue } - textBuilder.WriteString("case ") + buf.WriteString("case ") if c.Pkg() != currentPkg { - textBuilder.WriteString(c.Pkg().Name() + "." + c.Name()) - } else { - textBuilder.WriteString(c.Name()) + buf.WriteString(c.Pkg().Name()) + buf.WriteByte('.') } - textBuilder.WriteString(":\n") + buf.WriteString(c.Name()) + buf.WriteString(":\n\t") } - return bytes.ReplaceAll([]byte(textBuilder.String()), []byte("\n"), []byte("\n\t")) + return []byte(buf.String()) } -func buildTypesText(body *ast.BlockStmt, variants []namedVariant, currentPkg *types.Package, info *types.Info) []byte { +func buildTypesText(body *ast.BlockStmt, variants []caseType, currentPkg *types.Package, info *types.Info) []byte { handledVariants := caseTypes(body, info) if len(variants) == len(handledVariants) { return nil } - var textBuilder strings.Builder + var buf strings.Builder for _, c := range variants { if handledVariants[c] { continue // already handled } - textBuilder.WriteString("case ") + buf.WriteString("case ") if c.ptr { - textBuilder.WriteString("*") + buf.WriteByte('*') } if pkg := c.named.Obj().Pkg(); pkg != currentPkg { // TODO: use the correct package name when the import is renamed - textBuilder.WriteString(pkg.Name()) - textBuilder.WriteByte('.') + buf.WriteString(pkg.Name()) + buf.WriteByte('.') } - textBuilder.WriteString(c.named.Obj().Name()) - textBuilder.WriteString(":\n") + buf.WriteString(c.named.Obj().Name()) + buf.WriteString(":\n\t") } - return bytes.ReplaceAll([]byte(textBuilder.String()), []byte("\n"), []byte("\n\t")) + return []byte(buf.String()) } func caseConsts(body *ast.BlockStmt, info *types.Info) map[*types.Const]bool { out := map[*types.Const]bool{} for _, stmt := range body.List { for _, e := range stmt.(*ast.CaseClause).List { - if !info.Types[e].IsValue() { - continue + if info.Types[e].Value == nil { + continue // not a constant } if sel, ok := e.(*ast.SelectorExpr); ok { @@ -298,17 +262,9 @@ func caseConsts(body *ast.BlockStmt, info *types.Info) map[*types.Const]bool { } if e, ok := e.(*ast.Ident); ok { - obj, ok := info.Uses[e] - if !ok { - continue + if c, ok := info.Uses[e].(*types.Const); ok { + out[c] = true } - - c, ok := obj.(*types.Const) - if !ok { - continue - } - - out[c] = true } } } @@ -316,41 +272,26 @@ func caseConsts(body *ast.BlockStmt, info *types.Info) map[*types.Const]bool { return out } -type namedVariant struct { +type caseType struct { named *types.Named ptr bool } -func caseTypes(body *ast.BlockStmt, info *types.Info) map[namedVariant]bool { - out := map[namedVariant]bool{} +func caseTypes(body *ast.BlockStmt, info *types.Info) map[caseType]bool { + out := map[caseType]bool{} for _, stmt := range body.List { for _, e := range stmt.(*ast.CaseClause).List { - if !info.Types[e].IsType() { - continue - } - - var ptr bool - if str, ok := e.(*ast.StarExpr); ok { - ptr = true - e = str.X // replace *T with T - } - - if sel, ok := e.(*ast.SelectorExpr); ok { - e = sel.Sel // replace pkg.C with C - } - - if e, ok := e.(*ast.Ident); ok { - obj, ok := info.Uses[e] - if !ok { - continue + if tv, ok := info.Types[e]; ok && tv.IsType() { + t := tv.Type + ptr := false + if p, ok := t.(*types.Pointer); ok { + t = p.Elem() + ptr = true } - named, ok := obj.Type().(*types.Named) - if !ok { - continue + if named, ok := t.(*types.Named); ok { + out[caseType{named, ptr}] = true } - - out[namedVariant{named: named, ptr: ptr}] = true } } } From 388646e02b72dd5456aaec4ced52b5d45deb96cd Mon Sep 17 00:00:00 2001 From: Martin Asquino Date: Sun, 11 Feb 2024 12:15:01 +0000 Subject: [PATCH 13/20] invert order of handledVariants and variants calculations --- .../analysis/fillswitch/fillswitch.go | 57 +++++++++---------- 1 file changed, 27 insertions(+), 30 deletions(-) diff --git a/gopls/internal/analysis/fillswitch/fillswitch.go b/gopls/internal/analysis/fillswitch/fillswitch.go index 8cb247333b0..72d522e443d 100644 --- a/gopls/internal/analysis/fillswitch/fillswitch.go +++ b/gopls/internal/analysis/fillswitch/fillswitch.go @@ -61,6 +61,7 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * return nil } + handledVariants := caseTypes(stmt.Body, info) // Gather accessible package-level concrete types // that implement the switch interface type. scope := namedType.Obj().Pkg().Scope() @@ -80,20 +81,25 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * continue // inaccessible } + var key caseType if types.AssignableTo(obj.Type(), namedType.Obj().Type()) { - named, ok := obj.Type().(*types.Named) - if !ok { - continue + if named, ok := obj.Type().(*types.Named); ok { + key.named = named + key.ptr = false } - - variants = append(variants, caseType{named, false}) } else if ptr := types.NewPointer(obj.Type()); types.AssignableTo(ptr, namedType.Obj().Type()) { - named, ok := obj.Type().(*types.Named) - if !ok { + if named, ok := obj.Type().(*types.Named); ok { + key.named = named + key.ptr = true + } + } + + if key.named != nil { + if _, ok := handledVariants[key]; ok { continue } - variants = append(variants, caseType{named, true}) + variants = append(variants, key) } } @@ -101,7 +107,7 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * return nil } - newText := buildTypesText(stmt.Body, variants, pkg, info) + newText := buildTypesText(variants, pkg) if newText == nil { return nil } @@ -122,7 +128,11 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In } namedType := namedTypeFromSwitch(stmt, info) + if namedType == nil { + return nil + } + handledVariants := caseConsts(stmt.Body, info) // Gather accessible named constants of the same type as the switch value. scope := namedType.Obj().Pkg().Scope() var variants []*types.Const @@ -131,6 +141,11 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In if c, ok := obj.(*types.Const); ok && (obj.Pkg() == pkg || obj.Exported()) && // accessible types.Identical(obj.Type(), namedType.Obj().Type()) { + + if _, ok := handledVariants[c]; ok { + continue + } + variants = append(variants, c) } } @@ -139,7 +154,7 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In return nil } - newText := buildConstsText(stmt.Body, variants, pkg, info) + newText := buildConstsText(variants, pkg) if newText == nil { return nil } @@ -196,18 +211,9 @@ func hasDefaultCase(body *ast.BlockStmt) bool { return false } -func buildConstsText(body *ast.BlockStmt, variants []*types.Const, currentPkg *types.Package, info *types.Info) []byte { - handledVariants := caseConsts(body, info) - if len(variants) == len(handledVariants) { - return nil - } - +func buildConstsText(variants []*types.Const, currentPkg *types.Package) []byte { var buf strings.Builder for _, c := range variants { - if _, ok := handledVariants[c]; ok { - continue - } - buf.WriteString("case ") if c.Pkg() != currentPkg { buf.WriteString(c.Pkg().Name()) @@ -220,18 +226,9 @@ func buildConstsText(body *ast.BlockStmt, variants []*types.Const, currentPkg *t return []byte(buf.String()) } -func buildTypesText(body *ast.BlockStmt, variants []caseType, currentPkg *types.Package, info *types.Info) []byte { - handledVariants := caseTypes(body, info) - if len(variants) == len(handledVariants) { - return nil - } - +func buildTypesText(variants []caseType, currentPkg *types.Package) []byte { var buf strings.Builder for _, c := range variants { - if handledVariants[c] { - continue // already handled - } - buf.WriteString("case ") if c.ptr { buf.WriteByte('*') From 2f874f51f779ddeaf6e51cd9dbb5afabf0909bc5 Mon Sep 17 00:00:00 2001 From: Martin Asquino Date: Mon, 12 Feb 2024 08:43:38 +0000 Subject: [PATCH 14/20] more fixes --- .../analysis/fillswitch/fillswitch.go | 139 +++++++----------- 1 file changed, 51 insertions(+), 88 deletions(-) diff --git a/gopls/internal/analysis/fillswitch/fillswitch.go b/gopls/internal/analysis/fillswitch/fillswitch.go index 72d522e443d..f3099b0660f 100644 --- a/gopls/internal/analysis/fillswitch/fillswitch.go +++ b/gopls/internal/analysis/fillswitch/fillswitch.go @@ -5,11 +5,11 @@ package fillswitch import ( + "bytes" "fmt" "go/ast" "go/token" "go/types" - "strings" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/ast/inspector" @@ -43,7 +43,7 @@ func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Pac diags = append(diags, analysis.Diagnostic{ Message: fix.Message, Pos: n.Pos(), - End: n.End(), + End: n.Pos() + token.Pos(len("switch")), SuggestedFixes: []analysis.SuggestedFix{*fix}, }) }) @@ -61,15 +61,15 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * return nil } - handledVariants := caseTypes(stmt.Body, info) + existingCases := caseTypes(stmt.Body, info) // Gather accessible package-level concrete types // that implement the switch interface type. scope := namedType.Obj().Pkg().Scope() - var variants []caseType + var buf bytes.Buffer for _, name := range scope.Names() { obj := scope.Lookup(name) - if _, ok := obj.(*types.TypeName); !ok { - continue // not a type + if tname, ok := obj.(*types.TypeName); !ok || tname.IsAlias() { + continue // not a defined type } if types.IsInterface(obj.Type()) { @@ -83,32 +83,37 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * var key caseType if types.AssignableTo(obj.Type(), namedType.Obj().Type()) { - if named, ok := obj.Type().(*types.Named); ok { - key.named = named - key.ptr = false - } + key.named = obj.Type().(*types.Named) } else if ptr := types.NewPointer(obj.Type()); types.AssignableTo(ptr, namedType.Obj().Type()) { - if named, ok := obj.Type().(*types.Named); ok { - key.named = named - key.ptr = true - } + key.named = obj.Type().(*types.Named) + key.ptr = true } if key.named != nil { - if _, ok := handledVariants[key]; ok { + if _, ok := existingCases[key]; ok { continue } - variants = append(variants, key) - } - } + if buf.Len() > 0 { + buf.WriteString("\t") + } - if len(variants) == 0 { - return nil + buf.WriteString("case ") + if key.ptr { + buf.WriteByte('*') + } + + if p := key.named.Obj().Pkg(); p != pkg { + // TODO: use the correct package name when the import is renamed + buf.WriteString(p.Name()) + buf.WriteByte('.') + } + buf.WriteString(key.named.Obj().Name()) + buf.WriteString(":\n") + } } - newText := buildTypesText(variants, pkg) - if newText == nil { + if buf.Len() == 0 { return nil } @@ -117,7 +122,7 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * TextEdits: []analysis.TextEdit{{ Pos: stmt.End() - token.Pos(len("}")), End: stmt.End() - token.Pos(len("}")), - NewText: newText, + NewText: buf.Bytes(), }}, } } @@ -127,35 +132,40 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In return nil } - namedType := namedTypeFromSwitch(stmt, info) - if namedType == nil { + namedType, ok := info.TypeOf(stmt.Tag).(*types.Named) + if !ok { return nil } - handledVariants := caseConsts(stmt.Body, info) + existingCases := caseConsts(stmt.Body, info) // Gather accessible named constants of the same type as the switch value. scope := namedType.Obj().Pkg().Scope() - var variants []*types.Const + var buf bytes.Buffer for _, name := range scope.Names() { obj := scope.Lookup(name) if c, ok := obj.(*types.Const); ok && (obj.Pkg() == pkg || obj.Exported()) && // accessible types.Identical(obj.Type(), namedType.Obj().Type()) { - if _, ok := handledVariants[c]; ok { + if _, ok := existingCases[c]; ok { continue } - variants = append(variants, c) - } - } + if buf.Len() > 0 { + buf.WriteString("\t") + } - if len(variants) == 0 { - return nil + buf.WriteString("case ") + if c.Pkg() != pkg { + buf.WriteString(c.Pkg().Name()) + buf.WriteByte('.') + } + buf.WriteString(c.Name()) + buf.WriteString(":\n") + } } - newText := buildConstsText(variants, pkg) - if newText == nil { + if buf.Len() == 0 { return nil } @@ -164,41 +174,29 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In TextEdits: []analysis.TextEdit{{ Pos: stmt.End() - token.Pos(len("}")), End: stmt.End() - token.Pos(len("}")), - NewText: newText, + NewText: buf.Bytes(), }}, } } -func namedTypeFromSwitch(stmt *ast.SwitchStmt, info *types.Info) *types.Named { - namedType, ok := info.TypeOf(stmt.Tag).(*types.Named) - if !ok { - return nil - } - - return namedType -} - func namedTypeFromTypeSwitch(stmt *ast.TypeSwitchStmt, info *types.Info) *types.Named { - switch s := stmt.Assign.(type) { + switch assign := stmt.Assign.(type) { case *ast.ExprStmt: - if typ, ok := s.X.(*ast.TypeAssertExpr); ok { + if typ, ok := assign.X.(*ast.TypeAssertExpr); ok { if named, ok := info.TypeOf(typ.X).(*types.Named); ok { return named } } - return nil case *ast.AssignStmt: - if typ, ok := s.Rhs[0].(*ast.TypeAssertExpr); ok { + if typ, ok := assign.Rhs[0].(*ast.TypeAssertExpr); ok { if named, ok := info.TypeOf(typ.X).(*types.Named); ok { return named } } - - return nil - default: - return nil } + + return nil } func hasDefaultCase(body *ast.BlockStmt) bool { @@ -211,41 +209,6 @@ func hasDefaultCase(body *ast.BlockStmt) bool { return false } -func buildConstsText(variants []*types.Const, currentPkg *types.Package) []byte { - var buf strings.Builder - for _, c := range variants { - buf.WriteString("case ") - if c.Pkg() != currentPkg { - buf.WriteString(c.Pkg().Name()) - buf.WriteByte('.') - } - buf.WriteString(c.Name()) - buf.WriteString(":\n\t") - } - - return []byte(buf.String()) -} - -func buildTypesText(variants []caseType, currentPkg *types.Package) []byte { - var buf strings.Builder - for _, c := range variants { - buf.WriteString("case ") - if c.ptr { - buf.WriteByte('*') - } - - if pkg := c.named.Obj().Pkg(); pkg != currentPkg { - // TODO: use the correct package name when the import is renamed - buf.WriteString(pkg.Name()) - buf.WriteByte('.') - } - buf.WriteString(c.named.Obj().Name()) - buf.WriteString(":\n\t") - } - - return []byte(buf.String()) -} - func caseConsts(body *ast.BlockStmt, info *types.Info) map[*types.Const]bool { out := map[*types.Const]bool{} for _, stmt := range body.List { From 3e1e812a9fd6506a58da57a547755767ae05316f Mon Sep 17 00:00:00 2001 From: Martin Asquino Date: Mon, 12 Feb 2024 09:43:06 +0000 Subject: [PATCH 15/20] add default case --- .../analysis/fillswitch/fillswitch.go | 40 +++++++++++++++--- gopls/internal/test/.DS_Store | Bin 0 -> 6148 bytes gopls/internal/test/marker/.DS_Store | Bin 0 -> 6148 bytes gopls/internal/test/marker/testdata/.DS_Store | Bin 0 -> 6148 bytes .../testdata/codeaction/fill_switch.txt | 35 +++++++++++++-- .../codeaction/fill_switch_resolve.txt | 35 +++++++++++++-- 6 files changed, 99 insertions(+), 11 deletions(-) create mode 100644 gopls/internal/test/.DS_Store create mode 100644 gopls/internal/test/marker/.DS_Store create mode 100644 gopls/internal/test/marker/testdata/.DS_Store diff --git a/gopls/internal/analysis/fillswitch/fillswitch.go b/gopls/internal/analysis/fillswitch/fillswitch.go index f3099b0660f..8c7ef4f0866 100644 --- a/gopls/internal/analysis/fillswitch/fillswitch.go +++ b/gopls/internal/analysis/fillswitch/fillswitch.go @@ -117,6 +117,15 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * return nil } + switch assign := stmt.Assign.(type) { + case *ast.AssignStmt: + addDefaultCase(namedType, 'T', assign.Lhs[0], pkg, &buf) + case *ast.ExprStmt: + if assert, ok := assign.X.(*ast.TypeAssertExpr); ok { + addDefaultCase(namedType, 'T', assert.X, pkg, &buf) + } + } + return &analysis.SuggestedFix{ Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()), TextEdits: []analysis.TextEdit{{ @@ -145,11 +154,8 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In obj := scope.Lookup(name) if c, ok := obj.(*types.Const); ok && (obj.Pkg() == pkg || obj.Exported()) && // accessible - types.Identical(obj.Type(), namedType.Obj().Type()) { - - if _, ok := existingCases[c]; ok { - continue - } + types.Identical(obj.Type(), namedType.Obj().Type()) && + !existingCases[c] { if buf.Len() > 0 { buf.WriteString("\t") @@ -169,6 +175,8 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In return nil } + addDefaultCase(namedType, 'v', stmt.Tag, pkg, &buf) + return &analysis.SuggestedFix{ Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()), TextEdits: []analysis.TextEdit{{ @@ -179,6 +187,28 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In } } +func addDefaultCase(named *types.Named, formatVerb byte, expr ast.Expr, pkg *types.Package, buf *bytes.Buffer) { + buf.WriteString("\tdefault:\n") + buf.WriteString("\t\tpanic(fmt.Sprintf(\"unexpected ") + if named.Obj().Pkg() != pkg { + buf.WriteString(named.Obj().Pkg().Name()) + buf.WriteByte('.') + } + buf.WriteString(named.Obj().Name()) + buf.WriteString(": %") + buf.WriteByte(formatVerb) + buf.WriteString("\", ") + switch tag := expr.(type) { + case *ast.SelectorExpr: + fmt.Fprint(buf, tag.X) + buf.WriteByte('.') + buf.WriteString(tag.Sel.Name) + case *ast.Ident: + buf.WriteString(tag.Name) + } + buf.WriteString("))\n\t") +} + func namedTypeFromTypeSwitch(stmt *ast.TypeSwitchStmt, info *types.Info) *types.Named { switch assign := stmt.Assign.(type) { case *ast.ExprStmt: diff --git a/gopls/internal/test/.DS_Store b/gopls/internal/test/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..10782002443a3ff87eb8d6fac54ff0c239c3d1ac GIT binary patch literal 6148 zcmeHK%}T>S5Z<-5O({YS3OxqA7Od6^;w8lT0!H+pQWFw1G|iSIHHT8jSzpK}@p+ut z-5i1eZz6UEcE8#A+0A^A{b7u8XB8YXW-`VMXowt@20?S9YexqoayduHQoqQSek}Zk ziTS+|={=5agoyo0 z_x#%^EN15ZrATrwO2SkoL_r844|h=#h{6%8B*Kn(oA0PYVKG(=ltsZeem(BbtN{cS`P(D5ySXlb-HmI}cG!c{7u zO6B^A!BslgEgffTEETGB#^uT|k6yWYyl}ZX*ew~(xUG)Dvj|`+!`7P8T&em8e V#97cT(*fxsAPJ$482AMSz5wJ&Oe6pR literal 0 HcmV?d00001 diff --git a/gopls/internal/test/marker/.DS_Store b/gopls/internal/test/marker/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..8c4b80c96a8fc0fdc0ce7617151ec08ef4e4ee84 GIT binary patch literal 6148 zcmeHK%}T>S5Z-O8-BN@c6nYGJEm*A;#7l_v1&ruHr6we3FlI}WnnNk%tS{t~_&m<+ zZotwUJc-yD*!^bbXE*af_J=XX-FbM(n8g?q&=5H)HG<|^SHlD&ay3WDe3mmmD6(nI zM1RqQ-!9;rBbGt`?fZj72H^V$rb(RTgTW_nG+SHSZP6B8aqm6J!pr@9p1J<)7Drc7 zCPAh9!F3!J6KnTernw)d(M%=8VFV#}H*p%u!jpQhquM&4!|OBpTZky2<68pJHs}~EHG&6( z>r_CU%FPpl>vXW&CeATfYSihBtCe9MvvT!#;c9iT+bW!KMM^6;#+p*=uD!MG9?5YSgH0WiRQq^F$PZ=()z Yj=@qR&VqK84oDXPMF@4oz%MZH1>n(4`~Uy| literal 0 HcmV?d00001 diff --git a/gopls/internal/test/marker/testdata/.DS_Store b/gopls/internal/test/marker/testdata/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..818909b77d1d8de441a6cd796adcea39f603b22e GIT binary patch literal 6148 zcmeHK%}T>S5T0$TO)NqW3OxqA7Od6^;w9Aj0!H+pQWFz27_+5K&7l->))(?gd>&_Z zw_<5McoDHPF#FBUPnP{Q><<8l?l{~7r~`n7N?5Y7St1lCU67LT5DN7TAMPN77$l?% z+3fg>4A9zb;HpUoA%ZXK7v!TskjMbzZ^0mmM_H@&E=uLf+IrQhS~cs&dyu)8`Pq2X z^@mr~JC`yEX0{((#8KX_Z=J|A^W!ucs)9I-Fy;C(PD7b@8DomE=$P6$8KVpFP2M3kVHCSj=TL(6DeWZAWkOXadOAtB+ zU4w;2jGzczil|G4d144%j(*4Fxdsc3x*UWW8NXvj7UqQ_)adAUR5}P(Be%=|GceCU zS$FGn{-1vT{+}=69y7oU{3`}TrRVj!I3;toE=-QjS_kzOm4xyNjh`iGsG}Hj=_qcX aDnY+P2BK@Q(1;!sz6dB9xM2o Date: Mon, 12 Feb 2024 15:31:04 +0000 Subject: [PATCH 16/20] fix documentation and handle ast.CallExpr as switch tag --- gopls/internal/analysis/fillswitch/doc.go | 18 +++++++++------ .../analysis/fillswitch/fillswitch.go | 21 +++++++----------- gopls/internal/test/.DS_Store | Bin 6148 -> 0 bytes gopls/internal/test/marker/.DS_Store | Bin 6148 -> 0 bytes gopls/internal/test/marker/testdata/.DS_Store | Bin 6148 -> 0 bytes 5 files changed, 19 insertions(+), 20 deletions(-) delete mode 100644 gopls/internal/test/.DS_Store delete mode 100644 gopls/internal/test/marker/.DS_Store delete mode 100644 gopls/internal/test/marker/testdata/.DS_Store diff --git a/gopls/internal/analysis/fillswitch/doc.go b/gopls/internal/analysis/fillswitch/doc.go index c5235577740..076c3a1323d 100644 --- a/gopls/internal/analysis/fillswitch/doc.go +++ b/gopls/internal/analysis/fillswitch/doc.go @@ -15,13 +15,13 @@ // For an 'enum' switch, it will suggest cases for all possible values of the // type. // -// type Suit int8 -// const ( -// Spades Suit = iota -// Hearts -// Diamonds -// Clubs -// ) +// type Suit int8 +// const ( +// Spades Suit = iota +// Hearts +// Diamonds +// Clubs +// ) // // var s Suit // switch s { @@ -37,6 +37,8 @@ // case Hearts: // case Diamonds: // case Clubs: +// default: +// panic(fmt.Sprintf("unexpected Suit: %v", s)) // } // // For a type switch, it will suggest cases for all types that implement the @@ -58,5 +60,7 @@ // case *ast.AssignStmt: // case *ast.GoStmt: // ... +// default: +// panic(fmt.Sprintf("unexpected ast.Stmt: %T", stmt)) // } package fillswitch diff --git a/gopls/internal/analysis/fillswitch/fillswitch.go b/gopls/internal/analysis/fillswitch/fillswitch.go index 8c7ef4f0866..25d59a5dfd9 100644 --- a/gopls/internal/analysis/fillswitch/fillswitch.go +++ b/gopls/internal/analysis/fillswitch/fillswitch.go @@ -90,7 +90,7 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * } if key.named != nil { - if _, ok := existingCases[key]; ok { + if existingCases[key] { continue } @@ -117,12 +117,13 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * return nil } + // TODO: don't evaluate the assert.X expression a second time, switch assign := stmt.Assign.(type) { case *ast.AssignStmt: - addDefaultCase(namedType, 'T', assign.Lhs[0], pkg, &buf) + addDefaultCase(&buf, namedType, 'T', assign.Lhs[0], pkg) case *ast.ExprStmt: if assert, ok := assign.X.(*ast.TypeAssertExpr); ok { - addDefaultCase(namedType, 'T', assert.X, pkg, &buf) + addDefaultCase(&buf, namedType, 'T', assert.X, pkg) } } @@ -175,7 +176,8 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In return nil } - addDefaultCase(namedType, 'v', stmt.Tag, pkg, &buf) + // TODO: don't re-evaluate stmt.Tag. + addDefaultCase(&buf, namedType, 'v', stmt.Tag, pkg) return &analysis.SuggestedFix{ Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()), @@ -187,7 +189,7 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In } } -func addDefaultCase(named *types.Named, formatVerb byte, expr ast.Expr, pkg *types.Package, buf *bytes.Buffer) { +func addDefaultCase(buf *bytes.Buffer, named *types.Named, formatVerb byte, expr ast.Expr, pkg *types.Package) { buf.WriteString("\tdefault:\n") buf.WriteString("\t\tpanic(fmt.Sprintf(\"unexpected ") if named.Obj().Pkg() != pkg { @@ -198,14 +200,7 @@ func addDefaultCase(named *types.Named, formatVerb byte, expr ast.Expr, pkg *typ buf.WriteString(": %") buf.WriteByte(formatVerb) buf.WriteString("\", ") - switch tag := expr.(type) { - case *ast.SelectorExpr: - fmt.Fprint(buf, tag.X) - buf.WriteByte('.') - buf.WriteString(tag.Sel.Name) - case *ast.Ident: - buf.WriteString(tag.Name) - } + buf.WriteString(types.ExprString(expr)) buf.WriteString("))\n\t") } diff --git a/gopls/internal/test/.DS_Store b/gopls/internal/test/.DS_Store deleted file mode 100644 index 10782002443a3ff87eb8d6fac54ff0c239c3d1ac..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHK%}T>S5Z<-5O({YS3OxqA7Od6^;w8lT0!H+pQWFw1G|iSIHHT8jSzpK}@p+ut z-5i1eZz6UEcE8#A+0A^A{b7u8XB8YXW-`VMXowt@20?S9YexqoayduHQoqQSek}Zk ziTS+|={=5agoyo0 z_x#%^EN15ZrATrwO2SkoL_r844|h=#h{6%8B*Kn(oA0PYVKG(=ltsZeem(BbtN{cS`P(D5ySXlb-HmI}cG!c{7u zO6B^A!BslgEgffTEETGB#^uT|k6yWYyl}ZX*ew~(xUG)Dvj|`+!`7P8T&em8e V#97cT(*fxsAPJ$482AMSz5wJ&Oe6pR diff --git a/gopls/internal/test/marker/.DS_Store b/gopls/internal/test/marker/.DS_Store deleted file mode 100644 index 8c4b80c96a8fc0fdc0ce7617151ec08ef4e4ee84..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHK%}T>S5Z-O8-BN@c6nYGJEm*A;#7l_v1&ruHr6we3FlI}WnnNk%tS{t~_&m<+ zZotwUJc-yD*!^bbXE*af_J=XX-FbM(n8g?q&=5H)HG<|^SHlD&ay3WDe3mmmD6(nI zM1RqQ-!9;rBbGt`?fZj72H^V$rb(RTgTW_nG+SHSZP6B8aqm6J!pr@9p1J<)7Drc7 zCPAh9!F3!J6KnTernw)d(M%=8VFV#}H*p%u!jpQhquM&4!|OBpTZky2<68pJHs}~EHG&6( z>r_CU%FPpl>vXW&CeATfYSihBtCe9MvvT!#;c9iT+bW!KMM^6;#+p*=uD!MG9?5YSgH0WiRQq^F$PZ=()z Yj=@qR&VqK84oDXPMF@4oz%MZH1>n(4`~Uy| diff --git a/gopls/internal/test/marker/testdata/.DS_Store b/gopls/internal/test/marker/testdata/.DS_Store deleted file mode 100644 index 818909b77d1d8de441a6cd796adcea39f603b22e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHK%}T>S5T0$TO)NqW3OxqA7Od6^;w9Aj0!H+pQWFz27_+5K&7l->))(?gd>&_Z zw_<5McoDHPF#FBUPnP{Q><<8l?l{~7r~`n7N?5Y7St1lCU67LT5DN7TAMPN77$l?% z+3fg>4A9zb;HpUoA%ZXK7v!TskjMbzZ^0mmM_H@&E=uLf+IrQhS~cs&dyu)8`Pq2X z^@mr~JC`yEX0{((#8KX_Z=J|A^W!ucs)9I-Fy;C(PD7b@8DomE=$P6$8KVpFP2M3kVHCSj=TL(6DeWZAWkOXadOAtB+ zU4w;2jGzczil|G4d144%j(*4Fxdsc3x*UWW8NXvj7UqQ_)adAUR5}P(Be%=|GceCU zS$FGn{-1vT{+}=69y7oU{3`}TrRVj!I3;toE=-QjS_kzOm4xyNjh`iGsG}Hj=_qcX aDnY+P2BK@Q(1;!sz6dB9xM2o Date: Mon, 12 Feb 2024 15:49:28 +0000 Subject: [PATCH 17/20] use types.WriteExpr instead of types.ExprString --- gopls/internal/analysis/fillswitch/fillswitch.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gopls/internal/analysis/fillswitch/fillswitch.go b/gopls/internal/analysis/fillswitch/fillswitch.go index 25d59a5dfd9..66878dd1b3f 100644 --- a/gopls/internal/analysis/fillswitch/fillswitch.go +++ b/gopls/internal/analysis/fillswitch/fillswitch.go @@ -200,7 +200,7 @@ func addDefaultCase(buf *bytes.Buffer, named *types.Named, formatVerb byte, expr buf.WriteString(": %") buf.WriteByte(formatVerb) buf.WriteString("\", ") - buf.WriteString(types.ExprString(expr)) + types.WriteExpr(buf, expr) buf.WriteString("))\n\t") } From bb094689c479e7d42ad6e76b179612b8c878d7ec Mon Sep 17 00:00:00 2001 From: Martin Asquino Date: Tue, 13 Feb 2024 18:13:10 +0000 Subject: [PATCH 18/20] do not use types.WriteExpr --- .../analysis/fillswitch/fillswitch.go | 54 ++++++++++++++----- .../testdata/codeaction/fill_switch.txt | 10 ++-- .../codeaction/fill_switch_resolve.txt | 10 ++-- 3 files changed, 50 insertions(+), 24 deletions(-) diff --git a/gopls/internal/analysis/fillswitch/fillswitch.go b/gopls/internal/analysis/fillswitch/fillswitch.go index 66878dd1b3f..730ae6544ff 100644 --- a/gopls/internal/analysis/fillswitch/fillswitch.go +++ b/gopls/internal/analysis/fillswitch/fillswitch.go @@ -120,10 +120,10 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * // TODO: don't evaluate the assert.X expression a second time, switch assign := stmt.Assign.(type) { case *ast.AssignStmt: - addDefaultCase(&buf, namedType, 'T', assign.Lhs[0], pkg) + addDefaultCase(&buf, namedType, assign.Lhs[0]) case *ast.ExprStmt: if assert, ok := assign.X.(*ast.TypeAssertExpr); ok { - addDefaultCase(&buf, namedType, 'T', assert.X, pkg) + addDefaultCase(&buf, namedType, assert.X) } } @@ -177,7 +177,7 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In } // TODO: don't re-evaluate stmt.Tag. - addDefaultCase(&buf, namedType, 'v', stmt.Tag, pkg) + addDefaultCase(&buf, namedType, stmt.Tag) return &analysis.SuggestedFix{ Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()), @@ -189,19 +189,45 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In } } -func addDefaultCase(buf *bytes.Buffer, named *types.Named, formatVerb byte, expr ast.Expr, pkg *types.Package) { +func addDefaultCase(buf *bytes.Buffer, named *types.Named, expr ast.Expr) { buf.WriteString("\tdefault:\n") - buf.WriteString("\t\tpanic(fmt.Sprintf(\"unexpected ") - if named.Obj().Pkg() != pkg { - buf.WriteString(named.Obj().Pkg().Name()) - buf.WriteByte('.') + typeName := fmt.Sprintf("%s.%s", named.Obj().Pkg().Name(), named.Obj().Name()) + format := fmt.Sprintf("unexpected %s: %%#v", typeName) + + switch expr := expr.(type) { + case *ast.Ident: + fmt.Fprintf(buf, "\t\tpanic(fmt.Sprintf(%q, %s))\n\t", format, expr.Name) + case *ast.SelectorExpr: + // Use a new buffer to avoid writing to the original buffer if there's an error. + var selectorBuf bytes.Buffer + if err := writeSelector(&selectorBuf, expr); err != nil { + fmt.Fprintf(buf, "\t\tpanic(\"unexpected %s\")\n\t", typeName) + } else { + fmt.Fprintf(buf, "\t\tpanic(fmt.Sprintf(%q, %s))\n\t", format, selectorBuf.String()) + } + default: + fmt.Fprintf(buf, "\t\tpanic(\"unexpected %s\")\n\t", typeName) } - buf.WriteString(named.Obj().Name()) - buf.WriteString(": %") - buf.WriteByte(formatVerb) - buf.WriteString("\", ") - types.WriteExpr(buf, expr) - buf.WriteString("))\n\t") +} + +// writeSelector writes the formatted selector expression to the buffer. If one +// of the expressions in the chain in expr.X is not an *ast.Ident or +// *ast.SelectorExpr, an error is returned. +func writeSelector(buf *bytes.Buffer, expr *ast.SelectorExpr) error { + switch expr := expr.X.(type) { + case *ast.Ident: + buf.WriteString(expr.Name) + case *ast.SelectorExpr: + if err := writeSelector(buf, expr); err != nil { + return err + } + default: + return fmt.Errorf("unexpected type %T", expr) + } + + buf.WriteString(".") + buf.WriteString(expr.Sel.Name) + return nil } func namedTypeFromTypeSwitch(stmt *ast.TypeSwitchStmt, info *types.Info) *types.Named { diff --git a/gopls/internal/test/marker/testdata/codeaction/fill_switch.txt b/gopls/internal/test/marker/testdata/codeaction/fill_switch.txt index 299c197c8fd..2c1b19e130c 100644 --- a/gopls/internal/test/marker/testdata/codeaction/fill_switch.txt +++ b/gopls/internal/test/marker/testdata/codeaction/fill_switch.txt @@ -78,28 +78,28 @@ func doSwitch() { + case data.TypeBThree: + case data.TypeBTwo: + default: -+ panic(fmt.Sprintf("unexpected data.TypeB: %v", b)) ++ panic(fmt.Sprintf("unexpected data.TypeB: %#v", b)) -- @a2/a.go -- @@ -36 +36,4 @@ + case typeAOne: + case typeATwo: + default: -+ panic(fmt.Sprintf("unexpected typeA: %v", a)) ++ panic(fmt.Sprintf("unexpected fillswitch.typeA: %#v", a)) -- @a3/a.go -- @@ -40 +40,4 @@ + case notificationOne: + case notificationTwo: + default: -+ panic(fmt.Sprintf("unexpected notification: %T", n)) ++ panic(fmt.Sprintf("unexpected fillswitch.notification: %#v", n)) -- @a4/a.go -- @@ -43 +43,4 @@ + case notificationOne: + case notificationTwo: + default: -+ panic(fmt.Sprintf("unexpected notification: %T", nt)) ++ panic(fmt.Sprintf("unexpected fillswitch.notification: %#v", nt)) -- @a5/a.go -- @@ -51 +51,4 @@ + case typeAOne: + case typeATwo: + default: -+ panic(fmt.Sprintf("unexpected typeA: %v", s.a)) ++ panic(fmt.Sprintf("unexpected fillswitch.typeA: %#v", s.a)) diff --git a/gopls/internal/test/marker/testdata/codeaction/fill_switch_resolve.txt b/gopls/internal/test/marker/testdata/codeaction/fill_switch_resolve.txt index c49fd5722c7..504acd6043e 100644 --- a/gopls/internal/test/marker/testdata/codeaction/fill_switch_resolve.txt +++ b/gopls/internal/test/marker/testdata/codeaction/fill_switch_resolve.txt @@ -89,28 +89,28 @@ func doSwitch() { + case data.TypeBThree: + case data.TypeBTwo: + default: -+ panic(fmt.Sprintf("unexpected data.TypeB: %v", b)) ++ panic(fmt.Sprintf("unexpected data.TypeB: %#v", b)) -- @a2/a.go -- @@ -36 +36,4 @@ + case typeAOne: + case typeATwo: + default: -+ panic(fmt.Sprintf("unexpected typeA: %v", a)) ++ panic(fmt.Sprintf("unexpected fillswitch.typeA: %#v", a)) -- @a3/a.go -- @@ -40 +40,4 @@ + case notificationOne: + case notificationTwo: + default: -+ panic(fmt.Sprintf("unexpected notification: %T", n)) ++ panic(fmt.Sprintf("unexpected fillswitch.notification: %#v", n)) -- @a4/a.go -- @@ -43 +43,4 @@ + case notificationOne: + case notificationTwo: + default: -+ panic(fmt.Sprintf("unexpected notification: %T", nt)) ++ panic(fmt.Sprintf("unexpected fillswitch.notification: %#v", nt)) -- @a5/a.go -- @@ -51 +51,4 @@ + case typeAOne: + case typeATwo: + default: -+ panic(fmt.Sprintf("unexpected typeA: %v", s.a)) ++ panic(fmt.Sprintf("unexpected fillswitch.typeA: %#v", s.a)) From 0cbc264f933c2bb378d9bc27ea7407d513643c8e Mon Sep 17 00:00:00 2001 From: Martin Asquino Date: Tue, 13 Feb 2024 20:37:01 +0000 Subject: [PATCH 19/20] refactor addDefaultCase --- .../analysis/fillswitch/fillswitch.go | 53 ++++++++----------- 1 file changed, 21 insertions(+), 32 deletions(-) diff --git a/gopls/internal/analysis/fillswitch/fillswitch.go b/gopls/internal/analysis/fillswitch/fillswitch.go index 730ae6544ff..fada18d10e1 100644 --- a/gopls/internal/analysis/fillswitch/fillswitch.go +++ b/gopls/internal/analysis/fillswitch/fillswitch.go @@ -117,7 +117,6 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info * return nil } - // TODO: don't evaluate the assert.X expression a second time, switch assign := stmt.Assign.(type) { case *ast.AssignStmt: addDefaultCase(&buf, namedType, assign.Lhs[0]) @@ -176,7 +175,6 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In return nil } - // TODO: don't re-evaluate stmt.Tag. addDefaultCase(&buf, namedType, stmt.Tag) return &analysis.SuggestedFix{ @@ -192,42 +190,33 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In func addDefaultCase(buf *bytes.Buffer, named *types.Named, expr ast.Expr) { buf.WriteString("\tdefault:\n") typeName := fmt.Sprintf("%s.%s", named.Obj().Pkg().Name(), named.Obj().Name()) - format := fmt.Sprintf("unexpected %s: %%#v", typeName) - - switch expr := expr.(type) { - case *ast.Ident: - fmt.Fprintf(buf, "\t\tpanic(fmt.Sprintf(%q, %s))\n\t", format, expr.Name) - case *ast.SelectorExpr: - // Use a new buffer to avoid writing to the original buffer if there's an error. - var selectorBuf bytes.Buffer - if err := writeSelector(&selectorBuf, expr); err != nil { - fmt.Fprintf(buf, "\t\tpanic(\"unexpected %s\")\n\t", typeName) - } else { - fmt.Fprintf(buf, "\t\tpanic(fmt.Sprintf(%q, %s))\n\t", format, selectorBuf.String()) - } - default: - fmt.Fprintf(buf, "\t\tpanic(\"unexpected %s\")\n\t", typeName) + var dottedBuf bytes.Buffer + if writeDotted(&dottedBuf, expr) { + // Switch tag expression is a dotted path. + // It is safe to re-evaluate it in the default case. + format := fmt.Sprintf("unexpected %s: %%#v", typeName) + fmt.Fprintf(buf, "\t\tpanic(fmt.Sprintf(%q, %s))\n\t", format, dottedBuf.String()) + } else { + // Emit simpler message, without re-evaluating tag expression. + fmt.Fprintf(buf, "\t\tpanic(%q)\n\t", "unexpected "+typeName) } } -// writeSelector writes the formatted selector expression to the buffer. If one -// of the expressions in the chain in expr.X is not an *ast.Ident or -// *ast.SelectorExpr, an error is returned. -func writeSelector(buf *bytes.Buffer, expr *ast.SelectorExpr) error { - switch expr := expr.X.(type) { - case *ast.Ident: - buf.WriteString(expr.Name) +// writeDotted emits a dotted path a.b.c. +func writeDotted(buf *bytes.Buffer, e ast.Expr) bool { + switch e := e.(type) { case *ast.SelectorExpr: - if err := writeSelector(buf, expr); err != nil { - return err + if !writeDotted(buf, e.X) { + return false } - default: - return fmt.Errorf("unexpected type %T", expr) + buf.WriteByte('.') + buf.WriteString(e.Sel.Name) + return true + case *ast.Ident: + buf.WriteString(e.Name) + return true } - - buf.WriteString(".") - buf.WriteString(expr.Sel.Name) - return nil + return false } func namedTypeFromTypeSwitch(stmt *ast.TypeSwitchStmt, info *types.Info) *types.Named { From a04dc69c7bb1ea23e396b8e701980cc425cdffe8 Mon Sep 17 00:00:00 2001 From: Martin Asquino Date: Wed, 14 Feb 2024 09:00:17 +0000 Subject: [PATCH 20/20] move writeDotted into addDefaultCase --- .../analysis/fillswitch/fillswitch.go | 39 ++++++++++--------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/gopls/internal/analysis/fillswitch/fillswitch.go b/gopls/internal/analysis/fillswitch/fillswitch.go index fada18d10e1..b93ade01065 100644 --- a/gopls/internal/analysis/fillswitch/fillswitch.go +++ b/gopls/internal/analysis/fillswitch/fillswitch.go @@ -188,10 +188,28 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In } func addDefaultCase(buf *bytes.Buffer, named *types.Named, expr ast.Expr) { + var dottedBuf bytes.Buffer + // writeDotted emits a dotted path a.b.c. + var writeDotted func(e ast.Expr) bool + writeDotted = func(e ast.Expr) bool { + switch e := e.(type) { + case *ast.SelectorExpr: + if !writeDotted(e.X) { + return false + } + dottedBuf.WriteByte('.') + dottedBuf.WriteString(e.Sel.Name) + return true + case *ast.Ident: + dottedBuf.WriteString(e.Name) + return true + } + return false + } + buf.WriteString("\tdefault:\n") typeName := fmt.Sprintf("%s.%s", named.Obj().Pkg().Name(), named.Obj().Name()) - var dottedBuf bytes.Buffer - if writeDotted(&dottedBuf, expr) { + if writeDotted(expr) { // Switch tag expression is a dotted path. // It is safe to re-evaluate it in the default case. format := fmt.Sprintf("unexpected %s: %%#v", typeName) @@ -202,23 +220,6 @@ func addDefaultCase(buf *bytes.Buffer, named *types.Named, expr ast.Expr) { } } -// writeDotted emits a dotted path a.b.c. -func writeDotted(buf *bytes.Buffer, e ast.Expr) bool { - switch e := e.(type) { - case *ast.SelectorExpr: - if !writeDotted(buf, e.X) { - return false - } - buf.WriteByte('.') - buf.WriteString(e.Sel.Name) - return true - case *ast.Ident: - buf.WriteString(e.Name) - return true - } - return false -} - func namedTypeFromTypeSwitch(stmt *ast.TypeSwitchStmt, info *types.Info) *types.Named { switch assign := stmt.Assign.(type) { case *ast.ExprStmt: