8
8
"go/parser"
9
9
"go/token"
10
10
"os"
11
+ "path"
11
12
"runtime"
12
13
"strconv"
13
14
"strings"
@@ -39,32 +40,65 @@ func CallExprArgs(stackIndex int) ([]ast.Expr, error) {
39
40
}
40
41
debug ("call stack position: %s:%d" , filename , lineNum )
41
42
42
- node , err := getNodeAtLine (filename , lineNum )
43
+ source , err := getNodeAtLine (filename , lineNum )
43
44
if err != nil {
44
45
return nil , err
45
46
}
46
- debug ("found node: %s" , debugFormatNode {node })
47
+ debug ("found node: %s" , debugFormatNode {source . Node })
47
48
48
- return getCallExprArgs (node )
49
+ return getCallExprArgs (source )
49
50
}
50
51
51
- func getNodeAtLine (filename string , lineNum int ) (ast.Node , error ) {
52
+ type fileSource struct {
53
+ Node ast.Node
54
+ Imports imports
55
+ }
56
+
57
+ type imports map [string ]struct {}
58
+
59
+ func newImports (specs []* ast.ImportSpec ) imports {
60
+ result := make (imports )
61
+ for _ , spec := range specs {
62
+ pkgPath := strings .Trim (spec .Path .Value , `"` )
63
+ if ! strings .HasPrefix (pkgPath , `gotest.tools/` ) {
64
+ continue
65
+ }
66
+ name := path .Base (pkgPath )
67
+ // Only two packages use internal/source right now.
68
+ // Don't include the others to reduce the chance of a false positive
69
+ // match on the name of some other package or type.
70
+ if name != "assert" && name != "skip" {
71
+ continue
72
+ }
73
+ if spec .Name != nil {
74
+ name = spec .Name .Name
75
+ }
76
+ result [name ] = struct {}{}
77
+ }
78
+ return result
79
+ }
80
+
81
+ func getNodeAtLine (filename string , lineNum int ) (fileSource , error ) {
82
+ fs := fileSource {}
52
83
fileset := token .NewFileSet ()
53
84
astFile , err := parser .ParseFile (fileset , filename , nil , parser .AllErrors )
54
85
if err != nil {
55
- return nil , errors .Wrapf (err , "failed to parse source file: %s" , filename )
86
+ return fs , errors .Wrapf (err , "failed to parse source file: %s" , filename )
56
87
}
88
+ fs .Imports = newImports (astFile .Imports )
57
89
58
90
if node := scanToLine (fileset , astFile , lineNum ); node != nil {
59
- return node , nil
91
+ fs .Node = node
92
+ return fs , nil
60
93
}
61
94
if node := scanToDeferLine (fileset , astFile , lineNum ); node != nil {
62
95
node , err := guessDefer (node )
63
96
if err != nil || node != nil {
64
- return node , err
97
+ fs .Node = node
98
+ return fs , err
65
99
}
66
100
}
67
- return nil , errors .Errorf ("failed to find an expression on line %d in %s" , lineNum , filename )
101
+ return fs , errors .Errorf ("failed to find an expression on line %d in %s" , lineNum , filename )
68
102
}
69
103
70
104
func scanToLine (fileset * token.FileSet , node ast.Node , lineNum int ) ast.Node {
@@ -121,9 +155,9 @@ func GoVersionLessThan(major, minor int64) bool {
121
155
122
156
var goVersionBefore19 = GoVersionLessThan (1 , 9 )
123
157
124
- func getCallExprArgs (node ast. Node ) ([]ast.Expr , error ) {
125
- visitor := & callExprVisitor {}
126
- ast .Walk (visitor , node )
158
+ func getCallExprArgs (source fileSource ) ([]ast.Expr , error ) {
159
+ visitor := & callExprVisitor {imports : source . Imports }
160
+ ast .Walk (visitor , source . Node )
127
161
if visitor .expr == nil {
128
162
return nil , errors .New ("failed to find call expression" )
129
163
}
@@ -132,7 +166,8 @@ func getCallExprArgs(node ast.Node) ([]ast.Expr, error) {
132
166
}
133
167
134
168
type callExprVisitor struct {
135
- expr * ast.CallExpr
169
+ expr * ast.CallExpr
170
+ imports imports
136
171
}
137
172
138
173
func (v * callExprVisitor ) Visit (node ast.Node ) ast.Visitor {
@@ -143,7 +178,7 @@ func (v *callExprVisitor) Visit(node ast.Node) ast.Visitor {
143
178
144
179
switch typed := node .(type ) {
145
180
case * ast.CallExpr :
146
- if ! isGoTestToolsCallExpr (typed ) {
181
+ if ! isGoTestToolsCallExpr (typed , v . imports ) {
147
182
return v
148
183
}
149
184
v .expr = typed
@@ -155,7 +190,7 @@ func (v *callExprVisitor) Visit(node ast.Node) ast.Visitor {
155
190
return v
156
191
}
157
192
158
- func isGoTestToolsCallExpr (ce * ast.CallExpr ) bool {
193
+ func isGoTestToolsCallExpr (ce * ast.CallExpr , imports imports ) bool {
159
194
debug ("call expr function: (%T), %v" , ce .Fun , ce .Fun )
160
195
se , ok := ce .Fun .(* ast.SelectorExpr )
161
196
if ! ok {
@@ -165,17 +200,21 @@ func isGoTestToolsCallExpr(ce *ast.CallExpr) bool {
165
200
if ! ok {
166
201
return false
167
202
}
168
- switch ident .Name {
169
- // TODO: use import alias from file for these values
170
- case "assert" , "skip" :
203
+ switch {
204
+ case imports .isGoTestToolsPackageSelector (ident .Name ):
171
205
return true
172
- case "gotestToolsTestShim" :
206
+ case ident . Name == "gotestToolsTestShim" :
173
207
// gotestToolsTestShim is used by tests of this package.
174
208
return true
175
209
}
176
210
return false
177
211
}
178
212
213
+ func (i imports ) isGoTestToolsPackageSelector (name string ) bool {
214
+ _ , ok := i [name ]
215
+ return ok
216
+ }
217
+
179
218
// FormatNode using go/format.Node and return the result as a string
180
219
func FormatNode (node ast.Node ) (string , error ) {
181
220
buf := new (bytes.Buffer )
0 commit comments