未验证 提交 13b50810 编写于 作者: W wei liu 提交者: GitHub

refine mock querycoord (#22198)

Signed-off-by: NWei Liu <wei.liu@zilliz.com>
上级 f4e7b246
......@@ -332,5 +332,6 @@ generate-mockery: getdeps
# internal/rootcoord
$(PWD)/bin/mockery --name=IMetaTable --dir=$(PWD)/internal/rootcoord --output=$(PWD)/internal/rootcoord/mocks --filename=meta_table.go --with-expecter --outpkg=mockrootcoord
$(PWD)/bin/mockery --name=GarbageCollector --dir=$(PWD)/internal/rootcoord --output=$(PWD)/internal/rootcoord/mocks --filename=garbage_collector.go --with-expecter --outpkg=mockrootcoord
#internal/types
$(PWD)/bin/mockery --name=QueryCoordComponent --dir=$(PWD)/internal/types --output=$(PWD)/internal/types --filename=mock_querycoord.go --with-expecter --structname=MockQueryCoord --outpkg=types --inpackage
ci-ut: build-cpp-with-coverage generated-proto-go-without-cpp codecov-cpp codecov-go
......@@ -44,7 +44,6 @@ import (
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/proxy"
"github.com/milvus-io/milvus/internal/types"
......@@ -293,145 +292,6 @@ func (m *MockRootCoord) RenameCollection(ctx context.Context, req *milvuspb.Rena
return nil, nil
}
// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
type MockQueryCoord struct {
MockBase
initErr error
startErr error
stopErr error
regErr error
}
func (m *MockQueryCoord) Init() error {
return m.initErr
}
func (m *MockQueryCoord) Start() error {
return m.startErr
}
func (m *MockQueryCoord) Stop() error {
return m.stopErr
}
func (m *MockQueryCoord) Register() error {
return m.regErr
}
func (m *MockQueryCoord) UpdateStateCode(code commonpb.StateCode) {
}
func (m *MockQueryCoord) SetRootCoord(types.RootCoord) error {
return nil
}
func (m *MockQueryCoord) SetDataCoord(types.DataCoord) error {
return nil
}
func (m *MockQueryCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) {
return &milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
NodeID: int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
Role: "MockQueryCoord",
StateCode: commonpb.StateCode_Healthy,
ExtraInfo: nil,
},
SubcomponentStates: nil,
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}, nil
}
func (m *MockQueryCoord) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return nil, nil
}
func (m *MockQueryCoord) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return nil, nil
}
func (m *MockQueryCoord) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return nil, nil
}
func (m *MockQueryCoord) LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) (*commonpb.Status, error) {
return nil, nil
}
func (m *MockQueryCoord) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
return nil, nil
}
func (m *MockQueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return nil, nil
}
func (m *MockQueryCoord) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) {
return nil, nil
}
func (m *MockQueryCoord) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
return nil, nil
}
func (m *MockQueryCoord) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
return nil, nil
}
func (m *MockQueryCoord) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
return nil, nil
}
func (m *MockQueryCoord) LoadBalance(ctx context.Context, req *querypb.LoadBalanceRequest) (*commonpb.Status, error) {
return nil, nil
}
func (m *MockQueryCoord) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
return nil, nil
}
func (m *MockQueryCoord) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasRequest) (*milvuspb.GetReplicasResponse, error) {
return nil, nil
}
func (m *MockQueryCoord) GetShardLeaders(ctx context.Context, req *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
return nil, nil
}
func (m *MockQueryCoord) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) {
return nil, nil
}
func (m *MockQueryCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) {
return &milvuspb.CheckHealthResponse{
IsHealthy: true,
}, nil
}
func (m *MockQueryCoord) CreateResourceGroup(ctx context.Context, req *milvuspb.CreateResourceGroupRequest) (*commonpb.Status, error) {
return nil, nil
}
func (m *MockQueryCoord) DropResourceGroup(ctx context.Context, req *milvuspb.DropResourceGroupRequest) (*commonpb.Status, error) {
return nil, nil
}
func (m *MockQueryCoord) DescribeResourceGroup(ctx context.Context, req *querypb.DescribeResourceGroupRequest) (*querypb.DescribeResourceGroupResponse, error) {
return nil, nil
}
func (m *MockQueryCoord) TransferNode(ctx context.Context, req *milvuspb.TransferNodeRequest) (*commonpb.Status, error) {
return nil, nil
}
func (m *MockQueryCoord) TransferReplica(ctx context.Context, req *querypb.TransferReplicaRequest) (*commonpb.Status, error) {
return nil, nil
}
func (m *MockQueryCoord) ListResourceGroups(ctx context.Context, req *milvuspb.ListResourceGroupsRequest) (*milvuspb.ListResourceGroupsResponse, error) {
return nil, nil
}
// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
type MockDataCoord struct {
MockBase
......@@ -1090,15 +950,9 @@ func runAndWaitForServerReady(server *Server) error {
func Test_NewServer(t *testing.T) {
paramtable.Init()
ctx := context.Background()
server, err := NewServer(ctx, nil)
assert.NotNil(t, server)
assert.Nil(t, err)
server.proxy = &MockProxy{}
server.rootCoordClient = &MockRootCoord{}
server.queryCoordClient = &MockQueryCoord{}
server.dataCoordClient = &MockDataCoord{}
server := getServer(t)
var err error
t.Run("Run", func(t *testing.T) {
err = runAndWaitForServerReady(server)
assert.Nil(t, err)
......@@ -1448,15 +1302,8 @@ func Test_NewServer(t *testing.T) {
func TestServer_Check(t *testing.T) {
ctx := context.Background()
server, err := NewServer(ctx, nil)
assert.NotNil(t, server)
assert.Nil(t, err)
mockProxy := &MockProxy{}
server.proxy = mockProxy
server.rootCoordClient = &MockRootCoord{}
server.queryCoordClient = &MockQueryCoord{}
server.dataCoordClient = &MockDataCoord{}
server := getServer(t)
mockProxy := server.proxy.(*MockProxy)
req := &grpc_health_v1.HealthCheckRequest{Service: ""}
ret, err := server.Check(ctx, req)
......@@ -1503,21 +1350,14 @@ func TestServer_Check(t *testing.T) {
func TestServer_Watch(t *testing.T) {
ctx := context.Background()
server, err := NewServer(ctx, nil)
assert.NotNil(t, server)
assert.Nil(t, err)
mockProxy := &MockProxy{}
server.proxy = mockProxy
server.rootCoordClient = &MockRootCoord{}
server.queryCoordClient = &MockQueryCoord{}
server.dataCoordClient = &MockDataCoord{}
server := getServer(t)
mockProxy := server.proxy.(*MockProxy)
watchServer := milvusmock.NewGrpcHealthWatchServer()
resultChan := watchServer.Chan()
req := &grpc_health_v1.HealthCheckRequest{Service: ""}
//var ret *grpc_health_v1.HealthCheckResponse
err = server.Watch(req, watchServer)
err := server.Watch(req, watchServer)
ret := <-resultChan
assert.Nil(t, err)
......@@ -1567,19 +1407,10 @@ func TestServer_Watch(t *testing.T) {
}
func Test_NewServer_HTTPServer_Enabled(t *testing.T) {
ctx := context.Background()
server, err := NewServer(ctx, nil)
assert.NotNil(t, server)
assert.Nil(t, err)
server.proxy = &MockProxy{}
server.rootCoordClient = &MockRootCoord{}
server.queryCoordClient = &MockQueryCoord{}
server.dataCoordClient = &MockDataCoord{}
server := getServer(t)
paramtable.Get().Save(proxy.Params.HTTPCfg.Enabled.Key, "true")
err = runAndWaitForServerReady(server)
err := runAndWaitForServerReady(server)
assert.Nil(t, err)
err = server.Stop()
assert.Nil(t, err)
......@@ -1602,8 +1433,21 @@ func getServer(t *testing.T) *Server {
server.proxy = &MockProxy{}
server.rootCoordClient = &MockRootCoord{}
server.queryCoordClient = &MockQueryCoord{}
server.dataCoordClient = &MockDataCoord{}
mockQC := &types.MockQueryCoord{}
server.queryCoordClient = mockQC
mockQC.EXPECT().Init().Return(nil)
mockQC.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
NodeID: int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
Role: "MockQueryCoord",
StateCode: commonpb.StateCode_Healthy,
ExtraInfo: nil,
},
SubcomponentStates: nil,
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}, nil)
return server
}
......
......@@ -22,7 +22,6 @@ import (
"os"
"testing"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/paramtable"
......@@ -31,165 +30,9 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/stretchr/testify/assert"
clientv3 "go.etcd.io/etcd/client/v3"
"github.com/stretchr/testify/mock"
)
// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
type MockQueryCoord struct {
states *milvuspb.ComponentStates
status *commonpb.Status
err error
initErr error
startErr error
stopErr error
regErr error
strResp *milvuspb.StringResponse
showcolResp *querypb.ShowCollectionsResponse
showpartResp *querypb.ShowPartitionsResponse
partResp *querypb.GetPartitionStatesResponse
infoResp *querypb.GetSegmentInfoResponse
configResp *internalpb.ShowConfigurationsResponse
metricResp *milvuspb.GetMetricsResponse
replicasResp *milvuspb.GetReplicasResponse
shardLeadersResp *querypb.GetShardLeadersResponse
}
func (m *MockQueryCoord) Init() error {
return m.initErr
}
func (m *MockQueryCoord) Start() error {
return m.startErr
}
func (m *MockQueryCoord) Stop() error {
return m.stopErr
}
func (m *MockQueryCoord) Register() error {
log.Debug("MockQueryCoord::Register")
return m.regErr
}
func (m *MockQueryCoord) UpdateStateCode(code commonpb.StateCode) {
}
func (m *MockQueryCoord) SetAddress(address string) {
}
func (m *MockQueryCoord) SetEtcdClient(client *clientv3.Client) {
}
func (m *MockQueryCoord) SetRootCoord(types.RootCoord) error {
return nil
}
func (m *MockQueryCoord) SetDataCoord(types.DataCoord) error {
return nil
}
func (m *MockQueryCoord) SetQueryNodeCreator(func(ctx context.Context, addr string) (types.QueryNode, error)) {
}
func (m *MockQueryCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) {
log.Debug("MockQueryCoord::WaitForComponentStates")
return m.states, m.err
}
func (m *MockQueryCoord) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return m.strResp, m.err
}
func (m *MockQueryCoord) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return m.strResp, m.err
}
func (m *MockQueryCoord) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return m.showcolResp, m.err
}
func (m *MockQueryCoord) LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) (*commonpb.Status, error) {
return m.status, m.err
}
func (m *MockQueryCoord) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
return m.status, m.err
}
func (m *MockQueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return m.showpartResp, m.err
}
func (m *MockQueryCoord) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) {
return m.partResp, m.err
}
func (m *MockQueryCoord) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
return m.status, m.err
}
func (m *MockQueryCoord) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
return m.status, m.err
}
func (m *MockQueryCoord) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
return m.infoResp, m.err
}
func (m *MockQueryCoord) LoadBalance(ctx context.Context, req *querypb.LoadBalanceRequest) (*commonpb.Status, error) {
return m.status, m.err
}
func (m *MockQueryCoord) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) {
return m.configResp, m.err
}
func (m *MockQueryCoord) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
return m.metricResp, m.err
}
func (m *MockQueryCoord) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasRequest) (*milvuspb.GetReplicasResponse, error) {
return m.replicasResp, m.err
}
func (m *MockQueryCoord) GetShardLeaders(ctx context.Context, req *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
return m.shardLeadersResp, m.err
}
func (m *MockQueryCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) {
return &milvuspb.CheckHealthResponse{
IsHealthy: true,
}, m.err
}
func (m *MockQueryCoord) CreateResourceGroup(ctx context.Context, req *milvuspb.CreateResourceGroupRequest) (*commonpb.Status, error) {
return m.status, nil
}
func (m *MockQueryCoord) DropResourceGroup(ctx context.Context, req *milvuspb.DropResourceGroupRequest) (*commonpb.Status, error) {
return m.status, nil
}
func (m *MockQueryCoord) TransferNode(ctx context.Context, req *milvuspb.TransferNodeRequest) (*commonpb.Status, error) {
return m.status, nil
}
func (m *MockQueryCoord) TransferReplica(ctx context.Context, req *querypb.TransferReplicaRequest) (*commonpb.Status, error) {
return m.status, nil
}
func (m *MockQueryCoord) ListResourceGroups(ctx context.Context, req *milvuspb.ListResourceGroupsRequest) (*milvuspb.ListResourceGroupsResponse, error) {
return &milvuspb.ListResourceGroupsResponse{
Status: m.status,
}, nil
}
func (m *MockQueryCoord) DescribeResourceGroup(ctx context.Context, req *querypb.DescribeResourceGroupRequest) (*querypb.DescribeResourceGroupResponse, error) {
return &querypb.DescribeResourceGroupResponse{
Status: m.status,
}, nil
}
// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
type MockRootCoord struct {
types.RootCoord
......@@ -269,22 +112,6 @@ func Test_NewServer(t *testing.T) {
assert.Nil(t, err)
assert.NotNil(t, server)
mqc := &MockQueryCoord{
states: &milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Healthy},
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
},
status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
err: nil,
strResp: &milvuspb.StringResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}},
showcolResp: &querypb.ShowCollectionsResponse{},
showpartResp: &querypb.ShowPartitionsResponse{},
partResp: &querypb.GetPartitionStatesResponse{},
infoResp: &querypb.GetSegmentInfoResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}},
configResp: &internalpb.ShowConfigurationsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}},
metricResp: &milvuspb.GetMetricsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}},
}
mdc := &MockDataCoord{
stateErr: commonpb.ErrorCode_Success,
}
......@@ -293,6 +120,11 @@ func Test_NewServer(t *testing.T) {
stateErr: commonpb.ErrorCode_Success,
}
mqc := getQueryCoord()
successStatus := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}
t.Run("Run", func(t *testing.T) {
server.queryCoord = mqc
server.dataCoord = mdc
......@@ -303,6 +135,15 @@ func Test_NewServer(t *testing.T) {
})
t.Run("GetComponentStates", func(t *testing.T) {
mqc.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
NodeID: 0,
Role: "MockQueryCoord",
StateCode: commonpb.StateCode_Healthy,
},
Status: successStatus,
}, nil)
req := &milvuspb.GetComponentStatesRequest{}
states, err := server.GetComponentStates(ctx, req)
assert.Nil(t, err)
......@@ -311,6 +152,11 @@ func Test_NewServer(t *testing.T) {
t.Run("GetStatisticsChannel", func(t *testing.T) {
req := &internalpb.GetStatisticsChannelRequest{}
mqc.EXPECT().GetStatisticsChannel(mock.Anything).Return(
&milvuspb.StringResponse{
Status: successStatus,
}, nil,
)
resp, err := server.GetStatisticsChannel(ctx, req)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
......@@ -318,53 +164,71 @@ func Test_NewServer(t *testing.T) {
t.Run("GetTimeTickChannel", func(t *testing.T) {
req := &internalpb.GetTimeTickChannelRequest{}
mqc.EXPECT().GetTimeTickChannel(mock.Anything).Return(
&milvuspb.StringResponse{
Status: successStatus,
}, nil,
)
resp, err := server.GetTimeTickChannel(ctx, req)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
})
t.Run("ShowCollections", func(t *testing.T) {
mqc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(
&querypb.ShowCollectionsResponse{
Status: successStatus,
}, nil,
)
resp, err := server.ShowCollections(ctx, nil)
assert.Nil(t, err)
assert.NotNil(t, resp)
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
})
t.Run("LoadCollection", func(t *testing.T) {
mqc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil)
resp, err := server.LoadCollection(ctx, nil)
assert.Nil(t, err)
assert.NotNil(t, resp)
})
t.Run("ReleaseCollection", func(t *testing.T) {
mqc.EXPECT().ReleaseCollection(mock.Anything, mock.Anything).Return(successStatus, nil)
resp, err := server.ReleaseCollection(ctx, nil)
assert.Nil(t, err)
assert.NotNil(t, resp)
})
t.Run("ShowPartitions", func(t *testing.T) {
mqc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{Status: successStatus}, nil)
resp, err := server.ShowPartitions(ctx, nil)
assert.Nil(t, err)
assert.NotNil(t, resp)
})
t.Run("GetPartitionStates", func(t *testing.T) {
mqc.EXPECT().GetPartitionStates(mock.Anything, mock.Anything).Return(&querypb.GetPartitionStatesResponse{Status: successStatus}, nil)
resp, err := server.GetPartitionStates(ctx, nil)
assert.Nil(t, err)
assert.NotNil(t, resp)
})
t.Run("LoadPartitions", func(t *testing.T) {
mqc.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Return(successStatus, nil)
resp, err := server.LoadPartitions(ctx, nil)
assert.Nil(t, err)
assert.NotNil(t, resp)
})
t.Run("ReleasePartitions", func(t *testing.T) {
mqc.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).Return(successStatus, nil)
resp, err := server.ReleasePartitions(ctx, nil)
assert.Nil(t, err)
assert.NotNil(t, resp)
})
t.Run("GetTimeTickChannel", func(t *testing.T) {
mqc.EXPECT().GetTimeTickChannel(mock.Anything).Return(&milvuspb.StringResponse{Status: successStatus}, nil)
resp, err := server.GetTimeTickChannel(ctx, nil)
assert.Nil(t, err)
assert.NotNil(t, resp)
......@@ -372,6 +236,7 @@ func Test_NewServer(t *testing.T) {
t.Run("GetSegmentInfo", func(t *testing.T) {
req := &querypb.GetSegmentInfoRequest{}
mqc.EXPECT().GetSegmentInfo(mock.Anything, req).Return(&querypb.GetSegmentInfoResponse{Status: successStatus}, nil)
resp, err := server.GetSegmentInfo(ctx, req)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
......@@ -379,6 +244,7 @@ func Test_NewServer(t *testing.T) {
t.Run("LoadBalance", func(t *testing.T) {
req := &querypb.LoadBalanceRequest{}
mqc.EXPECT().LoadBalance(mock.Anything, req).Return(successStatus, nil)
resp, err := server.LoadBalance(ctx, req)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
......@@ -388,36 +254,43 @@ func Test_NewServer(t *testing.T) {
req := &milvuspb.GetMetricsRequest{
Request: "",
}
mqc.EXPECT().GetMetrics(mock.Anything, req).Return(&milvuspb.GetMetricsResponse{Status: successStatus}, nil)
resp, err := server.GetMetrics(ctx, req)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
})
t.Run("CheckHealth", func(t *testing.T) {
mqc.EXPECT().CheckHealth(mock.Anything, mock.Anything).Return(
&milvuspb.CheckHealthResponse{Status: successStatus, IsHealthy: true}, nil)
ret, err := server.CheckHealth(ctx, nil)
assert.Nil(t, err)
assert.Equal(t, true, ret.IsHealthy)
})
t.Run("CreateResourceGroup", func(t *testing.T) {
mqc.EXPECT().CreateResourceGroup(mock.Anything, mock.Anything).Return(successStatus, nil)
resp, err := server.CreateResourceGroup(ctx, nil)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
})
t.Run("DropResourceGroup", func(t *testing.T) {
mqc.EXPECT().DropResourceGroup(mock.Anything, mock.Anything).Return(successStatus, nil)
resp, err := server.DropResourceGroup(ctx, nil)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
})
t.Run("TransferNode", func(t *testing.T) {
mqc.EXPECT().TransferNode(mock.Anything, mock.Anything).Return(successStatus, nil)
resp, err := server.TransferNode(ctx, nil)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
})
t.Run("TransferReplica", func(t *testing.T) {
mqc.EXPECT().TransferReplica(mock.Anything, mock.Anything).Return(successStatus, nil)
resp, err := server.TransferReplica(ctx, nil)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
......@@ -425,12 +298,14 @@ func Test_NewServer(t *testing.T) {
t.Run("ListResourceGroups", func(t *testing.T) {
req := &milvuspb.ListResourceGroupsRequest{}
mqc.EXPECT().ListResourceGroups(mock.Anything, req).Return(&milvuspb.ListResourceGroupsResponse{Status: successStatus}, nil)
resp, err := server.ListResourceGroups(ctx, req)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
})
t.Run("DescribeResourceGroup", func(t *testing.T) {
mqc.EXPECT().DescribeResourceGroup(mock.Anything, mock.Anything).Return(&querypb.DescribeResourceGroupResponse{Status: successStatus}, nil)
resp, err := server.DescribeResourceGroup(ctx, nil)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
......@@ -448,9 +323,9 @@ func TestServer_Run1(t *testing.T) {
assert.Nil(t, err)
assert.NotNil(t, server)
server.queryCoord = &MockQueryCoord{
regErr: errors.New("error"),
}
mqc := getQueryCoord()
mqc.EXPECT().Start().Return(errors.New("error"))
server.queryCoord = mqc
err = server.Run()
assert.Error(t, err)
......@@ -464,7 +339,7 @@ func TestServer_Run2(t *testing.T) {
assert.Nil(t, err)
assert.NotNil(t, server)
server.queryCoord = &MockQueryCoord{}
server.queryCoord = getQueryCoord()
server.rootCoord = &MockRootCoord{
initErr: errors.New("error"),
}
......@@ -473,13 +348,27 @@ func TestServer_Run2(t *testing.T) {
assert.Nil(t, err)
}
func getQueryCoord() *types.MockQueryCoord {
mqc := &types.MockQueryCoord{}
mqc.EXPECT().Init().Return(nil)
mqc.EXPECT().SetEtcdClient(mock.Anything)
mqc.EXPECT().SetAddress(mock.Anything)
mqc.EXPECT().SetRootCoord(mock.Anything).Return(nil)
mqc.EXPECT().SetDataCoord(mock.Anything).Return(nil)
mqc.EXPECT().UpdateStateCode(mock.Anything)
mqc.EXPECT().Register().Return(nil)
mqc.EXPECT().Start().Return(nil)
mqc.EXPECT().Stop().Return(nil)
return mqc
}
func TestServer_Run3(t *testing.T) {
ctx := context.Background()
server, err := NewServer(ctx, nil)
assert.Nil(t, err)
assert.NotNil(t, server)
server.queryCoord = &MockQueryCoord{}
server.queryCoord = getQueryCoord()
server.rootCoord = &MockRootCoord{
startErr: errors.New("error"),
}
......@@ -495,7 +384,7 @@ func TestServer_Run4(t *testing.T) {
assert.Nil(t, err)
assert.NotNil(t, server)
server.queryCoord = &MockQueryCoord{}
server.queryCoord = getQueryCoord()
server.rootCoord = &MockRootCoord{}
server.dataCoord = &MockDataCoord{
initErr: errors.New("error"),
......@@ -511,7 +400,7 @@ func TestServer_Run5(t *testing.T) {
assert.Nil(t, err)
assert.NotNil(t, server)
server.queryCoord = &MockQueryCoord{}
server.queryCoord = getQueryCoord()
server.rootCoord = &MockRootCoord{}
server.dataCoord = &MockDataCoord{
startErr: errors.New("error"),
......
......@@ -6,6 +6,7 @@ import (
"google.golang.org/grpc/metadata"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util"
"github.com/milvus-io/milvus/internal/util/crypto"
......@@ -24,7 +25,7 @@ func TestValidAuth(t *testing.T) {
assert.False(t, res)
// normal metadata
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
queryCoord := &types.MockQueryCoord{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, queryCoord, mgr)
assert.Nil(t, err)
......@@ -54,7 +55,7 @@ func TestAuthenticationInterceptor(t *testing.T) {
assert.NotNil(t, err)
// mock metacache
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
queryCoord := &types.MockQueryCoord{}
mgr := newShardClientMgr()
err = InitMetaCache(ctx, rootCoord, queryCoord, mgr)
assert.Nil(t, err)
......
......@@ -29,6 +29,8 @@ import (
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/sessionutil"
......@@ -72,9 +74,11 @@ func TestProxy_CheckHealth(t *testing.T) {
})
t.Run("proxy health check is ok", func(t *testing.T) {
qc := &types.MockQueryCoord{}
qc.EXPECT().CheckHealth(mock.Anything, mock.Anything).Return(&milvuspb.CheckHealthResponse{IsHealthy: true}, nil)
node := &Proxy{
rootCoord: NewRootCoordMock(),
queryCoord: NewQueryCoordMock(),
queryCoord: qc,
dataCoord: NewDataCoordMock(),
session: &sessionutil.Session{ServerID: 1},
}
......@@ -96,22 +100,18 @@ func TestProxy_CheckHealth(t *testing.T) {
}, nil
}
checkHealthFunc2 := func(ctx context.Context,
req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) {
return nil, errors.New("test")
}
dataCoordMock := NewDataCoordMock()
dataCoordMock.checkHealthFunc = checkHealthFunc1
qc := &types.MockQueryCoord{}
qc.EXPECT().CheckHealth(mock.Anything, mock.Anything).Return(nil, errors.New("test"))
node := &Proxy{
session: &sessionutil.Session{ServerID: 1},
rootCoord: NewRootCoordMock(func(mock *RootCoordMock) {
mock.checkHealthFunc = checkHealthFunc1
}),
queryCoord: NewQueryCoordMock(func(mock *QueryCoordMock) {
mock.checkHealthFunc = checkHealthFunc2
}),
dataCoord: dataCoordMock}
queryCoord: qc,
dataCoord: dataCoordMock}
node.multiRateLimiter = NewMultiRateLimiter()
node.stateCode.Store(commonpb.StateCode_Healthy)
ctx := context.Background()
......@@ -122,10 +122,12 @@ func TestProxy_CheckHealth(t *testing.T) {
})
t.Run("check quota state", func(t *testing.T) {
qc := &types.MockQueryCoord{}
qc.EXPECT().CheckHealth(mock.Anything, mock.Anything).Return(&milvuspb.CheckHealthResponse{IsHealthy: true}, nil)
node := &Proxy{
rootCoord: NewRootCoordMock(),
dataCoord: NewDataCoordMock(),
queryCoord: NewQueryCoordMock(),
queryCoord: qc,
}
node.multiRateLimiter = NewMultiRateLimiter()
node.stateCode.Store(commonpb.StateCode_Healthy)
......@@ -209,7 +211,7 @@ func TestProxy_ResourceGroup(t *testing.T) {
node.multiRateLimiter = NewMultiRateLimiter()
node.stateCode.Store(commonpb.StateCode_Healthy)
qc := NewQueryCoordMock()
qc := types.NewMockQueryCoord(t)
node.SetQueryCoordClient(qc)
tsoAllocatorIns := newMockTsoAllocator()
......@@ -222,7 +224,10 @@ func TestProxy_ResourceGroup(t *testing.T) {
mgr := newShardClientMgr()
InitMetaCache(ctx, rc, qc, mgr)
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
t.Run("create resource group", func(t *testing.T) {
qc.EXPECT().CreateResourceGroup(mock.Anything, mock.Anything).Return(successStatus, nil)
resp, err := node.CreateResourceGroup(ctx, &milvuspb.CreateResourceGroupRequest{
ResourceGroup: "rg",
})
......@@ -231,6 +236,7 @@ func TestProxy_ResourceGroup(t *testing.T) {
})
t.Run("drop resource group", func(t *testing.T) {
qc.EXPECT().DropResourceGroup(mock.Anything, mock.Anything).Return(successStatus, nil)
resp, err := node.DropResourceGroup(ctx, &milvuspb.DropResourceGroupRequest{
ResourceGroup: "rg",
})
......@@ -239,6 +245,7 @@ func TestProxy_ResourceGroup(t *testing.T) {
})
t.Run("transfer node", func(t *testing.T) {
qc.EXPECT().TransferNode(mock.Anything, mock.Anything).Return(successStatus, nil)
resp, err := node.TransferNode(ctx, &milvuspb.TransferNodeRequest{
SourceResourceGroup: "rg1",
TargetResourceGroup: "rg2",
......@@ -249,6 +256,7 @@ func TestProxy_ResourceGroup(t *testing.T) {
})
t.Run("transfer replica", func(t *testing.T) {
qc.EXPECT().TransferReplica(mock.Anything, mock.Anything).Return(successStatus, nil)
resp, err := node.TransferReplica(ctx, &milvuspb.TransferReplicaRequest{
SourceResourceGroup: "rg1",
TargetResourceGroup: "rg2",
......@@ -260,12 +268,24 @@ func TestProxy_ResourceGroup(t *testing.T) {
})
t.Run("list resource group", func(t *testing.T) {
qc.EXPECT().ListResourceGroups(mock.Anything, mock.Anything).Return(&milvuspb.ListResourceGroupsResponse{Status: successStatus}, nil)
resp, err := node.ListResourceGroups(ctx, &milvuspb.ListResourceGroupsRequest{})
assert.NoError(t, err)
assert.Equal(t, resp.Status.ErrorCode, commonpb.ErrorCode_Success)
})
t.Run("describe resource group", func(t *testing.T) {
qc.EXPECT().DescribeResourceGroup(mock.Anything, mock.Anything).Return(&querypb.DescribeResourceGroupResponse{
Status: successStatus,
ResourceGroup: &querypb.ResourceGroupInfo{
Name: "rg",
Capacity: 1,
NumAvailableNode: 1,
NumLoadedReplica: nil,
NumOutgoingNode: nil,
NumIncomingNode: nil,
},
}, nil)
resp, err := node.DescribeResourceGroup(ctx, &milvuspb.DescribeResourceGroupRequest{
ResourceGroup: "rg",
})
......@@ -283,7 +303,7 @@ func TestProxy_InvalidResourceGroupName(t *testing.T) {
node.multiRateLimiter = NewMultiRateLimiter()
node.stateCode.Store(commonpb.StateCode_Healthy)
qc := NewQueryCoordMock()
qc := types.NewMockQueryCoord(t)
node.SetQueryCoordClient(qc)
tsoAllocatorIns := newMockTsoAllocator()
......
......@@ -28,6 +28,7 @@ import (
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
......@@ -59,21 +60,6 @@ func (m *MockRootCoordClientInterface) GetAccessCount() int {
return int(ret)
}
type MockQueryCoordClientInterface struct {
types.QueryCoord
Error bool
AccessCount int32
}
func (m *MockQueryCoordClientInterface) IncAccessCount() {
atomic.AddInt32(&m.AccessCount, 1)
}
func (m *MockQueryCoordClientInterface) GetAccessCount() int {
ret := atomic.LoadInt32(&m.AccessCount)
return int(ret)
}
func (m *MockRootCoordClientInterface) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) {
if m.Error {
return nil, errors.New("mocked error")
......@@ -217,26 +203,11 @@ func (m *MockRootCoordClientInterface) ListPolicy(ctx context.Context, in *inter
}, nil
}
func (m *MockQueryCoordClientInterface) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
if m.Error {
return nil, errors.New("mocked error")
}
m.IncAccessCount()
rsp := &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: []UniqueID{1, 2},
InMemoryPercentages: []int64{100, 50},
}
return rsp, nil
}
// Simulate the cache path and the
func TestMetaCache_GetCollection(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
queryCoord := &types.MockQueryCoord{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, queryCoord, mgr)
assert.Nil(t, err)
......@@ -287,7 +258,7 @@ func TestMetaCache_GetCollection(t *testing.T) {
func TestMetaCache_GetCollectionName(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
queryCoord := &types.MockQueryCoord{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, queryCoord, mgr)
assert.Nil(t, err)
......@@ -337,7 +308,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) {
func TestMetaCache_GetCollectionFailure(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
queryCoord := &types.MockQueryCoord{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, queryCoord, mgr)
assert.Nil(t, err)
......@@ -370,7 +341,7 @@ func TestMetaCache_GetCollectionFailure(t *testing.T) {
func TestMetaCache_GetNonExistCollection(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
queryCoord := &types.MockQueryCoord{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, queryCoord, mgr)
assert.Nil(t, err)
......@@ -386,7 +357,7 @@ func TestMetaCache_GetNonExistCollection(t *testing.T) {
func TestMetaCache_GetPartitionID(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
queryCoord := &types.MockQueryCoord{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, queryCoord, mgr)
assert.Nil(t, err)
......@@ -408,7 +379,7 @@ func TestMetaCache_GetPartitionID(t *testing.T) {
func TestMetaCache_ConcurrentTest1(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
queryCoord := &types.MockQueryCoord{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, queryCoord, mgr)
assert.Nil(t, err)
......@@ -462,7 +433,7 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) {
func TestMetaCache_GetPartitionError(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
queryCoord := &types.MockQueryCoord{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, queryCoord, mgr)
assert.Nil(t, err)
......@@ -498,7 +469,8 @@ func TestMetaCache_GetShards(t *testing.T) {
)
rootCoord := &MockRootCoordClientInterface{}
qc := NewQueryCoordMock()
qc := getQueryCoord()
qc.EXPECT().Init().Return(nil)
shardMgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, qc, shardMgr)
require.Nil(t, err)
......@@ -514,14 +486,36 @@ func TestMetaCache_GetShards(t *testing.T) {
})
t.Run("without shardLeaders in collection info invalid shardLeaders", func(t *testing.T) {
qc.validShardLeaders = false
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "not implemented",
},
}, nil).Times(1)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}, nil)
shards, err := globalMetaCache.GetShards(ctx, false, collectionName)
assert.Error(t, err)
assert.Empty(t, shards)
})
t.Run("without shardLeaders in collection info", func(t *testing.T) {
qc.validShardLeaders = true
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
},
},
}, nil)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}, nil)
shards, err := globalMetaCache.GetShards(ctx, true, collectionName)
assert.NoError(t, err)
assert.NotEmpty(t, shards)
......@@ -529,7 +523,12 @@ func TestMetaCache_GetShards(t *testing.T) {
assert.Equal(t, 3, len(shards["channel-1"]))
// get from cache
qc.validShardLeaders = false
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "not implemented",
},
}, nil)
shards, err = globalMetaCache.GetShards(ctx, true, collectionName)
assert.NoError(t, err)
......@@ -546,7 +545,8 @@ func TestMetaCache_ClearShards(t *testing.T) {
)
rootCoord := &MockRootCoordClientInterface{}
qc := NewQueryCoordMock()
qc := getQueryCoord()
qc.EXPECT().Init().Return(nil)
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, qc, mgr)
require.Nil(t, err)
......@@ -565,7 +565,21 @@ func TestMetaCache_ClearShards(t *testing.T) {
t.Run("Clear valid collection valid cache", func(t *testing.T) {
qc.validShardLeaders = true
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
},
},
}, nil).Times(1)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}, nil)
shards, err := globalMetaCache.GetShards(ctx, true, collectionName)
require.NoError(t, err)
require.NotEmpty(t, shards)
......@@ -574,7 +588,12 @@ func TestMetaCache_ClearShards(t *testing.T) {
globalMetaCache.ClearShards(collectionName)
qc.validShardLeaders = false
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "not implemented",
},
}, nil)
shards, err = globalMetaCache.GetShards(ctx, true, collectionName)
assert.Error(t, err)
assert.Empty(t, shards)
......@@ -583,7 +602,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
func TestMetaCache_PolicyInfo(t *testing.T) {
client := &MockRootCoordClientInterface{}
qc := &MockQueryCoordClientInterface{}
qc := &types.MockQueryCoord{}
mgr := newShardClientMgr()
t.Run("InitMetaCache", func(t *testing.T) {
......@@ -666,11 +685,21 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
func TestMetaCache_LoadCache(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
queryCoord := &types.MockQueryCoord{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, queryCoord, mgr)
assert.Nil(t, err)
qcCounter := 0
queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: []UniqueID{1, 2},
InMemoryPercentages: []int64{100, 50},
}, nil).Run(func(ctx context.Context, req *querypb.ShowCollectionsRequest) {
qcCounter++
})
t.Run("test IsCollectionLoaded", func(t *testing.T) {
info, err := globalMetaCache.GetCollectionInfo(ctx, "collection1")
assert.NoError(t, err)
......@@ -678,14 +707,14 @@ func TestMetaCache_LoadCache(t *testing.T) {
// no collectionInfo of collection1, should access RootCoord
assert.Equal(t, rootCoord.GetAccessCount(), 1)
// not loaded, should access QueryCoord
assert.Equal(t, queryCoord.GetAccessCount(), 1)
assert.Equal(t, qcCounter, 1)
info, err = globalMetaCache.GetCollectionInfo(ctx, "collection1")
assert.NoError(t, err)
assert.True(t, info.isLoaded)
// shouldn't access QueryCoord or RootCoord again
assert.Equal(t, rootCoord.GetAccessCount(), 1)
assert.Equal(t, queryCoord.GetAccessCount(), 1)
assert.Equal(t, qcCounter, 1)
// test collection2 not fully loaded
info, err = globalMetaCache.GetCollectionInfo(ctx, "collection2")
......@@ -694,7 +723,7 @@ func TestMetaCache_LoadCache(t *testing.T) {
// no collectionInfo of collection2, should access RootCoord
assert.Equal(t, rootCoord.GetAccessCount(), 2)
// not loaded, should access QueryCoord
assert.Equal(t, queryCoord.GetAccessCount(), 2)
assert.Equal(t, qcCounter, 2)
})
t.Run("test RemoveCollectionLoadCache", func(t *testing.T) {
......@@ -703,18 +732,26 @@ func TestMetaCache_LoadCache(t *testing.T) {
assert.NoError(t, err)
assert.True(t, info.isLoaded)
// should access QueryCoord
assert.Equal(t, queryCoord.GetAccessCount(), 3)
assert.Equal(t, qcCounter, 3)
})
}
func TestMetaCache_RemoveCollection(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
queryCoord := &types.MockQueryCoord{}
shardMgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, queryCoord, shardMgr)
assert.Nil(t, err)
queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: []UniqueID{1, 2},
InMemoryPercentages: []int64{100, 50},
}, nil)
info, err := globalMetaCache.GetCollectionInfo(ctx, "collection1")
assert.NoError(t, err)
assert.True(t, info.isLoaded)
......
......@@ -30,6 +30,7 @@ import (
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
)
......@@ -43,7 +44,7 @@ func TestProxy_metrics(t *testing.T) {
rc.Start()
defer rc.Stop()
qc := NewQueryCoordMock()
qc := getQueryCoord()
qc.Start()
defer qc.Stop()
......@@ -99,7 +100,7 @@ func TestProxy_metrics(t *testing.T) {
}, nil
}
qc.getMetricsFunc = func(ctx context.Context, request *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
getMetricsFunc := func(ctx context.Context, request *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
id := typeutil.UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
clusterTopology := metricsinfo.QueryClusterTopology{
......@@ -150,6 +151,7 @@ func TestProxy_metrics(t *testing.T) {
ComponentName: metricsinfo.ConstructComponentName(typeutil.QueryCoordRole, id),
}, nil
}
qc.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(getMetricsFunc(nil, nil))
dc.getMetricsFunc = func(ctx context.Context, request *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
id := typeutil.UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
......@@ -217,6 +219,5 @@ func TestProxy_metrics(t *testing.T) {
assert.NotNil(t, resp)
rc.getMetricsFunc = nil
qc.getMetricsFunc = nil
dc.getMetricsFunc = nil
}
......@@ -8,6 +8,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/stretchr/testify/assert"
......@@ -44,7 +45,7 @@ func TestPrivilegeInterceptor(t *testing.T) {
ctx = GetContext(context.Background(), "alice:123456")
client := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
queryCoord := &types.MockQueryCoord{}
mgr := newShardClientMgr()
client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) {
......@@ -145,7 +146,7 @@ func TestResourceGroupPrivilege(t *testing.T) {
ctx = GetContext(context.Background(), "fooo:123456")
client := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
queryCoord := &types.MockQueryCoord{}
mgr := newShardClientMgr()
client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) {
......
......@@ -72,6 +72,7 @@ import (
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
......@@ -4192,9 +4193,26 @@ func TestProxy_GetLoadState(t *testing.T) {
}()
{
q := NewQueryCoordMock()
q.state.Store(commonpb.StateCode_Abnormal)
proxy := &Proxy{queryCoord: q}
qc := getQueryCoord()
qc.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
NodeID: 0,
Role: typeutil.QueryCoordRole,
StateCode: commonpb.StateCode_Abnormal,
ExtraInfo: nil,
},
SubcomponentStates: nil,
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
}, nil)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
CollectionIDs: nil,
InMemoryPercentages: []int64{},
}, nil)
proxy := &Proxy{queryCoord: qc}
proxy.stateCode.Store(commonpb.StateCode_Healthy)
stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"})
assert.NoError(t, err)
......@@ -4206,13 +4224,23 @@ func TestProxy_GetLoadState(t *testing.T) {
}
{
q := NewQueryCoordMock(SetQueryCoordShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return nil, errors.New("test")
}), SetQueryCoordShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return nil, errors.New("test")
}))
q.state.Store(commonpb.StateCode_Healthy)
proxy := &Proxy{queryCoord: q}
qc := getQueryCoord()
qc.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
NodeID: 0,
Role: typeutil.QueryCoordRole,
StateCode: commonpb.StateCode_Healthy,
ExtraInfo: nil,
},
SubcomponentStates: nil,
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
}, nil)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(nil, errors.New("test"))
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(nil, errors.New("test"))
proxy := &Proxy{queryCoord: qc}
proxy.stateCode.Store(commonpb.StateCode_Healthy)
stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"})
......@@ -4237,15 +4265,27 @@ func TestProxy_GetLoadState(t *testing.T) {
}
{
q := NewQueryCoordMock(SetQueryCoordShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
CollectionIDs: request.CollectionIDs,
InMemoryPercentages: []int64{},
}, nil
}))
q.state.Store(commonpb.StateCode_Healthy)
proxy := &Proxy{queryCoord: q}
qc := getQueryCoord()
qc.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
NodeID: 0,
Role: typeutil.QueryCoordRole,
StateCode: commonpb.StateCode_Healthy,
ExtraInfo: nil,
},
SubcomponentStates: nil,
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
}, nil)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
CollectionIDs: nil,
InMemoryPercentages: []int64{},
}, nil)
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(nil, errors.New("test"))
proxy := &Proxy{queryCoord: qc}
proxy.stateCode.Store(commonpb.StateCode_Healthy)
stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"})
......@@ -4259,15 +4299,26 @@ func TestProxy_GetLoadState(t *testing.T) {
}
{
q := NewQueryCoordMock(SetQueryCoordShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
CollectionIDs: request.CollectionIDs,
InMemoryPercentages: []int64{100},
}, nil
}))
q.state.Store(commonpb.StateCode_Healthy)
proxy := &Proxy{queryCoord: q}
qc := getQueryCoord()
qc.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
NodeID: 0,
Role: typeutil.QueryCoordRole,
StateCode: commonpb.StateCode_Healthy,
ExtraInfo: nil,
},
SubcomponentStates: nil,
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
}, nil)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
CollectionIDs: nil,
InMemoryPercentages: []int64{100},
}, nil)
proxy := &Proxy{queryCoord: qc}
proxy.stateCode.Store(commonpb.StateCode_Healthy)
stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo", Base: &commonpb.MsgBase{}})
......@@ -4286,15 +4337,26 @@ func TestProxy_GetLoadState(t *testing.T) {
}
{
q := NewQueryCoordMock(SetQueryCoordShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
CollectionIDs: request.CollectionIDs,
InMemoryPercentages: []int64{50},
}, nil
}))
q.state.Store(commonpb.StateCode_Healthy)
proxy := &Proxy{queryCoord: q}
qc := getQueryCoord()
qc.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
NodeID: 0,
Role: typeutil.QueryCoordRole,
StateCode: commonpb.StateCode_Healthy,
ExtraInfo: nil,
},
SubcomponentStates: nil,
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
}, nil)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
CollectionIDs: nil,
InMemoryPercentages: []int64{50},
}, nil)
proxy := &Proxy{queryCoord: qc}
proxy.stateCode.Store(commonpb.StateCode_Healthy)
stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"})
......@@ -4309,17 +4371,27 @@ func TestProxy_GetLoadState(t *testing.T) {
}
t.Run("test insufficient memory", func(t *testing.T) {
q := NewQueryCoordMock(SetQueryCoordShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_InsufficientMemoryToLoad},
}, nil
}), SetQueryCoordShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_InsufficientMemoryToLoad},
}, nil
}))
q.state.Store(commonpb.StateCode_Healthy)
proxy := &Proxy{queryCoord: q}
qc := getQueryCoord()
qc.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
NodeID: 0,
Role: typeutil.QueryCoordRole,
StateCode: commonpb.StateCode_Healthy,
ExtraInfo: nil,
},
SubcomponentStates: nil,
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
}, nil)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_InsufficientMemoryToLoad},
}, nil)
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_InsufficientMemoryToLoad},
}, nil)
proxy := &Proxy{queryCoord: qc}
proxy.stateCode.Store(commonpb.StateCode_Healthy)
stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"})
......
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"context"
"fmt"
"sync"
"sync/atomic"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/uniquegenerator"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
type QueryCoordMockOption func(mock *QueryCoordMock)
type queryCoordShowCollectionsFuncType func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error)
type queryCoordShowPartitionsFuncType func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error)
type queryCoordShowConfigurationsFuncType func(ctx context.Context, request *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error)
func SetQueryCoordShowCollectionsFunc(f queryCoordShowCollectionsFuncType) QueryCoordMockOption {
return func(mock *QueryCoordMock) {
mock.showCollectionsFunc = f
}
}
func SetQueryCoordShowPartitionsFunc(f queryCoordShowPartitionsFuncType) QueryCoordMockOption {
return func(mock *QueryCoordMock) {
mock.showPartitionsFunc = f
}
}
func withValidShardLeaders() QueryCoordMockOption {
return func(mock *QueryCoordMock) {
mock.validShardLeaders = true
}
}
type QueryCoordMock struct {
nodeID typeutil.UniqueID
address string
state atomic.Value // internal.StateCode
collectionIDs []int64
inMemoryPercentages []int64
colMtx sync.RWMutex
showConfigurationsFunc queryCoordShowConfigurationsFuncType
showCollectionsFunc queryCoordShowCollectionsFuncType
getMetricsFunc getMetricsFuncType
showPartitionsFunc queryCoordShowPartitionsFuncType
statisticsChannel string
timeTickChannel string
validShardLeaders bool
checkHealthFunc func(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error)
}
func (coord *QueryCoordMock) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) {
if coord.checkHealthFunc != nil {
return coord.checkHealthFunc(ctx, req)
}
return &milvuspb.CheckHealthResponse{IsHealthy: true}, nil
}
func (coord *QueryCoordMock) updateState(state commonpb.StateCode) {
coord.state.Store(state)
}
func (coord *QueryCoordMock) getState() commonpb.StateCode {
return coord.state.Load().(commonpb.StateCode)
}
func (coord *QueryCoordMock) healthy() bool {
return coord.getState() == commonpb.StateCode_Healthy
}
func (coord *QueryCoordMock) Init() error {
coord.updateState(commonpb.StateCode_Initializing)
return nil
}
func (coord *QueryCoordMock) Start() error {
defer coord.updateState(commonpb.StateCode_Healthy)
return nil
}
func (coord *QueryCoordMock) Stop() error {
defer coord.updateState(commonpb.StateCode_Abnormal)
return nil
}
func (coord *QueryCoordMock) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) {
return &milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
NodeID: coord.nodeID,
Role: typeutil.QueryCoordRole,
StateCode: coord.getState(),
ExtraInfo: nil,
},
SubcomponentStates: nil,
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
}, nil
}
func (coord *QueryCoordMock) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return &milvuspb.StringResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
Value: coord.statisticsChannel,
}, nil
}
func (coord *QueryCoordMock) Register() error {
return nil
}
func (coord *QueryCoordMock) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return &milvuspb.StringResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
Value: coord.timeTickChannel,
}, nil
}
func (coord *QueryCoordMock) ResetShowCollectionsFunc() {
coord.showCollectionsFunc = nil
}
func (coord *QueryCoordMock) SetShowCollectionsFunc(f queryCoordShowCollectionsFuncType) {
coord.showCollectionsFunc = f
}
func (coord *QueryCoordMock) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
if !coord.healthy() {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "unhealthy",
},
}, nil
}
if coord.showCollectionsFunc != nil {
return coord.showCollectionsFunc(ctx, req)
}
coord.colMtx.RLock()
defer coord.colMtx.RUnlock()
resp := &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
CollectionIDs: coord.collectionIDs,
InMemoryPercentages: coord.inMemoryPercentages,
}
return resp, nil
}
func (coord *QueryCoordMock) LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) (*commonpb.Status, error) {
if !coord.healthy() {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "unhealthy",
}, nil
}
coord.colMtx.Lock()
defer coord.colMtx.Unlock()
for _, colID := range coord.collectionIDs {
if req.CollectionID == colID {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: fmt.Sprintf("collection %v already loaded", req.CollectionID),
}, nil
}
}
coord.collectionIDs = append(coord.collectionIDs, req.CollectionID)
coord.inMemoryPercentages = append(coord.inMemoryPercentages, 100)
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
}, nil
}
func (coord *QueryCoordMock) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
if !coord.healthy() {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "unhealthy",
}, nil
}
coord.colMtx.Lock()
defer coord.colMtx.Unlock()
for i := len(coord.collectionIDs) - 1; i >= 0; i-- {
if req.CollectionID == coord.collectionIDs[i] {
coord.collectionIDs = append(coord.collectionIDs[:i], coord.collectionIDs[i+1:]...)
coord.inMemoryPercentages = append(coord.inMemoryPercentages[:i], coord.inMemoryPercentages[i+1:]...)
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
}, nil
}
}
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: fmt.Sprintf("collection %v not loaded", req.CollectionID),
}, nil
}
func (coord *QueryCoordMock) SetShowPartitionsFunc(f queryCoordShowPartitionsFuncType) {
coord.showPartitionsFunc = f
}
func (coord *QueryCoordMock) ResetShowPartitionsFunc() {
coord.showPartitionsFunc = nil
}
func (coord *QueryCoordMock) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
if coord.showPartitionsFunc != nil {
return coord.showPartitionsFunc(ctx, req)
}
return nil, nil
}
func (coord *QueryCoordMock) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
if !coord.healthy() {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "unhealthy",
}, nil
}
panic("implement me")
}
func (coord *QueryCoordMock) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
if !coord.healthy() {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "unhealthy",
}, nil
}
panic("implement me")
}
func (coord *QueryCoordMock) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) {
if !coord.healthy() {
return &querypb.GetPartitionStatesResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "unhealthy",
},
}, nil
}
panic("implement me")
}
func (coord *QueryCoordMock) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
if !coord.healthy() {
return &querypb.GetSegmentInfoResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "unhealthy",
},
}, nil
}
panic("implement me")
}
func (coord *QueryCoordMock) LoadBalance(ctx context.Context, req *querypb.LoadBalanceRequest) (*commonpb.Status, error) {
if !coord.healthy() {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "unhealthy",
}, nil
}
panic("implement me")
}
func (coord *QueryCoordMock) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) {
if !coord.healthy() {
return &internalpb.ShowConfigurationsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "unhealthy",
},
}, nil
}
if coord.showConfigurationsFunc != nil {
return coord.showConfigurationsFunc(ctx, req)
}
return &internalpb.ShowConfigurationsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "not implemented",
},
}, nil
}
func (coord *QueryCoordMock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
if !coord.healthy() {
return &milvuspb.GetMetricsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "unhealthy",
},
}, nil
}
if coord.getMetricsFunc != nil {
return coord.getMetricsFunc(ctx, req)
}
return &milvuspb.GetMetricsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "not implemented",
},
Response: "",
ComponentName: "",
}, nil
}
func (coord *QueryCoordMock) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasRequest) (*milvuspb.GetReplicasResponse, error) {
if !coord.healthy() {
return &milvuspb.GetReplicasResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "unhealthy",
},
}, nil
}
return &milvuspb.GetReplicasResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "not implemented",
},
}, nil
}
func (coord *QueryCoordMock) GetShardLeaders(ctx context.Context, req *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
if !coord.healthy() {
return &querypb.GetShardLeadersResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "unhealthy",
},
}, nil
}
if coord.validShardLeaders {
return &querypb.GetShardLeadersResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
},
},
}, nil
}
return &querypb.GetShardLeadersResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "not implemented",
},
}, nil
}
func (coord *QueryCoordMock) CreateResourceGroup(ctx context.Context, req *milvuspb.CreateResourceGroupRequest) (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
}, nil
}
func (coord *QueryCoordMock) DropResourceGroup(ctx context.Context, req *milvuspb.DropResourceGroupRequest) (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
}, nil
}
func (coord *QueryCoordMock) TransferNode(ctx context.Context, req *milvuspb.TransferNodeRequest) (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
}, nil
}
func (coord *QueryCoordMock) TransferReplica(ctx context.Context, req *querypb.TransferReplicaRequest) (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
}, nil
}
func (coord *QueryCoordMock) ListResourceGroups(ctx context.Context, req *milvuspb.ListResourceGroupsRequest) (*milvuspb.ListResourceGroupsResponse, error) {
return &milvuspb.ListResourceGroupsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
ResourceGroups: []string{meta.DefaultResourceGroupName, "rg"},
}, nil
}
func (coord *QueryCoordMock) DescribeResourceGroup(ctx context.Context, req *querypb.DescribeResourceGroupRequest) (*querypb.DescribeResourceGroupResponse, error) {
if req.GetResourceGroup() == "rg" {
return &querypb.DescribeResourceGroupResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
ResourceGroup: &querypb.ResourceGroupInfo{
Name: "rg",
Capacity: 2,
NumAvailableNode: 1,
NumOutgoingNode: map[int64]int32{1: 1},
NumIncomingNode: map[int64]int32{2: 2},
},
}, nil
}
return &querypb.DescribeResourceGroupResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "",
},
}, nil
}
func NewQueryCoordMock(opts ...QueryCoordMockOption) *QueryCoordMock {
coord := &QueryCoordMock{
nodeID: UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
address: funcutil.GenRandomStr(), // TODO(dragondriver): random address
state: atomic.Value{},
collectionIDs: make([]int64, 0),
inMemoryPercentages: make([]int64, 0),
colMtx: sync.RWMutex{},
statisticsChannel: funcutil.GenRandomStr(),
timeTickChannel: funcutil.GenRandomStr(),
}
for _, opt := range opts {
opt(coord)
}
return coord
}
......@@ -23,8 +23,10 @@ import (
"testing"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
......@@ -50,7 +52,13 @@ func TestGetIndexStateTask_Execute(t *testing.T) {
ctx := context.Background()
rootCoord := newMockRootCoord()
queryCoord := NewQueryCoordMock()
queryCoord := getMockQueryCoord()
queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: []int64{},
}, nil)
datacoord := NewDataCoordMock()
gist := &getIndexStateTask{
......@@ -108,18 +116,15 @@ func TestDropIndexTask_PreExecute(t *testing.T) {
indexName := "_default_idx_101"
Params.Init()
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: nil,
}, nil
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc := getMockQueryCoord()
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: []int64{},
}, nil)
dc := NewDataCoordMock()
ctx := context.Background()
qc.updateState(commonpb.StateCode_Healthy)
mockCache := newMockCache()
mockCache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
......@@ -168,16 +173,13 @@ func TestDropIndexTask_PreExecute(t *testing.T) {
globalMetaCache = mockCache
t.Run("coll has been loaded", func(t *testing.T) {
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: []int64{collectionID},
}, nil
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy)
qc := getMockQueryCoord()
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: []int64{collectionID},
}, nil)
dit.queryCoord = qc
err := dit.PreExecute(ctx)
......@@ -185,11 +187,8 @@ func TestDropIndexTask_PreExecute(t *testing.T) {
})
t.Run("show collection error", func(t *testing.T) {
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return nil, errors.New("error")
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy)
qc := getMockQueryCoord()
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(nil, errors.New("error"))
dit.queryCoord = qc
err := dit.PreExecute(ctx)
......@@ -197,16 +196,13 @@ func TestDropIndexTask_PreExecute(t *testing.T) {
})
t.Run("show collection fail", func(t *testing.T) {
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "fail reason",
},
}, nil
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy)
qc := getMockQueryCoord()
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "fail reason",
},
}, nil)
dit.queryCoord = qc
err := dit.PreExecute(ctx)
......@@ -214,6 +210,23 @@ func TestDropIndexTask_PreExecute(t *testing.T) {
})
}
func getMockQueryCoord() *types.MockQueryCoord {
qc := &types.MockQueryCoord{}
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil)
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
},
},
}, nil)
return qc
}
func TestCreateIndexTask_PreExecute(t *testing.T) {
collectionName := "collection1"
collectionID := UniqueID(1)
......@@ -247,16 +260,13 @@ func TestCreateIndexTask_PreExecute(t *testing.T) {
}
t.Run("normal", func(t *testing.T) {
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: []int64{},
}, nil
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy)
qc := getMockQueryCoord()
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: []int64{},
}, nil)
cit.queryCoord = qc
err := cit.PreExecute(ctx)
......
......@@ -9,6 +9,7 @@ import (
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
......@@ -34,7 +35,7 @@ func TestQueryTask_all(t *testing.T) {
ctx = context.TODO()
rc = NewRootCoordMock()
qc = NewQueryCoordMock(withValidShardLeaders())
qc = types.NewMockQueryCoord(t)
qn = &QueryNodeMock{}
shardsNum = int32(2)
......@@ -48,6 +49,21 @@ func TestQueryTask_all(t *testing.T) {
}
)
successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
qc.EXPECT().Start().Return(nil)
qc.EXPECT().Stop().Return(nil)
qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&successStatus, nil)
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: &successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
},
},
}, nil)
mockCreator := func(ctx context.Context, address string) (types.QueryNode, error) {
return qn, nil
}
......@@ -97,6 +113,12 @@ func TestQueryTask_all(t *testing.T) {
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
assert.NoError(t, err)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &successStatus,
CollectionIDs: []int64{collectionID},
InMemoryPercentages: []int64{100},
}, nil)
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
......
......@@ -14,6 +14,7 @@ import (
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus/internal/common"
......@@ -117,11 +118,12 @@ func TestSearchTask_PreExecute(t *testing.T) {
var (
rc = NewRootCoordMock()
qc = NewQueryCoordMock()
qc = getQueryCoord()
ctx = context.TODO()
collectionName = t.Name() + funcutil.GenRandomStr()
)
successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
err = rc.Start()
defer rc.Stop()
......@@ -234,21 +236,15 @@ func TestSearchTask_PreExecute(t *testing.T) {
task.collectionName = collName
t.Run("show collection status unexpected error", func(t *testing.T) {
qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mock",
},
}, nil
})
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mock",
},
}, nil).Times(1)
assert.Error(t, task.PreExecute(ctx))
qc.ResetShowCollectionsFunc()
})
qc.ResetShowCollectionsFunc()
qc.ResetShowPartitionsFunc()
})
t.Run("search with timeout", func(t *testing.T) {
......@@ -256,6 +252,13 @@ func TestSearchTask_PreExecute(t *testing.T) {
createColl(t, collName, rc)
collID, err := globalMetaCache.GetCollectionID(context.TODO(), collName)
require.NoError(t, err)
qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&successStatus, nil)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &successStatus,
CollectionIDs: []int64{collID},
InMemoryPercentages: []int64{100},
}, nil)
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
......@@ -288,13 +291,20 @@ func TestSearchTask_PreExecute(t *testing.T) {
})
}
func getQueryCoord() *types.MockQueryCoord {
qc := &types.MockQueryCoord{}
qc.EXPECT().Start().Return(nil)
qc.EXPECT().Stop().Return(nil)
return qc
}
func TestSearchTaskV2_Execute(t *testing.T) {
var (
err error
rc = NewRootCoordMock()
qc = NewQueryCoordMock()
qc = getQueryCoord()
ctx = context.TODO()
collectionName = t.Name() + funcutil.GenRandomStr()
......@@ -1589,10 +1599,8 @@ func Test_checkIfLoaded(t *testing.T) {
return &collectionInfo{isLoaded: false}, nil
})
globalMetaCache = cache
qc := NewQueryCoordMock()
qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return nil, errors.New("mock")
})
qc := getQueryCoord()
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(nil, errors.New("mock")).Times(1)
_, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{1, 2})
assert.Error(t, err)
})
......@@ -1603,10 +1611,8 @@ func Test_checkIfLoaded(t *testing.T) {
return &collectionInfo{isLoaded: false}, nil
})
globalMetaCache = cache
qc := NewQueryCoordMock()
qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_CollectionNotExists}}, nil
})
qc := getQueryCoord()
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_CollectionNotExists}}, nil).Times(1)
_, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{1, 2})
assert.Error(t, err)
})
......@@ -1617,10 +1623,9 @@ func Test_checkIfLoaded(t *testing.T) {
return &collectionInfo{isLoaded: false}, nil
})
globalMetaCache = cache
qc := NewQueryCoordMock()
qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, InMemoryPercentages: []int64{100, 100}}, nil
})
qc := getQueryCoord()
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(
&querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, InMemoryPercentages: []int64{100, 100}}, nil).Times(1)
loaded, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{1, 2})
assert.NoError(t, err)
assert.True(t, loaded)
......@@ -1632,10 +1637,9 @@ func Test_checkIfLoaded(t *testing.T) {
return &collectionInfo{isLoaded: false}, nil
})
globalMetaCache = cache
qc := NewQueryCoordMock()
qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, InMemoryPercentages: []int64{100, 50}}, nil
})
qc := getQueryCoord()
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(
&querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, InMemoryPercentages: []int64{100, 50}}, nil).Times(1)
loaded, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{1, 2})
assert.NoError(t, err)
assert.False(t, loaded)
......@@ -1647,10 +1651,8 @@ func Test_checkIfLoaded(t *testing.T) {
return &collectionInfo{isLoaded: false}, nil
})
globalMetaCache = cache
qc := NewQueryCoordMock()
qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return nil, errors.New("mock")
})
qc := getQueryCoord()
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(nil, errors.New("mock")).Times(1)
_, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{1, 2})
assert.Error(t, err)
})
......@@ -1661,10 +1663,8 @@ func Test_checkIfLoaded(t *testing.T) {
return &collectionInfo{isLoaded: false}, nil
})
globalMetaCache = cache
qc := NewQueryCoordMock()
qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_CollectionNotExists}}, nil
})
qc := getQueryCoord()
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_CollectionNotExists}}, nil).Times(1)
_, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{1, 2})
assert.Error(t, err)
})
......@@ -1675,10 +1675,9 @@ func Test_checkIfLoaded(t *testing.T) {
return &collectionInfo{isLoaded: false}, nil
})
globalMetaCache = cache
qc := NewQueryCoordMock()
qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, PartitionIDs: []UniqueID{1, 2}}, nil
})
qc := getQueryCoord()
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(
&querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, PartitionIDs: []UniqueID{1, 2}}, nil).Times(1)
loaded, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{})
assert.NoError(t, err)
assert.False(t, loaded)
......@@ -1690,10 +1689,9 @@ func Test_checkIfLoaded(t *testing.T) {
return &collectionInfo{isLoaded: false}, nil
})
globalMetaCache = cache
qc := NewQueryCoordMock()
qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, PartitionIDs: []UniqueID{}}, nil
})
qc := getQueryCoord()
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(
&querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, PartitionIDs: []UniqueID{}}, nil).Times(1)
loaded, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{})
assert.NoError(t, err)
assert.False(t, loaded)
......@@ -1707,7 +1705,7 @@ func TestSearchTask_ErrExecute(t *testing.T) {
ctx = context.TODO()
rc = NewRootCoordMock()
qc = NewQueryCoordMock(withValidShardLeaders())
qc = getQueryCoord()
qn = &QueryNodeMock{}
shardsNum = int32(2)
......@@ -1766,6 +1764,23 @@ func TestSearchTask_ErrExecute(t *testing.T) {
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
assert.NoError(t, err)
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil)
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
},
},
}, nil)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: successStatus,
CollectionIDs: []int64{collectionID},
InMemoryPercentages: []int64{100},
}, nil)
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
......
......@@ -753,7 +753,7 @@ func TestHasCollectionTask(t *testing.T) {
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
qc := NewQueryCoordMock()
qc := getQueryCoord()
qc.Start()
defer qc.Stop()
ctx := context.Background()
......@@ -838,7 +838,7 @@ func TestDescribeCollectionTask(t *testing.T) {
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
qc := NewQueryCoordMock()
qc := getQueryCoord()
qc.Start()
defer qc.Stop()
ctx := context.Background()
......@@ -900,7 +900,7 @@ func TestDescribeCollectionTask_ShardsNum1(t *testing.T) {
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
qc := NewQueryCoordMock()
qc := getQueryCoord()
qc.Start()
defer qc.Stop()
ctx := context.Background()
......@@ -964,7 +964,7 @@ func TestDescribeCollectionTask_ShardsNum2(t *testing.T) {
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
qc := NewQueryCoordMock()
qc := getQueryCoord()
qc.Start()
defer qc.Stop()
ctx := context.Background()
......@@ -1081,18 +1081,17 @@ func TestDropPartitionTask(t *testing.T) {
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
partitionName := prefix + funcutil.GenRandomStr()
showPartitionsMock := func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
PartitionIDs: []int64{},
}, nil
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowPartitionsFunc(showPartitionsMock))
qc.updateState(commonpb.StateCode_Healthy)
qc := getQueryCoord()
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
PartitionIDs: []int64{},
}, nil)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}, nil)
mockCache := newMockCache()
mockCache.setGetPartitionIDFunc(func(ctx context.Context, collectionName string, partitionName string) (typeutil.UniqueID, error) {
return 1, nil
......@@ -1286,7 +1285,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
qc := NewQueryCoordMock()
qc := getQueryCoord()
qc.Start()
defer qc.Stop()
......@@ -1537,7 +1536,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
qc := NewQueryCoordMock()
qc := getQueryCoord()
qc.Start()
defer qc.Stop()
......@@ -2202,16 +2201,17 @@ func Test_createIndexTask_PreExecute(t *testing.T) {
FieldName: fieldName,
},
}
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: []int64{},
}, nil
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy)
qc := getQueryCoord()
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
PartitionIDs: []int64{},
}, nil)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}, nil)
cit.queryCoord = qc
t.Run("normal", func(t *testing.T) {
......@@ -2348,9 +2348,20 @@ func Test_dropCollectionTask_PostExecute(t *testing.T) {
func Test_loadCollectionTask_Execute(t *testing.T) {
rc := newMockRootCoord()
qc := NewQueryCoordMock(withValidShardLeaders())
dc := NewDataCoordMock()
qc := getQueryCoord()
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
PartitionIDs: []int64{},
}, nil)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}, nil)
dbName := funcutil.GenRandomStr()
collectionName := funcutil.GenRandomStr()
collectionID := UniqueID(1)
......@@ -2445,9 +2456,20 @@ func Test_loadCollectionTask_Execute(t *testing.T) {
func Test_loadPartitionTask_Execute(t *testing.T) {
rc := newMockRootCoord()
qc := NewQueryCoordMock(withValidShardLeaders())
dc := NewDataCoordMock()
qc := getQueryCoord()
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
PartitionIDs: []int64{},
}, nil)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}, nil)
dbName := funcutil.GenRandomStr()
collectionName := funcutil.GenRandomStr()
collectionID := UniqueID(1)
......@@ -2544,7 +2566,8 @@ func TestCreateResourceGroupTask(t *testing.T) {
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
qc := NewQueryCoordMock()
qc := getQueryCoord()
qc.EXPECT().CreateResourceGroup(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
qc.Start()
defer qc.Stop()
ctx := context.Background()
......@@ -2583,7 +2606,8 @@ func TestDropResourceGroupTask(t *testing.T) {
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
qc := NewQueryCoordMock()
qc := getQueryCoord()
qc.EXPECT().DropResourceGroup(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
qc.Start()
defer qc.Stop()
ctx := context.Background()
......@@ -2622,7 +2646,8 @@ func TestTransferNodeTask(t *testing.T) {
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
qc := NewQueryCoordMock()
qc := getQueryCoord()
qc.EXPECT().TransferNode(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
qc.Start()
defer qc.Stop()
ctx := context.Background()
......@@ -2661,7 +2686,8 @@ func TestTransferNodeTask(t *testing.T) {
func TestTransferReplicaTask(t *testing.T) {
rc := &MockRootCoordClientInterface{}
qc := NewQueryCoordMock()
qc := getQueryCoord()
qc.EXPECT().TransferReplica(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
qc.Start()
defer qc.Stop()
ctx := context.Background()
......@@ -2703,7 +2729,11 @@ func TestTransferReplicaTask(t *testing.T) {
func TestListResourceGroupsTask(t *testing.T) {
rc := &MockRootCoordClientInterface{}
qc := NewQueryCoordMock()
qc := getQueryCoord()
qc.EXPECT().ListResourceGroups(mock.Anything, mock.Anything).Return(&milvuspb.ListResourceGroupsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
ResourceGroups: []string{meta.DefaultResourceGroupName, "rg"},
}, nil)
qc.Start()
defer qc.Stop()
ctx := context.Background()
......@@ -2742,7 +2772,17 @@ func TestListResourceGroupsTask(t *testing.T) {
func TestDescribeResourceGroupTask(t *testing.T) {
rc := &MockRootCoordClientInterface{}
qc := NewQueryCoordMock()
qc := getQueryCoord()
qc.EXPECT().DescribeResourceGroup(mock.Anything, mock.Anything).Return(&querypb.DescribeResourceGroupResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
ResourceGroup: &querypb.ResourceGroupInfo{
Name: "rg",
Capacity: 2,
NumAvailableNode: 1,
NumOutgoingNode: map[int64]int32{1: 1},
NumIncomingNode: map[int64]int32{2: 2},
},
}, nil)
qc.Start()
defer qc.Stop()
ctx := context.Background()
......@@ -2787,7 +2827,10 @@ func TestDescribeResourceGroupTask(t *testing.T) {
func TestDescribeResourceGroupTaskFailed(t *testing.T) {
rc := &MockRootCoordClientInterface{}
qc := NewQueryCoordMock()
qc := getQueryCoord()
qc.EXPECT().DescribeResourceGroup(mock.Anything, mock.Anything).Return(&querypb.DescribeResourceGroupResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError},
}, nil)
qc.Start()
defer qc.Stop()
ctx := context.Background()
......
......@@ -26,6 +26,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc/metadata"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
......@@ -34,6 +35,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util"
"github.com/milvus-io/milvus/internal/util/crypto"
"github.com/milvus-io/milvus/internal/util/paramtable"
......@@ -838,17 +840,23 @@ func Test_isCollectionIsLoaded(t *testing.T) {
ctx := context.Background()
t.Run("normal", func(t *testing.T) {
collID := int64(1)
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
qc := &types.MockQueryCoord{}
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil)
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
},
CollectionIDs: []int64{collID, 10, 100},
}, nil
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy)
},
}, nil)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: successStatus,
CollectionIDs: []int64{collID, 10, 100},
}, nil)
loaded, err := isCollectionLoaded(ctx, qc, collID)
assert.NoError(t, err)
assert.True(t, loaded)
......@@ -856,17 +864,23 @@ func Test_isCollectionIsLoaded(t *testing.T) {
t.Run("error", func(t *testing.T) {
collID := int64(1)
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
qc := &types.MockQueryCoord{}
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil)
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
},
CollectionIDs: []int64{collID},
}, errors.New("error")
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy)
},
}, nil)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: successStatus,
CollectionIDs: []int64{collID},
}, errors.New("error"))
loaded, err := isCollectionLoaded(ctx, qc, collID)
assert.Error(t, err)
assert.False(t, loaded)
......@@ -874,17 +888,26 @@ func Test_isCollectionIsLoaded(t *testing.T) {
t.Run("fail", func(t *testing.T) {
collID := int64(1)
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "fail reason",
qc := &types.MockQueryCoord{}
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil)
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
},
CollectionIDs: []int64{collID},
}, nil
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy)
},
}, nil)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "fail reason",
},
CollectionIDs: []int64{collID},
}, nil)
loaded, err := isCollectionLoaded(ctx, qc, collID)
assert.Error(t, err)
assert.False(t, loaded)
......@@ -896,17 +919,26 @@ func Test_isPartitionIsLoaded(t *testing.T) {
t.Run("normal", func(t *testing.T) {
collID := int64(1)
partID := int64(2)
showPartitionsMock := func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
qc := &types.MockQueryCoord{}
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil)
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
},
PartitionIDs: []int64{partID},
}, nil
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowPartitionsFunc(showPartitionsMock))
qc.updateState(commonpb.StateCode_Healthy)
},
}, nil)
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
PartitionIDs: []int64{partID},
}, nil)
loaded, err := isPartitionLoaded(ctx, qc, collID, []int64{partID})
assert.NoError(t, err)
assert.True(t, loaded)
......@@ -915,17 +947,26 @@ func Test_isPartitionIsLoaded(t *testing.T) {
t.Run("error", func(t *testing.T) {
collID := int64(1)
partID := int64(2)
showPartitionsMock := func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
qc := &types.MockQueryCoord{}
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil)
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
},
PartitionIDs: []int64{partID},
}, errors.New("error")
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowPartitionsFunc(showPartitionsMock))
qc.updateState(commonpb.StateCode_Healthy)
},
}, nil)
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
PartitionIDs: []int64{partID},
}, errors.New("error"))
loaded, err := isPartitionLoaded(ctx, qc, collID, []int64{partID})
assert.Error(t, err)
assert.False(t, loaded)
......@@ -934,17 +975,26 @@ func Test_isPartitionIsLoaded(t *testing.T) {
t.Run("fail", func(t *testing.T) {
collID := int64(1)
partID := int64(2)
showPartitionsMock := func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "fail reason",
qc := &types.MockQueryCoord{}
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil)
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
},
PartitionIDs: []int64{partID},
}, nil
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowPartitionsFunc(showPartitionsMock))
qc.updateState(commonpb.StateCode_Healthy)
},
}, nil)
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "fail reason",
},
PartitionIDs: []int64{partID},
}, nil)
loaded, err := isPartitionLoaded(ctx, qc, collID, []int64{partID})
assert.Error(t, err)
assert.False(t, loaded)
......
......@@ -30,7 +30,6 @@ import (
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/paramtable"
......@@ -39,11 +38,6 @@ import (
var embedetcdServer *embed.Etcd
// mock of query coordinator client
type queryCoordMock struct {
types.QueryCoord
}
func setup() {
os.Setenv("QUERY_NODE_ID", "1")
paramtable.Init()
......
......@@ -7,6 +7,7 @@ import (
"os"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
......@@ -420,68 +421,74 @@ func withQueryCoord(qc types.QueryCoord) Opt {
}
func withUnhealthyQueryCoord() Opt {
qc := newMockQueryCoord()
qc.GetComponentStatesFunc = func(ctx context.Context) (*milvuspb.ComponentStates, error) {
return &milvuspb.ComponentStates{
qc := &types.MockQueryCoord{}
qc.EXPECT().GetComponentStates(mock.Anything).Return(
&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Abnormal},
Status: failStatus(commonpb.ErrorCode_UnexpectedError, "error mock GetComponentStates"),
}, retry.Unrecoverable(errors.New("error mock GetComponentStates"))
}
}, retry.Unrecoverable(errors.New("error mock GetComponentStates")),
)
return withQueryCoord(qc)
}
func withInvalidQueryCoord() Opt {
qc := newMockQueryCoord()
qc.GetComponentStatesFunc = func(ctx context.Context) (*milvuspb.ComponentStates, error) {
return &milvuspb.ComponentStates{
qc := &types.MockQueryCoord{}
qc.EXPECT().GetComponentStates(mock.Anything).Return(
&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Healthy},
Status: succStatus(),
}, nil
}
qc.ReleaseCollectionFunc = func(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
return nil, errors.New("error mock ReleaseCollection")
}
qc.GetSegmentInfoFunc = func(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
return nil, errors.New("error mock GetSegmentInfo")
}
}, nil,
)
qc.EXPECT().ReleaseCollection(mock.Anything, mock.Anything).Return(
nil, errors.New("error mock ReleaseCollection"),
)
qc.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything).Return(
nil, errors.New("error mock GetSegmentInfo"),
)
return withQueryCoord(qc)
}
func withFailedQueryCoord() Opt {
qc := newMockQueryCoord()
qc.GetComponentStatesFunc = func(ctx context.Context) (*milvuspb.ComponentStates, error) {
return &milvuspb.ComponentStates{
qc := &types.MockQueryCoord{}
qc.EXPECT().GetComponentStates(mock.Anything).Return(
&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Healthy},
Status: succStatus(),
}, nil
}
qc.ReleaseCollectionFunc = func(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
return failStatus(commonpb.ErrorCode_UnexpectedError, "mock release collection error"), nil
}
qc.GetSegmentInfoFunc = func(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
return &querypb.GetSegmentInfoResponse{
}, nil,
)
qc.EXPECT().ReleaseCollection(mock.Anything, mock.Anything).Return(
failStatus(commonpb.ErrorCode_UnexpectedError, "mock release collection error"), nil,
)
qc.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything).Return(
&querypb.GetSegmentInfoResponse{
Status: failStatus(commonpb.ErrorCode_UnexpectedError, "mock get segment info error"),
}, nil
}
}, nil,
)
return withQueryCoord(qc)
}
func withValidQueryCoord() Opt {
qc := newMockQueryCoord()
qc.GetComponentStatesFunc = func(ctx context.Context) (*milvuspb.ComponentStates, error) {
return &milvuspb.ComponentStates{
qc := &types.MockQueryCoord{}
qc.EXPECT().GetComponentStates(mock.Anything).Return(
&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Healthy},
Status: succStatus(),
}, nil
}
qc.ReleaseCollectionFunc = func(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
return succStatus(), nil
}
qc.GetSegmentInfoFunc = func(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
return &querypb.GetSegmentInfoResponse{
}, nil,
)
qc.EXPECT().ReleaseCollection(mock.Anything, mock.Anything).Return(
succStatus(), nil,
)
qc.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything).Return(
&querypb.GetSegmentInfoResponse{
Status: succStatus(),
}, nil
}
}, nil,
)
return withQueryCoord(qc)
}
......
......@@ -48,8 +48,8 @@ type IMetaTable_AddCollection_Call struct {
}
// AddCollection is a helper method to define mock.On call
// - ctx context.Context
// - coll *model.Collection
// - ctx context.Context
// - coll *model.Collection
func (_e *IMetaTable_Expecter) AddCollection(ctx interface{}, coll interface{}) *IMetaTable_AddCollection_Call {
return &IMetaTable_AddCollection_Call{Call: _e.mock.On("AddCollection", ctx, coll)}
}
......@@ -86,7 +86,7 @@ type IMetaTable_AddCredential_Call struct {
}
// AddCredential is a helper method to define mock.On call
// - credInfo *internalpb.CredentialInfo
// - credInfo *internalpb.CredentialInfo
func (_e *IMetaTable_Expecter) AddCredential(credInfo interface{}) *IMetaTable_AddCredential_Call {
return &IMetaTable_AddCredential_Call{Call: _e.mock.On("AddCredential", credInfo)}
}
......@@ -123,8 +123,8 @@ type IMetaTable_AddPartition_Call struct {
}
// AddPartition is a helper method to define mock.On call
// - ctx context.Context
// - partition *model.Partition
// - ctx context.Context
// - partition *model.Partition
func (_e *IMetaTable_Expecter) AddPartition(ctx interface{}, partition interface{}) *IMetaTable_AddPartition_Call {
return &IMetaTable_AddPartition_Call{Call: _e.mock.On("AddPartition", ctx, partition)}
}
......@@ -161,10 +161,10 @@ type IMetaTable_AlterAlias_Call struct {
}
// AlterAlias is a helper method to define mock.On call
// - ctx context.Context
// - alias string
// - collectionName string
// - ts uint64
// - ctx context.Context
// - alias string
// - collectionName string
// - ts uint64
func (_e *IMetaTable_Expecter) AlterAlias(ctx interface{}, alias interface{}, collectionName interface{}, ts interface{}) *IMetaTable_AlterAlias_Call {
return &IMetaTable_AlterAlias_Call{Call: _e.mock.On("AlterAlias", ctx, alias, collectionName, ts)}
}
......@@ -201,10 +201,10 @@ type IMetaTable_AlterCollection_Call struct {
}
// AlterCollection is a helper method to define mock.On call
// - ctx context.Context
// - oldColl *model.Collection
// - newColl *model.Collection
// - ts uint64
// - ctx context.Context
// - oldColl *model.Collection
// - newColl *model.Collection
// - ts uint64
func (_e *IMetaTable_Expecter) AlterCollection(ctx interface{}, oldColl interface{}, newColl interface{}, ts interface{}) *IMetaTable_AlterCollection_Call {
return &IMetaTable_AlterCollection_Call{Call: _e.mock.On("AlterCollection", ctx, oldColl, newColl, ts)}
}
......@@ -241,7 +241,7 @@ type IMetaTable_AlterCredential_Call struct {
}
// AlterCredential is a helper method to define mock.On call
// - credInfo *internalpb.CredentialInfo
// - credInfo *internalpb.CredentialInfo
func (_e *IMetaTable_Expecter) AlterCredential(credInfo interface{}) *IMetaTable_AlterCredential_Call {
return &IMetaTable_AlterCredential_Call{Call: _e.mock.On("AlterCredential", credInfo)}
}
......@@ -278,10 +278,10 @@ type IMetaTable_ChangeCollectionState_Call struct {
}
// ChangeCollectionState is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - state etcdpb.CollectionState
// - ts uint64
// - ctx context.Context
// - collectionID int64
// - state etcdpb.CollectionState
// - ts uint64
func (_e *IMetaTable_Expecter) ChangeCollectionState(ctx interface{}, collectionID interface{}, state interface{}, ts interface{}) *IMetaTable_ChangeCollectionState_Call {
return &IMetaTable_ChangeCollectionState_Call{Call: _e.mock.On("ChangeCollectionState", ctx, collectionID, state, ts)}
}
......@@ -318,11 +318,11 @@ type IMetaTable_ChangePartitionState_Call struct {
}
// ChangePartitionState is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - partitionID int64
// - state etcdpb.PartitionState
// - ts uint64
// - ctx context.Context
// - collectionID int64
// - partitionID int64
// - state etcdpb.PartitionState
// - ts uint64
func (_e *IMetaTable_Expecter) ChangePartitionState(ctx interface{}, collectionID interface{}, partitionID interface{}, state interface{}, ts interface{}) *IMetaTable_ChangePartitionState_Call {
return &IMetaTable_ChangePartitionState_Call{Call: _e.mock.On("ChangePartitionState", ctx, collectionID, partitionID, state, ts)}
}
......@@ -359,10 +359,10 @@ type IMetaTable_CreateAlias_Call struct {
}
// CreateAlias is a helper method to define mock.On call
// - ctx context.Context
// - alias string
// - collectionName string
// - ts uint64
// - ctx context.Context
// - alias string
// - collectionName string
// - ts uint64
func (_e *IMetaTable_Expecter) CreateAlias(ctx interface{}, alias interface{}, collectionName interface{}, ts interface{}) *IMetaTable_CreateAlias_Call {
return &IMetaTable_CreateAlias_Call{Call: _e.mock.On("CreateAlias", ctx, alias, collectionName, ts)}
}
......@@ -399,8 +399,8 @@ type IMetaTable_CreateRole_Call struct {
}
// CreateRole is a helper method to define mock.On call
// - tenant string
// - entity *milvuspb.RoleEntity
// - tenant string
// - entity *milvuspb.RoleEntity
func (_e *IMetaTable_Expecter) CreateRole(tenant interface{}, entity interface{}) *IMetaTable_CreateRole_Call {
return &IMetaTable_CreateRole_Call{Call: _e.mock.On("CreateRole", tenant, entity)}
}
......@@ -437,7 +437,7 @@ type IMetaTable_DeleteCredential_Call struct {
}
// DeleteCredential is a helper method to define mock.On call
// - username string
// - username string
func (_e *IMetaTable_Expecter) DeleteCredential(username interface{}) *IMetaTable_DeleteCredential_Call {
return &IMetaTable_DeleteCredential_Call{Call: _e.mock.On("DeleteCredential", username)}
}
......@@ -474,9 +474,9 @@ type IMetaTable_DropAlias_Call struct {
}
// DropAlias is a helper method to define mock.On call
// - ctx context.Context
// - alias string
// - ts uint64
// - ctx context.Context
// - alias string
// - ts uint64
func (_e *IMetaTable_Expecter) DropAlias(ctx interface{}, alias interface{}, ts interface{}) *IMetaTable_DropAlias_Call {
return &IMetaTable_DropAlias_Call{Call: _e.mock.On("DropAlias", ctx, alias, ts)}
}
......@@ -513,8 +513,8 @@ type IMetaTable_DropGrant_Call struct {
}
// DropGrant is a helper method to define mock.On call
// - tenant string
// - role *milvuspb.RoleEntity
// - tenant string
// - role *milvuspb.RoleEntity
func (_e *IMetaTable_Expecter) DropGrant(tenant interface{}, role interface{}) *IMetaTable_DropGrant_Call {
return &IMetaTable_DropGrant_Call{Call: _e.mock.On("DropGrant", tenant, role)}
}
......@@ -551,8 +551,8 @@ type IMetaTable_DropRole_Call struct {
}
// DropRole is a helper method to define mock.On call
// - tenant string
// - roleName string
// - tenant string
// - roleName string
func (_e *IMetaTable_Expecter) DropRole(tenant interface{}, roleName interface{}) *IMetaTable_DropRole_Call {
return &IMetaTable_DropRole_Call{Call: _e.mock.On("DropRole", tenant, roleName)}
}
......@@ -598,10 +598,10 @@ type IMetaTable_GetCollectionByID_Call struct {
}
// GetCollectionByID is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - ts uint64
// - allowUnavailable bool
// - ctx context.Context
// - collectionID int64
// - ts uint64
// - allowUnavailable bool
func (_e *IMetaTable_Expecter) GetCollectionByID(ctx interface{}, collectionID interface{}, ts interface{}, allowUnavailable interface{}) *IMetaTable_GetCollectionByID_Call {
return &IMetaTable_GetCollectionByID_Call{Call: _e.mock.On("GetCollectionByID", ctx, collectionID, ts, allowUnavailable)}
}
......@@ -647,9 +647,9 @@ type IMetaTable_GetCollectionByName_Call struct {
}
// GetCollectionByName is a helper method to define mock.On call
// - ctx context.Context
// - collectionName string
// - ts uint64
// - ctx context.Context
// - collectionName string
// - ts uint64
func (_e *IMetaTable_Expecter) GetCollectionByName(ctx interface{}, collectionName interface{}, ts interface{}) *IMetaTable_GetCollectionByName_Call {
return &IMetaTable_GetCollectionByName_Call{Call: _e.mock.On("GetCollectionByName", ctx, collectionName, ts)}
}
......@@ -688,7 +688,7 @@ type IMetaTable_GetCollectionVirtualChannels_Call struct {
}
// GetCollectionVirtualChannels is a helper method to define mock.On call
// - colID int64
// - colID int64
func (_e *IMetaTable_Expecter) GetCollectionVirtualChannels(colID interface{}) *IMetaTable_GetCollectionVirtualChannels_Call {
return &IMetaTable_GetCollectionVirtualChannels_Call{Call: _e.mock.On("GetCollectionVirtualChannels", colID)}
}
......@@ -734,7 +734,7 @@ type IMetaTable_GetCredential_Call struct {
}
// GetCredential is a helper method to define mock.On call
// - username string
// - username string
func (_e *IMetaTable_Expecter) GetCredential(username interface{}) *IMetaTable_GetCredential_Call {
return &IMetaTable_GetCredential_Call{Call: _e.mock.On("GetCredential", username)}
}
......@@ -778,9 +778,9 @@ type IMetaTable_GetPartitionByName_Call struct {
}
// GetPartitionByName is a helper method to define mock.On call
// - collID int64
// - partitionName string
// - ts uint64
// - collID int64
// - partitionName string
// - ts uint64
func (_e *IMetaTable_Expecter) GetPartitionByName(collID interface{}, partitionName interface{}, ts interface{}) *IMetaTable_GetPartitionByName_Call {
return &IMetaTable_GetPartitionByName_Call{Call: _e.mock.On("GetPartitionByName", collID, partitionName, ts)}
}
......@@ -824,9 +824,9 @@ type IMetaTable_GetPartitionNameByID_Call struct {
}
// GetPartitionNameByID is a helper method to define mock.On call
// - collID int64
// - partitionID int64
// - ts uint64
// - collID int64
// - partitionID int64
// - ts uint64
func (_e *IMetaTable_Expecter) GetPartitionNameByID(collID interface{}, partitionID interface{}, ts interface{}) *IMetaTable_GetPartitionNameByID_Call {
return &IMetaTable_GetPartitionNameByID_Call{Call: _e.mock.On("GetPartitionNameByID", collID, partitionID, ts)}
}
......@@ -863,7 +863,7 @@ type IMetaTable_IsAlias_Call struct {
}
// IsAlias is a helper method to define mock.On call
// - name string
// - name string
func (_e *IMetaTable_Expecter) IsAlias(name interface{}) *IMetaTable_IsAlias_Call {
return &IMetaTable_IsAlias_Call{Call: _e.mock.On("IsAlias", name)}
}
......@@ -909,8 +909,8 @@ type IMetaTable_ListAbnormalCollections_Call struct {
}
// ListAbnormalCollections is a helper method to define mock.On call
// - ctx context.Context
// - ts uint64
// - ctx context.Context
// - ts uint64
func (_e *IMetaTable_Expecter) ListAbnormalCollections(ctx interface{}, ts interface{}) *IMetaTable_ListAbnormalCollections_Call {
return &IMetaTable_ListAbnormalCollections_Call{Call: _e.mock.On("ListAbnormalCollections", ctx, ts)}
}
......@@ -949,7 +949,7 @@ type IMetaTable_ListAliasesByID_Call struct {
}
// ListAliasesByID is a helper method to define mock.On call
// - collID int64
// - collID int64
func (_e *IMetaTable_Expecter) ListAliasesByID(collID interface{}) *IMetaTable_ListAliasesByID_Call {
return &IMetaTable_ListAliasesByID_Call{Call: _e.mock.On("ListAliasesByID", collID)}
}
......@@ -1033,8 +1033,8 @@ type IMetaTable_ListCollections_Call struct {
}
// ListCollections is a helper method to define mock.On call
// - ctx context.Context
// - ts uint64
// - ctx context.Context
// - ts uint64
func (_e *IMetaTable_Expecter) ListCollections(ctx interface{}, ts interface{}) *IMetaTable_ListCollections_Call {
return &IMetaTable_ListCollections_Call{Call: _e.mock.On("ListCollections", ctx, ts)}
}
......@@ -1125,7 +1125,7 @@ type IMetaTable_ListPolicy_Call struct {
}
// ListPolicy is a helper method to define mock.On call
// - tenant string
// - tenant string
func (_e *IMetaTable_Expecter) ListPolicy(tenant interface{}) *IMetaTable_ListPolicy_Call {
return &IMetaTable_ListPolicy_Call{Call: _e.mock.On("ListPolicy", tenant)}
}
......@@ -1171,7 +1171,7 @@ type IMetaTable_ListUserRole_Call struct {
}
// ListUserRole is a helper method to define mock.On call
// - tenant string
// - tenant string
func (_e *IMetaTable_Expecter) ListUserRole(tenant interface{}) *IMetaTable_ListUserRole_Call {
return &IMetaTable_ListUserRole_Call{Call: _e.mock.On("ListUserRole", tenant)}
}
......@@ -1208,9 +1208,9 @@ type IMetaTable_OperatePrivilege_Call struct {
}
// OperatePrivilege is a helper method to define mock.On call
// - tenant string
// - entity *milvuspb.GrantEntity
// - operateType milvuspb.OperatePrivilegeType
// - tenant string
// - entity *milvuspb.GrantEntity
// - operateType milvuspb.OperatePrivilegeType
func (_e *IMetaTable_Expecter) OperatePrivilege(tenant interface{}, entity interface{}, operateType interface{}) *IMetaTable_OperatePrivilege_Call {
return &IMetaTable_OperatePrivilege_Call{Call: _e.mock.On("OperatePrivilege", tenant, entity, operateType)}
}
......@@ -1247,10 +1247,10 @@ type IMetaTable_OperateUserRole_Call struct {
}
// OperateUserRole is a helper method to define mock.On call
// - tenant string
// - userEntity *milvuspb.UserEntity
// - roleEntity *milvuspb.RoleEntity
// - operateType milvuspb.OperateUserRoleType
// - tenant string
// - userEntity *milvuspb.UserEntity
// - roleEntity *milvuspb.RoleEntity
// - operateType milvuspb.OperateUserRoleType
func (_e *IMetaTable_Expecter) OperateUserRole(tenant interface{}, userEntity interface{}, roleEntity interface{}, operateType interface{}) *IMetaTable_OperateUserRole_Call {
return &IMetaTable_OperateUserRole_Call{Call: _e.mock.On("OperateUserRole", tenant, userEntity, roleEntity, operateType)}
}
......@@ -1287,9 +1287,9 @@ type IMetaTable_RemoveCollection_Call struct {
}
// RemoveCollection is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - ts uint64
// - ctx context.Context
// - collectionID int64
// - ts uint64
func (_e *IMetaTable_Expecter) RemoveCollection(ctx interface{}, collectionID interface{}, ts interface{}) *IMetaTable_RemoveCollection_Call {
return &IMetaTable_RemoveCollection_Call{Call: _e.mock.On("RemoveCollection", ctx, collectionID, ts)}
}
......@@ -1326,10 +1326,10 @@ type IMetaTable_RemovePartition_Call struct {
}
// RemovePartition is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - partitionID int64
// - ts uint64
// - ctx context.Context
// - collectionID int64
// - partitionID int64
// - ts uint64
func (_e *IMetaTable_Expecter) RemovePartition(ctx interface{}, collectionID interface{}, partitionID interface{}, ts interface{}) *IMetaTable_RemovePartition_Call {
return &IMetaTable_RemovePartition_Call{Call: _e.mock.On("RemovePartition", ctx, collectionID, partitionID, ts)}
}
......@@ -1366,10 +1366,10 @@ type IMetaTable_RenameCollection_Call struct {
}
// RenameCollection is a helper method to define mock.On call
// - ctx context.Context
// - oldName string
// - newName string
// - ts uint64
// - ctx context.Context
// - oldName string
// - newName string
// - ts uint64
func (_e *IMetaTable_Expecter) RenameCollection(ctx interface{}, oldName interface{}, newName interface{}, ts interface{}) *IMetaTable_RenameCollection_Call {
return &IMetaTable_RenameCollection_Call{Call: _e.mock.On("RenameCollection", ctx, oldName, newName, ts)}
}
......@@ -1415,8 +1415,8 @@ type IMetaTable_SelectGrant_Call struct {
}
// SelectGrant is a helper method to define mock.On call
// - tenant string
// - entity *milvuspb.GrantEntity
// - tenant string
// - entity *milvuspb.GrantEntity
func (_e *IMetaTable_Expecter) SelectGrant(tenant interface{}, entity interface{}) *IMetaTable_SelectGrant_Call {
return &IMetaTable_SelectGrant_Call{Call: _e.mock.On("SelectGrant", tenant, entity)}
}
......@@ -1462,9 +1462,9 @@ type IMetaTable_SelectRole_Call struct {
}
// SelectRole is a helper method to define mock.On call
// - tenant string
// - entity *milvuspb.RoleEntity
// - includeUserInfo bool
// - tenant string
// - entity *milvuspb.RoleEntity
// - includeUserInfo bool
func (_e *IMetaTable_Expecter) SelectRole(tenant interface{}, entity interface{}, includeUserInfo interface{}) *IMetaTable_SelectRole_Call {
return &IMetaTable_SelectRole_Call{Call: _e.mock.On("SelectRole", tenant, entity, includeUserInfo)}
}
......@@ -1510,9 +1510,9 @@ type IMetaTable_SelectUser_Call struct {
}
// SelectUser is a helper method to define mock.On call
// - tenant string
// - entity *milvuspb.UserEntity
// - includeRoleInfo bool
// - tenant string
// - entity *milvuspb.UserEntity
// - includeRoleInfo bool
func (_e *IMetaTable_Expecter) SelectUser(tenant interface{}, entity interface{}, includeRoleInfo interface{}) *IMetaTable_SelectUser_Call {
return &IMetaTable_SelectUser_Call{Call: _e.mock.On("SelectUser", tenant, entity, includeRoleInfo)}
}
......
......@@ -24,42 +24,24 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/tsoutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
type queryCoordMockForQuota struct {
mockQueryCoord
retErr bool
retFailStatus bool
}
type dataCoordMockForQuota struct {
mockDataCoord
retErr bool
retFailStatus bool
}
func (q *queryCoordMockForQuota) GetMetrics(ctx context.Context, request *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
if q.retErr {
return nil, fmt.Errorf("mock err")
}
if q.retFailStatus {
return &milvuspb.GetMetricsResponse{
Status: failStatus(commonpb.ErrorCode_UnexpectedError, "mock failure status"),
}, nil
}
return &milvuspb.GetMetricsResponse{
Status: succStatus(),
}, nil
}
func (d *dataCoordMockForQuota) GetMetrics(ctx context.Context, request *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
if d.retErr {
return nil, fmt.Errorf("mock err")
......@@ -85,36 +67,44 @@ func TestQuotaCenter(t *testing.T) {
pcm := newProxyClientManager(core.proxyCreator)
t.Run("test QuotaCenter", func(t *testing.T) {
quotaCenter := NewQuotaCenter(pcm, &queryCoordMockForQuota{}, &dataCoordMockForQuota{}, core.tsoAllocator)
qc := types.NewMockQueryCoord(t)
quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator)
go quotaCenter.run()
time.Sleep(10 * time.Millisecond)
quotaCenter.stop()
})
t.Run("test syncMetrics", func(t *testing.T) {
quotaCenter := NewQuotaCenter(pcm, &queryCoordMockForQuota{}, &dataCoordMockForQuota{}, core.tsoAllocator)
qc := types.NewMockQueryCoord(t)
qc.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{Status: succStatus()}, nil)
quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator)
err = quotaCenter.syncMetrics()
assert.Error(t, err) // for empty response
quotaCenter = NewQuotaCenter(pcm, &queryCoordMockForQuota{retErr: true}, &dataCoordMockForQuota{}, core.tsoAllocator)
quotaCenter = NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator)
err = quotaCenter.syncMetrics()
assert.Error(t, err)
quotaCenter = NewQuotaCenter(pcm, &queryCoordMockForQuota{}, &dataCoordMockForQuota{retErr: true}, core.tsoAllocator)
quotaCenter = NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{retFailStatus: true}, core.tsoAllocator)
err = quotaCenter.syncMetrics()
assert.Error(t, err)
quotaCenter = NewQuotaCenter(pcm, &queryCoordMockForQuota{retFailStatus: true}, &dataCoordMockForQuota{}, core.tsoAllocator)
qc.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(nil, fmt.Errorf("mock err"))
quotaCenter = NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{retErr: true}, core.tsoAllocator)
err = quotaCenter.syncMetrics()
assert.Error(t, err)
quotaCenter = NewQuotaCenter(pcm, &queryCoordMockForQuota{}, &dataCoordMockForQuota{retFailStatus: true}, core.tsoAllocator)
qc.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{
Status: failStatus(commonpb.ErrorCode_UnexpectedError, "mock failure status"),
}, nil)
quotaCenter = NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator)
err = quotaCenter.syncMetrics()
assert.Error(t, err)
})
t.Run("test forceDeny", func(t *testing.T) {
quotaCenter := NewQuotaCenter(pcm, &queryCoordMockForQuota{}, &dataCoordMockForQuota{}, core.tsoAllocator)
qc := types.NewMockQueryCoord(t)
quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator)
quotaCenter.forceDenyReading(commonpb.ErrorCode_ForceDeny)
assert.Equal(t, Limit(0), quotaCenter.currentRates[internalpb.RateType_DQLQuery])
assert.Equal(t, Limit(0), quotaCenter.currentRates[internalpb.RateType_DQLQuery])
......@@ -124,7 +114,8 @@ func TestQuotaCenter(t *testing.T) {
})
t.Run("test calculateRates", func(t *testing.T) {
quotaCenter := NewQuotaCenter(pcm, &queryCoordMockForQuota{}, &dataCoordMockForQuota{}, core.tsoAllocator)
qc := types.NewMockQueryCoord(t)
quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator)
err = quotaCenter.calculateRates()
assert.NoError(t, err)
alloc := newMockTsoAllocator()
......@@ -137,8 +128,9 @@ func TestQuotaCenter(t *testing.T) {
})
t.Run("test getTimeTickDelayFactor", func(t *testing.T) {
qc := types.NewMockQueryCoord(t)
// test MaxTimestamp
quotaCenter := NewQuotaCenter(pcm, &queryCoordMockForQuota{}, &dataCoordMockForQuota{}, core.tsoAllocator)
quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator)
factor := quotaCenter.getTimeTickDelayFactor(0)
assert.Equal(t, float64(1), factor)
......@@ -187,7 +179,8 @@ func TestQuotaCenter(t *testing.T) {
})
t.Run("test getTimeTickDelayFactor factors", func(t *testing.T) {
quotaCenter := NewQuotaCenter(pcm, &queryCoordMockForQuota{}, &dataCoordMockForQuota{}, core.tsoAllocator)
qc := types.NewMockQueryCoord(t)
quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator)
type ttCase struct {
maxTtDelay time.Duration
curTt time.Time
......@@ -233,7 +226,8 @@ func TestQuotaCenter(t *testing.T) {
})
t.Run("test getNQInQueryFactor", func(t *testing.T) {
quotaCenter := NewQuotaCenter(pcm, &queryCoordMockForQuota{}, &dataCoordMockForQuota{}, core.tsoAllocator)
qc := types.NewMockQueryCoord(t)
quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator)
factor := quotaCenter.getNQInQueryFactor()
assert.Equal(t, float64(1), factor)
......@@ -259,7 +253,8 @@ func TestQuotaCenter(t *testing.T) {
})
t.Run("test getQueryLatencyFactor", func(t *testing.T) {
quotaCenter := NewQuotaCenter(pcm, &queryCoordMockForQuota{}, &dataCoordMockForQuota{}, core.tsoAllocator)
qc := types.NewMockQueryCoord(t)
quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator)
factor := quotaCenter.getQueryLatencyFactor()
assert.Equal(t, float64(1), factor)
......@@ -284,7 +279,8 @@ func TestQuotaCenter(t *testing.T) {
})
t.Run("test checkReadResult", func(t *testing.T) {
quotaCenter := NewQuotaCenter(pcm, &queryCoordMockForQuota{}, &dataCoordMockForQuota{}, core.tsoAllocator)
qc := types.NewMockQueryCoord(t)
quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator)
factor := quotaCenter.getReadResultFactor()
assert.Equal(t, float64(1), factor)
......@@ -309,7 +305,8 @@ func TestQuotaCenter(t *testing.T) {
})
t.Run("test calculateReadRates", func(t *testing.T) {
quotaCenter := NewQuotaCenter(pcm, &queryCoordMockForQuota{}, &dataCoordMockForQuota{}, core.tsoAllocator)
qc := types.NewMockQueryCoord(t)
quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator)
quotaCenter.proxyMetrics = map[UniqueID]*metricsinfo.ProxyQuotaMetrics{
1: {Rms: []metricsinfo.RateMetric{
{Label: internalpb.RateType_DQLSearch.String(), Rate: 100},
......@@ -351,7 +348,8 @@ func TestQuotaCenter(t *testing.T) {
})
t.Run("test calculateWriteRates", func(t *testing.T) {
quotaCenter := NewQuotaCenter(pcm, &queryCoordMockForQuota{}, &dataCoordMockForQuota{}, core.tsoAllocator)
qc := types.NewMockQueryCoord(t)
quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator)
err = quotaCenter.calculateWriteRates()
assert.NoError(t, err)
......@@ -376,7 +374,8 @@ func TestQuotaCenter(t *testing.T) {
})
t.Run("test getMemoryFactor basic", func(t *testing.T) {
quotaCenter := NewQuotaCenter(pcm, &queryCoordMockForQuota{}, &dataCoordMockForQuota{}, core.tsoAllocator)
qc := types.NewMockQueryCoord(t)
quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator)
factor := quotaCenter.getMemoryFactor()
assert.Equal(t, float64(1), factor)
quotaCenter.dataNodeMetrics = map[UniqueID]*metricsinfo.DataNodeQuotaMetrics{1: {Hms: metricsinfo.HardwareMetrics{MemoryUsage: 100, Memory: 100}}}
......@@ -388,7 +387,8 @@ func TestQuotaCenter(t *testing.T) {
})
t.Run("test getMemoryFactor factors", func(t *testing.T) {
quotaCenter := NewQuotaCenter(pcm, &queryCoordMockForQuota{}, &dataCoordMockForQuota{}, core.tsoAllocator)
qc := types.NewMockQueryCoord(t)
quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator)
type memCase struct {
lowWater float64
highWater float64
......@@ -428,7 +428,8 @@ func TestQuotaCenter(t *testing.T) {
})
t.Run("test ifDiskQuotaExceeded", func(t *testing.T) {
quotaCenter := NewQuotaCenter(pcm, &queryCoordMockForQuota{}, &dataCoordMockForQuota{}, core.tsoAllocator)
qc := types.NewMockQueryCoord(t)
quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator)
paramtable.Get().Save(Params.QuotaConfig.DiskProtectionEnabled.Key, "false")
ok := quotaCenter.ifDiskQuotaExceeded()
......@@ -449,7 +450,8 @@ func TestQuotaCenter(t *testing.T) {
})
t.Run("test setRates", func(t *testing.T) {
quotaCenter := NewQuotaCenter(pcm, &queryCoordMockForQuota{}, &dataCoordMockForQuota{}, core.tsoAllocator)
qc := types.NewMockQueryCoord(t)
quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator)
quotaCenter.currentRates[internalpb.RateType_DMLInsert] = 100
quotaCenter.quotaStates[milvuspb.QuotaState_DenyToWrite] = commonpb.ErrorCode_MemoryQuotaExhausted
quotaCenter.quotaStates[milvuspb.QuotaState_DenyToRead] = commonpb.ErrorCode_ForceDeny
......@@ -458,14 +460,16 @@ func TestQuotaCenter(t *testing.T) {
})
t.Run("test recordMetrics", func(t *testing.T) {
quotaCenter := NewQuotaCenter(pcm, &queryCoordMockForQuota{}, &dataCoordMockForQuota{}, core.tsoAllocator)
qc := types.NewMockQueryCoord(t)
quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator)
quotaCenter.quotaStates[milvuspb.QuotaState_DenyToWrite] = commonpb.ErrorCode_MemoryQuotaExhausted
quotaCenter.quotaStates[milvuspb.QuotaState_DenyToRead] = commonpb.ErrorCode_ForceDeny
quotaCenter.recordMetrics()
})
t.Run("test guaranteeMinRate", func(t *testing.T) {
quotaCenter := NewQuotaCenter(pcm, &queryCoordMockForQuota{}, &dataCoordMockForQuota{}, core.tsoAllocator)
qc := types.NewMockQueryCoord(t)
quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator)
minRate := Limit(100)
quotaCenter.currentRates[internalpb.RateType_DQLSearch] = Limit(50)
quotaCenter.guaranteeMinRate(float64(minRate), internalpb.RateType_DQLSearch)
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册