Skip to content

Commit 320bf61

Browse files
committed
gopls: add fill switch cases code action
1 parent d077888 commit 320bf61

File tree

8 files changed

+660
-0
lines changed

8 files changed

+660
-0
lines changed
Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
// Copyright 2020 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// Package fillswitch defines an Analyzer that automatically
6+
// fills the missing cases in type switches or switches over named types.
7+
//
8+
// The analyzer's diagnostic is merely a prompt.
9+
// The actual fix is created by a separate direct call from gopls to
10+
// the SuggestedFixes function.
11+
// Tests of Analyzer.Run can be found in ./testdata/src.
12+
// Tests of the SuggestedFixes logic live in ../../testdata/fillswitch.
13+
package fillswitch
14+
15+
import (
16+
"bytes"
17+
"context"
18+
"errors"
19+
"fmt"
20+
"go/ast"
21+
"go/token"
22+
"go/types"
23+
"slices"
24+
"strings"
25+
26+
"golang.org/x/tools/go/analysis"
27+
"golang.org/x/tools/go/ast/astutil"
28+
"golang.org/x/tools/go/ast/inspector"
29+
"golang.org/x/tools/gopls/internal/cache"
30+
"golang.org/x/tools/gopls/internal/cache/parsego"
31+
)
32+
33+
const FixCategory = "fillswitch" // recognized by gopls ApplyFix
34+
35+
// errNoSuggestedFix is returned when no suggested fix is available. This could
36+
// be because all cases are already covered, or (in the case of a type switch)
37+
// because the remaining cases are for types not accessible by the current
38+
// package.
39+
var errNoSuggestedFix = errors.New("no suggested fix")
40+
41+
// Diagnose computes diagnostics for switch statements with missing cases
42+
// overlapping with the provided start and end position.
43+
//
44+
// The diagnostic contains a lazy fix; the actual patch is computed
45+
// (via the ApplyFix command) by a call to [SuggestedFix].
46+
//
47+
// If either start or end is invalid, the entire package is inspected.
48+
func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Package, info *types.Info) []analysis.Diagnostic {
49+
var diags []analysis.Diagnostic
50+
nodeFilter := []ast.Node{(*ast.SwitchStmt)(nil), (*ast.TypeSwitchStmt)(nil)}
51+
inspect.Preorder(nodeFilter, func(n ast.Node) {
52+
if expr, ok := n.(*ast.SwitchStmt); ok {
53+
if (start.IsValid() && expr.End() < start) || (end.IsValid() && expr.Pos() > end) {
54+
return // non-overlapping
55+
}
56+
57+
if defaultHandled(expr.Body) {
58+
return
59+
}
60+
61+
namedType, err := namedTypeFromSwitch(expr, info)
62+
if err != nil {
63+
return
64+
}
65+
66+
if _, err := suggestedFixSwitch(expr, pkg, info); err != nil {
67+
return
68+
}
69+
70+
diags = append(diags, analysis.Diagnostic{
71+
Message: "Switch has missing cases",
72+
Pos: expr.Pos(),
73+
End: expr.End(),
74+
Category: FixCategory,
75+
SuggestedFixes: []analysis.SuggestedFix{{
76+
Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()),
77+
// No TextEdits => computed later by gopls.
78+
}},
79+
})
80+
}
81+
82+
if expr, ok := n.(*ast.TypeSwitchStmt); ok {
83+
if (start.IsValid() && expr.End() < start) || (end.IsValid() && expr.Pos() > end) {
84+
return // non-overlapping
85+
}
86+
87+
if defaultHandled(expr.Body) {
88+
return
89+
}
90+
91+
namedType, err := namedTypeFromTypeSwitch(expr, info)
92+
if err != nil {
93+
return
94+
}
95+
96+
if _, err := suggestedFixTypeSwitch(expr, pkg, info); err != nil {
97+
return
98+
}
99+
100+
diags = append(diags, analysis.Diagnostic{
101+
Message: "Switch has missing cases",
102+
Pos: expr.Pos(),
103+
End: expr.End(),
104+
Category: FixCategory,
105+
SuggestedFixes: []analysis.SuggestedFix{{
106+
Message: fmt.Sprintf("Add cases for %v", namedType.Obj().Name()),
107+
// No TextEdits => computed later by gopls.
108+
}},
109+
})
110+
}
111+
})
112+
113+
return diags
114+
}
115+
116+
func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) {
117+
namedType, err := namedTypeFromTypeSwitch(stmt, info)
118+
if err != nil {
119+
return nil, err
120+
}
121+
122+
scope := namedType.Obj().Pkg().Scope()
123+
variants := make([]string, 0)
124+
for _, name := range scope.Names() {
125+
obj := scope.Lookup(name)
126+
if _, ok := obj.(*types.TypeName); !ok {
127+
continue
128+
}
129+
130+
if types.Identical(obj.Type(), namedType.Obj().Type()) {
131+
continue
132+
}
133+
134+
if types.AssignableTo(obj.Type(), namedType.Obj().Type()) {
135+
if obj.Pkg().Name() != pkg.Name() {
136+
if !obj.Exported() {
137+
continue
138+
}
139+
140+
variants = append(variants, obj.Pkg().Name()+"."+obj.Name())
141+
} else {
142+
variants = append(variants, obj.Name())
143+
}
144+
} else if types.AssignableTo(types.NewPointer(obj.Type()), namedType.Obj().Type()) {
145+
if obj.Pkg().Name() != pkg.Name() {
146+
if !obj.Exported() {
147+
continue
148+
}
149+
150+
variants = append(variants, "*"+obj.Pkg().Name()+"."+obj.Name())
151+
} else {
152+
variants = append(variants, "*"+obj.Name())
153+
}
154+
}
155+
}
156+
157+
handledVariants := getHandledVariants(stmt.Body)
158+
if len(variants) == 0 || len(variants) == len(handledVariants) {
159+
return nil, errNoSuggestedFix
160+
}
161+
162+
newText := buildNewText(variants, handledVariants)
163+
return &analysis.SuggestedFix{
164+
Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()),
165+
TextEdits: []analysis.TextEdit{{
166+
Pos: stmt.End() - 1,
167+
End: stmt.End() - 1,
168+
NewText: indent([]byte(newText), []byte{'\t'}),
169+
}},
170+
}, nil
171+
}
172+
173+
func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) {
174+
namedType, err := namedTypeFromSwitch(stmt, info)
175+
if err != nil {
176+
return nil, err
177+
}
178+
179+
scope := namedType.Obj().Pkg().Scope()
180+
variants := make([]string, 0)
181+
for _, name := range scope.Names() {
182+
obj := scope.Lookup(name)
183+
if obj.Id() == namedType.Obj().Id() {
184+
continue
185+
}
186+
187+
if types.Identical(obj.Type(), namedType.Obj().Type()) {
188+
// TODO: comparing the package name like this feels wrong, is it?
189+
if obj.Pkg().Name() != pkg.Name() {
190+
if !obj.Exported() {
191+
continue
192+
}
193+
194+
variants = append(variants, obj.Pkg().Name()+"."+obj.Name())
195+
} else {
196+
variants = append(variants, obj.Name())
197+
}
198+
}
199+
}
200+
201+
handledVariants := getHandledVariants(stmt.Body)
202+
if len(variants) == 0 || len(variants) == len(handledVariants) {
203+
return nil, errNoSuggestedFix
204+
}
205+
206+
newText := buildNewText(variants, handledVariants)
207+
return &analysis.SuggestedFix{
208+
Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()),
209+
TextEdits: []analysis.TextEdit{{
210+
Pos: stmt.End() - 1,
211+
End: stmt.End() - 1,
212+
NewText: indent([]byte(newText), []byte{'\t'}),
213+
}},
214+
}, nil
215+
}
216+
217+
func namedTypeFromSwitch(stmt *ast.SwitchStmt, info *types.Info) (*types.Named, error) {
218+
typ := info.TypeOf(stmt.Tag)
219+
if typ == nil {
220+
return nil, errors.New("expected switch statement to have a tag")
221+
}
222+
223+
namedType, ok := typ.(*types.Named)
224+
if !ok {
225+
return nil, errors.New("switch statement is not on a named type")
226+
}
227+
228+
return namedType, nil
229+
}
230+
231+
func namedTypeFromTypeSwitch(stmt *ast.TypeSwitchStmt, info *types.Info) (*types.Named, error) {
232+
switch s := stmt.Assign.(type) {
233+
case *ast.ExprStmt:
234+
typ := s.X.(*ast.TypeAssertExpr)
235+
namedType, ok := info.TypeOf(typ.X).(*types.Named)
236+
if !ok {
237+
return nil, errors.New("type switch expression is not on a named type")
238+
}
239+
240+
return namedType, nil
241+
case *ast.AssignStmt:
242+
for _, expr := range s.Rhs {
243+
typ, ok := expr.(*ast.TypeAssertExpr)
244+
if !ok {
245+
continue
246+
}
247+
248+
namedType, ok := info.TypeOf(typ.X).(*types.Named)
249+
if !ok {
250+
continue
251+
}
252+
253+
return namedType, nil
254+
}
255+
256+
return nil, errors.New("expected type switch expression to have a named type")
257+
default:
258+
return nil, errors.New("node is not a type switch statement")
259+
}
260+
}
261+
262+
func defaultHandled(body *ast.BlockStmt) bool {
263+
for _, bl := range body.List {
264+
if len(bl.(*ast.CaseClause).List) == 0 {
265+
return true
266+
}
267+
}
268+
269+
return false
270+
}
271+
272+
func buildNewText(variants []string, handledVariants []string) string {
273+
var textBuilder strings.Builder
274+
for _, c := range variants {
275+
if slices.Contains(handledVariants, c) {
276+
continue
277+
}
278+
279+
textBuilder.WriteString("case ")
280+
textBuilder.WriteString(c)
281+
textBuilder.WriteString(":\n")
282+
}
283+
284+
return textBuilder.String()
285+
}
286+
287+
func getHandledVariants(body *ast.BlockStmt) []string {
288+
out := make([]string, 0)
289+
for _, bl := range body.List {
290+
for _, c := range bl.(*ast.CaseClause).List {
291+
switch v := c.(type) {
292+
case *ast.Ident:
293+
out = append(out, v.Name)
294+
case *ast.SelectorExpr:
295+
out = append(out, v.X.(*ast.Ident).Name+"."+v.Sel.Name)
296+
case *ast.StarExpr:
297+
out = append(out, "*"+v.X.(*ast.Ident).Name)
298+
}
299+
}
300+
}
301+
302+
return out
303+
}
304+
305+
// SuggestedFix computes the suggested fix for the kinds of
306+
// diagnostics produced by the Analyzer above.
307+
func SuggestedFix(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Package, pgf *parsego.File, start, end token.Pos) (*token.FileSet, *analysis.SuggestedFix, error) {
308+
pos := start // don't use the end
309+
path, _ := astutil.PathEnclosingInterval(pgf.File, pos, pos)
310+
if len(path) < 2 {
311+
return nil, nil, fmt.Errorf("no expression found")
312+
}
313+
314+
switch stmt := path[0].(type) {
315+
case *ast.SwitchStmt:
316+
fix, err := suggestedFixSwitch(stmt, pkg.GetTypes(), pkg.GetTypesInfo())
317+
if err != nil {
318+
return nil, nil, err
319+
}
320+
321+
return pkg.FileSet(), fix, nil
322+
case *ast.TypeSwitchStmt:
323+
fix, err := suggestedFixTypeSwitch(stmt, pkg.GetTypes(), pkg.GetTypesInfo())
324+
if err != nil {
325+
return nil, nil, err
326+
}
327+
328+
return pkg.FileSet(), fix, nil
329+
default:
330+
return nil, nil, fmt.Errorf("no switch statement found")
331+
}
332+
}
333+
334+
// indent works line by line through str, prefixing each line with
335+
// prefix.
336+
func indent(str, prefix []byte) []byte {
337+
split := bytes.Split(str, []byte("\n"))
338+
newText := bytes.NewBuffer(nil)
339+
for i, s := range split {
340+
if i != 0 {
341+
newText.Write(prefix)
342+
}
343+
344+
newText.Write(s)
345+
if i < len(split)-1 {
346+
newText.WriteByte('\n')
347+
}
348+
}
349+
return newText.Bytes()
350+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright 2020 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package fillswitch_test
6+
7+
import (
8+
"go/token"
9+
"testing"
10+
11+
"golang.org/x/tools/go/analysis"
12+
"golang.org/x/tools/go/analysis/analysistest"
13+
"golang.org/x/tools/go/analysis/passes/inspect"
14+
"golang.org/x/tools/go/ast/inspector"
15+
"golang.org/x/tools/gopls/internal/analysis/fillswitch"
16+
)
17+
18+
// analyzer allows us to test the fillswitch code action using the analysistest
19+
// harness. (fillswitch used to be a gopls analyzer.)
20+
var analyzer = &analysis.Analyzer{
21+
Name: "fillswitch",
22+
Doc: "test only",
23+
Requires: []*analysis.Analyzer{inspect.Analyzer},
24+
Run: func(pass *analysis.Pass) (any, error) {
25+
inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
26+
for _, d := range fillswitch.Diagnose(inspect, token.NoPos, token.NoPos, pass.Pkg, pass.TypesInfo) {
27+
pass.Report(d)
28+
}
29+
return nil, nil
30+
},
31+
URL: "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/fillswitch",
32+
RunDespiteErrors: true,
33+
}
34+
35+
func Test(t *testing.T) {
36+
testdata := analysistest.TestData()
37+
analysistest.Run(t, testdata, analyzer, "a")
38+
}

0 commit comments

Comments
 (0)