Skip to content

Commit c8fda66

Browse files
committed
return edits on diagnostic
1 parent bab3012 commit c8fda66

File tree

5 files changed

+111
-93
lines changed

5 files changed

+111
-93
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// Copyright 2024 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// Package fillswitch identifies switches with missing cases.
6+
//
7+
// It will provide diagnostics for type switches or switches over named types
8+
// that are missing cases and provides a code action to fill those in.
9+
//
10+
// If the switch statement is over a named type, it will suggest cases for all
11+
// const values that are assignable to the named type.
12+
//
13+
// type T int
14+
// const (
15+
// A T = iota
16+
// B
17+
// C
18+
// )
19+
//
20+
// var t T
21+
// switch t {
22+
// case A:
23+
// }
24+
//
25+
// It will provide a diagnostic with a suggested edit to fill in the remaining
26+
// cases:
27+
//
28+
// var t T
29+
// switch t {
30+
// case A:
31+
// case B:
32+
// case C:
33+
// }
34+
//
35+
// If the switch statement is over type of an interface, it will suggest cases for all types
36+
// that implement the interface.
37+
//
38+
// type I interface {
39+
// M()
40+
// }
41+
//
42+
// type T struct{}
43+
// func (t *T) M() {}
44+
//
45+
// type E struct{}
46+
// func (e *E) M() {}
47+
//
48+
// var i I
49+
// switch i.(type) {
50+
// case *T:
51+
// }
52+
//
53+
// It will provide a diagnostic with a suggested edit to fill in the remaining
54+
// cases:
55+
//
56+
// var i I
57+
// switch i.(type) {
58+
// case *T:
59+
// case *E:
60+
// }
61+
//
62+
// The provided diagnostics will only suggest cases for types that are defined
63+
// on the same package as the switch statement, or for types that are exported;
64+
// and it will not suggest any case if the switch handles the default case.
65+
package fillswitch

gopls/internal/analysis/fillswitch/fillswitch.go

Lines changed: 24 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ package fillswitch
1414

1515
import (
1616
"bytes"
17-
"context"
1817
"errors"
1918
"fmt"
2019
"go/ast"
@@ -24,20 +23,14 @@ import (
2423
"strings"
2524

2625
"golang.org/x/tools/go/analysis"
27-
"golang.org/x/tools/go/ast/astutil"
2826
"golang.org/x/tools/go/ast/inspector"
29-
"golang.org/x/tools/gopls/internal/cache"
30-
"golang.org/x/tools/gopls/internal/cache/parsego"
3127
)
3228

33-
const FixCategory = "fillswitch" // recognized by gopls ApplyFix
29+
const FixCategory = "fillswitch"
3430

3531
// Diagnose computes diagnostics for switch statements with missing cases
3632
// overlapping with the provided start and end position.
3733
//
38-
// The diagnostic contains a lazy fix; the actual patch is computed
39-
// (via the ApplyFix command) by a call to [SuggestedFix].
40-
//
4134
// If either start or end is invalid, the entire package is inspected.
4235
func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Package, info *types.Info) []analysis.Diagnostic {
4336
var diags []analysis.Diagnostic
@@ -50,49 +43,35 @@ func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Pac
5043
return // non-overlapping
5144
}
5245

53-
namedType, err := namedTypeFromSwitch(expr, info)
54-
if err != nil {
55-
return
56-
}
57-
58-
if fix, err := suggestedFixSwitch(expr, pkg, info); err != nil || fix == nil {
46+
fix, err := suggestedFixSwitch(expr, pkg, info)
47+
if err != nil || fix == nil {
5948
return
6049
}
6150

6251
diags = append(diags, analysis.Diagnostic{
63-
Message: "Switch has missing cases",
64-
Pos: expr.Pos(),
65-
End: expr.End(),
66-
Category: FixCategory,
67-
SuggestedFixes: []analysis.SuggestedFix{{
68-
Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()),
69-
// No TextEdits => computed later by gopls.
70-
}},
52+
Message: fix.Message,
53+
Pos: expr.Pos(),
54+
End: expr.End(),
55+
Category: FixCategory,
56+
SuggestedFixes: []analysis.SuggestedFix{*fix},
7157
})
7258
case *ast.TypeSwitchStmt:
7359
if start.IsValid() && expr.End() < start ||
7460
end.IsValid() && expr.Pos() > end {
7561
return // non-overlapping
7662
}
7763

78-
namedType, err := namedTypeFromTypeSwitch(expr, info)
79-
if err != nil {
80-
return
81-
}
82-
83-
if fix, err := suggestedFixTypeSwitch(expr, pkg, info); err != nil || fix == nil {
64+
fix, err := suggestedFixTypeSwitch(expr, pkg, info)
65+
if err != nil || fix == nil {
8466
return
8567
}
8668

8769
diags = append(diags, analysis.Diagnostic{
88-
Message: "Switch has missing cases",
89-
Pos: expr.Pos(),
90-
End: expr.End(),
91-
Category: FixCategory,
92-
SuggestedFixes: []analysis.SuggestedFix{{
93-
Message: fmt.Sprintf("Add cases for %v", namedType.Obj().Name()),
94-
// No TextEdits => computed later by gopls.
95-
}},
70+
Message: fix.Message,
71+
Pos: expr.Pos(),
72+
End: expr.End(),
73+
Category: FixCategory,
74+
SuggestedFixes: []analysis.SuggestedFix{*fix},
9675
})
9776
}
9877
})
@@ -134,7 +113,7 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info *
134113
}
135114
}
136115

137-
handledVariants := typeSwitchCases(stmt.Body, info)
116+
handledVariants := caseTypes(stmt.Body, info)
138117
if len(variants) == 0 || len(variants) == len(handledVariants) {
139118
return nil, nil
140119
}
@@ -144,7 +123,7 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info *
144123
TextEdits: []analysis.TextEdit{{
145124
Pos: stmt.End() - 1,
146125
End: stmt.End() - 1,
147-
NewText: buildNewTypesText(variants, handledVariants, pkg),
126+
NewText: buildTypesText(variants, handledVariants, pkg),
148127
}},
149128
}, nil
150129
}
@@ -170,7 +149,7 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In
170149

171150
samePkg := obj.Pkg() != pkg
172151
if samePkg && !obj.Exported() {
173-
continue
152+
continue // inaccessible
174153
}
175154

176155
if types.Identical(obj.Type(), namedType.Obj().Type()) {
@@ -188,7 +167,7 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In
188167
TextEdits: []analysis.TextEdit{{
189168
Pos: stmt.End() - 1,
190169
End: stmt.End() - 1,
191-
NewText: buildNewConstsText(variants, handledVariants, pkg),
170+
NewText: buildConstsText(variants, handledVariants, pkg),
192171
}},
193172
}, nil
194173
}
@@ -252,7 +231,7 @@ func hasDefaultCase(body *ast.BlockStmt) bool {
252231
return false
253232
}
254233

255-
func buildNewConstsText(variants []*types.Const, handledVariants []*types.Const, currentPkg *types.Package) []byte {
234+
func buildConstsText(variants []*types.Const, handledVariants []*types.Const, currentPkg *types.Package) []byte {
256235
var textBuilder strings.Builder
257236
for _, c := range variants {
258237
if slices.Contains(handledVariants, c) {
@@ -287,7 +266,7 @@ func isSameType(c, t types.Type) bool {
287266
return false
288267
}
289268

290-
func buildNewTypesText(variants []types.Type, handledVariants []types.Type, currentPkg *types.Package) []byte {
269+
func buildTypesText(variants []types.Type, handledVariants []types.Type, currentPkg *types.Package) []byte {
291270
var textBuilder strings.Builder
292271
for _, c := range variants {
293272
if slices.ContainsFunc(handledVariants, func(t types.Type) bool { return isSameType(c, t) }) {
@@ -309,6 +288,7 @@ func buildNewTypesText(variants []types.Type, handledVariants []types.Type, curr
309288
}
310289

311290
if e.Obj().Pkg() != currentPkg {
291+
// TODO: use the correct package name when the import is renamed
312292
textBuilder.WriteString("*" + e.Obj().Pkg().Name() + "." + e.Obj().Name())
313293
} else {
314294
textBuilder.WriteString("*" + e.Obj().Name())
@@ -335,6 +315,7 @@ func caseConsts(body *ast.BlockStmt, info *types.Info) []*types.Const {
335315
if !ok {
336316
continue
337317
}
318+
338319
c, ok := obj.(*types.Const)
339320
if !ok {
340321
continue
@@ -365,7 +346,7 @@ func caseConsts(body *ast.BlockStmt, info *types.Info) []*types.Const {
365346
return out
366347
}
367348

368-
func typeSwitchCases(body *ast.BlockStmt, info *types.Info) []types.Type {
349+
func caseTypes(body *ast.BlockStmt, info *types.Info) []types.Type {
369350
var out []types.Type
370351
for _, stmt := range body.List {
371352
for _, e := range stmt.(*ast.CaseClause).List {
@@ -421,32 +402,3 @@ func typeSwitchCases(body *ast.BlockStmt, info *types.Info) []types.Type {
421402

422403
return out
423404
}
424-
425-
// SuggestedFix computes the suggested fix for the kinds of
426-
// diagnostics produced by the Analyzer above.
427-
func SuggestedFix(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Package, pgf *parsego.File, start, end token.Pos) (*token.FileSet, *analysis.SuggestedFix, error) {
428-
pos := start // don't use the end
429-
path, _ := astutil.PathEnclosingInterval(pgf.File, pos, pos)
430-
if len(path) < 2 {
431-
return nil, nil, fmt.Errorf("no expression found")
432-
}
433-
434-
switch stmt := path[0].(type) {
435-
case *ast.SwitchStmt:
436-
fix, err := suggestedFixSwitch(stmt, pkg.GetTypes(), pkg.GetTypesInfo())
437-
if err != nil {
438-
return nil, nil, err
439-
}
440-
441-
return pkg.FileSet(), fix, nil
442-
case *ast.TypeSwitchStmt:
443-
fix, err := suggestedFixTypeSwitch(stmt, pkg.GetTypes(), pkg.GetTypesInfo())
444-
if err != nil {
445-
return nil, nil, err
446-
}
447-
448-
return pkg.FileSet(), fix, nil
449-
default:
450-
return nil, nil, fmt.Errorf("no switch statement found")
451-
}
452-
}

gopls/internal/analysis/fillswitch/testdata/src/a/a.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ const (
1818

1919
func doSwitch() {
2020
var a typeA
21-
switch a { // want `Switch has missing cases`
21+
switch a { // want `Add cases for typeA`
2222
}
2323

24-
switch a { // want `Switch has missing cases`
24+
switch a { // want `Add cases for typeA`
2525
case typeAOne:
2626
}
2727

@@ -37,7 +37,7 @@ func doSwitch() {
3737
}
3838

3939
var b data.TypeB
40-
switch b { // want `Switch has missing cases`
40+
switch b { // want `Add cases for TypeB`
4141
case data.TypeBOne:
4242
}
4343
}
@@ -56,10 +56,10 @@ func (notificationTwo) isNotification() {}
5656

5757
func doTypeSwitch() {
5858
var not notification
59-
switch not.(type) { // want `Switch has missing cases`
59+
switch not.(type) { // want `Add cases for notification`
6060
}
6161

62-
switch not.(type) { // want `Switch has missing cases`
62+
switch not.(type) { // want `Add cases for notification`
6363
case notificationOne:
6464
}
6565

gopls/internal/golang/codeaction.go

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ func CodeActions(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle,
9999
return nil, err
100100
}
101101
if want[protocol.RefactorRewrite] {
102-
rewrites, err := getRewriteCodeActions(pkg, pgf, fh, rng, snapshot.Options())
102+
rewrites, err := getRewriteCodeActions(ctx, pkg, snapshot, pgf, fh, rng, snapshot.Options())
103103
if err != nil {
104104
return nil, err
105105
}
@@ -254,7 +254,7 @@ func newCodeAction(title string, kind protocol.CodeActionKind, cmd *protocol.Com
254254
}
255255

256256
// getRewriteCodeActions returns refactor.rewrite code actions available at the specified range.
257-
func getRewriteCodeActions(pkg *cache.Package, pgf *ParsedGoFile, fh file.Handle, rng protocol.Range, options *settings.Options) (_ []protocol.CodeAction, rerr error) {
257+
func getRewriteCodeActions(ctx context.Context, pkg *cache.Package, snapshot *cache.Snapshot, pgf *ParsedGoFile, fh file.Handle, rng protocol.Range, options *settings.Options) (_ []protocol.CodeAction, rerr error) {
258258
// golang/go#61693: code actions were refactored to run outside of the
259259
// analysis framework, but as a result they lost their panic recovery.
260260
//
@@ -330,24 +330,27 @@ func getRewriteCodeActions(pkg *cache.Package, pgf *ParsedGoFile, fh file.Handle
330330
}
331331

332332
for _, diag := range fillswitch.Diagnose(inspect, start, end, pkg.GetTypes(), pkg.GetTypesInfo()) {
333-
rng, err := pgf.Mapper.PosRange(pgf.Tok, diag.Pos, diag.End)
333+
edits, err := suggestedFixToEdits(ctx, snapshot, pkg.FileSet(), &diag.SuggestedFixes[0])
334334
if err != nil {
335335
return nil, err
336336
}
337-
for _, fix := range diag.SuggestedFixes {
338-
cmd, err := command.NewApplyFixCommand(fix.Message, command.ApplyFixArgs{
339-
Fix: diag.Category,
340-
URI: pgf.URI,
341-
Range: rng,
342-
ResolveEdits: supportsResolveEdits(options),
337+
338+
changes := []protocol.DocumentChanges{} // must be a slice
339+
for _, edit := range edits {
340+
edit := edit
341+
changes = append(changes, protocol.DocumentChanges{
342+
TextDocumentEdit: &edit,
343343
})
344-
if err != nil {
345-
return nil, err
346-
}
347-
commands = append(commands, cmd)
348344
}
349-
}
350345

346+
actions = append(actions, protocol.CodeAction{
347+
Title: diag.Message,
348+
Kind: protocol.RefactorRewrite,
349+
Edit: &protocol.WorkspaceEdit{
350+
DocumentChanges: changes,
351+
},
352+
})
353+
}
351354
for i := range commands {
352355
actions = append(actions, newCodeAction(commands[i].Title, protocol.RefactorRewrite, &commands[i], nil, options))
353356
}

gopls/internal/golang/fix.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414
"golang.org/x/tools/go/analysis"
1515
"golang.org/x/tools/gopls/internal/analysis/embeddirective"
1616
"golang.org/x/tools/gopls/internal/analysis/fillstruct"
17-
"golang.org/x/tools/gopls/internal/analysis/fillswitch"
1817
"golang.org/x/tools/gopls/internal/analysis/stubmethods"
1918
"golang.org/x/tools/gopls/internal/analysis/undeclaredname"
2019
"golang.org/x/tools/gopls/internal/analysis/unusedparams"
@@ -108,7 +107,6 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file
108107
fillstruct.FixCategory: singleFile(fillstruct.SuggestedFix),
109108
stubmethods.FixCategory: stubMethodsFixer,
110109
undeclaredname.FixCategory: singleFile(undeclaredname.SuggestedFix),
111-
fillswitch.FixCategory: fillswitch.SuggestedFix,
112110

113111
// Ad-hoc fixers: these are used when the command is
114112
// constructed directly by logic in server/code_action.

0 commit comments

Comments
 (0)