gomock 是 Google 开源的 Go 语言 mock 框架，配合标准库 testing 使用。

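GoMock is a mocking framework for the Go programming language: https://github.com/golang/mock. Note that this repository was archived in June 2023; the maintained fork is go.uber.org/mock (install its mockgen with `go install go.uber.org/mock/mockgen@latest`).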


安装 mockgen

To install the latest released version of golang/mock's mockgen (v1.6.0), use:
Go version < 1.16

GO111MODULE=on go get github.com/golang/mock/mockgen@v1.6.0

Go 1.16+

go install github.com/golang/mock/mockgen@v1.6.0


// mockgen -source=./driver/navigator_driver.go -destination ./driver/navigator_driver_mock.go -package driver

type INavigatorDriver interface 
    Query(Ctx context.Context,
        SqlClient *sqlclient.SQLClient,
        sql string,
        searchOptions ...*engine.Option,
    ) ([]map[string]interface, error)

    BatchGetProductInfoMap(Ctx context.Context,
        SqlClient *sqlclient.SQLClient,
        date string,
        ids []int64,
        entityFields []string,
    ) (map[int64]interface, error)

    BatchGetBrandInfoMap(Ctx context.Context,
        SqlClient *sqlclient.SQLClient,
        date string,
        ids []int64,
        entityFields []string,
    ) (map[int64]interface, error)

// NavigatorDriver is presumably the production (non-mock) implementation of
// INavigatorDriver; its methods are not shown in this excerpt.
type NavigatorDriver struct{}

使用 mockgen 命令行工具（source 模式）自动生成 gomock 代码：


mockgen -source=./driver/navigator_driver.go -destination ./driver/navigator_driver_mock.go -package driver

其中，-source 指定包含接口定义的源文件，-destination 指定生成文件的输出路径，-package 指定生成代码所属的包名；navigator_driver_mock.go 就是生成的 mock 代码。



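生成的 mock 代码如下。其中 MockINavigatorDriver 是 INavigatorDriver 接口的 mock 实现，EXPECT() 返回的 MockINavigatorDriverMockRecorder 用于在测试中声明期望的调用：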

// Code generated by MockGen. DO NOT EDIT.
// Source: ./driver/navigator_driver.go

// Package driver is a generated GoMock package.
package driver

import (
    context "context"
    reflect "reflect"

    gomock "github.com/golang/mock/gomock"
    // NOTE(review): the generated file must also import the project packages
    // that provide sqlclient.SQLClient and engine.Option; their import paths
    // are omitted in this excerpt.
)

// MockINavigatorDriver is a mock of INavigatorDriver interface.
type MockINavigatorDriver struct 
    ctrl     *gomock.Controller
    recorder *MockINavigatorDriverMockRecorder

// MockINavigatorDriverMockRecorder is the mock recorder for MockINavigatorDriver.
type MockINavigatorDriverMockRecorder struct 
    mock *MockINavigatorDriver

// NewMockINavigatorDriver creates a new mock instance.
func NewMockINavigatorDriver(ctrl *gomock.Controller) *MockINavigatorDriver 
    mock := &MockINavigatorDriverctrl: ctrl
    mock.recorder = &MockINavigatorDriverMockRecordermock
    return mock

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockINavigatorDriver) EXPECT() *MockINavigatorDriverMockRecorder 
    return m.recorder

// BatchGetBrandInfoList mocks base method.
func (m *MockINavigatorDriver) BatchGetBrandInfoMap(Ctx context.Context, SqlClient *sqlclient.SQLClient, date string, ids []int64, entityFields []string) (map[int64]interface, error) 
    ret := m.ctrl.Call(m, "BatchGetBrandInfoMap", Ctx, SqlClient, date, ids, entityFields)
    ret0, _ := ret[0].(map[int64]interface)
    ret1, _ := ret[1].(error)
    return ret0, ret1

// BatchGetBrandInfoList indicates an expected call of BatchGetBrandInfoList.
func (mr *MockINavigatorDriverMockRecorder) BatchGetBrandInfoList(Ctx, SqlClient, date, ids, entityFields interface) *gomock.Call 
    return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetBrandInfoMap", reflect.TypeOf((*MockINavigatorDriver)(nil).BatchGetBrandInfoMap), Ctx, SqlClient, date, ids, entityFields)

// BatchGetProductInfoList mocks base method.
func (m *MockINavigatorDriver) BatchGetProductInfoMap(Ctx context.Context, SqlClient *sqlclient.SQLClient, date string, ids []int64, entityFields []string) (map[int64]interface, error) 
    ret := m.ctrl.Call(m, "BatchGetProductInfoMap", Ctx, SqlClient, date, ids, entityFields)
    ret0, _ := ret[0].(map[int64]interface)
    ret1, _ := ret[1].(error)
    return ret0, ret1

// BatchGetProductInfoList indicates an expected call of BatchGetProductInfoList.
func (mr *MockINavigatorDriverMockRecorder) BatchGetProductInfoList(Ctx, SqlClient, date, ids, entityFields interface) *gomock.Call 
    return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetProductInfoMap", reflect.TypeOf((*MockINavigatorDriver)(nil).BatchGetProductInfoMap), Ctx, SqlClient, date, ids, entityFields)

// Query mocks base method.
func (m *MockINavigatorDriver) Query(Ctx context.Context, SqlClient *sqlclient.SQLClient, sqlKey, sql string, searchOptions ...*engine.Option) ([]map[string]interface, error) 
    varargs := []interfaceCtx, SqlClient, sqlKey, sql
    for _, a := range searchOptions 
        varargs = append(varargs, a)
    ret := m.ctrl.Call(m, "Query", varargs...)
    ret0, _ := ret[0].([]map[string]interface)
    ret1, _ := ret[1].(error)
    return ret0, ret1

// Query indicates an expected call of Query.
func (mr *MockINavigatorDriverMockRecorder) Query(Ctx, SqlClient, sqlKey, sql interface, searchOptions ...interface) *gomock.Call 
    varargs := append([]interfaceCtx, SqlClient, sqlKey, sql, searchOptions...)
    return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockINavigatorDriver)(nil).Query), varargs...)


下面我们来 mock 如下代码中 navigatorDriver.Query 这个接口调用的返回值：

// The call whose return value is mocked (excerpt from RenderDataTable):
datasourceData, _ := navigatorDriver.Query(u.Ctx, u.Datasource.SqlClient, u.Datasource.SqlKey, navigatorSQL)
// RenderDataTable builds the rows of a data-table component.
//
// It queries u.Datasource through u.INavigatorDriver (optionally adding rank
// info) and loads the rows into a SQLite table created via the driver
// package. It then re-selects them with the IndexInfo/RankInfo UDFs applied
// to the metric and rank columns, applies the DFLimit/DFOffset paging, fills
// the entity-info columns, and returns the resulting records.
//
// It returns the error "datasource empty" when the loaded table has no rows.
// Errors from the other helper calls are discarded or not checked.
//
// NOTE(review): this excerpt has lost its braces and part of the memCQL
// builder chain (step 2.8), so it does not compile as shown.
func (u *UIComponent) RenderDataTable() ([]map[string]interface, error) 
    // NOTE(review): bu is never used below and Go rejects unused locals, so a
    // statement consuming it (perhaps a log call) seems to be missing here.
    bu, _ := json.Marshal(u)
    endTime := u.DateRangeFilter.EndDate
    dateType := u.DateRangeFilter.DaysType
    // Date string of the current period, derived from the filter's end date.
    dateStr := indexu.GetDateStr(endTime)

    // 1. Fetch rows from the data source.
    // 1.1 Columns to SELECT from the data source.
    columnNames := make([]string, 0)
    for _, e := range u.Datasource.Columns 
        columnNames = append(columnNames, e.Name)

    // 1.2 Data-source table.
    datasourceTable := u.Datasource.TableName

    // 1.3 Filter condition; the DateRangeFilter is always applied.
    whereExpr := buildWhereExpr(dateStr, dateType, u)
    datasourceCQL := alpha.NewCQL().SELECT(columnNames...).FROM(datasourceTable).WHERE(whereExpr).ORDERBY2(u.Datasource.DSOrderType, u.Datasource.DSOrderColumn).LIMIT2(u.Datasource.DSLimit)

    // 1.4 Query the data source (Navigator).
    // TODO MOCK: in production the real NavigatorQueryList query is used;
    // presumably tests inject a mock through u.INavigatorDriver.
    navigatorDriver := u.INavigatorDriver
    navigatorSQL := datasourceCQL.Compile()
    logu.CtxInfo(u.Ctx, "RenderDataTable", "navigatorSQL: %v", navigatorSQL)
    // NOTE(review): the Query error is discarded, so a failed fetch continues
    // with whatever rows (if any) were returned.
    datasourceData, _ := navigatorDriver.Query(u.Ctx, u.Datasource.SqlClient, u.Datasource.SqlKey, navigatorSQL)
    // 1.5 Add rank information (keyed by RankKey) to the RSD rows.
    if u.NeedRank 
        datasourceData = rocket.NewRSD2(datasourceData, u.Ctx, u.Datasource.SqlClient, u.INavigatorDriver).WithRank(u.RankKey, u.Datasource.Columns).Records

    // 2. In-memory metric computation (SQLite).
    sqlite, _ := driver.InitSqlite(u.Ctx, map[string]interface)

    // 2.1 Derive column metadata from the rows returned by the data source.
    columns := ParseColumnsMeta(datasourceData)

    // 2.2 Generate a (presumably unique) SQLite table name.
    sqliteTableName := driver.GenerateUniqSQLiteTableName()

    // 2.3 Create the in-memory table.
    driver.CreateTable(u.Ctx, sqliteTableName, columns, sqlite)
    // 2.4 Copy the rows into the in-memory table.
    driver.InsertData(u.Ctx, sqliteTableName, datasourceData, columns, sqlite)

    // 2.5 Check the in-memory row count.
    // NOTE(review): this passes a nil context (the later driver.Query call
    // uses u.Ctx) and discards the error; if count comes back empty (e.g. on
    // error), count[0] panics.
    count, _, _ := driver.Query(nil, fmt.Sprintf("select count(1) as count from %s", sqliteTableName), sqlite)

    // A count that cannot be converted to int64 silently skips this check.
    if nums, err := convert.ToInt64E(count[0]["count"]); err == nil && nums <= 0 
        return nil, fmt.Errorf("datasource empty")

    // 2.6 Non-metric columns are selected as-is.
    var selectItem = []string
    for _, c := range u.Datasource.Columns  // column metadata, including the metric-computation rules
        if !c.IsDataIndex 
            cname := c.Name
            selectItem = append(selectItem, cname)
    // 2.7 Metric columns are wrapped in the IndexInfo UDF.
    indexColumns := getIndexColumns(u.Datasource.Columns)
    // add incr select items
    for _, column := range indexColumns 
        columnName := column.Name
        var exp = fmt.Sprintf("IndexInfo(%s) as %s", columnName, columnName)
        selectItem = append(selectItem, exp)

    // 2.8 Add the RankInfo UDF to the query.
    if u.NeedRank 
        // The rank key is specified separately; it is not one of the data-source columns.
        rankKey := u.RankKey
        selectItem = append(selectItem, fmt.Sprintf("RankInfo(%s) as %s", rankKey, rankKey))

    // NOTE(review): the rest of this builder chain was lost in this excerpt
    // (presumably SELECT(selectItem...) FROM sqliteTableName, ordered by
    // u.DFOrderType / u.DFOrderColumn).
    memCQL := alpha.NewCQL().

    // Paging: with both DFLimit and DFOffset set, LIMIT3(DFLimit, DFOffset) is
    // used; with only DFLimit, LIMIT2(DFLimit). A DFOffset without a DFLimit
    // is ignored.
    if u.DFLimit != nil && u.DFOffset != nil 
        memCQL = memCQL.LIMIT3(*u.DFLimit, *u.DFOffset)
     else if u.DFLimit != nil && u.DFOffset == nil 
        memCQL = memCQL.LIMIT2(*u.DFLimit)

    incrSQL := memCQL.Compile()

    result, _, _ := driver.Query(u.Ctx, incrSQL, sqlite)

    rsd := rocket.NewRSD2(result, u.Ctx, u.Datasource.SqlClient, u.INavigatorDriver).
        UnmarshalRankInfo(u.NeedRank, u.RankKey).
        FillEntityInfoColumn(dateStr, u.Datasource.Columns)

    return rsd.Records, nil


mock 测试代码


// Excerpt of TestDataTableUIComponent.
// NOTE(review): the `mockDriver.EXPECT().` line before Query(...) and the
// `.Return(...)` line after it appear to have been lost, so this fragment
// does not compile on its own.
ctrl := gomock.NewController(t)
    defer ctrl.Finish()

    mockDriver := driver.NewMockINavigatorDriver(ctrl)
    // Expected result of NavigatorQueryList (the Query call).
        Query(ctx, SqlClient, "compass_strategy_chance_property_product_stats_di", gomock.Any(), gomock.Any()).


var (
    ctx       = context.Background()
    SqlClient = gomock.Any()

// TestDataTableUIComponent renders a data-table UIComponent whose driver
// calls are served by a MockINavigatorDriver and prints the number of rows.
//
// NOTE(review): this excerpt has lost its braces, the mockDriver.EXPECT() /
// .Return(...) lines around each expectation, and the data-source and mock
// arguments of NewUIComponent, so it does not compile as shown. It also makes
// no assertions; consider checking the Render error and len(result).
func TestDataTableUIComponent(t *testing.T) 
    ctrl := gomock.NewController(t)
    // Since gomock 1.5.0, NewController(t) registers Finish via t.Cleanup, so
    // this explicit call is redundant.
    defer ctrl.Finish()

    mockDriver := driver.NewMockINavigatorDriver(ctrl)
    // Expected result of NavigatorQueryList (the Query call).
    // NOTE(review): a `mockDriver.EXPECT().` line before and a `.Return(...)`
    // line after each expectation below appear to be missing.
        Query(ctx, SqlClient, "compass_strategy_chance_property_product_stats_di", gomock.Any(), gomock.Any()).

        BatchGetProductInfoList(ctx, SqlClient, gomock.Any(), gomock.Any(), gomock.Any()).

    //  EXPECT().
    //  BatchGetBrandInfoList(ctx, SqlClient, gomock.Any(), gomock.Any(), gomock.Any()).
    //  Return(driver.MockNavigatorQueryListBrandMap())

    // Initialize the data source.
    columns := []datasource.Column
        Name: "date",
        Name: "days_type",
        Name: "stats_date",
        Name: "cate_id",
        Name: "cate_name",
        Name: "property_name",
        Name: "market_name",
        Name: "product_property_value",
        Name: "product_id", IsRowKey: true, NeedFillEntityInfo: true, EntityType: datasource.Product, EntityInfoColumnKey: "product_info",
        Name: "pay_amt", IsDataIndex: true,
        Name: "pay_combo_cnt", IsDataIndex: true,

    // NOTE(review): "datasoure" is misspelled (perhaps to avoid shadowing the
    // datasource package); it is not referenced in the lines that survive here.
    datasoure := &datasource.DataSource
        TableName:     "compass_strategy_chance_property_product_stats_di",
        Columns:       columns,
        SqlKey:        "compass_strategy_chance_property_product_stats_di",
        SearchOptions: []*engine.Option,
        DSOrderColumn: "pay_combo_cnt",
        DSOrderType:   alpha.DESC,
        DSLimit:       50,

    // Create the component.
    UIComponent := NewUIComponent(
            DaysType:  constu.DateType_LAST_SEVEN_DAYS,
            StartDate: 0,
            EndDate:   1653177600,
            DimCondition: map[string]string"cate_id": "123",
                "market_name":            "碎花",
                "product_property_value": "长款裙子",

    // In-memory paging.
    PageNo := int64(2)
    PageSize := int64(5)

    // NOTE(review): dflimit holds (PageNo-1)*PageSize, which reads like an
    // offset, and dfoffset holds PageSize, which reads like a page size.
    // Unless LIMIT3 takes (offset, count), these look swapped. With PageNo=2
    // and PageSize=5 both values are 5, so this test cannot detect it.
    dflimit := (PageNo - 1) * PageSize
    dfoffset := PageSize

    UIComponent.DFOrderColumn = "pay_combo_cnt"
    UIComponent.DFOrderType = alpha.DESC
    UIComponent.DFLimit = &dflimit
    UIComponent.DFOffset = &dfoffset
    UIComponent.NeedRank = true
    UIComponent.RankKey = "rank"

    // Render() is UIComponent's single data entry point.
    // NOTE(review): the returned error is discarded.
    result, _ := UIComponent.Render()

    fmt.Println("size:", len(result))

    // NOTE(review): b is unused and Go rejects unused locals, so the statement
    // consuming it seems to be missing from this excerpt.
    b, _ := json.Marshal(result)


