Skip to content

Commit db76cce

Browse files
committed
feat: support interface methods
1 parent cb4bf76 commit db76cce

File tree

4 files changed

+28
-36
lines changed

4 files changed

+28
-36
lines changed

musttag.go

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func New(funcs ...Func) *analysis.Analyzer {
6060
}
6161
}
6262

63-
return run(pass, mainModule, allFuncs)
63+
return nil, run(pass, mainModule, allFuncs)
6464
},
6565
}
6666
}
@@ -86,43 +86,34 @@ func flags(funcs *[]Func) flag.FlagSet {
8686
return *fs
8787
}
8888

89-
func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (_ any, err error) {
89+
func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) error {
9090
visit := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
91-
filter := []ast.Node{(*ast.CallExpr)(nil)}
9291

93-
visit.Preorder(filter, func(node ast.Node) {
94-
if err != nil {
95-
return // there is already an error.
96-
}
92+
for node := range visit.PreorderSeq((*ast.CallExpr)(nil)) {
93+
call := node.(*ast.CallExpr)
9794

98-
call, ok := node.(*ast.CallExpr)
95+
callee, ok := typeutil.Callee(pass.TypesInfo, call).(*types.Func)
9996
if !ok {
100-
return
101-
}
102-
103-
callee := typeutil.StaticCallee(pass.TypesInfo, call)
104-
if callee == nil {
105-
return
97+
continue
10698
}
10799

108100
fn, ok := funcs[cutVendor(callee.FullName())]
109101
if !ok {
110-
return
102+
continue
111103
}
112104

113105
if len(call.Args) <= fn.ArgPos {
114-
err = fmt.Errorf("musttag: Func.ArgPos cannot be %d: %s accepts only %d argument(s)", fn.ArgPos, fn.Name, len(call.Args))
115-
return
106+
return fmt.Errorf("musttag: Func.ArgPos cannot be %d: %s accepts only %d argument(s)", fn.ArgPos, fn.Name, len(call.Args))
116107
}
117108

118109
arg := call.Args[fn.ArgPos]
119110
if ident, ok := arg.(*ast.Ident); ok && ident.Obj == nil {
120-
return // e.g. json.Marshal(nil)
111+
continue // e.g. json.Marshal(nil)
121112
}
122113

123114
typ := pass.TypesInfo.TypeOf(arg)
124115
if typ == nil {
125-
return
116+
continue
126117
}
127118

128119
checker := checker{
@@ -132,13 +123,13 @@ func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (_ any,
132123
imports: pass.Pkg.Imports(),
133124
}
134125
if checker.isValidType(typ, fn.Tag) {
135-
return
126+
continue
136127
}
137128

138129
pass.Reportf(arg.Pos(), "the given struct should be annotated with the `%s` tag", fn.Tag)
139-
})
130+
}
140131

141-
return nil, err
132+
return nil
142133
}
143134

144135
type checker struct {
@@ -176,7 +167,6 @@ func (c *checker) parseStruct(typ types.Type) (*types.Struct, bool) {
176167
return c.parseStruct(typ.Elem())
177168
case *types.Map:
178169
return c.parseStruct(typ.Elem())
179-
180170
case *types.Named: // a struct of the named type.
181171
pkg := typ.Obj().Pkg()
182172
if pkg == nil {
@@ -190,10 +180,8 @@ func (c *checker) parseStruct(typ types.Type) (*types.Struct, bool) {
190180
return nil, false
191181
}
192182
return styp, true
193-
194183
case *types.Struct: // an anonymous struct.
195184
return typ, true
196-
197185
default:
198186
return nil, false
199187
}
@@ -208,15 +196,12 @@ func (c *checker) isValidStruct(styp *types.Struct, tag string) bool {
208196

209197
tagValue, ok := reflect.StructTag(styp.Tag(i)).Lookup(tag)
210198
if !ok {
211-
// tag is not required for embedded types.
212199
if !field.Embedded() {
213-
return false
200+
return false // tag is not required for embedded types.
214201
}
215202
}
216-
217-
// the field is explicitly ignored.
218203
if tagValue == "-" {
219-
continue
204+
continue // the field is explicitly ignored.
220205
}
221206

222207
if !c.isValidType(field.Type(), tag) {

musttag_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ func TestAnalyzer(t *testing.T) {
1818

1919
t.Run("tests", func(t *testing.T) {
2020
analyzer := New(
21-
Func{Name: "example.com/custom.Marshal", Tag: "custom", ArgPos: 0},
22-
Func{Name: "example.com/custom.Unmarshal", Tag: "custom", ArgPos: 1},
21+
Func{Name: "example.com/custom.Function", Tag: "custom", ArgPos: 0},
22+
Func{Name: "(example.com/custom.Struct).Method", Tag: "custom", ArgPos: 0},
23+
Func{Name: "(example.com/custom.Interface).Method", Tag: "custom", ArgPos: 0},
2324
)
2425
analysistest.Run(t, testdata, analyzer, "tests")
2526
})
Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
package custom
22

3-
func Marshal(any) ([]byte, error) { return nil, nil }
4-
func Unmarshal([]byte, any) error { return nil }
3+
func Function(any) ([]byte, error) { return nil, nil }
4+
5+
type Struct struct{}
6+
7+
func (Struct) Method(any) ([]byte, error) { return nil, nil }
8+
9+
type Interface interface{ Method(any) ([]byte, error) }

testdata/src/tests/builtins.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ func testSQLX() {
195195

196196
func testCustom() {
197197
var st Struct
198-
custom.Marshal(st) // want "the given struct should be annotated with the `custom` tag"
199-
custom.Unmarshal(nil, &st) // want "the given struct should be annotated with the `custom` tag"
198+
custom.Function(st) // want "the given struct should be annotated with the `custom` tag"
199+
new(custom.Struct).Method(st) // want "the given struct should be annotated with the `custom` tag"
200+
(custom.Interface)(nil).Method(st) // want "the given struct should be annotated with the `custom` tag"
200201
}

0 commit comments

Comments
 (0)