未验证 提交 7dfab0aa 编写于 作者: Y yah01 提交者: GitHub

Add unit test for case of failed to sync segments to shard leader (#16712)

Signed-off-by: Nyah01 <yang.cen@zilliz.com>
上级 97757405
......@@ -94,6 +94,21 @@ func waitLoadCollectionDone(ctx context.Context, queryCoord *QueryCoord, collect
return nil
}
func waitLoadCollectionRollbackDone(queryCoord *QueryCoord, collectionID UniqueID) bool {
maxRetryNum := 100
for cnt := 0; cnt < maxRetryNum; cnt++ {
_, err := queryCoord.meta.getCollectionInfoByID(collectionID)
if err != nil {
return true
}
log.Debug("waiting for rollback done...")
time.Sleep(100 * time.Millisecond)
}
return false
}
func TestGrpcTask(t *testing.T) {
refreshParams()
ctx := context.Background()
......@@ -1024,6 +1039,43 @@ func TestLoadPartitionsWithReplicas(t *testing.T) {
assert.Nil(t, err)
}
func TestLoadCollectionSyncSegmentsFail(t *testing.T) {
refreshParams()
ctx := context.Background()
defer removeAllSession()
queryCoord, err := startQueryCoord(ctx)
assert.Nil(t, err)
defer queryCoord.Stop()
node1, err := startQueryNodeServer(ctx)
assert.Nil(t, err)
waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID)
defer node1.stop()
node1.syncReplicaSegments = returnFailedResult
// Failed to sync segments should cause rollback
loadCollectionReq := &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
},
CollectionID: defaultCollectionID,
Schema: genDefaultCollectionSchema(false),
ReplicaNumber: 1,
}
status, err := queryCoord.LoadCollection(ctx, loadCollectionReq)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
// Wait for rollback done
rollbackDone := waitLoadCollectionRollbackDone(queryCoord, loadCollectionReq.CollectionID)
assert.True(t, rollbackDone)
assert.NoError(t, node1.stop())
assert.NoError(t, queryCoord.Stop())
assert.NoError(t, removeAllSession())
}
func Test_RepeatedLoadSamePartitions(t *testing.T) {
refreshParams()
ctx := context.Background()
......
......@@ -71,6 +71,7 @@ type queryNodeServerMock struct {
releaseCollection rpcHandler
releasePartition rpcHandler
releaseSegments rpcHandler
syncReplicaSegments rpcHandler
getSegmentInfos func() (*querypb.GetSegmentInfoResponse, error)
getMetrics func() (*milvuspb.GetMetricsResponse, error)
......@@ -95,6 +96,7 @@ func newQueryNodeServerMock(ctx context.Context) *queryNodeServerMock {
releaseCollection: returnSuccessResult,
releasePartition: returnSuccessResult,
releaseSegments: returnSuccessResult,
syncReplicaSegments: returnSuccessResult,
getSegmentInfos: returnSuccessGetSegmentInfoResult,
getMetrics: returnSuccessGetMetricsResult,
......@@ -273,9 +275,7 @@ func (qs *queryNodeServerMock) GetSegmentInfo(ctx context.Context, req *querypb.
}
func (qs *queryNodeServerMock) SyncReplicaSegments(ctx context.Context, req *querypb.SyncReplicaSegmentsRequest) (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}, nil
return qs.syncReplicaSegments()
}
func (qs *queryNodeServerMock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
......
......@@ -480,8 +480,11 @@ func Test_LoadCollectionExecuteFail(t *testing.T) {
func TestLoadCollectionNoEnoughNodeFail(t *testing.T) {
refreshParams()
ctx := context.Background()
defer removeAllSession()
queryCoord, err := startQueryCoord(ctx)
assert.Nil(t, err)
defer queryCoord.Stop()
node1, err := startQueryNodeServer(ctx)
assert.Nil(t, err)
......@@ -489,16 +492,13 @@ func TestLoadCollectionNoEnoughNodeFail(t *testing.T) {
assert.Nil(t, err)
waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID)
waitQueryNodeOnline(queryCoord.cluster, node2.queryNodeID)
defer node1.stop()
defer node2.stop()
loadCollectionTask := genLoadCollectionTask(ctx, queryCoord)
loadCollectionTask.ReplicaNumber = 3
err = queryCoord.scheduler.processTask(loadCollectionTask)
assert.Error(t, err)
assert.NoError(t, node1.stop())
assert.NoError(t, node2.stop())
assert.NoError(t, queryCoord.Stop())
assert.NoError(t, removeAllSession())
}
func Test_LoadPartitionAssignTaskFail(t *testing.T) {
......
......@@ -168,12 +168,12 @@ func syncReplicaSegments(ctx context.Context, cluster Cluster, childTasks []task
}
for dmc, leaders := range shardLeaders {
for _, leader := range leaders {
segments, ok := shardSegments[dmc]
if !ok {
break
}
segments, ok := shardSegments[dmc]
if !ok {
continue
}
for _, leader := range leaders {
req := querypb.SyncReplicaSegmentsRequest{
VchannelName: dmc,
ReplicaSegments: make([]*querypb.ReplicaSegmentsInfo, 0, len(segments)),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册