Skip to content
This repository was archived by the owner on Jun 27, 2023. It is now read-only.
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 46 additions & 68 deletions mockgen/mockgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,35 +294,13 @@ func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface,
g.p("")
g.GenerateMockMethod(mockType, m, pkgOverride)
g.p("")
g.GenerateMockRecorderMethod(mockType, m)
g.GenerateMockRecorderMethod(mockType, m, pkgOverride)
}
}

// GenerateMockMethod generates a mock method implementation.
// If non-empty, pkgOverride is the package in which unqualified types reside.
func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride string) error {
args := make([]string, len(m.In))
argNames := make([]string, len(m.In))
for i, p := range m.In {
name := p.Name
if name == "" {
name = fmt.Sprintf("_param%d", i)
}
ts := p.Type.String(g.packageMap, pkgOverride)
args[i] = name + " " + ts
argNames[i] = name
}
if m.Variadic != nil {
name := m.Variadic.Name
if name == "" {
name = fmt.Sprintf("_param%d", len(m.In))
}
ts := m.Variadic.Type.String(g.packageMap, pkgOverride)
args = append(args, name+" ..."+ts)
argNames = append(argNames, name)
}
argString := strings.Join(args, ", ")

rets := make([]string, len(m.Out))
for i, p := range m.Out {
rets[i] = p.Type.String(g.packageMap, pkgOverride)
Expand All @@ -335,24 +313,8 @@ func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOver
retString = " " + retString
}

g.p("func (_m *%v) %v(%v)%v {", mockType, m.Name, argString, retString)
g.in()
callArgs := g.funcStart("_m", mockType, m, pkgOverride, retString)

callArgs := strings.Join(argNames, ", ")
if callArgs != "" {
callArgs = ", " + callArgs
}
if m.Variadic != nil {
// Non-trivial. The generated code must build a []interface{},
// but the variadic argument may be any type.
g.p("_s := []interface{}{%s}", strings.Join(argNames[:len(argNames)-1], ", "))
g.p("for _, _x := range %s {", argNames[len(argNames)-1])
g.in()
g.p("_s = append(_s, _x)")
g.out()
g.p("}")
callArgs = ", _s..."
}
if len(m.Out) == 0 {
g.p(`_m.ctrl.Call(_m, "%v"%v)`, m.Name, callArgs)
} else {
Expand All @@ -374,45 +336,61 @@ func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOver
return nil
}

func (g *generator) GenerateMockRecorderMethod(mockType string, m *model.Method) error {
nargs := len(m.In)
args := make([]string, nargs)
for i := 0; i < nargs; i++ {
args[i] = "arg" + strconv.Itoa(i)
}
argString := strings.Join(args, ", ")
if nargs > 0 {
argString += " interface{}"
func (g *generator) GenerateMockRecorderMethod(mockType string, m *model.Method, pkgOverride string) error {

callArgs := g.funcStart("_mr", "_"+mockType+"Recorder", m, pkgOverride, "*gomock.Call")
g.p(`return _mr.mock.ctrl.RecordCall(_mr.mock, "%v"%v)`, m.Name, callArgs)

g.out()
g.p("}")
return nil
}

// funcStart generates the signature of the generated method and the type
// conversions needed for variadic functions. It returns the string of arguments
// that the generated function body needs to use in its call.
func (g *generator) funcStart(recvName, mockType string, m *model.Method, pkgOverride string, retString string) string {
args := make([]string, len(m.In))
argNames := make([]string, len(m.In))
for i, p := range m.In {
name := p.Name
if name == "" {
name = fmt.Sprintf("_param%d", i)
}
ts := p.Type.String(g.packageMap, pkgOverride)
args[i] = name + " " + ts
argNames[i] = name
}
if m.Variadic != nil {
if nargs > 0 {
argString += ", "
name := m.Variadic.Name
if name == "" {
name = fmt.Sprintf("_param%d", len(m.In))
}
argString += fmt.Sprintf("arg%d ...interface{}", nargs)
ts := m.Variadic.Type.String(g.packageMap, pkgOverride)
args = append(args, name+" ..."+ts)
argNames = append(argNames, name)
}
argString := strings.Join(args, ", ")

g.p("func (_mr *_%vRecorder) %v(%v) *gomock.Call {", mockType, m.Name, argString)
g.p("func (%s *%v) %v(%v)%v {", recvName, mockType, m.Name, argString, retString)
g.in()

callArgs := strings.Join(args, ", ")
if nargs > 0 {
callArgs := strings.Join(argNames, ", ")
if callArgs != "" {
callArgs = ", " + callArgs
}
if m.Variadic != nil {
if nargs == 0 {
// Easy: just use ... to push the arguments through.
callArgs = ", arg0..."
} else {
// Hard: create a temporary slice.
g.p("_s := append([]interface{}{%s}, arg%d...)", strings.Join(args, ", "), nargs)
callArgs = ", _s..."
}
// Non-trivial. The generated code must build a []interface{},
// but the variadic argument may be any type.
g.p("_s := []interface{}{%s}", strings.Join(argNames[:len(argNames)-1], ", "))
g.p("for _, _x := range %s {", argNames[len(argNames)-1])
g.in()
g.p("_s = append(_s, _x)")
g.out()
g.p("}")
callArgs = ", _s..."
}
g.p(`return _mr.mock.ctrl.RecordCall(_mr.mock, "%v"%v)`, m.Name, callArgs)

g.out()
g.p("}")
return nil
return callArgs
}

// Output returns the generator's output, formatted in the standard Go style.
Expand Down