Skip to content

Commit e86e93a

Browse files
committed
internal/source: use imports to find appropriate selector names
1 parent 101bdb1 commit e86e93a

File tree

1 file changed

+57
-18
lines changed

1 file changed

+57
-18
lines changed

internal/source/source.go

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"go/parser"
99
"go/token"
1010
"os"
11+
"path"
1112
"runtime"
1213
"strconv"
1314
"strings"
@@ -39,32 +40,65 @@ func CallExprArgs(stackIndex int) ([]ast.Expr, error) {
3940
}
4041
debug("call stack position: %s:%d", filename, lineNum)
4142

42-
node, err := getNodeAtLine(filename, lineNum)
43+
source, err := getNodeAtLine(filename, lineNum)
4344
if err != nil {
4445
return nil, err
4546
}
46-
debug("found node: %s", debugFormatNode{node})
47+
debug("found node: %s", debugFormatNode{source.Node})
4748

48-
return getCallExprArgs(node)
49+
return getCallExprArgs(source)
4950
}
5051

51-
func getNodeAtLine(filename string, lineNum int) (ast.Node, error) {
52+
type fileSource struct {
53+
Node ast.Node
54+
Imports imports
55+
}
56+
57+
type imports map[string]struct{}
58+
59+
func newImports(specs []*ast.ImportSpec) imports {
60+
result := make(imports)
61+
for _, spec := range specs {
62+
pkgPath := strings.Trim(spec.Path.Value, `"`)
63+
if !strings.HasPrefix(pkgPath, `gotest.tools/`) {
64+
continue
65+
}
66+
name := path.Base(pkgPath)
67+
// Only two packages use internal/source right now.
68+
// Don't include the others to reduce the chance of a false positive
69+
// match on the name of some other package or type.
70+
if name != "assert" && name != "skip" {
71+
continue
72+
}
73+
if spec.Name != nil {
74+
name = spec.Name.Name
75+
}
76+
result[name] = struct{}{}
77+
}
78+
return result
79+
}
80+
81+
func getNodeAtLine(filename string, lineNum int) (fileSource, error) {
82+
fs := fileSource{}
5283
fileset := token.NewFileSet()
5384
astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors)
5485
if err != nil {
55-
return nil, errors.Wrapf(err, "failed to parse source file: %s", filename)
86+
return fs, errors.Wrapf(err, "failed to parse source file: %s", filename)
5687
}
88+
fs.Imports = newImports(astFile.Imports)
5789

5890
if node := scanToLine(fileset, astFile, lineNum); node != nil {
59-
return node, nil
91+
fs.Node = node
92+
return fs, nil
6093
}
6194
if node := scanToDeferLine(fileset, astFile, lineNum); node != nil {
6295
node, err := guessDefer(node)
6396
if err != nil || node != nil {
64-
return node, err
97+
fs.Node = node
98+
return fs, err
6599
}
66100
}
67-
return nil, errors.Errorf("failed to find an expression on line %d in %s", lineNum, filename)
101+
return fs, errors.Errorf("failed to find an expression on line %d in %s", lineNum, filename)
68102
}
69103

70104
func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
@@ -121,9 +155,9 @@ func GoVersionLessThan(major, minor int64) bool {
121155

122156
var goVersionBefore19 = GoVersionLessThan(1, 9)
123157

124-
func getCallExprArgs(node ast.Node) ([]ast.Expr, error) {
125-
visitor := &callExprVisitor{}
126-
ast.Walk(visitor, node)
158+
func getCallExprArgs(source fileSource) ([]ast.Expr, error) {
159+
visitor := &callExprVisitor{imports: source.Imports}
160+
ast.Walk(visitor, source.Node)
127161
if visitor.expr == nil {
128162
return nil, errors.New("failed to find call expression")
129163
}
@@ -132,7 +166,8 @@ func getCallExprArgs(node ast.Node) ([]ast.Expr, error) {
132166
}
133167

134168
type callExprVisitor struct {
135-
expr *ast.CallExpr
169+
expr *ast.CallExpr
170+
imports imports
136171
}
137172

138173
func (v *callExprVisitor) Visit(node ast.Node) ast.Visitor {
@@ -143,7 +178,7 @@ func (v *callExprVisitor) Visit(node ast.Node) ast.Visitor {
143178

144179
switch typed := node.(type) {
145180
case *ast.CallExpr:
146-
if !isGoTestToolsCallExpr(typed) {
181+
if !isGoTestToolsCallExpr(typed, v.imports) {
147182
return v
148183
}
149184
v.expr = typed
@@ -155,7 +190,7 @@ func (v *callExprVisitor) Visit(node ast.Node) ast.Visitor {
155190
return v
156191
}
157192

158-
func isGoTestToolsCallExpr(ce *ast.CallExpr) bool {
193+
func isGoTestToolsCallExpr(ce *ast.CallExpr, imports imports) bool {
159194
debug("call expr function: (%T), %v", ce.Fun, ce.Fun)
160195
se, ok := ce.Fun.(*ast.SelectorExpr)
161196
if !ok {
@@ -165,17 +200,21 @@ func isGoTestToolsCallExpr(ce *ast.CallExpr) bool {
165200
if !ok {
166201
return false
167202
}
168-
switch ident.Name {
169-
// TODO: use import alias from file for these values
170-
case "assert", "skip":
203+
switch {
204+
case imports.isGoTestToolsPackageSelector(ident.Name):
171205
return true
172-
case "gotestToolsTestShim":
206+
case ident.Name == "gotestToolsTestShim":
173207
// gotestToolsTestShim is used by tests of this package.
174208
return true
175209
}
176210
return false
177211
}
178212

213+
func (i imports) isGoTestToolsPackageSelector(name string) bool {
214+
_, ok := i[name]
215+
return ok
216+
}
217+
179218
// FormatNode using go/format.Node and return the result as a string
180219
func FormatNode(node ast.Node) (string, error) {
181220
buf := new(bytes.Buffer)

0 commit comments

Comments
 (0)