diff --git a/internal/querycoord/cluster.go b/internal/querycoord/cluster.go index 8a9aede36ea091d9b4666bd3ea085929b6e3bf5c..0f2eda4c226bcde845734b8e94e3924fd2d5e2bc 100644 --- a/internal/querycoord/cluster.go +++ b/internal/querycoord/cluster.go @@ -374,6 +374,7 @@ func (c *queryNodeCluster) releasePartitions(ctx context.Context, nodeID int64, log.Debug("ReleasePartitions: queryNode release partitions error", zap.String("error", err.Error())) return err } + for _, partitionID := range in.PartitionIDs { err = c.clusterMeta.releasePartition(in.CollectionID, partitionID) if err != nil { diff --git a/internal/querycoord/impl.go b/internal/querycoord/impl.go index 74453522da09a083c1123d910c4bf1059e566300..309b978ca8a7b35b0903a74472911c61be5fdfc7 100644 --- a/internal/querycoord/impl.go +++ b/internal/querycoord/impl.go @@ -141,21 +141,23 @@ func (qc *QueryCoord) LoadCollection(ctx context.Context, req *querypb.LoadColle return status, err } + baseTask := newBaseTask(qc.loopCtx, querypb.TriggerCondition_grpcRequest) loadCollectionTask := &LoadCollectionTask{ - BaseTask: BaseTask{ - ctx: qc.loopCtx, - Condition: NewTaskCondition(qc.loopCtx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, + BaseTask: baseTask, LoadCollectionRequest: req, rootCoord: qc.rootCoordClient, dataCoord: qc.dataCoordClient, cluster: qc.cluster, meta: qc.meta, } - qc.scheduler.Enqueue([]task{loadCollectionTask}) + err := qc.scheduler.Enqueue(loadCollectionTask) + if err != nil { + status.ErrorCode = commonpb.ErrorCode_UnexpectedError + status.Reason = err.Error() + return status, err + } - err := loadCollectionTask.WaitToFinish() + err = loadCollectionTask.WaitToFinish() if err != nil { status.ErrorCode = commonpb.ErrorCode_UnexpectedError status.Reason = err.Error() @@ -188,20 +190,22 @@ func (qc *QueryCoord) ReleaseCollection(ctx context.Context, req *querypb.Releas return status, nil } + baseTask := newBaseTask(qc.loopCtx, querypb.TriggerCondition_grpcRequest) releaseCollectionTask := &ReleaseCollectionTask{ - BaseTask: BaseTask{ - ctx: qc.loopCtx, - Condition: NewTaskCondition(qc.loopCtx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, + BaseTask: baseTask, ReleaseCollectionRequest: req, cluster: qc.cluster, meta: qc.meta, rootCoord: qc.rootCoordClient, } - qc.scheduler.Enqueue([]task{releaseCollectionTask}) + err := qc.scheduler.Enqueue(releaseCollectionTask) + if err != nil { + status.ErrorCode = commonpb.ErrorCode_UnexpectedError + status.Reason = err.Error() + return status, err + } - err := releaseCollectionTask.WaitToFinish() + err = releaseCollectionTask.WaitToFinish() if err != nil { status.ErrorCode = commonpb.ErrorCode_UnexpectedError status.Reason = err.Error() @@ -329,20 +333,22 @@ func (qc *QueryCoord) LoadPartitions(ctx context.Context, req *querypb.LoadParti req.PartitionIDs = partitionIDsToLoad } + baseTask := newBaseTask(qc.loopCtx, querypb.TriggerCondition_grpcRequest) loadPartitionTask := &LoadPartitionTask{ - BaseTask: BaseTask{ - ctx: qc.loopCtx, - Condition: NewTaskCondition(qc.loopCtx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, + BaseTask: baseTask, LoadPartitionsRequest: req, dataCoord: qc.dataCoordClient, cluster: qc.cluster, meta: qc.meta, } - qc.scheduler.Enqueue([]task{loadPartitionTask}) + err := qc.scheduler.Enqueue(loadPartitionTask) + if err != nil { + status.ErrorCode = commonpb.ErrorCode_UnexpectedError + status.Reason = err.Error() + return status, err + } - err := loadPartitionTask.WaitToFinish() + err = loadPartitionTask.WaitToFinish() if err != nil { status.ErrorCode = commonpb.ErrorCode_UnexpectedError status.Reason = err.Error() @@ -398,18 +404,20 @@ func (qc *QueryCoord) ReleasePartitions(ctx context.Context, req *querypb.Releas } req.PartitionIDs = toReleasedPartitions + baseTask := newBaseTask(qc.loopCtx, querypb.TriggerCondition_grpcRequest) releasePartitionTask := &ReleasePartitionTask{ - BaseTask: BaseTask{ - ctx: qc.loopCtx, - Condition: NewTaskCondition(qc.loopCtx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, + BaseTask: baseTask, ReleasePartitionsRequest: req, cluster: qc.cluster, } - qc.scheduler.Enqueue([]task{releasePartitionTask}) + err := qc.scheduler.Enqueue(releasePartitionTask) + if err != nil { + status.ErrorCode = commonpb.ErrorCode_UnexpectedError + status.Reason = err.Error() + return status, err + } - err := releasePartitionTask.WaitToFinish() + err = releasePartitionTask.WaitToFinish() if err != nil { status.ErrorCode = commonpb.ErrorCode_UnexpectedError status.Reason = err.Error() diff --git a/internal/querycoord/impl_test.go b/internal/querycoord/impl_test.go index 384719ecdd59d8abf22be031f1d657d1eb0e43fc..9ed5405a8b4e2d35e6cb684d5f1bed58a2c3459d 100644 --- a/internal/querycoord/impl_test.go +++ b/internal/querycoord/impl_test.go @@ -13,6 +13,7 @@ package querycoord import ( "context" "encoding/json" + "errors" "testing" "time" @@ -328,6 +329,87 @@ func TestGrpcTask(t *testing.T) { assert.Nil(t, err) } +func TestGrpcTaskEnqueueFail(t *testing.T) { + refreshParams() + ctx := context.Background() + queryCoord, err := startQueryCoord(ctx) + assert.Nil(t, err) + + _, err = startQueryNodeServer(ctx) + assert.Nil(t, err) + + taskIDAllocator := queryCoord.scheduler.taskIDAllocator + failedAllocator := func() (UniqueID, error) { + return 0, errors.New("scheduler failed to allocate ID") + } + + queryCoord.scheduler.taskIDAllocator = failedAllocator + + t.Run("Test LoadPartition", func(t *testing.T) { + status, err := queryCoord.LoadPartitions(ctx, &querypb.LoadPartitionsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadPartitions, + }, + CollectionID: defaultCollectionID, + PartitionIDs: []UniqueID{defaultPartitionID}, + Schema: genCollectionSchema(defaultCollectionID, false), + }) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode) + assert.NotNil(t, err) + }) + + t.Run("Test LoadCollection", func(t *testing.T) { + status, err := queryCoord.LoadCollection(ctx, &querypb.LoadCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadCollection, + }, + CollectionID: defaultCollectionID, + Schema: genCollectionSchema(defaultCollectionID, false), + }) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode) + assert.NotNil(t, err) + }) + + queryCoord.scheduler.taskIDAllocator = taskIDAllocator + status, err := queryCoord.LoadCollection(ctx, &querypb.LoadCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadCollection, + }, + CollectionID: defaultCollectionID, + Schema: genCollectionSchema(defaultCollectionID, false), + }) + assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode) + assert.Nil(t, err) + queryCoord.scheduler.taskIDAllocator = failedAllocator + + t.Run("Test ReleasePartition", func(t *testing.T) { + status, err := queryCoord.ReleasePartitions(ctx, &querypb.ReleasePartitionsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_ReleasePartitions, + }, + CollectionID: defaultCollectionID, + PartitionIDs: []UniqueID{defaultPartitionID}, + }) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode) + assert.NotNil(t, err) + }) + + t.Run("Test ReleaseCollection", func(t *testing.T) { + status, err := queryCoord.ReleaseCollection(ctx, &querypb.ReleaseCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_ReleaseCollection, + }, + CollectionID: defaultCollectionID, + }) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode) + assert.NotNil(t, err) + }) + + queryCoord.Stop() + err = removeAllSession() + assert.Nil(t, err) +} + func TestLoadBalanceTask(t *testing.T) { refreshParams() baseCtx := context.Background() @@ -371,7 +453,7 @@ func TestLoadBalanceTask(t *testing.T) { } loadBalanceTask := &LoadBalanceTask{ - BaseTask: BaseTask{ + BaseTask: &BaseTask{ ctx: baseCtx, Condition: NewTaskCondition(baseCtx), triggerCondition: querypb.TriggerCondition_nodeDown, @@ -382,7 +464,7 @@ func TestLoadBalanceTask(t *testing.T) { cluster: queryCoord.cluster, meta: queryCoord.meta, } - queryCoord.scheduler.Enqueue([]task{loadBalanceTask}) + queryCoord.scheduler.Enqueue(loadBalanceTask) res, err = queryCoord.ReleaseCollection(baseCtx, &querypb.ReleaseCollectionRequest{ Base: &commonpb.MsgBase{ @@ -400,6 +482,7 @@ func TestLoadBalanceTask(t *testing.T) { } func TestGrpcTaskBeforeHealthy(t *testing.T) { + refreshParams() ctx := context.Background() unHealthyCoord, err := startUnHealthyQueryCoord(ctx) assert.Nil(t, err) diff --git a/internal/querycoord/meta.go b/internal/querycoord/meta.go index 85d39047a61e02b0028b032b29ce3a4f279a9dd0..4e00f34603b3680d730e188d207fe2cf5b38141b 100644 --- a/internal/querycoord/meta.go +++ b/internal/querycoord/meta.go @@ -423,25 +423,23 @@ func (m *MetaReplica) releaseCollection(collectionID UniqueID) error { defer m.Unlock() delete(m.collectionInfos, collectionID) + var err error for id, info := range m.segmentInfos { if info.CollectionID == collectionID { - err := removeSegmentInfo(id, m.client) + err = removeSegmentInfo(id, m.client) if err != nil { - log.Error("remove segmentInfo error", zap.Any("error", err.Error()), zap.Int64("segmentID", id)) - return err + log.Warn("remove segmentInfo error", zap.Any("error", err.Error()), zap.Int64("segmentID", id)) } delete(m.segmentInfos, id) } } - delete(m.queryChannelInfos, collectionID) - err := removeGlobalCollectionInfo(collectionID, m.client) + err = removeGlobalCollectionInfo(collectionID, m.client) if err != nil { - log.Error("remove collectionInfo error", zap.Any("error", err.Error()), zap.Int64("collectionID", collectionID)) - return err + log.Warn("remove collectionInfo error", zap.Any("error", err.Error()), zap.Int64("collectionID", collectionID)) } - return nil + return err } func (m *MetaReplica) releasePartition(collectionID UniqueID, partitionID UniqueID) error { diff --git a/internal/querycoord/mock_3rd_component_test.go b/internal/querycoord/mock_3rd_component_test.go index 2488ef4c3f3f8f8d1f03ae71bfc78199893f49df..214152804ea9e5edfc161627f3ab3010b750998f 100644 --- a/internal/querycoord/mock_3rd_component_test.go +++ b/internal/querycoord/mock_3rd_component_test.go @@ -214,7 +214,9 @@ func (rc *rootCoordMock) createCollection(collectionID UniqueID) { if _, ok := rc.Col2partition[collectionID]; !ok { rc.CollectionIDs = append(rc.CollectionIDs, collectionID) - rc.Col2partition[collectionID] = make([]UniqueID, 0) + partitionIDs := make([]UniqueID, 0) + partitionIDs = append(partitionIDs, defaultPartitionID+1) + rc.Col2partition[collectionID] = partitionIDs } } @@ -222,13 +224,30 @@ func (rc *rootCoordMock) createPartition(collectionID UniqueID, partitionID Uniq rc.Lock() defer rc.Unlock() - if _, ok := rc.Col2partition[collectionID]; ok { - rc.Col2partition[collectionID] = append(rc.Col2partition[collectionID], partitionID) + if partitionIDs, ok := rc.Col2partition[collectionID]; ok { + partitionExist := false + for _, id := range partitionIDs { + if id == partitionID { + partitionExist = true + break + } + } + if !partitionExist { + rc.Col2partition[collectionID] = append(rc.Col2partition[collectionID], partitionID) + } + return nil } return errors.New("collection not exist") } +func (rc *rootCoordMock) CreatePartition(ctx context.Context, req *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) { + rc.createPartition(defaultCollectionID, defaultPartitionID) + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, nil +} + func (rc *rootCoordMock) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { collectionID := in.CollectionID status := &commonpb.Status{ @@ -244,7 +263,6 @@ func (rc *rootCoordMock) ShowPartitions(ctx context.Context, in *milvuspb.ShowPa } rc.createCollection(collectionID) - rc.createPartition(collectionID, defaultPartitionID) return &milvuspb.ShowPartitionsResponse{ Status: status, @@ -267,16 +285,17 @@ type dataCoordMock struct { minioKV kv.BaseKV collections []UniqueID col2DmChannels map[UniqueID][]*datapb.VchannelInfo - partitionID2Segment map[UniqueID]UniqueID - Segment2Binlog map[UniqueID][]*datapb.SegmentBinlogs - assignedSegmentID UniqueID + partitionID2Segment map[UniqueID][]UniqueID + Segment2Binlog map[UniqueID]*datapb.SegmentBinlogs + baseSegmentID UniqueID + channelNumPerCol int } func newDataCoordMock(ctx context.Context) (*dataCoordMock, error) { collectionIDs := make([]UniqueID, 0) col2DmChannels := make(map[UniqueID][]*datapb.VchannelInfo) - partitionID2Segment := make(map[UniqueID]UniqueID) - segment2Binglog := make(map[UniqueID][]*datapb.SegmentBinlogs) + partitionID2Segments := make(map[UniqueID][]UniqueID) + segment2Binglog := make(map[UniqueID]*datapb.SegmentBinlogs) // create minio client option := &minioKV.Option{ @@ -296,9 +315,10 @@ func newDataCoordMock(ctx context.Context) (*dataCoordMock, error) { minioKV: kv, collections: collectionIDs, col2DmChannels: col2DmChannels, - partitionID2Segment: partitionID2Segment, + partitionID2Segment: partitionID2Segments, Segment2Binlog: segment2Binglog, - assignedSegmentID: defaultSegmentID, + baseSegmentID: defaultSegmentID, + channelNumPerCol: 2, }, nil } @@ -306,28 +326,36 @@ func (data *dataCoordMock) GetRecoveryInfo(ctx context.Context, req *datapb.GetR collectionID := req.CollectionID partitionID := req.PartitionID - if _, ok := data.col2DmChannels[collectionID]; !ok { - segmentID := data.assignedSegmentID - data.partitionID2Segment[partitionID] = segmentID - fieldID2Paths, err := generateInsertBinLog(collectionID, partitionID, segmentID, "queryCoorf-mockDataCoord", data.minioKV) - if err != nil { - return nil, err - } - fieldBinlogs := make([]*datapb.FieldBinlog, 0) - for fieldID, path := range fieldID2Paths { - fieldBinlog := &datapb.FieldBinlog{ - FieldID: fieldID, - Binlogs: []string{path}, + if _, ok := data.partitionID2Segment[partitionID]; !ok { + segmentIDs := make([]UniqueID, 0) + for i := 0; i < data.channelNumPerCol; i++ { + segmentID := data.baseSegmentID + if _, ok := data.Segment2Binlog[segmentID]; !ok { + fieldID2Paths, err := generateInsertBinLog(collectionID, partitionID, segmentID, "queryCoorf-mockDataCoord", data.minioKV) + if err != nil { + return nil, err + } + fieldBinlogs := make([]*datapb.FieldBinlog, 0) + for fieldID, path := range fieldID2Paths { + fieldBinlog := &datapb.FieldBinlog{ + FieldID: fieldID, + Binlogs: []string{path}, + } + fieldBinlogs = append(fieldBinlogs, fieldBinlog) + } + segmentBinlog := &datapb.SegmentBinlogs{ + SegmentID: segmentID, + FieldBinlogs: fieldBinlogs, + } + data.Segment2Binlog[segmentID] = segmentBinlog } - fieldBinlogs = append(fieldBinlogs, fieldBinlog) - } - data.Segment2Binlog[segmentID] = make([]*datapb.SegmentBinlogs, 0) - segmentBinlog := &datapb.SegmentBinlogs{ - SegmentID: segmentID, - FieldBinlogs: fieldBinlogs, + segmentIDs = append(segmentIDs, segmentID) + data.baseSegmentID++ } - data.Segment2Binlog[segmentID] = append(data.Segment2Binlog[segmentID], segmentBinlog) + data.partitionID2Segment[partitionID] = segmentIDs + } + if _, ok := data.col2DmChannels[collectionID]; !ok { channelInfos := make([]*datapb.VchannelInfo, 0) data.collections = append(data.collections, collectionID) collectionName := funcutil.RandomString(8) @@ -339,20 +367,24 @@ func (data *dataCoordMock) GetRecoveryInfo(ctx context.Context, req *datapb.GetR SeekPosition: &internalpb.MsgPosition{ ChannelName: vChannel, }, - FlushedSegments: []*datapb.SegmentInfo{{ID: segmentID}}, } channelInfos = append(channelInfos, channelInfo) } data.col2DmChannels[collectionID] = channelInfos } - segmentID := data.partitionID2Segment[partitionID] + binlogs := make([]*datapb.SegmentBinlogs, 0) + for _, segmentID := range data.partitionID2Segment[partitionID] { + if _, ok := data.Segment2Binlog[segmentID]; ok { + binlogs = append(binlogs, data.Segment2Binlog[segmentID]) + } + } return &datapb.GetRecoveryInfoResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, }, Channels: data.col2DmChannels[collectionID], - Binlogs: data.Segment2Binlog[segmentID], + Binlogs: binlogs, }, nil } diff --git a/internal/querycoord/query_coord.go b/internal/querycoord/query_coord.go index 0acbefced456a89a24e4e1e186523fb24bcf8a26..e4c1f093b4dbaa7e934f4ba5f8e0c6233bb19878 100644 --- a/internal/querycoord/query_coord.go +++ b/internal/querycoord/query_coord.go @@ -87,6 +87,8 @@ func (qc *QueryCoord) Register() error { // Init function initializes the queryCoord's meta, cluster, etcdKV and task scheduler func (qc *QueryCoord) Init() error { + log.Debug("query coordinator start init") + //connect etcd connectEtcdFn := func() error { etcdKV, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath) if err != nil { @@ -221,19 +223,17 @@ func (qc *QueryCoord) watchNodeLoop() { SourceNodeIDs: offlineNodeIDs, } + baseTask := newBaseTask(qc.loopCtx, querypb.TriggerCondition_nodeDown) loadBalanceTask := &LoadBalanceTask{ - BaseTask: BaseTask{ - ctx: qc.loopCtx, - Condition: NewTaskCondition(qc.loopCtx), - triggerCondition: querypb.TriggerCondition_nodeDown, - }, + BaseTask: baseTask, LoadBalanceRequest: loadBalanceSegment, rootCoord: qc.rootCoordClient, dataCoord: qc.dataCoordClient, cluster: qc.cluster, meta: qc.meta, } - qc.scheduler.Enqueue([]task{loadBalanceTask}) + //TODO::deal enqueue error + qc.scheduler.Enqueue(loadBalanceTask) log.Debug("start a loadBalance task", zap.Any("task", loadBalanceTask)) } @@ -271,21 +271,19 @@ func (qc *QueryCoord) watchNodeLoop() { BalanceReason: querypb.TriggerCondition_nodeDown, } + baseTask := newBaseTask(qc.loopCtx, querypb.TriggerCondition_nodeDown) loadBalanceTask := &LoadBalanceTask{ - BaseTask: BaseTask{ - ctx: qc.loopCtx, - Condition: NewTaskCondition(qc.loopCtx), - triggerCondition: querypb.TriggerCondition_nodeDown, - }, + BaseTask: baseTask, LoadBalanceRequest: loadBalanceSegment, rootCoord: qc.rootCoordClient, dataCoord: qc.dataCoordClient, cluster: qc.cluster, meta: qc.meta, } - qc.scheduler.Enqueue([]task{loadBalanceTask}) - log.Debug("start a loadBalance task", zap.Any("task", loadBalanceTask)) qc.metricsCacheManager.InvalidateSystemInfoMetrics() + //TODO:: deal enqueue error + qc.scheduler.Enqueue(loadBalanceTask) + log.Debug("start a loadBalance task", zap.Any("task", loadBalanceTask)) } } } diff --git a/internal/querycoord/query_coord_test.go b/internal/querycoord/query_coord_test.go index f0cdae43cb2e115739b8fa470efa640f6564d059..e7e25f97f5340082bfa90be97db0560e614d260b 100644 --- a/internal/querycoord/query_coord_test.go +++ b/internal/querycoord/query_coord_test.go @@ -93,6 +93,11 @@ func startQueryCoord(ctx context.Context) (*QueryCoord, error) { return coord, nil } +func createDefaultPartition(ctx context.Context, queryCoord *QueryCoord) error { + _, err := queryCoord.rootCoordClient.CreatePartition(ctx, nil) + return err +} + func startUnHealthyQueryCoord(ctx context.Context) (*QueryCoord, error) { factory := msgstream.NewPmsFactory() diff --git a/internal/querycoord/task.go b/internal/querycoord/task.go index c3af7015c236892959f19aca8ddd1fdc667ce337..a7b33d9496abff51193ae92319b5436304991e8f 100644 --- a/internal/querycoord/task.go +++ b/internal/querycoord/task.go @@ -15,6 +15,7 @@ import ( "context" "errors" "fmt" + "sync" "time" "github.com/golang/protobuf/proto" @@ -39,6 +40,10 @@ const ( loadBalanceInfoPrefix = "queryCoord-loadBalanceInfo" ) +const ( + MaxRetryNum = 5 +) + type taskState int const ( @@ -46,6 +51,7 @@ const ( taskDoing taskState = 1 taskDone taskState = 3 taskExpired taskState = 4 + taskFailed taskState = 5 ) type task interface { @@ -55,33 +61,63 @@ type task interface { MsgBase() *commonpb.MsgBase Type() commonpb.MsgType Timestamp() Timestamp + TriggerCondition() querypb.TriggerCondition PreExecute(ctx context.Context) error Execute(ctx context.Context) error PostExecute(ctx context.Context) error + Reschedule(ctx context.Context) ([]task, error) + RollBack(ctx context.Context) []task WaitToFinish() error Notify(err error) TaskPriority() querypb.TriggerCondition + SetParentTask(t task) GetParentTask() task GetChildTask() []task AddChildTask(t task) + RemoveChildTaskByID(taskID UniqueID) IsValid() bool - Reschedule() ([]task, error) Marshal() ([]byte, error) State() taskState SetState(state taskState) + IsRetryable() bool + SetResultInfo(err error) + GetResultInfo() *commonpb.Status + UpdateTaskProcess() } type BaseTask struct { Condition - ctx context.Context - cancel context.CancelFunc - result *commonpb.Status - state taskState + ctx context.Context + cancel context.CancelFunc + result *commonpb.Status + resultMu sync.RWMutex + state taskState + stateMu sync.RWMutex + retryCount int + //sync.RWMutex taskID UniqueID triggerCondition querypb.TriggerCondition parentTask task childTasks []task + childTasksMu sync.RWMutex +} + +func newBaseTask(ctx context.Context, triggerType querypb.TriggerCondition) *BaseTask { + childCtx, cancel := context.WithCancel(ctx) + condition := NewTaskCondition(childCtx) + + baseTask := &BaseTask{ + ctx: childCtx, + cancel: cancel, + Condition: condition, + state: taskUndo, + retryCount: MaxRetryNum, + triggerCondition: triggerType, + childTasks: []task{}, + } + + return baseTask } func (bt *BaseTask) ID() UniqueID { @@ -96,41 +132,108 @@ func (bt *BaseTask) TraceCtx() context.Context { return bt.ctx } +func (bt *BaseTask) TriggerCondition() querypb.TriggerCondition { + return bt.triggerCondition +} + func (bt *BaseTask) TaskPriority() querypb.TriggerCondition { return bt.triggerCondition } +func (bt *BaseTask) SetParentTask(t task) { + bt.parentTask = t +} + func (bt *BaseTask) GetParentTask() task { return bt.parentTask } func (bt *BaseTask) GetChildTask() []task { + bt.childTasksMu.RLock() + defer bt.childTasksMu.RUnlock() + return bt.childTasks } func (bt *BaseTask) AddChildTask(t task) { + bt.childTasksMu.Lock() + defer bt.childTasksMu.Unlock() + bt.childTasks = append(bt.childTasks, t) } +func (bt *BaseTask) RemoveChildTaskByID(taskID UniqueID) { + bt.childTasksMu.Lock() + defer bt.childTasksMu.Unlock() + + result := make([]task, 0) + for _, t := range bt.childTasks { + if t.ID() != taskID { + result = append(result, t) + } + } + bt.childTasks = result +} + func (bt *BaseTask) IsValid() bool { return true } -func (bt *BaseTask) Reschedule() ([]task, error) { +func (bt *BaseTask) Reschedule(ctx context.Context) ([]task, error) { return nil, nil } func (bt *BaseTask) State() taskState { + bt.stateMu.RLock() + defer bt.stateMu.RUnlock() return bt.state } func (bt *BaseTask) SetState(state taskState) { + bt.stateMu.Lock() + defer bt.stateMu.Unlock() bt.state = state } +func (bt *BaseTask) IsRetryable() bool { + return bt.retryCount > 0 +} + +func (bt *BaseTask) SetResultInfo(err error) { + bt.resultMu.Lock() + defer bt.resultMu.Unlock() + + if bt.result == nil { + bt.result = &commonpb.Status{} + } + if err == nil { + bt.result.ErrorCode = commonpb.ErrorCode_Success + bt.result.Reason = "" + return + } + + bt.result.ErrorCode = commonpb.ErrorCode_UnexpectedError + bt.result.Reason = bt.result.Reason + ", " + err.Error() +} + +func (bt *BaseTask) GetResultInfo() *commonpb.Status { + bt.resultMu.RLock() + defer bt.resultMu.RUnlock() + return proto.Clone(bt.result).(*commonpb.Status) +} + +func (bt *BaseTask) UpdateTaskProcess() { + // TODO:: +} + +func (bt *BaseTask) RollBack(ctx context.Context) []task { + //TODO:: + return nil +} + //************************grpcTask***************************// type LoadCollectionTask struct { - BaseTask + *BaseTask *querypb.LoadCollectionRequest rootCoord types.RootCoord dataCoord types.DataCoord @@ -154,13 +257,28 @@ func (lct *LoadCollectionTask) Timestamp() Timestamp { return lct.Base.Timestamp } +func (lct *LoadCollectionTask) UpdateTaskProcess() { + collectionID := lct.CollectionID + childTasks := lct.GetChildTask() + allDone := true + for _, t := range childTasks { + if t.State() != taskDone { + allDone = false + } + } + if allDone { + err := lct.meta.setLoadPercentage(collectionID, 0, 100, querypb.LoadType_loadCollection) + if err != nil { + log.Error("loadCollectionTask: set load percentage to meta's collectionInfo", zap.Int64("collectionID", collectionID)) + lct.SetResultInfo(err) + } + } +} + func (lct *LoadCollectionTask) PreExecute(ctx context.Context) error { collectionID := lct.CollectionID schema := lct.Schema - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - } - lct.result = status + lct.SetResultInfo(nil) log.Debug("start do LoadCollectionTask", zap.Int64("msgID", lct.ID()), zap.Int64("collectionID", collectionID), @@ -169,10 +287,10 @@ func (lct *LoadCollectionTask) PreExecute(ctx context.Context) error { } func (lct *LoadCollectionTask) Execute(ctx context.Context) error { + defer func() { + lct.retryCount-- + }() collectionID := lct.CollectionID - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - } showPartitionRequest := &milvuspb.ShowPartitionsRequest{ Base: &commonpb.MsgBase{ @@ -182,8 +300,7 @@ func (lct *LoadCollectionTask) Execute(ctx context.Context) error { } showPartitionResponse, err := lct.rootCoord.ShowPartitions(ctx, showPartitionRequest) if err != nil { - status.Reason = err.Error() - lct.result = status + lct.SetResultInfo(err) return err } log.Debug("loadCollectionTask: get collection's all partitionIDs", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", showPartitionResponse.PartitionIDs)) @@ -233,8 +350,7 @@ func (lct *LoadCollectionTask) Execute(ctx context.Context) error { } recoveryInfo, err := lct.dataCoord.GetRecoveryInfo(ctx, getRecoveryInfoRequest) if err != nil { - status.Reason = err.Error() - lct.result = status + lct.SetResultInfo(err) return err } @@ -302,10 +418,10 @@ func (lct *LoadCollectionTask) Execute(ctx context.Context) error { } } - err = assignInternalTask(ctx, collectionID, lct, lct.meta, lct.cluster, loadSegmentReqs, watchDmChannelReqs) + err = assignInternalTask(ctx, collectionID, lct, lct.meta, lct.cluster, loadSegmentReqs, watchDmChannelReqs, false) if err != nil { - status.Reason = err.Error() - lct.result = status + log.Warn("loadCollectionTask: assign child task failed", zap.Int64("collectionID", collectionID)) + lct.SetResultInfo(err) return err } log.Debug("loadCollectionTask: assign child task done", zap.Int64("collectionID", collectionID)) @@ -318,54 +434,52 @@ func (lct *LoadCollectionTask) Execute(ctx context.Context) error { func (lct *LoadCollectionTask) PostExecute(ctx context.Context) error { collectionID := lct.CollectionID - if lct.State() == taskDone { - err := lct.meta.setLoadPercentage(collectionID, 0, 100, querypb.LoadType_loadCollection) - if err != nil { - log.Debug("loadCollectionTask: set load percentage to meta's collectionInfo", zap.Int64("collectionID", collectionID)) - return err - } - return nil - } if lct.result.ErrorCode != commonpb.ErrorCode_Success { - lct.childTasks = make([]task, 0) - nodes, err := lct.cluster.onlineNodes() + lct.childTasks = []task{} + err := lct.meta.releaseCollection(collectionID) if err != nil { - log.Debug(err.Error()) - } - for nodeID := range nodes { - req := &querypb.ReleaseCollectionRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_ReleaseCollection, - MsgID: lct.Base.MsgID, - Timestamp: lct.Base.Timestamp, - SourceID: lct.Base.SourceID, - }, - DbID: lct.DbID, - CollectionID: lct.CollectionID, - NodeID: nodeID, - } - releaseCollectionTask := &ReleaseCollectionTask{ - BaseTask: BaseTask{ - ctx: ctx, - Condition: NewTaskCondition(ctx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, - ReleaseCollectionRequest: req, - cluster: lct.cluster, - } - lct.AddChildTask(releaseCollectionTask) - log.Debug("loadCollectionTask: add a releaseCollectionTask to loadCollectionTask's childTask", zap.Any("task", releaseCollectionTask)) + log.Error("LoadCollectionTask: occur error when release collection info from meta", zap.Error(err)) } } - lct.meta.addCollection(collectionID, lct.Schema) + log.Debug("LoadCollectionTask postExecute done", zap.Int64("msgID", lct.ID()), zap.Int64("collectionID", collectionID)) return nil } +func (lct *LoadCollectionTask) RollBack(ctx context.Context) []task { + nodes, _ := lct.cluster.onlineNodes() + resultTasks := make([]task, 0) + //TODO::call rootCoord.ReleaseDQLMessageStream + for nodeID := range nodes { + //brute force rollBack, should optimize + req := &querypb.ReleaseCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_ReleaseCollection, + MsgID: lct.Base.MsgID, + Timestamp: lct.Base.Timestamp, + SourceID: lct.Base.SourceID, + }, + DbID: lct.DbID, + CollectionID: lct.CollectionID, + NodeID: nodeID, + } + baseTask := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest) + baseTask.SetParentTask(lct) + releaseCollectionTask := &ReleaseCollectionTask{ + BaseTask: baseTask, + ReleaseCollectionRequest: req, + cluster: lct.cluster, + } + resultTasks = append(resultTasks, releaseCollectionTask) + } + log.Debug("loadCollectionTask: rollBack loadCollectionTask", zap.Any("loadCollectionTask", lct), zap.Any("rollBack task", resultTasks)) + return resultTasks +} + type ReleaseCollectionTask struct { - BaseTask + *BaseTask *querypb.ReleaseCollectionRequest cluster Cluster meta Meta @@ -390,10 +504,7 @@ func (rct *ReleaseCollectionTask) Timestamp() Timestamp { func (rct *ReleaseCollectionTask) PreExecute(context.Context) error { collectionID := rct.CollectionID - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - } - rct.result = status + rct.SetResultInfo(nil) log.Debug("start do ReleaseCollectionTask", zap.Int64("msgID", rct.ID()), zap.Int64("collectionID", collectionID)) @@ -401,10 +512,11 @@ func (rct *ReleaseCollectionTask) PreExecute(context.Context) error { } func (rct *ReleaseCollectionTask) Execute(ctx context.Context) error { + defer func() { + rct.retryCount-- + }() collectionID := rct.CollectionID - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - } + // if nodeID ==0, it means that the release request has not been assigned to the specified query node if rct.NodeID <= 0 { rct.meta.releaseCollection(collectionID) @@ -419,17 +531,10 @@ func (rct *ReleaseCollectionTask) Execute(ctx context.Context) error { CollectionID: rct.CollectionID, } res, err := rct.rootCoord.ReleaseDQLMessageStream(rct.ctx, releaseDQLMessageStreamReq) - if err != nil { - log.Error("ReleaseCollectionTask: release collection end, releaseDQLMessageStream occur error", zap.Int64("collectionID", rct.CollectionID)) - status.Reason = err.Error() - rct.result = status - return err - } - if res.ErrorCode != commonpb.ErrorCode_Success { - log.Error("ReleaseCollectionTask: release collection end, releaseDQLMessageStream occur error", zap.Int64("collectionID", rct.CollectionID)) + if res.ErrorCode != commonpb.ErrorCode_Success || err != nil { + log.Warn("ReleaseCollectionTask: release collection end, releaseDQLMessageStream occur error", zap.Int64("collectionID", rct.CollectionID)) err = errors.New("rootCoord releaseDQLMessageStream failed") - status.Reason = err.Error() - rct.result = status + rct.SetResultInfo(err) return err } @@ -440,24 +545,22 @@ func (rct *ReleaseCollectionTask) Execute(ctx context.Context) error { for nodeID := range nodes { req := proto.Clone(rct.ReleaseCollectionRequest).(*querypb.ReleaseCollectionRequest) req.NodeID = nodeID + baseTask := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest) + baseTask.SetParentTask(rct) releaseCollectionTask := &ReleaseCollectionTask{ - BaseTask: BaseTask{ - ctx: ctx, - Condition: NewTaskCondition(ctx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, + BaseTask: baseTask, ReleaseCollectionRequest: req, cluster: rct.cluster, } + rct.AddChildTask(releaseCollectionTask) log.Debug("ReleaseCollectionTask: add a releaseCollectionTask to releaseCollectionTask's childTask", zap.Any("task", releaseCollectionTask)) } } else { err := rct.cluster.releaseCollection(ctx, rct.NodeID, rct.ReleaseCollectionRequest) if err != nil { - log.Error("ReleaseCollectionTask: release collection end, node occur error", zap.Int64("nodeID", rct.NodeID)) - status.Reason = err.Error() - rct.result = status + log.Warn("ReleaseCollectionTask: release collection end, node occur error", zap.Int64("nodeID", rct.NodeID)) + rct.SetResultInfo(err) return err } } @@ -471,6 +574,9 @@ func (rct *ReleaseCollectionTask) Execute(ctx context.Context) error { func (rct *ReleaseCollectionTask) PostExecute(context.Context) error { collectionID := rct.CollectionID + if rct.result.ErrorCode != commonpb.ErrorCode_Success { + rct.childTasks = []task{} + } log.Debug("ReleaseCollectionTask postExecute done", zap.Int64("msgID", rct.ID()), @@ -479,8 +585,15 @@ func (rct *ReleaseCollectionTask) PostExecute(context.Context) error { return nil } +func (rct *ReleaseCollectionTask) RollBack(ctx context.Context) []task { + //TODO:: + //if taskID == 0, recovery meta + //if taskID != 0, recovery collection on queryNode + return nil +} + type LoadPartitionTask struct { - BaseTask + *BaseTask *querypb.LoadPartitionsRequest dataCoord types.DataCoord cluster Cluster @@ -504,12 +617,30 @@ func (lpt *LoadPartitionTask) Timestamp() Timestamp { return lpt.Base.Timestamp } -func (lpt *LoadPartitionTask) PreExecute(context.Context) error { +func (lpt *LoadPartitionTask) UpdateTaskProcess() { collectionID := lpt.CollectionID - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, + partitionIDs := lpt.PartitionIDs + childTasks := lpt.GetChildTask() + allDone := true + for _, t := range childTasks { + if t.State() != taskDone { + allDone = false + } + } + if allDone { + for _, id := range partitionIDs { + err := lpt.meta.setLoadPercentage(collectionID, id, 100, querypb.LoadType_LoadPartition) + if err != nil { + log.Error("loadPartitionTask: set load percentage to meta's collectionInfo", zap.Int64("collectionID", collectionID), zap.Int64("partitionID", id)) + lpt.SetResultInfo(err) + } + } } - lpt.result = status +} + +func (lpt *LoadPartitionTask) PreExecute(context.Context) error { + collectionID := lpt.CollectionID + lpt.SetResultInfo(nil) log.Debug("start do LoadPartitionTask", zap.Int64("msgID", lpt.ID()), zap.Int64("collectionID", collectionID)) @@ -517,6 +648,9 @@ func (lpt *LoadPartitionTask) PreExecute(context.Context) error { } func (lpt *LoadPartitionTask) Execute(ctx context.Context) error { + defer func() { + lpt.retryCount-- + }() collectionID := lpt.CollectionID partitionIDs := lpt.PartitionIDs @@ -527,9 +661,6 @@ func (lpt *LoadPartitionTask) Execute(ctx context.Context) error { for _, id := range partitionIDs { lpt.meta.addPartition(collectionID, id) } - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - } segmentsToLoad := make([]UniqueID, 0) loadSegmentReqs := make([]*querypb.LoadSegmentsRequest, 0) @@ -543,8 +674,7 @@ func (lpt *LoadPartitionTask) Execute(ctx context.Context) error { } recoveryInfo, err := lpt.dataCoord.GetRecoveryInfo(ctx, getRecoveryInfoRequest) if err != nil { - status.Reason = err.Error() - lpt.result = status + lpt.SetResultInfo(err) return err } @@ -586,10 +716,10 @@ func (lpt *LoadPartitionTask) Execute(ctx context.Context) error { log.Debug("LoadPartitionTask: set watchDmChannelsRequests", zap.Any("request", watchDmRequest), zap.Int64("collectionID", collectionID)) } } - err := assignInternalTask(ctx, collectionID, lpt, lpt.meta, lpt.cluster, loadSegmentReqs, watchDmReqs) + err := assignInternalTask(ctx, collectionID, lpt, lpt.meta, lpt.cluster, loadSegmentReqs, watchDmReqs, false) if err != nil { - status.Reason = err.Error() - lpt.result = status + log.Warn("LoadPartitionTask: assign child task failed", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs)) + lpt.SetResultInfo(err) return err } log.Debug("LoadPartitionTask: assign child task done", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs)) @@ -604,81 +734,23 @@ func (lpt *LoadPartitionTask) Execute(ctx context.Context) error { func (lpt *LoadPartitionTask) PostExecute(ctx context.Context) error { collectionID := lpt.CollectionID partitionIDs := lpt.PartitionIDs - if lpt.State() == taskDone { - for _, id := range partitionIDs { - err := lpt.meta.setLoadPercentage(collectionID, id, 100, querypb.LoadType_LoadPartition) - if err != nil { - log.Debug("loadPartitionTask: set load percentage to meta's collectionInfo", zap.Int64("collectionID", collectionID), zap.Int64("partitionID", id)) - return err - } - } - return nil - } if lpt.result.ErrorCode != commonpb.ErrorCode_Success { - lpt.childTasks = make([]task, 0) + lpt.childTasks = []task{} if lpt.addCol { - nodes, err := lpt.cluster.onlineNodes() + err := lpt.meta.releaseCollection(collectionID) if err != nil { - log.Debug(err.Error()) - } - for nodeID := range nodes { - req := &querypb.ReleaseCollectionRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_ReleaseCollection, - MsgID: lpt.Base.MsgID, - Timestamp: lpt.Base.Timestamp, - SourceID: lpt.Base.SourceID, - }, - DbID: lpt.DbID, - CollectionID: lpt.CollectionID, - NodeID: nodeID, - } - releaseCollectionTask := &ReleaseCollectionTask{ - BaseTask: BaseTask{ - ctx: ctx, - Condition: NewTaskCondition(ctx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, - ReleaseCollectionRequest: req, - cluster: lpt.cluster, - } - lpt.AddChildTask(releaseCollectionTask) - log.Debug("loadPartitionTask: add a releaseCollectionTask to loadPartitionTask's childTask", zap.Any("task", releaseCollectionTask)) + log.Error("LoadPartitionTask: occur error when release collection info from meta", zap.Error(err)) } } else { - nodes, err := lpt.cluster.onlineNodes() - if err != nil { - log.Debug(err.Error()) - } - for nodeID := range nodes { - req := &querypb.ReleasePartitionsRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_ReleasePartitions, - MsgID: lpt.Base.MsgID, - Timestamp: lpt.Base.Timestamp, - SourceID: lpt.Base.SourceID, - }, - DbID: lpt.DbID, - CollectionID: lpt.CollectionID, - PartitionIDs: partitionIDs, - NodeID: nodeID, - } - - releasePartitionTask := &ReleasePartitionTask{ - BaseTask: BaseTask{ - ctx: ctx, - Condition: NewTaskCondition(ctx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, - - ReleasePartitionsRequest: req, - cluster: lpt.cluster, + for _, partitionID := range partitionIDs { + err := lpt.meta.releasePartition(collectionID, partitionID) + if err != nil { + log.Error("LoadPartitionTask: occur error when release partition info from meta", zap.Error(err)) } - lpt.AddChildTask(releasePartitionTask) - log.Debug("loadPartitionTask: add a releasePartitionTask to loadPartitionTask's childTask", zap.Any("task", releasePartitionTask)) } } } + log.Debug("LoadPartitionTask postExecute done", zap.Int64("msgID", lpt.ID()), zap.Int64("collectionID", collectionID), @@ -686,8 +758,65 @@ func (lpt *LoadPartitionTask) PostExecute(ctx context.Context) error { return nil } +func (lpt *LoadPartitionTask) RollBack(ctx context.Context) []task { + partitionIDs := lpt.PartitionIDs + resultTasks := make([]task, 0) + //brute force rollBack, should optimize + if lpt.addCol { + nodes, _ := lpt.cluster.onlineNodes() + for nodeID := range nodes { + req := &querypb.ReleaseCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_ReleaseCollection, + MsgID: lpt.Base.MsgID, + Timestamp: lpt.Base.Timestamp, + SourceID: lpt.Base.SourceID, + }, + DbID: lpt.DbID, + CollectionID: lpt.CollectionID, + NodeID: nodeID, + } + baseTask := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest) + baseTask.SetParentTask(lpt) + releaseCollectionTask := &ReleaseCollectionTask{ + BaseTask: baseTask, + ReleaseCollectionRequest: req, + cluster: lpt.cluster, + } + resultTasks = append(resultTasks, releaseCollectionTask) + } + } else { + nodes, _ := lpt.cluster.onlineNodes() + for nodeID := range nodes { + req := &querypb.ReleasePartitionsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_ReleasePartitions, + MsgID: lpt.Base.MsgID, + Timestamp: lpt.Base.Timestamp, + SourceID: lpt.Base.SourceID, + }, + DbID: lpt.DbID, + CollectionID: lpt.CollectionID, + PartitionIDs: partitionIDs, + NodeID: nodeID, + } + + baseTask := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest) + baseTask.SetParentTask(lpt) + releasePartitionTask := &ReleasePartitionTask{ + BaseTask: baseTask, + ReleasePartitionsRequest: req, + cluster: lpt.cluster, + } + resultTasks = append(resultTasks, releasePartitionTask) + } + } + log.Debug("loadPartitionTask: rollBack loadPartitionTask", zap.Any("loadPartitionTask", lpt), zap.Any("rollBack task", resultTasks)) + return resultTasks +} + type ReleasePartitionTask struct { - BaseTask + *BaseTask *querypb.ReleasePartitionsRequest cluster Cluster } @@ -710,10 +839,7 @@ func (rpt *ReleasePartitionTask) Timestamp() Timestamp { func (rpt *ReleasePartitionTask) PreExecute(context.Context) error { collectionID := rpt.CollectionID - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - } - rpt.result = status + rpt.SetResultInfo(nil) log.Debug("start do releasePartitionTask", zap.Int64("msgID", rpt.ID()), zap.Int64("collectionID", collectionID)) @@ -721,11 +847,12 @@ func (rpt *ReleasePartitionTask) PreExecute(context.Context) error { } func (rpt *ReleasePartitionTask) Execute(ctx context.Context) error { + defer func() { + rpt.retryCount-- + }() collectionID := rpt.CollectionID partitionIDs := rpt.PartitionIDs - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - } + // if nodeID ==0, it means that the release request has not been assigned to the specified query node if rpt.NodeID <= 0 { nodes, err := rpt.cluster.onlineNodes() @@ -735,13 +862,10 @@ func (rpt *ReleasePartitionTask) Execute(ctx context.Context) error { for nodeID := range nodes { req := proto.Clone(rpt.ReleasePartitionsRequest).(*querypb.ReleasePartitionsRequest) req.NodeID = nodeID + baseTask := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest) + baseTask.SetParentTask(rpt) releasePartitionTask := &ReleasePartitionTask{ - BaseTask: BaseTask{ - ctx: ctx, - Condition: NewTaskCondition(ctx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, - + BaseTask: baseTask, ReleasePartitionsRequest: req, cluster: rpt.cluster, } @@ -751,9 +875,8 @@ func (rpt *ReleasePartitionTask) Execute(ctx context.Context) error { } else { err := rpt.cluster.releasePartitions(ctx, rpt.NodeID, rpt.ReleasePartitionsRequest) if err != nil { - log.Error("ReleasePartitionsTask: release partition end, node occur error", zap.String("nodeID", fmt.Sprintln(rpt.NodeID))) - status.Reason = err.Error() - rpt.result = status + log.Warn("ReleasePartitionsTask: release partition end, node occur error", zap.String("nodeID", fmt.Sprintln(rpt.NodeID))) + rpt.SetResultInfo(err) return err } } @@ -769,6 +892,9 @@ func (rpt *ReleasePartitionTask) Execute(ctx context.Context) error { func (rpt *ReleasePartitionTask) PostExecute(context.Context) error { collectionID := rpt.CollectionID partitionIDs := rpt.PartitionIDs + if rpt.result.ErrorCode != commonpb.ErrorCode_Success { + rpt.childTasks = []task{} + } log.Debug("ReleasePartitionTask postExecute done", zap.Int64("msgID", rpt.ID()), @@ -778,12 +904,20 @@ func (rpt *ReleasePartitionTask) PostExecute(context.Context) error { return nil } +func (rpt *ReleasePartitionTask) RollBack(ctx context.Context) []task { + //TODO:: + //if taskID == 0, recovery meta + //if taskID != 0, recovery partition on queryNode + return nil +} + //****************************internal task*******************************// type LoadSegmentTask struct { - BaseTask + *BaseTask *querypb.LoadSegmentsRequest - meta Meta - cluster Cluster + meta Meta + cluster Cluster + excludeNodeIDs []int64 } func (lst *LoadSegmentTask) MsgBase() *commonpb.MsgBase { @@ -795,12 +929,12 @@ func (lst *LoadSegmentTask) Marshal() ([]byte, error) { } func (lst *LoadSegmentTask) IsValid() bool { - onService, err := lst.cluster.isOnline(lst.NodeID) + online, err := lst.cluster.isOnline(lst.NodeID) if err != nil { return false } - return lst.ctx != nil && onService + return lst.ctx != nil && online } func (lst *LoadSegmentTask) Type() commonpb.MsgType { @@ -811,15 +945,21 @@ func (lst *LoadSegmentTask) Timestamp() Timestamp { return lst.Base.Timestamp } +func (lst *LoadSegmentTask) UpdateTaskProcess() { + parentTask := lst.GetParentTask() + if parentTask == nil { + log.Warn("LoadSegmentTask: parentTask should not be nil") + return + } + parentTask.UpdateTaskProcess() +} + func (lst *LoadSegmentTask) PreExecute(context.Context) error { segmentIDs := make([]UniqueID, 0) for _, info := range lst.Infos { segmentIDs = append(segmentIDs, info.SegmentID) } - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - } - lst.result = status + lst.SetResultInfo(nil) log.Debug("start do loadSegmentTask", zap.Int64s("segmentIDs", segmentIDs), zap.Int64("loaded nodeID", lst.NodeID), @@ -828,14 +968,14 @@ func (lst *LoadSegmentTask) PreExecute(context.Context) error { } func (lst *LoadSegmentTask) Execute(ctx context.Context) error { - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - } + defer func() { + lst.retryCount-- + }() + err := lst.cluster.loadSegments(ctx, lst.NodeID, lst.LoadSegmentsRequest) if err != nil { - log.Error("LoadSegmentTask: loadSegment occur error", zap.Int64("taskID", lst.ID())) - status.Reason = err.Error() - lst.result = status + log.Warn("LoadSegmentTask: loadSegment occur error", zap.Int64("taskID", lst.ID())) + lst.SetResultInfo(err) return err } @@ -843,20 +983,26 @@ func (lst *LoadSegmentTask) Execute(ctx context.Context) error { zap.Int64("taskID", lst.ID())) return nil } + func (lst *LoadSegmentTask) PostExecute(context.Context) error { log.Debug("loadSegmentTask postExecute done", zap.Int64("taskID", lst.ID())) return nil } -func (lst *LoadSegmentTask) Reschedule() ([]task, error) { +func (lst *LoadSegmentTask) Reschedule(ctx context.Context) ([]task, error) { segmentIDs := make([]UniqueID, 0) collectionID := lst.Infos[0].CollectionID reScheduledTask := make([]task, 0) for _, info := range lst.Infos { segmentIDs = append(segmentIDs, info.SegmentID) } - segment2Nodes := shuffleSegmentsToQueryNode(segmentIDs, lst.cluster) + lst.excludeNodeIDs = append(lst.excludeNodeIDs, lst.NodeID) + segment2Nodes, err := shuffleSegmentsToQueryNode(segmentIDs, lst.cluster, false, lst.excludeNodeIDs) + if err != nil { + log.Error("loadSegment reschedule failed", zap.Int64s("excludeNodes", lst.excludeNodeIDs), zap.Error(err)) + return nil, err + } node2segmentInfos := make(map[int64][]*querypb.SegmentLoadInfo) for index, info := range lst.Infos { nodeID := segment2Nodes[index] @@ -867,12 +1013,10 @@ func (lst *LoadSegmentTask) Reschedule() ([]task, error) { } for nodeID, infos := range node2segmentInfos { + loadSegmentBaseTask := newBaseTask(ctx, lst.TriggerCondition()) + loadSegmentBaseTask.SetParentTask(lst.GetParentTask()) loadSegmentTask := &LoadSegmentTask{ - BaseTask: BaseTask{ - ctx: lst.ctx, - Condition: NewTaskCondition(lst.ctx), - triggerCondition: lst.LoadCondition, - }, + BaseTask: loadSegmentBaseTask, LoadSegmentsRequest: &querypb.LoadSegmentsRequest{ Base: lst.Base, NodeID: nodeID, @@ -880,8 +1024,9 @@ func (lst *LoadSegmentTask) Reschedule() ([]task, error) { Schema: lst.Schema, LoadCondition: lst.LoadCondition, }, - meta: lst.meta, - cluster: lst.cluster, + meta: lst.meta, + cluster: lst.cluster, + excludeNodeIDs: lst.excludeNodeIDs, } reScheduledTask = append(reScheduledTask, loadSegmentTask) log.Debug("LoadSegmentTask: add a loadSegmentTask to RescheduleTasks", zap.Any("task", loadSegmentTask)) @@ -902,13 +1047,10 @@ func (lst *LoadSegmentTask) Reschedule() ([]task, error) { RequestChannelID: queryChannel, ResultChannelID: queryResultChannel, } + watchQueryChannelBaseTask := newBaseTask(ctx, lst.TriggerCondition()) + watchQueryChannelBaseTask.SetParentTask(lst.GetParentTask()) watchQueryChannelTask := &WatchQueryChannelTask{ - BaseTask: BaseTask{ - ctx: lst.ctx, - Condition: NewTaskCondition(lst.ctx), - triggerCondition: lst.LoadCondition, - }, - + BaseTask: watchQueryChannelBaseTask, AddQueryChannelRequest: addQueryChannelRequest, cluster: lst.cluster, } @@ -921,7 +1063,7 @@ func (lst *LoadSegmentTask) Reschedule() ([]task, error) { } type ReleaseSegmentTask struct { - BaseTask + *BaseTask *querypb.ReleaseSegmentsRequest cluster Cluster } @@ -935,11 +1077,11 @@ func (rst *ReleaseSegmentTask) Marshal() ([]byte, error) { } func (rst *ReleaseSegmentTask) IsValid() bool { - onService, err := rst.cluster.isOnline(rst.NodeID) + online, err := rst.cluster.isOnline(rst.NodeID) if err != nil { return false } - return rst.ctx != nil && onService + return rst.ctx != nil && online } func (rst *ReleaseSegmentTask) Type() commonpb.MsgType { @@ -952,10 +1094,7 @@ func (rst *ReleaseSegmentTask) Timestamp() Timestamp { func (rst *ReleaseSegmentTask) PreExecute(context.Context) error { segmentIDs := rst.SegmentIDs - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - } - rst.result = status + rst.SetResultInfo(nil) log.Debug("start do releaseSegmentTask", zap.Int64s("segmentIDs", segmentIDs), zap.Int64("loaded nodeID", rst.NodeID), @@ -964,14 +1103,14 @@ func (rst *ReleaseSegmentTask) PreExecute(context.Context) error { } func (rst *ReleaseSegmentTask) Execute(ctx context.Context) error { - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - } + defer func() { + rst.retryCount-- + }() + err := rst.cluster.releaseSegments(rst.ctx, rst.NodeID, rst.ReleaseSegmentsRequest) if err != nil { - log.Error("ReleaseSegmentTask: releaseSegment occur error", zap.Int64("taskID", rst.ID())) - status.Reason = err.Error() - rst.result = status + log.Warn("ReleaseSegmentTask: releaseSegment occur error", zap.Int64("taskID", rst.ID())) + rst.SetResultInfo(err) return err } @@ -990,10 +1129,11 @@ func (rst *ReleaseSegmentTask) PostExecute(context.Context) error { } type WatchDmChannelTask struct { - BaseTask + *BaseTask *querypb.WatchDmChannelsRequest - meta Meta - cluster Cluster + meta Meta + cluster Cluster + excludeNodeIDs []int64 } func (wdt *WatchDmChannelTask) MsgBase() *commonpb.MsgBase { @@ -1005,11 +1145,11 @@ func (wdt *WatchDmChannelTask) Marshal() ([]byte, error) { } func (wdt *WatchDmChannelTask) IsValid() bool { - onService, err := wdt.cluster.isOnline(wdt.NodeID) + online, err := wdt.cluster.isOnline(wdt.NodeID) if err != nil { return false } - return wdt.ctx != nil && onService + return wdt.ctx != nil && online } func (wdt *WatchDmChannelTask) Type() commonpb.MsgType { @@ -1020,16 +1160,22 @@ func (wdt *WatchDmChannelTask) Timestamp() Timestamp { return wdt.Base.Timestamp } +func (wdt *WatchDmChannelTask) UpdateTaskProcess() { + parentTask := wdt.GetParentTask() + if parentTask == nil { + log.Warn("WatchDmChannelTask: parentTask should not be nil") + return + } + parentTask.UpdateTaskProcess() +} + func (wdt *WatchDmChannelTask) PreExecute(context.Context) error { channelInfos := wdt.Infos channels := make([]string, 0) for _, info := range channelInfos { channels = append(channels, info.ChannelName) } - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - } - wdt.result = status + wdt.SetResultInfo(nil) log.Debug("start do watchDmChannelTask", zap.Strings("dmChannels", channels), zap.Int64("loaded nodeID", wdt.NodeID), @@ -1038,14 +1184,14 @@ func (wdt *WatchDmChannelTask) PreExecute(context.Context) error { } func (wdt *WatchDmChannelTask) Execute(ctx context.Context) error { - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - } + defer func() { + wdt.retryCount-- + }() + err := wdt.cluster.watchDmChannels(wdt.ctx, wdt.NodeID, wdt.WatchDmChannelsRequest) if err != nil { - log.Error("WatchDmChannelTask: watchDmChannel occur error", zap.Int64("taskID", wdt.ID())) - status.Reason = err.Error() - wdt.result = status + log.Warn("WatchDmChannelTask: watchDmChannel occur error", zap.Int64("taskID", wdt.ID())) + wdt.SetResultInfo(err) return err } @@ -1060,7 +1206,7 @@ func (wdt *WatchDmChannelTask) PostExecute(context.Context) error { return nil } -func (wdt *WatchDmChannelTask) Reschedule() ([]task, error) { +func (wdt *WatchDmChannelTask) Reschedule(ctx context.Context) ([]task, error) { collectionID := wdt.CollectionID channelIDs := make([]string, 0) reScheduledTask := make([]task, 0) @@ -1068,7 +1214,12 @@ func (wdt *WatchDmChannelTask) Reschedule() ([]task, error) { channelIDs = append(channelIDs, info.ChannelName) } - channel2Nodes := shuffleChannelsToQueryNode(channelIDs, wdt.cluster) + wdt.excludeNodeIDs = append(wdt.excludeNodeIDs, wdt.NodeID) + channel2Nodes, err := shuffleChannelsToQueryNode(channelIDs, wdt.cluster, false, wdt.excludeNodeIDs) + if err != nil { + log.Error("watchDmChannel reschedule failed", zap.Int64s("excludeNodes", wdt.excludeNodeIDs), zap.Error(err)) + return nil, err + } node2channelInfos := make(map[int64][]*datapb.VchannelInfo) for index, info := range wdt.Infos { nodeID := channel2Nodes[index] @@ -1079,12 +1230,10 @@ func (wdt *WatchDmChannelTask) Reschedule() ([]task, error) { } for nodeID, infos := range node2channelInfos { - loadSegmentTask := &WatchDmChannelTask{ - BaseTask: BaseTask{ - ctx: wdt.ctx, - Condition: NewTaskCondition(wdt.ctx), - triggerCondition: wdt.triggerCondition, - }, + watchDmChannelBaseTask := newBaseTask(ctx, wdt.TriggerCondition()) + watchDmChannelBaseTask.SetParentTask(wdt.GetParentTask()) + watchDmChannelTask := &WatchDmChannelTask{ + BaseTask: watchDmChannelBaseTask, WatchDmChannelsRequest: &querypb.WatchDmChannelsRequest{ Base: wdt.Base, NodeID: nodeID, @@ -1094,11 +1243,12 @@ func (wdt *WatchDmChannelTask) Reschedule() ([]task, error) { Schema: wdt.Schema, ExcludeInfos: wdt.ExcludeInfos, }, - meta: wdt.meta, - cluster: wdt.cluster, + meta: wdt.meta, + cluster: wdt.cluster, + excludeNodeIDs: wdt.excludeNodeIDs, } - reScheduledTask = append(reScheduledTask, loadSegmentTask) - log.Debug("WatchDmChannelTask: add a watchDmChannelTask to RescheduleTasks", zap.Any("task", loadSegmentTask)) + reScheduledTask = append(reScheduledTask, watchDmChannelTask) + log.Debug("WatchDmChannelTask: add a watchDmChannelTask to RescheduleTasks", zap.Any("task", watchDmChannelTask)) hasWatchQueryChannel := wdt.cluster.hasWatchedQueryChannel(wdt.ctx, nodeID, collectionID) if !hasWatchQueryChannel { @@ -1116,13 +1266,10 @@ func (wdt *WatchDmChannelTask) Reschedule() ([]task, error) { RequestChannelID: queryChannel, ResultChannelID: queryResultChannel, } + watchQueryChannelBaseTask := newBaseTask(ctx, wdt.TriggerCondition()) + watchQueryChannelBaseTask.SetParentTask(wdt.GetParentTask()) watchQueryChannelTask := &WatchQueryChannelTask{ - BaseTask: BaseTask{ - ctx: wdt.ctx, - Condition: NewTaskCondition(wdt.ctx), - triggerCondition: wdt.triggerCondition, - }, - + BaseTask: watchQueryChannelBaseTask, AddQueryChannelRequest: addQueryChannelRequest, cluster: wdt.cluster, } @@ -1135,7 +1282,7 @@ func (wdt *WatchDmChannelTask) Reschedule() ([]task, error) { } type WatchQueryChannelTask struct { - BaseTask + *BaseTask *querypb.AddQueryChannelRequest cluster Cluster } @@ -1149,12 +1296,12 @@ func (wqt *WatchQueryChannelTask) Marshal() ([]byte, error) { } func (wqt *WatchQueryChannelTask) IsValid() bool { - onService, err := wqt.cluster.isOnline(wqt.NodeID) + online, err := wqt.cluster.isOnline(wqt.NodeID) if err != nil { return false } - return wqt.ctx != nil && onService + return wqt.ctx != nil && online } func (wqt *WatchQueryChannelTask) Type() commonpb.MsgType { @@ -1165,11 +1312,17 @@ func (wqt *WatchQueryChannelTask) Timestamp() Timestamp { return wqt.Base.Timestamp } -func (wqt *WatchQueryChannelTask) PreExecute(context.Context) error { - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, +func (wqt *WatchQueryChannelTask) UpdateTaskProcess() { + parentTask := wqt.GetParentTask() + if parentTask == nil { + log.Warn("WatchQueryChannelTask: parentTask should not be nil") + return } - wqt.result = status + parentTask.UpdateTaskProcess() +} + +func (wqt *WatchQueryChannelTask) PreExecute(context.Context) error { + wqt.SetResultInfo(nil) log.Debug("start do WatchQueryChannelTask", zap.Int64("collectionID", wqt.CollectionID), zap.String("queryChannel", wqt.RequestChannelID), @@ -1180,14 +1333,14 @@ func (wqt *WatchQueryChannelTask) PreExecute(context.Context) error { } func (wqt *WatchQueryChannelTask) Execute(ctx context.Context) error { - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - } + defer func() { + wqt.retryCount-- + }() + err := wqt.cluster.addQueryChannel(wqt.ctx, wqt.NodeID, wqt.AddQueryChannelRequest) if err != nil { - log.Error("WatchQueryChannelTask: watchQueryChannel occur error", zap.Int64("taskID", wqt.ID())) - status.Reason = err.Error() - wqt.result = status + log.Warn("WatchQueryChannelTask: watchQueryChannel occur error", zap.Int64("taskID", wqt.ID())) + wqt.SetResultInfo(err) return err } @@ -1214,7 +1367,7 @@ type HandoffTask struct { //*********************** ***load balance task*** ************************// type LoadBalanceTask struct { - BaseTask + *BaseTask *querypb.LoadBalanceRequest rootCoord types.RootCoord dataCoord types.DataCoord @@ -1239,10 +1392,7 @@ func (lbt *LoadBalanceTask) Timestamp() Timestamp { } func (lbt *LoadBalanceTask) PreExecute(context.Context) error { - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - } - lbt.result = status + lbt.SetResultInfo(nil) log.Debug("start do LoadBalanceTask", zap.Int64s("sourceNodeIDs", lbt.SourceNodeIDs), zap.Any("balanceReason", lbt.BalanceReason), @@ -1251,20 +1401,20 @@ func (lbt *LoadBalanceTask) PreExecute(context.Context) error { } func (lbt *LoadBalanceTask) Execute(ctx context.Context) error { - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - } + defer func() { + lbt.retryCount-- + }() if lbt.triggerCondition == querypb.TriggerCondition_nodeDown { for _, nodeID := range lbt.SourceNodeIDs { - lbt.meta.deleteSegmentInfoByNodeID(nodeID) collectionInfos := lbt.cluster.getCollectionInfosByID(lbt.ctx, nodeID) for _, info := range collectionInfos { collectionID := info.CollectionID metaInfo, err := lbt.meta.getCollectionInfoByID(collectionID) if err != nil { - log.Error("LoadBalanceTask: getCollectionInfoByID occur error", zap.String("error", err.Error())) - continue + log.Warn("LoadBalanceTask: getCollectionInfoByID occur error", zap.String("error", err.Error())) + lbt.SetResultInfo(err) + return err } loadType := metaInfo.LoadType schema := metaInfo.Schema @@ -1277,8 +1427,7 @@ func (lbt *LoadBalanceTask) Execute(ctx context.Context) error { dmChannels, err := lbt.meta.getDmChannelsByNodeID(collectionID, nodeID) if err != nil { - status.Reason = err.Error() - lbt.result = status + lbt.SetResultInfo(err) return err } @@ -1292,8 +1441,7 @@ func (lbt *LoadBalanceTask) Execute(ctx context.Context) error { } recoveryInfo, err := lbt.dataCoord.GetRecoveryInfo(ctx, getRecoveryInfo) if err != nil { - status.Reason = err.Error() - lbt.result = status + lbt.SetResultInfo(err) return err } @@ -1364,10 +1512,10 @@ func (lbt *LoadBalanceTask) Execute(ctx context.Context) error { } } } - err = assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, watchDmChannelReqs) + err = assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, watchDmChannelReqs, true) if err != nil { - status.Reason = err.Error() - lbt.result = status + log.Warn("loadBalanceTask: assign child task failed", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs)) + lbt.SetResultInfo(err) return err } log.Debug("loadBalanceTask: assign child task done", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs)) @@ -1388,12 +1536,21 @@ func (lbt *LoadBalanceTask) Execute(ctx context.Context) error { } func (lbt *LoadBalanceTask) PostExecute(context.Context) error { - for _, id := range lbt.SourceNodeIDs { - err := lbt.cluster.removeNodeInfo(id) - if err != nil { - log.Error("LoadBalanceTask: remove mode info error", zap.Int64("nodeID", id)) + if lbt.result.ErrorCode == commonpb.ErrorCode_Success { + for _, id := range lbt.SourceNodeIDs { + err := lbt.cluster.removeNodeInfo(id) + if err != nil { + log.Error("LoadBalanceTask: occur error when removing node info from cluster", zap.Int64("nodeID", id)) + } + err = lbt.meta.deleteSegmentInfoByNodeID(id) + if err != nil { + log.Error("LoadBalanceTask: occur error when removing node info from meta", zap.Int64("nodeID", id)) + } } + } else { + lbt.childTasks = []task{} } + log.Debug("LoadBalanceTask postExecute done", zap.Int64s("sourceNodeIDs", lbt.SourceNodeIDs), zap.Any("balanceReason", lbt.BalanceReason), @@ -1401,7 +1558,7 @@ func (lbt *LoadBalanceTask) PostExecute(context.Context) error { return nil } -func shuffleChannelsToQueryNode(dmChannels []string, cluster Cluster) []int64 { +func shuffleChannelsToQueryNode(dmChannels []string, cluster Cluster, wait bool, excludeNodeIDs []int64) ([]int64, error) { maxNumChannels := 0 nodes := make(map[int64]Node) var err error @@ -1409,10 +1566,21 @@ func shuffleChannelsToQueryNode(dmChannels []string, cluster Cluster) []int64 { nodes, err = cluster.onlineNodes() if err != nil { log.Debug(err.Error()) + if !wait { + return nil, err + } time.Sleep(1 * time.Second) continue } - break + for _, id := range excludeNodeIDs { + delete(nodes, id) + } + if len(nodes) > 0 { + break + } + if !wait { + return nil, errors.New("no queryNode to allocate") + } } for nodeID := range nodes { @@ -1423,7 +1591,7 @@ func shuffleChannelsToQueryNode(dmChannels []string, cluster Cluster) []int64 { } res := make([]int64, 0) if len(dmChannels) == 0 { - return res + return res, nil } offset := 0 @@ -1439,7 +1607,7 @@ func shuffleChannelsToQueryNode(dmChannels []string, cluster Cluster) []int64 { res = append(res, nodeID) offset++ if offset == len(dmChannels) { - return res + return res, nil } } } else { @@ -1447,7 +1615,7 @@ func shuffleChannelsToQueryNode(dmChannels []string, cluster Cluster) []int64 { res = append(res, nodeID) offset++ if offset == len(dmChannels) { - return res + return res, nil } } } @@ -1457,7 +1625,7 @@ func shuffleChannelsToQueryNode(dmChannels []string, cluster Cluster) []int64 { } } -func shuffleSegmentsToQueryNode(segmentIDs []UniqueID, cluster Cluster) []int64 { +func shuffleSegmentsToQueryNode(segmentIDs []UniqueID, cluster Cluster, wait bool, excludeNodeIDs []int64) ([]int64, error) { maxNumSegments := 0 nodes := make(map[int64]Node) var err error @@ -1465,10 +1633,21 @@ func shuffleSegmentsToQueryNode(segmentIDs []UniqueID, cluster Cluster) []int64 nodes, err = cluster.onlineNodes() if err != nil { log.Debug(err.Error()) + if !wait { + return nil, err + } time.Sleep(1 * time.Second) continue } - break + for _, id := range excludeNodeIDs { + delete(nodes, id) + } + if len(nodes) > 0 { + break + } + if !wait { + return nil, errors.New("no queryNode to allocate") + } } for nodeID := range nodes { numSegments, _ := cluster.getNumSegments(nodeID) @@ -1479,7 +1658,7 @@ func shuffleSegmentsToQueryNode(segmentIDs []UniqueID, cluster Cluster) []int64 res := make([]int64, 0) if len(segmentIDs) == 0 { - return res + return res, nil } offset := 0 @@ -1495,7 +1674,7 @@ func shuffleSegmentsToQueryNode(segmentIDs []UniqueID, cluster Cluster) []int64 res = append(res, nodeID) offset++ if offset == len(segmentIDs) { - return res + return res, nil } } } else { @@ -1503,7 +1682,7 @@ func shuffleSegmentsToQueryNode(segmentIDs []UniqueID, cluster Cluster) []int64 res = append(res, nodeID) offset++ if offset == len(segmentIDs) { - return res + return res, nil } } } @@ -1544,14 +1723,15 @@ func mergeVChannelInfo(info1 *datapb.VchannelInfo, info2 *datapb.VchannelInfo) * FlushedSegments: flushedSegments, } } + func assignInternalTask(ctx context.Context, collectionID UniqueID, parentTask task, meta Meta, cluster Cluster, loadSegmentRequests []*querypb.LoadSegmentsRequest, - watchDmChannelRequests []*querypb.WatchDmChannelsRequest) error { - + watchDmChannelRequests []*querypb.WatchDmChannelsRequest, + wait bool) error { sp, _ := trace.StartSpanFromContext(ctx) defer sp.Finish() segmentsToLoad := make([]UniqueID, 0) @@ -1562,9 +1742,17 @@ func assignInternalTask(ctx context.Context, for _, req := range watchDmChannelRequests { channelsToWatch = append(channelsToWatch, req.Infos[0].ChannelName) } - segment2Nodes := shuffleSegmentsToQueryNode(segmentsToLoad, cluster) - watchRequest2Nodes := shuffleChannelsToQueryNode(channelsToWatch, cluster) + segment2Nodes, err := shuffleSegmentsToQueryNode(segmentsToLoad, cluster, wait, nil) + if err != nil { + log.Error("assignInternalTask: segment to node failed", zap.Any("segments map", segment2Nodes), zap.Int64("collectionID", collectionID)) + return err + } log.Debug("assignInternalTask: segment to node", zap.Any("segments map", segment2Nodes), zap.Int64("collectionID", collectionID)) + watchRequest2Nodes, err := shuffleChannelsToQueryNode(channelsToWatch, cluster, wait, nil) + if err != nil { + log.Error("assignInternalTask: watch request to node failed", zap.Any("request map", watchRequest2Nodes), zap.Int64("collectionID", collectionID)) + return err + } log.Debug("assignInternalTask: watch request to node", zap.Any("request map", watchRequest2Nodes), zap.Int64("collectionID", collectionID)) watchQueryChannelInfo := make(map[int64]bool) @@ -1592,16 +1780,14 @@ func assignInternalTask(ctx context.Context, for nodeID, loadSegmentsReq := range node2Segments { ctx = opentracing.ContextWithSpan(context.Background(), sp) loadSegmentsReq.NodeID = nodeID + baseTask := newBaseTask(ctx, parentTask.TriggerCondition()) + baseTask.SetParentTask(parentTask) loadSegmentTask := &LoadSegmentTask{ - BaseTask: BaseTask{ - ctx: ctx, - Condition: NewTaskCondition(ctx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, - + BaseTask: baseTask, LoadSegmentsRequest: loadSegmentsReq, meta: meta, cluster: cluster, + excludeNodeIDs: []int64{}, } parentTask.AddChildTask(loadSegmentTask) log.Debug("assignInternalTask: add a loadSegmentTask childTask", zap.Any("task", loadSegmentTask)) @@ -1611,15 +1797,14 @@ func assignInternalTask(ctx context.Context, ctx = opentracing.ContextWithSpan(context.Background(), sp) watchDmChannelReq := watchDmChannelRequests[index] watchDmChannelReq.NodeID = nodeID + baseTask := newBaseTask(ctx, parentTask.TriggerCondition()) + baseTask.SetParentTask(parentTask) watchDmChannelTask := &WatchDmChannelTask{ - BaseTask: BaseTask{ - ctx: ctx, - Condition: NewTaskCondition(ctx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, + BaseTask: baseTask, WatchDmChannelsRequest: watchDmChannelReq, meta: meta, cluster: cluster, + excludeNodeIDs: []int64{}, } parentTask.AddChildTask(watchDmChannelTask) log.Debug("assignInternalTask: add a watchDmChannelTask childTask", zap.Any("task", watchDmChannelTask)) @@ -1642,12 +1827,10 @@ func assignInternalTask(ctx context.Context, RequestChannelID: queryChannel, ResultChannelID: queryResultChannel, } + baseTask := newBaseTask(ctx, parentTask.TriggerCondition()) + baseTask.SetParentTask(parentTask) watchQueryChannelTask := &WatchQueryChannelTask{ - BaseTask: BaseTask{ - ctx: ctx, - Condition: NewTaskCondition(ctx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, + BaseTask: baseTask, AddQueryChannelRequest: addQueryChannelRequest, cluster: cluster, @@ -1656,6 +1839,5 @@ func assignInternalTask(ctx context.Context, log.Debug("assignInternalTask: add a watchQueryChannelTask childTask", zap.Any("task", watchQueryChannelTask)) } } - return nil } diff --git a/internal/querycoord/task_scheduler.go b/internal/querycoord/task_scheduler.go index d792649a3d7f0a02e43c149de16f14de5d49a902..7eb98ad017e4792e1e079b4f75b6efdd741ea458 100644 --- a/internal/querycoord/task_scheduler.go +++ b/internal/querycoord/task_scheduler.go @@ -59,31 +59,29 @@ func (queue *TaskQueue) taskFull() bool { return int64(queue.tasks.Len()) >= queue.maxTask } -func (queue *TaskQueue) addTask(tasks []task) { +func (queue *TaskQueue) addTask(t task) { queue.Lock() defer queue.Unlock() - for _, t := range tasks { - if queue.tasks.Len() == 0 { - queue.taskChan <- 1 - queue.tasks.PushBack(t) - continue - } + if queue.tasks.Len() == 0 { + queue.taskChan <- 1 + queue.tasks.PushBack(t) + return + } - for e := queue.tasks.Back(); e != nil; e = e.Prev() { - if t.TaskPriority() > e.Value.(task).TaskPriority() { - if e.Prev() == nil { - queue.taskChan <- 1 - queue.tasks.InsertBefore(t, e) - break - } - continue + for e := queue.tasks.Back(); e != nil; e = e.Prev() { + if t.TaskPriority() > e.Value.(task).TaskPriority() { + if e.Prev() == nil { + queue.taskChan <- 1 + queue.tasks.InsertBefore(t, e) + break } - //TODO:: take care of timestamp - queue.taskChan <- 1 - queue.tasks.InsertAfter(t, e) - break + continue } + //TODO:: take care of timestamp + queue.taskChan <- 1 + queue.tasks.InsertAfter(t, e) + break } } @@ -123,12 +121,13 @@ func NewTaskQueue() *TaskQueue { // TaskScheduler controls the scheduling of trigger tasks and internal tasks type TaskScheduler struct { - triggerTaskQueue *TaskQueue - activateTaskChan chan task - meta Meta - cluster Cluster - taskIDAllocator func() (UniqueID, error) - client *etcdkv.EtcdKV + triggerTaskQueue *TaskQueue + activateTaskChan chan task + meta Meta + cluster Cluster + taskIDAllocator func() (UniqueID, error) + client *etcdkv.EtcdKV + stopActivateTaskLoopChan chan int rootCoord types.RootCoord dataCoord types.DataCoord @@ -141,17 +140,20 @@ type TaskScheduler struct { func NewTaskScheduler(ctx context.Context, meta Meta, cluster Cluster, kv *etcdkv.EtcdKV, rootCoord types.RootCoord, dataCoord types.DataCoord) (*TaskScheduler, error) { ctx1, cancel := context.WithCancel(ctx) taskChan := make(chan task, 1024) + stopTaskLoopChan := make(chan int, 1) s := &TaskScheduler{ - ctx: ctx1, - cancel: cancel, - meta: meta, - cluster: cluster, - activateTaskChan: taskChan, - client: kv, - rootCoord: rootCoord, - dataCoord: dataCoord, + ctx: ctx1, + cancel: cancel, + meta: meta, + cluster: cluster, + activateTaskChan: taskChan, + client: kv, + stopActivateTaskLoopChan: stopTaskLoopChan, + rootCoord: rootCoord, + dataCoord: dataCoord, } s.triggerTaskQueue = NewTaskQueue() + //init id allocator etcdKV, err := tsoutil.NewTSOKVBase(Params.EtcdEndpoints, Params.KvRootPath, "queryCoordTaskID") if err != nil { return nil, err @@ -166,6 +168,7 @@ func NewTaskScheduler(ctx context.Context, meta Meta, cluster Cluster, kv *etcdk } err = s.reloadFromKV() if err != nil { + log.Error("reload task from kv failed", zap.Error(err)) return nil, err } @@ -192,7 +195,7 @@ func (scheduler *TaskScheduler) reloadFromKV() error { if err != nil { return err } - t, err := scheduler.unmarshalTask(triggerTaskValues[index]) + t, err := scheduler.unmarshalTask(taskID, triggerTaskValues[index]) if err != nil { return err } @@ -205,7 +208,7 @@ func (scheduler *TaskScheduler) reloadFromKV() error { if err != nil { return err } - t, err := scheduler.unmarshalTask(activeTaskValues[index]) + t, err := scheduler.unmarshalTask(taskID, activeTaskValues[index]) if err != nil { return err } @@ -232,15 +235,17 @@ func (scheduler *TaskScheduler) reloadFromKV() error { } var doneTriggerTask task = nil - for id, t := range triggerTasks { - if taskInfos[id] == taskDone { + for _, t := range triggerTasks { + if t.State() == taskDone { doneTriggerTask = t for _, childTask := range activeTasks { + childTask.SetParentTask(t) //replace child task after reScheduler t.AddChildTask(childTask) } + t.SetResultInfo(nil) continue } - scheduler.triggerTaskQueue.addTask([]task{t}) + scheduler.triggerTaskQueue.addTask(t) } if doneTriggerTask != nil { @@ -250,26 +255,23 @@ func (scheduler *TaskScheduler) reloadFromKV() error { return nil } -func (scheduler *TaskScheduler) unmarshalTask(t string) (task, error) { +func (scheduler *TaskScheduler) unmarshalTask(taskID UniqueID, t string) (task, error) { header := commonpb.MsgHeader{} err := proto.Unmarshal([]byte(t), &header) if err != nil { return nil, fmt.Errorf("Failed to unmarshal message header, err %s ", err.Error()) } var newTask task + baseTask := newBaseTask(scheduler.ctx, querypb.TriggerCondition_grpcRequest) switch header.Base.MsgType { case commonpb.MsgType_LoadCollection: loadReq := querypb.LoadCollectionRequest{} err = proto.Unmarshal([]byte(t), &loadReq) if err != nil { - log.Error(err.Error()) + return nil, err } loadCollectionTask := &LoadCollectionTask{ - BaseTask: BaseTask{ - ctx: scheduler.ctx, - Condition: NewTaskCondition(scheduler.ctx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, + BaseTask: baseTask, LoadCollectionRequest: &loadReq, rootCoord: scheduler.rootCoord, dataCoord: scheduler.dataCoord, @@ -281,14 +283,10 @@ func (scheduler *TaskScheduler) unmarshalTask(t string) (task, error) { loadReq := querypb.LoadPartitionsRequest{} err = proto.Unmarshal([]byte(t), &loadReq) if err != nil { - log.Error(err.Error()) + return nil, err } loadPartitionTask := &LoadPartitionTask{ - BaseTask: BaseTask{ - ctx: scheduler.ctx, - Condition: NewTaskCondition(scheduler.ctx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, + BaseTask: baseTask, LoadPartitionsRequest: &loadReq, dataCoord: scheduler.dataCoord, cluster: scheduler.cluster, @@ -299,14 +297,10 @@ func (scheduler *TaskScheduler) unmarshalTask(t string) (task, error) { loadReq := querypb.ReleaseCollectionRequest{} err = proto.Unmarshal([]byte(t), &loadReq) if err != nil { - log.Error(err.Error()) + return nil, err } releaseCollectionTask := &ReleaseCollectionTask{ - BaseTask: BaseTask{ - ctx: scheduler.ctx, - Condition: NewTaskCondition(scheduler.ctx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, + BaseTask: baseTask, ReleaseCollectionRequest: &loadReq, cluster: scheduler.cluster, meta: scheduler.meta, @@ -317,96 +311,79 @@ func (scheduler *TaskScheduler) unmarshalTask(t string) (task, error) { loadReq := querypb.ReleasePartitionsRequest{} err = proto.Unmarshal([]byte(t), &loadReq) if err != nil { - log.Error(err.Error()) + return nil, err } releasePartitionTask := &ReleasePartitionTask{ - BaseTask: BaseTask{ - ctx: scheduler.ctx, - Condition: NewTaskCondition(scheduler.ctx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, + BaseTask: baseTask, ReleasePartitionsRequest: &loadReq, cluster: scheduler.cluster, } newTask = releasePartitionTask case commonpb.MsgType_LoadSegments: + //TODO::trigger condition may be different loadReq := querypb.LoadSegmentsRequest{} err = proto.Unmarshal([]byte(t), &loadReq) if err != nil { - log.Error(err.Error()) + return nil, err } loadSegmentTask := &LoadSegmentTask{ - BaseTask: BaseTask{ - ctx: scheduler.ctx, - Condition: NewTaskCondition(scheduler.ctx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, + BaseTask: baseTask, LoadSegmentsRequest: &loadReq, cluster: scheduler.cluster, meta: scheduler.meta, + excludeNodeIDs: []int64{}, } newTask = loadSegmentTask case commonpb.MsgType_ReleaseSegments: + //TODO::trigger condition may be different loadReq := querypb.ReleaseSegmentsRequest{} err = proto.Unmarshal([]byte(t), &loadReq) if err != nil { - log.Error(err.Error()) + return nil, err } releaseSegmentTask := &ReleaseSegmentTask{ - BaseTask: BaseTask{ - ctx: scheduler.ctx, - Condition: NewTaskCondition(scheduler.ctx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, + BaseTask: baseTask, ReleaseSegmentsRequest: &loadReq, cluster: scheduler.cluster, } newTask = releaseSegmentTask case commonpb.MsgType_WatchDmChannels: + //TODO::trigger condition may be different loadReq := querypb.WatchDmChannelsRequest{} err = proto.Unmarshal([]byte(t), &loadReq) if err != nil { - log.Error(err.Error()) + return nil, err } watchDmChannelTask := &WatchDmChannelTask{ - BaseTask: BaseTask{ - ctx: scheduler.ctx, - Condition: NewTaskCondition(scheduler.ctx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, + BaseTask: baseTask, WatchDmChannelsRequest: &loadReq, cluster: scheduler.cluster, meta: scheduler.meta, + excludeNodeIDs: []int64{}, } newTask = watchDmChannelTask case commonpb.MsgType_WatchQueryChannels: + //TODO::trigger condition may be different loadReq := querypb.AddQueryChannelRequest{} err = proto.Unmarshal([]byte(t), &loadReq) if err != nil { - log.Error(err.Error()) + return nil, err } watchQueryChannelTask := &WatchQueryChannelTask{ - BaseTask: BaseTask{ - ctx: scheduler.ctx, - Condition: NewTaskCondition(scheduler.ctx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, + BaseTask: baseTask, AddQueryChannelRequest: &loadReq, cluster: scheduler.cluster, } newTask = watchQueryChannelTask case commonpb.MsgType_LoadBalanceSegments: + //TODO::trigger condition may be different loadReq := querypb.LoadBalanceRequest{} err = proto.Unmarshal([]byte(t), &loadReq) if err != nil { - log.Error(err.Error()) + return nil, err } loadBalanceTask := &LoadBalanceTask{ - BaseTask: BaseTask{ - ctx: scheduler.ctx, - Condition: NewTaskCondition(scheduler.ctx), - triggerCondition: loadReq.BalanceReason, - }, + BaseTask: baseTask, LoadBalanceRequest: &loadReq, rootCoord: scheduler.rootCoord, dataCoord: scheduler.dataCoord, @@ -420,105 +397,115 @@ func (scheduler *TaskScheduler) unmarshalTask(t string) (task, error) { return nil, err } + newTask.SetID(taskID) return newTask, nil } // Enqueue pushs a trigger task to triggerTaskQueue and assigns task id -func (scheduler *TaskScheduler) Enqueue(tasks []task) { - for _, t := range tasks { - id, err := scheduler.taskIDAllocator() - if err != nil { - log.Error(err.Error()) - } - t.SetID(id) +func (scheduler *TaskScheduler) Enqueue(t task) error { + id, err := scheduler.taskIDAllocator() + if err != nil { + log.Error("allocator trigger taskID failed", zap.Error(err)) + return err + } + t.SetID(id) + kvs := make(map[string]string) + taskKey := fmt.Sprintf("%s/%d", triggerTaskPrefix, t.ID()) + blobs, err := t.Marshal() + if err != nil { + log.Error("error when save marshal task", zap.Int64("taskID", t.ID()), zap.Error(err)) + return err + } + kvs[taskKey] = string(blobs) + stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID()) + kvs[stateKey] = strconv.Itoa(int(taskUndo)) + err = scheduler.client.MultiSave(kvs) + if err != nil { + //TODO::clean etcd meta + log.Error("error when save trigger task to etcd", zap.Int64("taskID", t.ID()), zap.Error(err)) + return err + } + t.SetState(taskUndo) + scheduler.triggerTaskQueue.addTask(t) + log.Debug("EnQueue a triggerTask and save to etcd", zap.Int64("taskID", t.ID())) + + return nil +} + +func (scheduler *TaskScheduler) processTask(t task) error { + var taskInfoKey string + // assign taskID for childTask and update triggerTask's childTask to etcd + updateKVFn := func(parentTask task) error { kvs := make(map[string]string) - taskKey := fmt.Sprintf("%s/%d", triggerTaskPrefix, t.ID()) - blobs, err := t.Marshal() - if err != nil { - log.Error("error when save marshal task", zap.Int64("taskID", t.ID()), zap.String("error", err.Error())) + kvs[taskInfoKey] = strconv.Itoa(int(taskDone)) + for _, childTask := range parentTask.GetChildTask() { + id, err := scheduler.taskIDAllocator() + if err != nil { + return err + } + childTask.SetID(id) + childTaskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, childTask.ID()) + blobs, err := childTask.Marshal() + if err != nil { + return err + } + kvs[childTaskKey] = string(blobs) + stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, childTask.ID()) + kvs[stateKey] = strconv.Itoa(int(taskUndo)) } - kvs[taskKey] = string(blobs) - stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID()) - kvs[stateKey] = strconv.Itoa(int(taskUndo)) - err = scheduler.client.MultiSave(kvs) + err := scheduler.client.MultiSave(kvs) if err != nil { - log.Error("error when save trigger task to etcd", zap.Int64("taskID", t.ID()), zap.String("error", err.Error())) + return err } - log.Debug("EnQueue a triggerTask and save to etcd", zap.Int64("taskID", t.ID())) - t.SetState(taskUndo) + return nil } - scheduler.triggerTaskQueue.addTask(tasks) -} - -func (scheduler *TaskScheduler) processTask(t task) error { span, ctx := trace.StartSpanFromContext(t.TraceCtx(), opentracing.Tags{ "Type": t.Type(), "ID": t.ID(), }) + var err error defer span.Finish() + + defer func() { + //task postExecute + span.LogFields(oplog.Int64("processTask: scheduler process PostExecute", t.ID())) + t.PostExecute(ctx) + }() + + // task preExecute span.LogFields(oplog.Int64("processTask: scheduler process PreExecute", t.ID())) t.PreExecute(ctx) - - key := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID()) - err := scheduler.client.Save(key, strconv.Itoa(int(taskDoing))) + taskInfoKey = fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID()) + err = scheduler.client.Save(taskInfoKey, strconv.Itoa(int(taskDoing))) if err != nil { - log.Error("processTask: update task state err", zap.String("reason", err.Error()), zap.Int64("taskID", t.ID())) trace.LogError(span, err) + t.SetResultInfo(err) return err } t.SetState(taskDoing) + // task execute span.LogFields(oplog.Int64("processTask: scheduler process Execute", t.ID())) err = t.Execute(ctx) if err != nil { - log.Debug("processTask: execute err", zap.String("reason", err.Error()), zap.Int64("taskID", t.ID())) trace.LogError(span, err) return err } - - for _, childTask := range t.GetChildTask() { - if childTask == nil { - log.Error("processTask: child task equal nil") - continue - } - - id, err := scheduler.taskIDAllocator() - if err != nil { - return err - } - childTask.SetID(id) - kvs := make(map[string]string) - taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, childTask.ID()) - blobs, err := childTask.Marshal() - if err != nil { - log.Error("processTask: marshal task err", zap.String("reason", err.Error())) - trace.LogError(span, err) - return err - } - kvs[taskKey] = string(blobs) - stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, childTask.ID()) - kvs[stateKey] = strconv.Itoa(int(taskUndo)) - err = scheduler.client.MultiSave(kvs) - if err != nil { - log.Error("processTask: save active task info err", zap.String("reason", err.Error())) - trace.LogError(span, err) - return err - } - log.Debug("processTask: save active task to etcd", zap.Int64("parent taskID", t.ID()), zap.Int64("child taskID", childTask.ID())) - } - - err = scheduler.client.Save(key, strconv.Itoa(int(taskDone))) + err = updateKVFn(t) if err != nil { - log.Error("processTask: update task state err", zap.String("reason", err.Error()), zap.Int64("taskID", t.ID())) trace.LogError(span, err) + t.SetResultInfo(err) return err } + log.Debug("processTask: update etcd success", zap.Int64("parent taskID", t.ID())) + if t.Type() == commonpb.MsgType_LoadCollection || t.Type() == commonpb.MsgType_LoadPartitions { + t.Notify(nil) + } - span.LogFields(oplog.Int64("processTask: scheduler process PostExecute", t.ID())) - t.PostExecute(ctx) t.SetState(taskDone) + t.UpdateTaskProcess() return nil } @@ -526,140 +513,258 @@ func (scheduler *TaskScheduler) processTask(t task) error { func (scheduler *TaskScheduler) scheduleLoop() { defer scheduler.wg.Done() activeTaskWg := &sync.WaitGroup{} + var triggerTask task + + processInternalTaskFn := func(activateTasks []task, triggerTask task) { + log.Debug("scheduleLoop: num of child task", zap.Int("num child task", len(activateTasks))) + for _, childTask := range activateTasks { + if childTask != nil { + log.Debug("scheduleLoop: add a activate task to activateChan", zap.Int64("taskID", childTask.ID())) + scheduler.activateTaskChan <- childTask + activeTaskWg.Add(1) + go scheduler.waitActivateTaskDone(activeTaskWg, childTask, triggerTask) + } + } + activeTaskWg.Wait() + } + + rollBackInterTaskFn := func(triggerTask task, originInternalTasks []task, rollBackTasks []task) error { + saves := make(map[string]string) + removes := make([]string, 0) + childTaskIDs := make([]int64, 0) + for _, t := range originInternalTasks { + childTaskIDs = append(childTaskIDs, t.ID()) + taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, t.ID()) + removes = append(removes, taskKey) + stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID()) + removes = append(removes, stateKey) + } + + for _, t := range rollBackTasks { + id, err := scheduler.taskIDAllocator() + if err != nil { + return err + } + t.SetID(id) + taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, t.ID()) + blobs, err := t.Marshal() + if err != nil { + return err + } + saves[taskKey] = string(blobs) + stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID()) + saves[stateKey] = strconv.Itoa(int(taskUndo)) + } + + err := scheduler.client.MultiSaveAndRemove(saves, removes) + if err != nil { + return err + } + for _, taskID := range childTaskIDs { + triggerTask.RemoveChildTaskByID(taskID) + } + for _, t := range rollBackTasks { + triggerTask.AddChildTask(t) + } + + return nil + } + + removeTaskFromKVFn := func(triggerTask task) error { + keys := make([]string, 0) + taskKey := fmt.Sprintf("%s/%d", triggerTaskPrefix, triggerTask.ID()) + stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, triggerTask.ID()) + keys = append(keys, taskKey) + keys = append(keys, stateKey) + childTasks := triggerTask.GetChildTask() + for _, t := range childTasks { + taskKey = fmt.Sprintf("%s/%d", activeTaskPrefix, t.ID()) + stateKey = fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID()) + keys = append(keys, taskKey) + keys = append(keys, stateKey) + } + err := scheduler.client.MultiRemove(keys) + if err != nil { + return err + } + return nil + } for { - var err error = nil + var err error select { case <-scheduler.ctx.Done(): + scheduler.stopActivateTaskLoopChan <- 1 return case <-scheduler.triggerTaskQueue.Chan(): - t := scheduler.triggerTaskQueue.PopTask() - log.Debug("scheduleLoop: pop a triggerTask from triggerTaskQueue", zap.Int64("taskID", t.ID())) - if t.State() < taskDone { - err = scheduler.processTask(t) + triggerTask = scheduler.triggerTaskQueue.PopTask() + log.Debug("scheduleLoop: pop a triggerTask from triggerTaskQueue", zap.Int64("triggerTaskID", triggerTask.ID())) + alreadyNotify := true + if triggerTask.State() == taskUndo || triggerTask.State() == taskDoing { + err = scheduler.processTask(triggerTask) if err != nil { - log.Error("scheduleLoop: process task error", zap.Any("error", err.Error())) - t.Notify(err) - t.PostExecute(scheduler.ctx) - } - if t.Type() == commonpb.MsgType_LoadCollection || t.Type() == commonpb.MsgType_LoadPartitions { - t.Notify(err) + log.Debug("scheduleLoop: process triggerTask failed", zap.Int64("triggerTaskID", triggerTask.ID()), zap.Error(err)) + alreadyNotify = false } } - log.Debug("scheduleLoop: num of child task", zap.Int("num child task", len(t.GetChildTask()))) - for _, childTask := range t.GetChildTask() { - if childTask != nil { - log.Debug("scheduleLoop: add a activate task to activateChan", zap.Int64("taskID", childTask.ID())) - scheduler.activateTaskChan <- childTask - activeTaskWg.Add(1) - go scheduler.waitActivateTaskDone(activeTaskWg, childTask) - } + if triggerTask.Type() != commonpb.MsgType_LoadCollection && triggerTask.Type() != commonpb.MsgType_LoadPartitions { + alreadyNotify = false } - activeTaskWg.Wait() - if t.Type() == commonpb.MsgType_LoadCollection || t.Type() == commonpb.MsgType_LoadPartitions { - t.PostExecute(scheduler.ctx) + + childTasks := triggerTask.GetChildTask() + if len(childTasks) != 0 { + activateTasks := make([]task, len(childTasks)) + copy(activateTasks, childTasks) + processInternalTaskFn(activateTasks, triggerTask) + resultStatus := triggerTask.GetResultInfo() + if resultStatus.ErrorCode != commonpb.ErrorCode_Success { + rollBackTasks := triggerTask.RollBack(scheduler.ctx) + log.Debug("scheduleLoop: start rollBack after triggerTask failed", + zap.Int64("triggerTaskID", triggerTask.ID()), + zap.Any("rollBackTasks", rollBackTasks)) + err = rollBackInterTaskFn(triggerTask, childTasks, rollBackTasks) + if err != nil { + log.Error("scheduleLoop: rollBackInternalTask error", + zap.Int64("triggerTaskID", triggerTask.ID()), + zap.Error(err)) + triggerTask.SetResultInfo(err) + } else { + processInternalTaskFn(rollBackTasks, triggerTask) + } + } } - keys := make([]string, 0) - taskKey := fmt.Sprintf("%s/%d", triggerTaskPrefix, t.ID()) - stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID()) - keys = append(keys, taskKey) - keys = append(keys, stateKey) - err = scheduler.client.MultiRemove(keys) + err = removeTaskFromKVFn(triggerTask) if err != nil { - log.Error("scheduleLoop: error when remove trigger task to etcd", zap.Int64("taskID", t.ID())) - t.Notify(err) - continue + log.Error("scheduleLoop: error when remove trigger and internal tasks from etcd", zap.Int64("triggerTaskID", triggerTask.ID()), zap.Error(err)) + triggerTask.SetResultInfo(err) + } else { + log.Debug("scheduleLoop: trigger task done and delete from etcd", zap.Int64("triggerTaskID", triggerTask.ID())) + } + + resultStatus := triggerTask.GetResultInfo() + if resultStatus.ErrorCode != commonpb.ErrorCode_Success { + triggerTask.SetState(taskFailed) + if !alreadyNotify { + triggerTask.Notify(errors.New(resultStatus.Reason)) + } + } else { + triggerTask.UpdateTaskProcess() + triggerTask.SetState(taskExpired) + if !alreadyNotify { + triggerTask.Notify(nil) + } } - log.Debug("scheduleLoop: trigger task done and delete from etcd", zap.Int64("taskID", t.ID())) - t.Notify(err) } } } -func (scheduler *TaskScheduler) waitActivateTaskDone(wg *sync.WaitGroup, t task) { +func (scheduler *TaskScheduler) waitActivateTaskDone(wg *sync.WaitGroup, t task, triggerTask task) { defer wg.Done() - err := t.WaitToFinish() - if err != nil { - log.Debug("waitActivateTaskDone: activate task return err", zap.Any("error", err.Error()), zap.Int64("taskID", t.ID())) - redoFunc1 := func() { - if !t.IsValid() { - reScheduledTasks, err := t.Reschedule() - if err != nil { - log.Error(err.Error()) - return - } - removes := make([]string, 0) - taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, t.ID()) - removes = append(removes, taskKey) - stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID()) - removes = append(removes, stateKey) - - saves := make(map[string]string) - reSchedID := make([]int64, 0) - for _, rt := range reScheduledTasks { - if rt != nil { - id, err := scheduler.taskIDAllocator() - if err != nil { - log.Error(err.Error()) - continue - } - rt.SetID(id) - taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, rt.ID()) - blobs, err := rt.Marshal() - if err != nil { - log.Error("waitActivateTaskDone: error when marshal active task") - continue - //TODO::xige-16 deal error when marshal task failed - } - saves[taskKey] = string(blobs) - stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, rt.ID()) - saves[stateKey] = strconv.Itoa(int(taskUndo)) - reSchedID = append(reSchedID, rt.ID()) + var err error + redoFunc1 := func() { + if !t.IsValid() || !t.IsRetryable() { + log.Debug("waitActivateTaskDone: reSchedule the activate task", + zap.Int64("taskID", t.ID()), + zap.Int64("triggerTaskID", triggerTask.ID())) + reScheduledTasks, err := t.Reschedule(scheduler.ctx) + if err != nil { + log.Error("waitActivateTaskDone: reschedule task error", + zap.Int64("taskID", t.ID()), + zap.Int64("triggerTaskID", triggerTask.ID()), + zap.Error(err)) + triggerTask.SetResultInfo(err) + return + } + removes := make([]string, 0) + taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, t.ID()) + removes = append(removes, taskKey) + stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID()) + removes = append(removes, stateKey) + + saves := make(map[string]string) + for _, rt := range reScheduledTasks { + if rt != nil { + id, err := scheduler.taskIDAllocator() + if err != nil { + log.Error("waitActivateTaskDone: allocate id error", + zap.Int64("triggerTaskID", triggerTask.ID()), + zap.Error(err)) + triggerTask.SetResultInfo(err) + return } - } - err = scheduler.client.MultiSaveAndRemove(saves, removes) - if err != nil { - log.Error("waitActivateTaskDone: error when save and remove task from etcd") - //TODO::xige-16 deal error when save meta failed - } - log.Debug("waitActivateTaskDone: delete failed active task and save reScheduled task to etcd", zap.Int64("failed taskID", t.ID()), zap.Int64s("reScheduled taskIDs", reSchedID)) - - for _, rt := range reScheduledTasks { - if rt != nil { - log.Debug("waitActivateTaskDone: add a reScheduled active task to activateChan", zap.Int64("taskID", rt.ID())) - scheduler.activateTaskChan <- rt - wg.Add(1) - go scheduler.waitActivateTaskDone(wg, rt) + rt.SetID(id) + log.Debug("waitActivateTaskDone: reScheduler set id", zap.Int64("id", rt.ID())) + taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, rt.ID()) + blobs, err := rt.Marshal() + if err != nil { + log.Error("waitActivateTaskDone: error when marshal active task", + zap.Int64("triggerTaskID", triggerTask.ID()), + zap.Error(err)) + triggerTask.SetResultInfo(err) + return } + saves[taskKey] = string(blobs) + stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, rt.ID()) + saves[stateKey] = strconv.Itoa(int(taskUndo)) } - //delete task from etcd - } else { - log.Debug("waitActivateTaskDone: retry the active task", zap.Int64("taskID", t.ID())) - scheduler.activateTaskChan <- t - wg.Add(1) - go scheduler.waitActivateTaskDone(wg, t) } + //TODO::queryNode auto watch queryChannel, then update etcd use same id directly + err = scheduler.client.MultiSaveAndRemove(saves, removes) + if err != nil { + log.Error("waitActivateTaskDone: error when save and remove task from etcd", zap.Int64("triggerTaskID", triggerTask.ID())) + triggerTask.SetResultInfo(err) + return + } + triggerTask.RemoveChildTaskByID(t.ID()) + log.Debug("waitActivateTaskDone: delete failed active task and save reScheduled task to etcd", + zap.Int64("triggerTaskID", triggerTask.ID()), + zap.Int64("failed taskID", t.ID()), + zap.Any("reScheduled tasks", reScheduledTasks)) + + for _, rt := range reScheduledTasks { + if rt != nil { + triggerTask.AddChildTask(rt) + log.Debug("waitActivateTaskDone: add a reScheduled active task to activateChan", zap.Int64("taskID", rt.ID())) + scheduler.activateTaskChan <- rt + wg.Add(1) + go scheduler.waitActivateTaskDone(wg, rt, triggerTask) + } + } + //delete task from etcd + } else { + log.Debug("waitActivateTaskDone: retry the active task", + zap.Int64("taskID", t.ID()), + zap.Int64("triggerTaskID", triggerTask.ID())) + scheduler.activateTaskChan <- t + wg.Add(1) + go scheduler.waitActivateTaskDone(wg, t, triggerTask) } + } - redoFunc2 := func() { - if t.IsValid() { - log.Debug("waitActivateTaskDone: retry the active task", zap.Int64("taskID", t.ID())) - scheduler.activateTaskChan <- t - wg.Add(1) - go scheduler.waitActivateTaskDone(wg, t) - } else { - removes := make([]string, 0) - taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, t.ID()) - removes = append(removes, taskKey) - stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID()) - removes = append(removes, stateKey) - err = scheduler.client.MultiRemove(removes) - if err != nil { - log.Error("waitActivateTaskDone: error when remove task from etcd", zap.Int64("taskID", t.ID())) - } + redoFunc2 := func(err error) { + if t.IsValid() { + if !t.IsRetryable() { + log.Error("waitActivateTaskDone: activate task failed after retry", + zap.Int64("taskID", t.ID()), + zap.Int64("triggerTaskID", triggerTask.ID())) + triggerTask.SetResultInfo(err) + return } + log.Debug("waitActivateTaskDone: retry the active task", + zap.Int64("taskID", t.ID()), + zap.Int64("triggerTaskID", triggerTask.ID())) + scheduler.activateTaskChan <- t + wg.Add(1) + go scheduler.waitActivateTaskDone(wg, t, triggerTask) } + } + err = t.WaitToFinish() + if err != nil { + log.Debug("waitActivateTaskDone: activate task return err", + zap.Int64("taskID", t.ID()), + zap.Int64("triggerTaskID", triggerTask.ID()), + zap.Error(err)) switch t.Type() { case commonpb.MsgType_LoadSegments: @@ -667,48 +772,37 @@ func (scheduler *TaskScheduler) waitActivateTaskDone(wg *sync.WaitGroup, t task) case commonpb.MsgType_WatchDmChannels: redoFunc1() case commonpb.MsgType_WatchQueryChannels: - redoFunc2() + redoFunc2(err) case commonpb.MsgType_ReleaseSegments: - redoFunc2() + redoFunc2(err) case commonpb.MsgType_ReleaseCollection: - redoFunc2() + redoFunc2(err) case commonpb.MsgType_ReleasePartitions: - redoFunc2() + redoFunc2(err) default: //TODO:: case commonpb.MsgType_RemoveDmChannels: } } else { - keys := make([]string, 0) - taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, t.ID()) - stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID()) - keys = append(keys, taskKey) - keys = append(keys, stateKey) - err = scheduler.client.MultiRemove(keys) - if err != nil { - log.Error("waitActivateTaskDone: error when remove task from etcd", zap.Int64("taskID", t.ID())) - } - log.Debug("waitActivateTaskDone: delete activate task from etcd", zap.Int64("taskID", t.ID())) + log.Debug("waitActivateTaskDone: one activate task done", + zap.Int64("taskID", t.ID()), + zap.Int64("triggerTaskID", triggerTask.ID())) } - log.Debug("waitActivateTaskDone: one activate task done", zap.Int64("taskID", t.ID())) } func (scheduler *TaskScheduler) processActivateTaskLoop() { defer scheduler.wg.Done() for { select { - case <-scheduler.ctx.Done(): + case <-scheduler.stopActivateTaskLoopChan: + log.Debug("processActivateTaskLoop, ctx done") return + case t := <-scheduler.activateTaskChan: if t == nil { log.Error("processActivateTaskLoop: pop a nil active task", zap.Int64("taskID", t.ID())) continue } - stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID()) - err := scheduler.client.Save(stateKey, strconv.Itoa(int(taskDoing))) - if err != nil { - t.Notify(err) - continue - } + log.Debug("processActivateTaskLoop: pop a active task from activateChan", zap.Int64("taskID", t.ID())) go func() { err := scheduler.processTask(t) diff --git a/internal/querycoord/task_scheduler_test.go b/internal/querycoord/task_scheduler_test.go index ca83e0261bc34f5adc982782faa99474cc1c88a4..ece00f41436ad7790df5b0872e8425c86ceb4caf 100644 --- a/internal/querycoord/task_scheduler_test.go +++ b/internal/querycoord/task_scheduler_test.go @@ -49,6 +49,7 @@ func (tt *testTask) Timestamp() Timestamp { } func (tt *testTask) PreExecute(ctx context.Context) error { + tt.SetResultInfo(nil) log.Debug("test task preExecute...") return nil } @@ -59,7 +60,7 @@ func (tt *testTask) Execute(ctx context.Context) error { switch tt.baseMsg.MsgType { case commonpb.MsgType_LoadSegments: childTask := &LoadSegmentTask{ - BaseTask: BaseTask{ + BaseTask: &BaseTask{ ctx: tt.ctx, Condition: NewTaskCondition(tt.ctx), triggerCondition: tt.triggerCondition, @@ -70,13 +71,14 @@ func (tt *testTask) Execute(ctx context.Context) error { }, NodeID: tt.nodeID, }, - meta: tt.meta, - cluster: tt.cluster, + meta: tt.meta, + cluster: tt.cluster, + excludeNodeIDs: []int64{}, } tt.AddChildTask(childTask) case commonpb.MsgType_WatchDmChannels: childTask := &WatchDmChannelTask{ - BaseTask: BaseTask{ + BaseTask: &BaseTask{ ctx: tt.ctx, Condition: NewTaskCondition(tt.ctx), triggerCondition: tt.triggerCondition, @@ -87,13 +89,14 @@ func (tt *testTask) Execute(ctx context.Context) error { }, NodeID: tt.nodeID, }, - cluster: tt.cluster, - meta: tt.meta, + cluster: tt.cluster, + meta: tt.meta, + excludeNodeIDs: []int64{}, } tt.AddChildTask(childTask) case commonpb.MsgType_WatchQueryChannels: childTask := &WatchQueryChannelTask{ - BaseTask: BaseTask{ + BaseTask: &BaseTask{ ctx: tt.ctx, Condition: NewTaskCondition(tt.ctx), triggerCondition: tt.triggerCondition, @@ -129,12 +132,7 @@ func TestWatchQueryChannel_ClearEtcdInfoAfterAssignedNodeDown(t *testing.T) { queryNode.addQueryChannels = returnFailedResult nodeID := queryNode.queryNodeID - for { - _, err = queryCoord.cluster.getNodeByID(nodeID) - if err == nil { - break - } - } + waitQueryNodeOnline(queryCoord.cluster, nodeID) testTask := &testTask{ BaseTask: BaseTask{ ctx: baseCtx, @@ -148,7 +146,7 @@ func TestWatchQueryChannel_ClearEtcdInfoAfterAssignedNodeDown(t *testing.T) { meta: queryCoord.meta, nodeID: nodeID, } - queryCoord.scheduler.Enqueue([]task{testTask}) + queryCoord.scheduler.Enqueue(testTask) queryNode.stop() err = removeNodeSession(queryNode.queryNodeID) @@ -169,7 +167,11 @@ func TestUnMarshalTask(t *testing.T) { refreshParams() kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath) assert.Nil(t, err) - taskScheduler := &TaskScheduler{} + baseCtx, cancel := context.WithCancel(context.Background()) + taskScheduler := &TaskScheduler{ + ctx: baseCtx, + cancel: cancel, + } t.Run("Test LoadCollectionTask", func(t *testing.T) { loadTask := &LoadCollectionTask{ @@ -187,7 +189,7 @@ func TestUnMarshalTask(t *testing.T) { value, err := kv.Load("testMarshalLoadCollection") assert.Nil(t, err) - task, err := taskScheduler.unmarshalTask(value) + task, err := taskScheduler.unmarshalTask(1000, value) assert.Nil(t, err) assert.Equal(t, task.Type(), commonpb.MsgType_LoadCollection) }) @@ -208,7 +210,7 @@ func TestUnMarshalTask(t *testing.T) { value, err := kv.Load("testMarshalLoadPartition") assert.Nil(t, err) - task, err := taskScheduler.unmarshalTask(value) + task, err := taskScheduler.unmarshalTask(1001, value) assert.Nil(t, err) assert.Equal(t, task.Type(), commonpb.MsgType_LoadPartitions) }) @@ -229,7 +231,7 @@ func TestUnMarshalTask(t *testing.T) { value, err := kv.Load("testMarshalReleaseCollection") assert.Nil(t, err) - task, err := taskScheduler.unmarshalTask(value) + task, err := taskScheduler.unmarshalTask(1002, value) assert.Nil(t, err) assert.Equal(t, task.Type(), commonpb.MsgType_ReleaseCollection) }) @@ -250,7 +252,7 @@ func TestUnMarshalTask(t *testing.T) { value, err := kv.Load("testMarshalReleasePartition") assert.Nil(t, err) - task, err := taskScheduler.unmarshalTask(value) + task, err := taskScheduler.unmarshalTask(1003, value) assert.Nil(t, err) assert.Equal(t, task.Type(), commonpb.MsgType_ReleasePartitions) }) @@ -271,7 +273,7 @@ func TestUnMarshalTask(t *testing.T) { value, err := kv.Load("testMarshalLoadSegment") assert.Nil(t, err) - task, err := taskScheduler.unmarshalTask(value) + task, err := taskScheduler.unmarshalTask(1004, value) assert.Nil(t, err) assert.Equal(t, task.Type(), commonpb.MsgType_LoadSegments) }) @@ -292,7 +294,7 @@ func TestUnMarshalTask(t *testing.T) { value, err := kv.Load("testMarshalReleaseSegment") assert.Nil(t, err) - task, err := taskScheduler.unmarshalTask(value) + task, err := taskScheduler.unmarshalTask(1005, value) assert.Nil(t, err) assert.Equal(t, task.Type(), commonpb.MsgType_ReleaseSegments) }) @@ -313,7 +315,7 @@ func TestUnMarshalTask(t *testing.T) { value, err := kv.Load("testMarshalWatchDmChannel") assert.Nil(t, err) - task, err := taskScheduler.unmarshalTask(value) + task, err := taskScheduler.unmarshalTask(1006, value) assert.Nil(t, err) assert.Equal(t, task.Type(), commonpb.MsgType_WatchDmChannels) }) @@ -334,7 +336,7 @@ func TestUnMarshalTask(t *testing.T) { value, err := kv.Load("testMarshalWatchQueryChannel") assert.Nil(t, err) - task, err := taskScheduler.unmarshalTask(value) + task, err := taskScheduler.unmarshalTask(1007, value) assert.Nil(t, err) assert.Equal(t, task.Type(), commonpb.MsgType_WatchQueryChannels) }) @@ -356,17 +358,22 @@ func TestUnMarshalTask(t *testing.T) { value, err := kv.Load("testMarshalLoadBalanceTask") assert.Nil(t, err) - task, err := taskScheduler.unmarshalTask(value) + task, err := taskScheduler.unmarshalTask(1008, value) assert.Nil(t, err) assert.Equal(t, task.Type(), commonpb.MsgType_LoadBalanceSegments) }) + + taskScheduler.Close() } func TestReloadTaskFromKV(t *testing.T) { refreshParams() kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath) assert.Nil(t, err) + baseCtx, cancel := context.WithCancel(context.Background()) taskScheduler := &TaskScheduler{ + ctx: baseCtx, + cancel: cancel, client: kv, triggerTaskQueue: NewTaskQueue(), } diff --git a/internal/querycoord/task_test.go b/internal/querycoord/task_test.go index f03e725843cc9c5b1219daa05c0aa10d2a1358df..82e6b2fadff864869b16a5f2c35c87e182420908 100644 --- a/internal/querycoord/task_test.go +++ b/internal/querycoord/task_test.go @@ -17,9 +17,215 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" ) +func genLoadCollectionTask(ctx context.Context, queryCoord *QueryCoord) *LoadCollectionTask { + req := &querypb.LoadCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadCollection, + }, + CollectionID: defaultCollectionID, + Schema: genCollectionSchema(defaultCollectionID, false), + } + baseTask := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest) + loadCollectionTask := &LoadCollectionTask{ + BaseTask: baseTask, + LoadCollectionRequest: req, + rootCoord: queryCoord.rootCoordClient, + dataCoord: queryCoord.dataCoordClient, + cluster: queryCoord.cluster, + meta: queryCoord.meta, + } + return loadCollectionTask +} + +func genLoadPartitionTask(ctx context.Context, queryCoord *QueryCoord) *LoadPartitionTask { + req := &querypb.LoadPartitionsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadPartitions, + }, + CollectionID: defaultCollectionID, + PartitionIDs: []UniqueID{defaultPartitionID}, + } + baseTask := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest) + loadPartitionTask := &LoadPartitionTask{ + BaseTask: baseTask, + LoadPartitionsRequest: req, + dataCoord: queryCoord.dataCoordClient, + cluster: queryCoord.cluster, + meta: queryCoord.meta, + } + return loadPartitionTask +} + +func genReleaseCollectionTask(ctx context.Context, queryCoord *QueryCoord) *ReleaseCollectionTask { + req := &querypb.ReleaseCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_ReleaseCollection, + }, + CollectionID: defaultCollectionID, + } + baseTask := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest) + releaseCollectionTask := &ReleaseCollectionTask{ + BaseTask: baseTask, + ReleaseCollectionRequest: req, + rootCoord: queryCoord.rootCoordClient, + cluster: queryCoord.cluster, + meta: queryCoord.meta, + } + + return releaseCollectionTask +} + +func genReleasePartitionTask(ctx context.Context, queryCoord *QueryCoord) *ReleasePartitionTask { + req := &querypb.ReleasePartitionsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_ReleasePartitions, + }, + CollectionID: defaultCollectionID, + PartitionIDs: []UniqueID{defaultPartitionID}, + } + baseTask := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest) + releasePartitionTask := &ReleasePartitionTask{ + BaseTask: baseTask, + ReleasePartitionsRequest: req, + cluster: queryCoord.cluster, + } + + return releasePartitionTask +} + +func genReleaseSegmentTask(ctx context.Context, queryCoord *QueryCoord, nodeID int64) *ReleaseSegmentTask { + req := &querypb.ReleaseSegmentsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_ReleaseSegments, + }, + NodeID: nodeID, + CollectionID: defaultCollectionID, + PartitionIDs: []UniqueID{defaultPartitionID}, + SegmentIDs: []UniqueID{defaultSegmentID}, + } + baseTask := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest) + releaseSegmentTask := &ReleaseSegmentTask{ + BaseTask: baseTask, + ReleaseSegmentsRequest: req, + cluster: queryCoord.cluster, + } + return releaseSegmentTask +} + +func genWatchDmChannelTask(ctx context.Context, queryCoord *QueryCoord, nodeID int64) *WatchDmChannelTask { + schema := genCollectionSchema(defaultCollectionID, false) + vChannelInfo := &datapb.VchannelInfo{ + CollectionID: defaultCollectionID, + ChannelName: "testDmChannel", + } + req := &querypb.WatchDmChannelsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_WatchDmChannels, + }, + NodeID: nodeID, + CollectionID: defaultCollectionID, + PartitionID: defaultPartitionID, + Schema: schema, + Infos: []*datapb.VchannelInfo{vChannelInfo}, + } + baseTask := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest) + baseTask.taskID = 100 + watchDmChannelTask := &WatchDmChannelTask{ + BaseTask: baseTask, + WatchDmChannelsRequest: req, + cluster: queryCoord.cluster, + meta: queryCoord.meta, + excludeNodeIDs: []int64{}, + } + + parentReq := &querypb.LoadCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadCollection, + }, + CollectionID: defaultCollectionID, + Schema: genCollectionSchema(defaultCollectionID, false), + } + baseParentTask := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest) + baseParentTask.taskID = 10 + parentTask := &LoadCollectionTask{ + BaseTask: baseParentTask, + LoadCollectionRequest: parentReq, + rootCoord: queryCoord.rootCoordClient, + dataCoord: queryCoord.dataCoordClient, + meta: queryCoord.meta, + cluster: queryCoord.cluster, + } + parentTask.SetState(taskDone) + parentTask.SetResultInfo(nil) + parentTask.AddChildTask(watchDmChannelTask) + watchDmChannelTask.SetParentTask(parentTask) + + queryCoord.meta.addCollection(defaultCollectionID, schema) + return watchDmChannelTask +} +func genLoadSegmentTask(ctx context.Context, queryCoord *QueryCoord, nodeID int64) *LoadSegmentTask { + schema := genCollectionSchema(defaultCollectionID, false) + segmentInfo := &querypb.SegmentLoadInfo{ + SegmentID: defaultSegmentID, + PartitionID: defaultPartitionID, + CollectionID: defaultCollectionID, + } + req := &querypb.LoadSegmentsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadSegments, + }, + NodeID: nodeID, + Schema: schema, + Infos: []*querypb.SegmentLoadInfo{segmentInfo}, + } + baseTask := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest) + baseTask.taskID = 100 + loadSegmentTask := &LoadSegmentTask{ + BaseTask: baseTask, + LoadSegmentsRequest: req, + cluster: queryCoord.cluster, + meta: queryCoord.meta, + excludeNodeIDs: []int64{}, + } + + parentReq := &querypb.LoadCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadCollection, + }, + CollectionID: defaultCollectionID, + Schema: genCollectionSchema(defaultCollectionID, false), + } + baseParentTask := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest) + baseParentTask.taskID = 10 + parentTask := &LoadCollectionTask{ + BaseTask: baseParentTask, + LoadCollectionRequest: parentReq, + rootCoord: queryCoord.rootCoordClient, + dataCoord: queryCoord.dataCoordClient, + meta: queryCoord.meta, + cluster: queryCoord.cluster, + } + parentTask.SetState(taskDone) + parentTask.SetResultInfo(nil) + parentTask.AddChildTask(loadSegmentTask) + loadSegmentTask.SetParentTask(parentTask) + + queryCoord.meta.addCollection(defaultCollectionID, schema) + return loadSegmentTask +} + +func waitTaskFinalState(t task, state taskState) { + for { + if t.State() == state { + break + } + } +} + func TestTriggerTask(t *testing.T) { refreshParams() ctx := context.Background() @@ -28,98 +234,32 @@ func TestTriggerTask(t *testing.T) { node, err := startQueryNodeServer(ctx) assert.Nil(t, err) + waitQueryNodeOnline(queryCoord.cluster, node.queryNodeID) t.Run("Test LoadCollection", func(t *testing.T) { - req := &querypb.LoadCollectionRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_LoadCollection, - }, - CollectionID: defaultCollectionID, - Schema: genCollectionSchema(defaultCollectionID, false), - } - loadCollectionTask := &LoadCollectionTask{ - BaseTask: BaseTask{ - ctx: ctx, - Condition: NewTaskCondition(ctx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, - LoadCollectionRequest: req, - rootCoord: queryCoord.rootCoordClient, - dataCoord: queryCoord.dataCoordClient, - cluster: queryCoord.cluster, - meta: queryCoord.meta, - } + loadCollectionTask := genLoadCollectionTask(ctx, queryCoord) err = queryCoord.scheduler.processTask(loadCollectionTask) assert.Nil(t, err) }) t.Run("Test ReleaseCollection", func(t *testing.T) { - req := &querypb.ReleaseCollectionRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_ReleaseCollection, - }, - CollectionID: defaultCollectionID, - } - loadCollectionTask := &ReleaseCollectionTask{ - BaseTask: BaseTask{ - ctx: ctx, - Condition: NewTaskCondition(ctx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, - ReleaseCollectionRequest: req, - rootCoord: queryCoord.rootCoordClient, - cluster: queryCoord.cluster, - meta: queryCoord.meta, - } - - err = queryCoord.scheduler.processTask(loadCollectionTask) + releaseCollectionTask := genReleaseCollectionTask(ctx, queryCoord) + err = queryCoord.scheduler.processTask(releaseCollectionTask) assert.Nil(t, err) }) t.Run("Test LoadPartition", func(t *testing.T) { - req := &querypb.LoadPartitionsRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_LoadPartitions, - }, - CollectionID: defaultCollectionID, - PartitionIDs: []UniqueID{defaultPartitionID}, - } - loadCollectionTask := &LoadPartitionTask{ - BaseTask: BaseTask{ - ctx: ctx, - Condition: NewTaskCondition(ctx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, - LoadPartitionsRequest: req, - dataCoord: queryCoord.dataCoordClient, - cluster: queryCoord.cluster, - meta: queryCoord.meta, - } + loadPartitionTask := genLoadPartitionTask(ctx, queryCoord) - err = queryCoord.scheduler.processTask(loadCollectionTask) + err = queryCoord.scheduler.processTask(loadPartitionTask) assert.Nil(t, err) }) t.Run("Test ReleasePartition", func(t *testing.T) { - req := &querypb.ReleasePartitionsRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_ReleasePartitions, - }, - CollectionID: defaultCollectionID, - PartitionIDs: []UniqueID{defaultPartitionID}, - } - loadCollectionTask := &ReleasePartitionTask{ - BaseTask: BaseTask{ - ctx: ctx, - Condition: NewTaskCondition(ctx), - triggerCondition: querypb.TriggerCondition_grpcRequest, - }, - ReleasePartitionsRequest: req, - cluster: queryCoord.cluster, - } + releasePartitionTask := genReleaseCollectionTask(ctx, queryCoord) - err = queryCoord.scheduler.processTask(loadCollectionTask) + err = queryCoord.scheduler.processTask(releasePartitionTask) assert.Nil(t, err) }) @@ -128,3 +268,388 @@ func TestTriggerTask(t *testing.T) { err = removeAllSession() assert.Nil(t, err) } + +func Test_LoadCollectionAfterLoadPartition(t *testing.T) { + refreshParams() + ctx := context.Background() + queryCoord, err := startQueryCoord(ctx) + assert.Nil(t, err) + + node, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + waitQueryNodeOnline(queryCoord.cluster, node.queryNodeID) + + loadPartitionTask := genLoadPartitionTask(ctx, queryCoord) + err = queryCoord.scheduler.Enqueue(loadPartitionTask) + assert.Nil(t, err) + + loadCollectionTask := genLoadCollectionTask(ctx, queryCoord) + err = queryCoord.scheduler.Enqueue(loadCollectionTask) + assert.Nil(t, err) + + releaseCollectionTask := genReleaseCollectionTask(ctx, queryCoord) + err = queryCoord.scheduler.Enqueue(releaseCollectionTask) + assert.Nil(t, err) + + err = releaseCollectionTask.WaitToFinish() + assert.Nil(t, err) + + node.stop() + queryCoord.Stop() + err = removeAllSession() + assert.Nil(t, err) +} + +func Test_RepeatLoadCollection(t *testing.T) { + refreshParams() + ctx := context.Background() + queryCoord, err := startQueryCoord(ctx) + assert.Nil(t, err) + + node, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + waitQueryNodeOnline(queryCoord.cluster, node.queryNodeID) + + loadCollectionTask1 := genLoadCollectionTask(ctx, queryCoord) + err = queryCoord.scheduler.Enqueue(loadCollectionTask1) + assert.Nil(t, err) + + createDefaultPartition(ctx, queryCoord) + loadCollectionTask2 := genLoadCollectionTask(ctx, queryCoord) + err = queryCoord.scheduler.Enqueue(loadCollectionTask2) + assert.Nil(t, err) + + releaseCollectionTask := genReleaseCollectionTask(ctx, queryCoord) + err = queryCoord.scheduler.Enqueue(releaseCollectionTask) + assert.Nil(t, err) + + err = releaseCollectionTask.WaitToFinish() + assert.Nil(t, err) + + node.stop() + queryCoord.Stop() + err = removeAllSession() + assert.Nil(t, err) +} + +func Test_LoadCollectionAssignTaskFail(t *testing.T) { + refreshParams() + ctx := context.Background() + queryCoord, err := startQueryCoord(ctx) + assert.Nil(t, err) + + loadCollectionTask := genLoadCollectionTask(ctx, queryCoord) + err = queryCoord.scheduler.Enqueue(loadCollectionTask) + assert.Nil(t, err) + + err = loadCollectionTask.WaitToFinish() + assert.NotNil(t, err) + + queryCoord.Stop() + err = removeAllSession() + assert.Nil(t, err) +} + +func Test_LoadCollectionExecuteFail(t *testing.T) { + refreshParams() + ctx := context.Background() + queryCoord, err := startQueryCoord(ctx) + assert.Nil(t, err) + + node, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + + node.loadSegment = returnFailedResult + waitQueryNodeOnline(queryCoord.cluster, node.queryNodeID) + + loadCollectionTask := genLoadCollectionTask(ctx, queryCoord) + err = queryCoord.scheduler.Enqueue(loadCollectionTask) + assert.Nil(t, err) + + waitTaskFinalState(loadCollectionTask, taskFailed) + + node.stop() + queryCoord.Stop() + err = removeAllSession() + assert.Nil(t, err) +} + +func Test_LoadPartitionAssignTaskFail(t *testing.T) { + refreshParams() + ctx := context.Background() + queryCoord, err := startQueryCoord(ctx) + assert.Nil(t, err) + + loadPartitionTask := genLoadPartitionTask(ctx, queryCoord) + err = queryCoord.scheduler.Enqueue(loadPartitionTask) + assert.Nil(t, err) + + err = loadPartitionTask.WaitToFinish() + assert.NotNil(t, err) + + queryCoord.Stop() + err = removeAllSession() + assert.Nil(t, err) +} + +func Test_LoadPartitionExecuteFail(t *testing.T) { + refreshParams() + ctx := context.Background() + queryCoord, err := startQueryCoord(ctx) + assert.Nil(t, err) + + node, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + + node.loadSegment = returnFailedResult + + waitQueryNodeOnline(queryCoord.cluster, node.queryNodeID) + loadPartitionTask := genLoadPartitionTask(ctx, queryCoord) + err = queryCoord.scheduler.Enqueue(loadPartitionTask) + assert.Nil(t, err) + + waitTaskFinalState(loadPartitionTask, taskFailed) + + node.stop() + queryCoord.Stop() + err = removeAllSession() + assert.Nil(t, err) +} + +func Test_LoadPartitionExecuteFailAfterLoadCollection(t *testing.T) { + refreshParams() + ctx := context.Background() + queryCoord, err := startQueryCoord(ctx) + assert.Nil(t, err) + + node, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + + waitQueryNodeOnline(queryCoord.cluster, node.queryNodeID) + loadCollectionTask := genLoadCollectionTask(ctx, queryCoord) + err = queryCoord.scheduler.Enqueue(loadCollectionTask) + assert.Nil(t, err) + + waitTaskFinalState(loadCollectionTask, taskExpired) + + createDefaultPartition(ctx, queryCoord) + node.watchDmChannels = returnFailedResult + + loadPartitionTask := genLoadPartitionTask(ctx, queryCoord) + err = queryCoord.scheduler.Enqueue(loadPartitionTask) + assert.Nil(t, err) + + waitTaskFinalState(loadPartitionTask, taskFailed) + + node.stop() + queryCoord.Stop() + err = removeAllSession() + assert.Nil(t, err) +} + +func Test_ReleaseCollectionExecuteFail(t *testing.T) { + refreshParams() + ctx := context.Background() + queryCoord, err := startQueryCoord(ctx) + assert.Nil(t, err) + + node, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + node.releaseCollection = returnFailedResult + + waitQueryNodeOnline(queryCoord.cluster, node.queryNodeID) + releaseCollectionTask := genReleaseCollectionTask(ctx, queryCoord) + err = queryCoord.scheduler.Enqueue(releaseCollectionTask) + assert.Nil(t, err) + + waitTaskFinalState(releaseCollectionTask, taskFailed) + + node.stop() + queryCoord.Stop() + err = removeAllSession() + assert.Nil(t, err) +} + +func Test_LoadSegmentReschedule(t *testing.T) { + refreshParams() + ctx := context.Background() + queryCoord, err := startQueryCoord(ctx) + assert.Nil(t, err) + + node1, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + node1.loadSegment = returnFailedResult + + node2, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + + waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID) + waitQueryNodeOnline(queryCoord.cluster, node2.queryNodeID) + + loadCollectionTask := genLoadCollectionTask(ctx, queryCoord) + err = queryCoord.scheduler.Enqueue(loadCollectionTask) + assert.Nil(t, err) + + waitTaskFinalState(loadCollectionTask, taskExpired) + + node1.stop() + node2.stop() + queryCoord.Stop() + err = removeAllSession() + assert.Nil(t, err) +} + +func Test_WatchDmChannelReschedule(t *testing.T) { + refreshParams() + ctx := context.Background() + queryCoord, err := startQueryCoord(ctx) + assert.Nil(t, err) + + node1, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + node1.watchDmChannels = returnFailedResult + + node2, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + + waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID) + waitQueryNodeOnline(queryCoord.cluster, node2.queryNodeID) + + loadCollectionTask := genLoadCollectionTask(ctx, queryCoord) + err = queryCoord.scheduler.Enqueue(loadCollectionTask) + assert.Nil(t, err) + + waitTaskFinalState(loadCollectionTask, taskExpired) + + node1.stop() + node2.stop() + queryCoord.Stop() + err = removeAllSession() + assert.Nil(t, err) +} + +func Test_ReleaseSegmentTask(t *testing.T) { + refreshParams() + ctx := context.Background() + queryCoord, err := startQueryCoord(ctx) + assert.Nil(t, err) + + node1, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + + waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID) + releaseSegmentTask := genReleaseSegmentTask(ctx, queryCoord, node1.queryNodeID) + queryCoord.scheduler.activateTaskChan <- releaseSegmentTask + + waitTaskFinalState(releaseSegmentTask, taskDone) + + queryCoord.Stop() + err = removeAllSession() + assert.Nil(t, err) +} + +func Test_RescheduleDmChannelWithWatchQueryChannel(t *testing.T) { + refreshParams() + ctx := context.Background() + queryCoord, err := startQueryCoord(ctx) + assert.Nil(t, err) + + node1, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + node2, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + + waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID) + waitQueryNodeOnline(queryCoord.cluster, node2.queryNodeID) + + node1.watchDmChannels = returnFailedResult + watchDmChannelTask := genWatchDmChannelTask(ctx, queryCoord, node1.queryNodeID) + loadCollectionTask := watchDmChannelTask.parentTask + queryCoord.scheduler.triggerTaskQueue.addTask(loadCollectionTask) + + waitTaskFinalState(loadCollectionTask, taskExpired) + + queryCoord.Stop() + err = removeAllSession() + assert.Nil(t, err) +} + +func Test_RescheduleSegmentWithWatchQueryChannel(t *testing.T) { + refreshParams() + ctx := context.Background() + queryCoord, err := startQueryCoord(ctx) + assert.Nil(t, err) + + node1, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + node2, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + + waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID) + waitQueryNodeOnline(queryCoord.cluster, node2.queryNodeID) + + node1.loadSegment = returnFailedResult + loadSegmentTask := genLoadSegmentTask(ctx, queryCoord, node1.queryNodeID) + loadCollectionTask := loadSegmentTask.parentTask + queryCoord.scheduler.triggerTaskQueue.addTask(loadCollectionTask) + + waitTaskFinalState(loadCollectionTask, taskExpired) + + queryCoord.Stop() + err = removeAllSession() + assert.Nil(t, err) +} + +func Test_RescheduleSegmentEndWithFail(t *testing.T) { + refreshParams() + ctx := context.Background() + queryCoord, err := startQueryCoord(ctx) + assert.Nil(t, err) + + node1, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + node1.loadSegment = returnFailedResult + node2, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + node2.loadSegment = returnFailedResult + + waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID) + waitQueryNodeOnline(queryCoord.cluster, node2.queryNodeID) + + loadSegmentTask := genLoadSegmentTask(ctx, queryCoord, node1.queryNodeID) + loadCollectionTask := loadSegmentTask.parentTask + queryCoord.scheduler.triggerTaskQueue.addTask(loadCollectionTask) + + waitTaskFinalState(loadCollectionTask, taskFailed) + + queryCoord.Stop() + err = removeAllSession() + assert.Nil(t, err) +} + +func Test_RescheduleDmChannelsEndWithFail(t *testing.T) { + refreshParams() + ctx := context.Background() + queryCoord, err := startQueryCoord(ctx) + assert.Nil(t, err) + + node1, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + node1.watchDmChannels = returnFailedResult + node2, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + node2.watchDmChannels = returnFailedResult + + waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID) + waitQueryNodeOnline(queryCoord.cluster, node2.queryNodeID) + + watchDmChannelTask := genWatchDmChannelTask(ctx, queryCoord, node1.queryNodeID) + loadCollectionTask := watchDmChannelTask.parentTask + queryCoord.scheduler.triggerTaskQueue.addTask(loadCollectionTask) + + waitTaskFinalState(loadCollectionTask, taskFailed) + + queryCoord.Stop() + err = removeAllSession() + assert.Nil(t, err) +}