Unverified commit 97aa2bd8 authored by yah01, committed by GitHub

Merge load segment tasks to improve performance (#19234)

Signed-off-by: yah01 <yang.cen@zilliz.com>
Parent 84b3c7ea
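
The change in a nutshell: load-segment tasks that target the same QueryNode, collection and shard are buffered and merged into a single LoadSegmentsRequest, so one RPC carries several segments instead of one RPC per segment. A minimal, standalone sketch of that grouping idea (the types and names below are illustrative, not the Milvus ones):

    // Illustrative sketch only -- not Milvus code: group requests by
    // (node, collection, shard) and merge their segment lists.
    package main

    import "fmt"

    // mergeKey identifies requests that may be merged together.
    type mergeKey struct {
        NodeID       int64
        CollectionID int64
        Shard        string
    }

    // loadRequest is a stand-in for querypb.LoadSegmentsRequest.
    type loadRequest struct {
        key      mergeKey
        segments []int64
    }

    func main() {
        incoming := []loadRequest{
            {mergeKey{1, 1000, "dmc0"}, []int64{1}},
            {mergeKey{1, 1000, "dmc0"}, []int64{2}},
            {mergeKey{2, 1000, "dmc0"}, []int64{3}},
        }

        // One merged request per key means one RPC per (node, collection, shard)
        // instead of one RPC per segment.
        merged := make(map[mergeKey]*loadRequest)
        for _, r := range incoming {
            if m, ok := merged[r.key]; ok {
                m.segments = append(m.segments, r.segments...)
            } else {
                r := r
                merged[r.key] = &r
            }
        }
        for k, m := range merged {
            fmt.Printf("LoadSegments -> node %d, shard %s, segments %v\n", k.NodeID, k.Shard, m.segments)
        }
    }
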
......@@ -38,8 +38,7 @@ func NewScheduler() *Scheduler {
}
func (scheduler *Scheduler) Start(ctx context.Context) {
scheduler.wg.Add(1)
go scheduler.schedule(ctx)
scheduler.schedule(ctx)
}
func (scheduler *Scheduler) Stop() {
......@@ -48,10 +47,10 @@ func (scheduler *Scheduler) Stop() {
}
func (scheduler *Scheduler) schedule(ctx context.Context) {
defer scheduler.wg.Done()
ticker := time.NewTicker(500 * time.Millisecond)
scheduler.wg.Add(1)
go func() {
defer scheduler.wg.Done()
ticker := time.NewTicker(500 * time.Millisecond)
for {
select {
case <-ctx.Done():
......
......@@ -306,6 +306,9 @@ func (s *Server) Start() error {
log.Info("start job scheduler...")
s.jobScheduler.Start(s.ctx)
log.Info("start task scheduler...")
s.taskScheduler.Start(s.ctx)
log.Info("start checker controller...")
s.checkerController.Start(s.ctx)
......@@ -336,6 +339,9 @@ func (s *Server) Stop() error {
log.Info("stop checker controller...")
s.checkerController.Stop()
log.Info("stop task scheduler...")
s.taskScheduler.Stop()
log.Info("stop job scheduler...")
s.jobScheduler.Stop()
......
......@@ -12,7 +12,6 @@ import (
grpcquerynodeclient "github.com/milvus-io/milvus/internal/distributed/querynode/client"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap"
)
......@@ -42,20 +41,12 @@ type Cluster interface {
Stop()
}
type segmentIndex struct {
NodeID int64
CollectionID int64
Shard string
}
// QueryCluster is used to send requests to QueryNodes and manage connections
type QueryCluster struct {
*clients
nodeManager *NodeManager
wg sync.WaitGroup
ch chan struct{}
scheduler *typeutil.GroupScheduler[segmentIndex, *commonpb.Status]
}
func NewCluster(nodeManager *NodeManager) *QueryCluster {
......@@ -63,21 +54,18 @@ func NewCluster(nodeManager *NodeManager) *QueryCluster {
clients: newClients(),
nodeManager: nodeManager,
ch: make(chan struct{}),
scheduler: typeutil.NewGroupScheduler[segmentIndex, *commonpb.Status](),
}
c.wg.Add(1)
go c.updateLoop()
return c
}
func (c *QueryCluster) Start(ctx context.Context) {
c.scheduler.Start(ctx)
c.wg.Add(1)
go c.updateLoop()
}
func (c *QueryCluster) Stop() {
c.clients.closeAll()
close(c.ch)
c.scheduler.Stop()
c.wg.Wait()
}
......@@ -101,13 +89,6 @@ func (c *QueryCluster) updateLoop() {
}
func (c *QueryCluster) LoadSegments(ctx context.Context, nodeID int64, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
// task := NewLoadSegmentsTask(c, nodeID, req)
// c.scheduler.Add(task)
// return task.Wait()
return c.loadSegments(ctx, nodeID, req)
}
func (c *QueryCluster) loadSegments(ctx context.Context, nodeID int64, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
var status *commonpb.Status
var err error
err1 := c.send(ctx, nodeID, func(cli *grpcquerynodeclient.Client) {
......
......@@ -73,7 +73,6 @@ func (suite *ClusterTestSuite) setupCluster() {
suite.nodeManager.Add(node)
}
suite.cluster = NewCluster(suite.nodeManager)
suite.cluster.Start(context.Background())
}
func (suite *ClusterTestSuite) createTestServers() []querypb.QueryNodeServer {
......
......@@ -7,6 +7,7 @@ import (
"github.com/milvus-io/milvus/api/commonpb"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
......@@ -23,6 +24,8 @@ type actionIndex struct {
}
type Executor struct {
doneCh chan struct{}
wg sync.WaitGroup
meta *meta.Meta
dist *meta.DistributionManager
broker meta.Broker
......@@ -30,6 +33,9 @@ type Executor struct {
cluster session.Cluster
nodeMgr *session.NodeManager
// Merge load segment requests
merger *Merger[segmentIndex, *querypb.LoadSegmentsRequest]
executingActions sync.Map
}
......@@ -40,21 +46,33 @@ func NewExecutor(meta *meta.Meta,
cluster session.Cluster,
nodeMgr *session.NodeManager) *Executor {
return &Executor{
doneCh: make(chan struct{}),
meta: meta,
dist: dist,
broker: broker,
targetMgr: targetMgr,
cluster: cluster,
nodeMgr: nodeMgr,
merger: NewMerger[segmentIndex, *querypb.LoadSegmentsRequest](),
executingActions: sync.Map{},
}
}
func (ex *Executor) Start(ctx context.Context) {
ex.merger.Start(ctx)
ex.scheduleRequests()
}
func (ex *Executor) Stop() {
ex.merger.Stop()
ex.wg.Wait()
}
// Execute executes the given action,
// does nothing and returns false if the action is already committed,
// returns true otherwise.
func (ex *Executor) Execute(task Task, step int, action Action) bool {
func (ex *Executor) Execute(task Task, step int) bool {
index := actionIndex{
Task: task.ID(),
Step: step,
......@@ -65,42 +83,114 @@ func (ex *Executor) Execute(task Task, step int, action Action) bool {
}
log := log.With(
zap.Int64("task", task.ID()),
zap.Int64("taskID", task.ID()),
zap.Int("step", step),
zap.Int64("source", task.SourceID()),
)
go func() {
log.Info("execute the action of task")
switch action := action.(type) {
switch task.Actions()[step].(type) {
case *SegmentAction:
ex.executeSegmentAction(task.(*SegmentTask), action)
ex.executeSegmentAction(task.(*SegmentTask), step)
case *ChannelAction:
ex.executeDmChannelAction(task.(*ChannelTask), action)
ex.executeDmChannelAction(task.(*ChannelTask), step)
}
ex.executingActions.Delete(index)
}()
return true
}
func (ex *Executor) executeSegmentAction(task *SegmentTask, action *SegmentAction) {
switch action.Type() {
func (ex *Executor) scheduleRequests() {
ex.wg.Add(1)
go func() {
defer ex.wg.Done()
for mergeTask := range ex.merger.Chan() {
task := mergeTask.(*LoadSegmentsTask)
log.Info("get merge task, process it",
zap.Int64("collectionID", task.req.GetCollectionID()),
zap.String("shard", task.req.GetInfos()[0].GetInsertChannel()),
zap.Int64("nodeID", task.req.GetDstNodeID()),
zap.Int("taskNum", len(task.tasks)),
)
go ex.processMergeTask(mergeTask.(*LoadSegmentsTask))
}
}()
}
func (ex *Executor) processMergeTask(mergeTask *LoadSegmentsTask) {
task := mergeTask.tasks[0]
action := task.Actions()[mergeTask.steps[0]]
defer func() {
for i := range mergeTask.tasks {
mergeTask.tasks[i].SetErr(task.Err())
ex.removeTask(mergeTask.tasks[i], mergeTask.steps[i])
}
}()
log := log.With(
zap.Int64("taskID", task.ID()),
zap.Int64("collectionID", task.CollectionID()),
zap.Int64("segmentID", task.segmentID),
zap.Int64("node", action.Node()),
zap.Int64("source", task.SourceID()),
)
// Get shard leader for the given replica and segment
channel := mergeTask.req.GetInfos()[0].GetInsertChannel()
leader, ok := getShardLeader(ex.meta.ReplicaManager, ex.dist, task.CollectionID(), action.Node(), channel)
if !ok {
msg := "no shard leader for the segment to execute loading"
task.SetErr(utils.WrapError(msg, ErrTaskStale))
log.Warn(msg, zap.String("shard", channel))
return
}
ctx, cancel := context.WithTimeout(task.Context(), actionTimeout)
status, err := ex.cluster.LoadSegments(ctx, leader, mergeTask.req)
cancel()
if err != nil {
log.Warn("failed to load segment, it may be a false failure", zap.Error(err))
return
}
if status.ErrorCode != commonpb.ErrorCode_Success {
log.Warn("failed to load segment", zap.String("reason", status.GetReason()))
return
}
}
func (ex *Executor) removeTask(task Task, step int) {
log.Info("excute task done, remove it",
zap.Int64("taskID", task.ID()),
zap.Int("step", step),
zap.Error(task.Err()))
index := actionIndex{
Task: task.ID(),
Step: step,
}
ex.executingActions.Delete(index)
}
func (ex *Executor) executeSegmentAction(task *SegmentTask, step int) {
switch task.Actions()[step].Type() {
case ActionTypeGrow:
ex.loadSegment(task, action)
ex.loadSegment(task, step)
case ActionTypeReduce:
ex.releaseSegment(task, action)
ex.releaseSegment(task, step)
}
}
func (ex *Executor) loadSegment(task *SegmentTask, action *SegmentAction) {
// loadSegment commits the request to merger,
// not really executes the request
func (ex *Executor) loadSegment(task *SegmentTask, step int) {
action := task.Actions()[step].(*SegmentAction)
log := log.With(
zap.Int64("task", task.ID()),
zap.Int64("collection", task.CollectionID()),
zap.Int64("segment", task.segmentID),
zap.Int64("taskID", task.ID()),
zap.Int64("collectionID", task.CollectionID()),
zap.Int64("segmentID", task.segmentID),
zap.Int64("node", action.Node()),
zap.Int64("source", task.SourceID()),
)
......@@ -152,24 +242,21 @@ func (ex *Executor) loadSegment(task *SegmentTask, action *SegmentAction) {
}
req := packLoadSegmentRequest(task, action, schema, loadMeta, loadInfo, deltaPositions)
status, err := ex.cluster.LoadSegments(ctx, leader, req)
if err != nil {
log.Warn("failed to load segment, it may be a false failure", zap.Error(err))
return
}
if status.ErrorCode != commonpb.ErrorCode_Success {
log.Warn("failed to load segment", zap.String("reason", status.GetReason()))
return
}
loadTask := NewLoadSegmentsTask(task, step, req)
ex.merger.Add(loadTask)
log.Info("load segment task committed")
}
func (ex *Executor) releaseSegment(task *SegmentTask, action *SegmentAction) {
func (ex *Executor) releaseSegment(task *SegmentTask, step int) {
defer ex.removeTask(task, step)
action := task.Actions()[step].(*SegmentAction)
defer action.isReleaseCommitted.Store(true)
log := log.With(
zap.Int64("task", task.ID()),
zap.Int64("collection", task.CollectionID()),
zap.Int64("segment", task.segmentID),
zap.Int64("taskID", task.ID()),
zap.Int64("collectionID", task.CollectionID()),
zap.Int64("segmentID", task.segmentID),
zap.Int64("node", action.Node()),
zap.Int64("source", task.SourceID()),
)
......@@ -215,20 +302,23 @@ func (ex *Executor) releaseSegment(task *SegmentTask, action *SegmentAction) {
}
}
func (ex *Executor) executeDmChannelAction(task *ChannelTask, action *ChannelAction) {
switch action.Type() {
func (ex *Executor) executeDmChannelAction(task *ChannelTask, step int) {
switch task.Actions()[step].Type() {
case ActionTypeGrow:
ex.subDmChannel(task, action)
ex.subDmChannel(task, step)
case ActionTypeReduce:
ex.unsubDmChannel(task, action)
ex.unsubDmChannel(task, step)
}
}
func (ex *Executor) subDmChannel(task *ChannelTask, action *ChannelAction) {
func (ex *Executor) subDmChannel(task *ChannelTask, step int) {
defer ex.removeTask(task, step)
action := task.Actions()[step].(*ChannelAction)
log := log.With(
zap.Int64("task", task.ID()),
zap.Int64("collection", task.CollectionID()),
zap.Int64("taskID", task.ID()),
zap.Int64("collectionID", task.CollectionID()),
zap.String("channel", task.Channel()),
zap.Int64("node", action.Node()),
zap.Int64("source", task.SourceID()),
......@@ -294,10 +384,13 @@ func (ex *Executor) subDmChannel(task *ChannelTask, action *ChannelAction) {
log.Info("subscribe DmChannel done")
}
func (ex *Executor) unsubDmChannel(task *ChannelTask, action *ChannelAction) {
func (ex *Executor) unsubDmChannel(task *ChannelTask, step int) {
defer ex.removeTask(task, step)
action := task.Actions()[step].(*ChannelAction)
log := log.With(
zap.Int64("task", task.ID()),
zap.Int64("collection", task.CollectionID()),
zap.Int64("taskID", task.ID()),
zap.Int64("collectionID", task.CollectionID()),
zap.String("channel", task.Channel()),
zap.Int64("node", action.Node()),
zap.Int64("source", task.SourceID()),
......
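
Two details in the new executor are easy to miss: Execute claims a (taskID, step) pair in executingActions before doing any work, so scheduling the same step again while it is in flight is a no-op, and loadSegment no longer calls the cluster itself but only commits a LoadSegmentsTask to the merger; processMergeTask later sends one merged RPC and removeTask releases every merged sub-task. The claim-before-dispatch part boils down to roughly this pattern (a sketch with illustrative names, not the exact Milvus code):

    // Illustrative sketch -- claim-before-dispatch with sync.Map, so the same
    // (task, step) is started at most once while it is still in flight.
    package main

    import (
        "fmt"
        "sync"
    )

    type actionIndex struct {
        Task int64
        Step int
    }

    type executor struct {
        executing sync.Map // actionIndex -> struct{}
    }

    // execute returns false if this (task, step) is already being executed.
    func (ex *executor) execute(taskID int64, step int, run func()) bool {
        idx := actionIndex{Task: taskID, Step: step}
        if _, loaded := ex.executing.LoadOrStore(idx, struct{}{}); loaded {
            return false
        }
        go func() {
            defer ex.executing.Delete(idx) // release the claim when the work is done
            run()
        }()
        return true
    }

    func main() {
        ex := &executor{}
        release := make(chan struct{})
        started := ex.execute(1, 0, func() { <-release }) // first call starts the work
        dup := ex.execute(1, 0, func() {})                // second call is rejected while in flight
        fmt.Println(started, dup)                         // true false
        close(release)
    }
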
package session
package task
import (
"context"
"time"
"github.com/milvus-io/milvus/api/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
var _ typeutil.MergeableTask[segmentIndex, *commonpb.Status] = (*LoadSegmentsTask)(nil)
type MergeableTask[K comparable, R any] interface {
ID() K
Merge(other MergeableTask[K, R])
}
var _ MergeableTask[segmentIndex, *querypb.LoadSegmentsRequest] = (*LoadSegmentsTask)(nil)
type segmentIndex struct {
NodeID int64
CollectionID int64
Shard string
}
type LoadSegmentsTask struct {
doneCh chan struct{}
cluster *QueryCluster
nodeID int64
req *querypb.LoadSegmentsRequest
result *commonpb.Status
err error
tasks []*SegmentTask
steps []int
req *querypb.LoadSegmentsRequest
}
func NewLoadSegmentsTask(cluster *QueryCluster, nodeID int64, req *querypb.LoadSegmentsRequest) *LoadSegmentsTask {
func NewLoadSegmentsTask(task *SegmentTask, step int, req *querypb.LoadSegmentsRequest) *LoadSegmentsTask {
return &LoadSegmentsTask{
doneCh: make(chan struct{}),
cluster: cluster,
nodeID: nodeID,
req: req,
tasks: []*SegmentTask{task},
steps: []int{step},
req: req,
}
}
func (task *LoadSegmentsTask) ID() segmentIndex {
return segmentIndex{
NodeID: task.nodeID,
NodeID: task.req.GetDstNodeID(),
CollectionID: task.req.GetCollectionID(),
Shard: task.req.GetInfos()[0].GetInsertChannel(),
}
}
func (task *LoadSegmentsTask) Execute() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
status, err := task.cluster.loadSegments(ctx, task.nodeID, task.req)
if err != nil {
task.err = err
return err
}
task.result = status
return nil
}
func (task *LoadSegmentsTask) Merge(other typeutil.MergeableTask[segmentIndex, *commonpb.Status]) {
task.req.Infos = append(task.req.Infos, other.(*LoadSegmentsTask).req.GetInfos()...)
func (task *LoadSegmentsTask) Merge(other MergeableTask[segmentIndex, *querypb.LoadSegmentsRequest]) {
otherTask := other.(*LoadSegmentsTask)
task.tasks = append(task.tasks, otherTask.tasks...)
task.steps = append(task.steps, otherTask.steps...)
task.req.Infos = append(task.req.Infos, otherTask.req.GetInfos()...)
deltaPositions := make(map[string]*internalpb.MsgPosition)
for _, position := range task.req.DeltaPositions {
deltaPositions[position.GetChannelName()] = position
}
for _, position := range other.(*LoadSegmentsTask).req.GetDeltaPositions() {
for _, position := range otherTask.req.GetDeltaPositions() {
merged, ok := deltaPositions[position.GetChannelName()]
if !ok || merged.GetTimestamp() > position.GetTimestamp() {
merged = position
......@@ -69,19 +62,6 @@ func (task *LoadSegmentsTask) Merge(other typeutil.MergeableTask[segmentIndex, *
}
}
func (task *LoadSegmentsTask) SetResult(result *commonpb.Status) {
task.result = result
}
func (task *LoadSegmentsTask) SetError(err error) {
task.err = err
}
func (task *LoadSegmentsTask) Done() {
close(task.doneCh)
}
func (task *LoadSegmentsTask) Wait() (*commonpb.Status, error) {
<-task.doneCh
return task.result, task.err
func (task *LoadSegmentsTask) Result() *querypb.LoadSegmentsRequest {
return task.req
}
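
Merge concatenates the segment infos of both requests and, for the delta positions, keeps the earliest timestamp seen per channel, so the merged request replays deltas from a checkpoint that is safe for every merged task (the MergerSuite further down checks exactly this: timestamps 2/3, 3/2 and 1/1 collapse to 1/1). A reduced sketch of that rule, with a plain struct standing in for internalpb.MsgPosition:

    // Illustrative sketch -- earliest-timestamp-per-channel merge of delta
    // positions; a plain struct stands in for internalpb.MsgPosition.
    package main

    import "fmt"

    type position struct {
        Channel   string
        Timestamp uint64
    }

    func mergePositions(a, b []position) []position {
        byChannel := make(map[string]position)
        for _, p := range a {
            byChannel[p.Channel] = p
        }
        for _, p := range b {
            // Keep the smaller timestamp so the merged request replays from a
            // checkpoint that is early enough for every merged task.
            if cur, ok := byChannel[p.Channel]; !ok || cur.Timestamp > p.Timestamp {
                byChannel[p.Channel] = p
            }
        }
        out := make([]position, 0, len(byChannel))
        for _, p := range byChannel {
            out = append(out, p)
        }
        return out
    }

    func main() {
        a := []position{{"dmc0", 2}, {"dmc1", 3}}
        b := []position{{"dmc0", 3}, {"dmc1", 1}}
        fmt.Println(mergePositions(a, b)) // dmc0 keeps 2, dmc1 keeps 1
    }
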
package typeutil
package task
import (
"context"
......@@ -6,74 +6,70 @@ import (
"time"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap"
)
// GroupScheduler schedules requests,
// all requests within the same partition & node will run sequentially,
// with group commit
// Merger merges tasks with the same mergeID.
const (
taskQueueCap = 16
waitQueueCap = 128
)
type MergeableTask[K comparable, R any] interface {
ID() K
Execute() error
Merge(other MergeableTask[K, R])
SetResult(R)
SetError(error)
Done()
Wait() (R, error)
}
type GroupScheduler[K comparable, R any] struct {
type Merger[K comparable, R any] struct {
stopCh chan struct{}
wg sync.WaitGroup
processors *ConcurrentSet[K] // Tasks of having processor
processors *typeutil.ConcurrentSet[K] // Tasks of having processor
queues map[K]chan MergeableTask[K, R] // TaskID -> Queue
waitQueue chan MergeableTask[K, R]
outCh chan MergeableTask[K, R]
}
func NewGroupScheduler[K comparable, R any]() *GroupScheduler[K, R] {
return &GroupScheduler[K, R]{
func NewMerger[K comparable, R any]() *Merger[K, R] {
return &Merger[K, R]{
stopCh: make(chan struct{}),
processors: NewConcurrentSet[K](),
processors: typeutil.NewConcurrentSet[K](),
queues: make(map[K]chan MergeableTask[K, R]),
waitQueue: make(chan MergeableTask[K, R], waitQueueCap),
outCh: make(chan MergeableTask[K, R], taskQueueCap),
}
}
func (scheduler *GroupScheduler[K, R]) Start(ctx context.Context) {
scheduler.wg.Add(1)
go scheduler.schedule(ctx)
func (merger *Merger[K, R]) Start(ctx context.Context) {
merger.schedule(ctx)
}
func (scheduler *GroupScheduler[K, R]) Stop() {
close(scheduler.stopCh)
scheduler.wg.Wait()
func (merger *Merger[K, R]) Stop() {
close(merger.stopCh)
merger.wg.Wait()
close(merger.outCh)
}
func (scheduler *GroupScheduler[K, R]) schedule(ctx context.Context) {
defer scheduler.wg.Done()
func (merger *Merger[K, R]) Chan() <-chan MergeableTask[K, R] {
return merger.outCh
}
ticker := time.NewTicker(500 * time.Millisecond)
func (merger *Merger[K, R]) schedule(ctx context.Context) {
merger.wg.Add(1)
go func() {
defer merger.wg.Done()
ticker := time.NewTicker(500 * time.Millisecond)
for {
select {
case <-ctx.Done():
log.Info("GroupScheduler stopped due to context canceled")
log.Info("Merger stopped due to context canceled")
return
case <-scheduler.stopCh:
log.Info("GroupScheduler stopped")
case <-merger.stopCh:
log.Info("Merger stopped")
return
case task := <-scheduler.waitQueue:
queue, ok := scheduler.queues[task.ID()]
case task := <-merger.waitQueue:
queue, ok := merger.queues[task.ID()]
if !ok {
queue = make(chan MergeableTask[K, R], taskQueueCap)
scheduler.queues[task.ID()] = queue
merger.queues[task.ID()] = queue
}
outer:
for {
......@@ -81,17 +77,17 @@ func (scheduler *GroupScheduler[K, R]) schedule(ctx context.Context) {
case queue <- task:
break outer
default: // Queue full, flush and retry
scheduler.startProcessor(task.ID(), queue)
merger.merge(task.ID(), queue)
}
}
case <-ticker.C:
for id, queue := range scheduler.queues {
for id, queue := range merger.queues {
if len(queue) > 0 {
scheduler.startProcessor(id, queue)
merger.merge(id, queue)
} else {
// Release resource if no job for the task
delete(scheduler.queues, id)
// Release resource if no task for the queue
delete(merger.queues, id)
}
}
}
......@@ -99,53 +95,45 @@ func (scheduler *GroupScheduler[K, R]) schedule(ctx context.Context) {
}()
}
func (scheduler *GroupScheduler[K, R]) isStopped() bool {
func (merger *Merger[K, R]) isStopped() bool {
select {
case <-scheduler.stopCh:
case <-merger.stopCh:
return true
default:
return false
}
}
func (scheduler *GroupScheduler[K, R]) Add(job MergeableTask[K, R]) {
scheduler.waitQueue <- job
func (merger *Merger[K, R]) Add(task MergeableTask[K, R]) {
merger.waitQueue <- task
}
func (scheduler *GroupScheduler[K, R]) startProcessor(id K, queue chan MergeableTask[K, R]) {
if scheduler.isStopped() {
func (merger *Merger[K, R]) merge(id K, queue chan MergeableTask[K, R]) {
if merger.isStopped() {
return
}
if !scheduler.processors.Insert(id) {
if !merger.processors.Insert(id) {
return
}
scheduler.wg.Add(1)
go scheduler.processQueue(id, queue)
merger.wg.Add(1)
go merger.mergeQueue(id, queue)
}
// processQueue processes tasks in the given queue,
// mergeQueue merges tasks in the given queue,
// it only processes tasks with the number of the length of queue at the time,
// to avoid leaking goroutines
func (scheduler *GroupScheduler[K, R]) processQueue(id K, queue chan MergeableTask[K, R]) {
defer scheduler.wg.Done()
defer scheduler.processors.Remove(id)
func (merger *Merger[K, R]) mergeQueue(id K, queue chan MergeableTask[K, R]) {
defer merger.wg.Done()
defer merger.processors.Remove(id)
len := len(queue)
buffer := make([]MergeableTask[K, R], len)
for i := range buffer {
buffer[i] = <-queue
if i > 0 {
buffer[0].Merge(buffer[i])
}
task := <-queue
for i := 1; i < len; i++ {
task.Merge(<-queue)
}
buffer[0].Execute()
buffer[0].Done()
result, err := buffer[0].Wait()
for _, buffer := range buffer[1:] {
buffer.SetResult(result)
buffer.SetError(err)
buffer.Done()
}
log.Info("merge tasks done",
zap.Any("mergeID", task.ID()))
merger.outCh <- task
}
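
Add only puts a task on the wait queue; the schedule loop routes it into a per-key buffered queue, and a flush happens either on the 500 ms tick or when a per-key queue fills up. mergeQueue then drains exactly as many tasks as were queued when it started, folds them into the first one via Merge, and emits the result on the channel returned by Chan. The drain-and-fold step, reduced to a standalone sketch:

    // Illustrative sketch -- drain a snapshot of the queue and fold the tasks
    // into the first one; later arrivals wait for the next flush.
    package main

    import "fmt"

    type loadTask struct{ segments []int64 }

    func (t *loadTask) merge(other *loadTask) {
        t.segments = append(t.segments, other.segments...)
    }

    func drainAndFold(queue chan *loadTask) *loadTask {
        n := len(queue) // snapshot the length so the amount of work is bounded
        if n == 0 {
            return nil
        }
        first := <-queue
        for i := 1; i < n; i++ {
            first.merge(<-queue)
        }
        return first
    }

    func main() {
        q := make(chan *loadTask, 16)
        q <- &loadTask{segments: []int64{1}}
        q <- &loadTask{segments: []int64{2}}
        q <- &loadTask{segments: []int64{3}}
        fmt.Println(drainAndFold(q).segments) // [1 2 3]
    }
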
package task
import (
"context"
"testing"
"time"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/stretchr/testify/suite"
)
type MergerSuite struct {
suite.Suite
// Data
collectionID int64
replicaID int64
nodeID int64
requests map[int64]*querypb.LoadSegmentsRequest
merger *Merger[segmentIndex, *querypb.LoadSegmentsRequest]
}
func (suite *MergerSuite) SetupSuite() {
suite.collectionID = 1000
suite.replicaID = 100
suite.nodeID = 1
suite.requests = map[int64]*querypb.LoadSegmentsRequest{
1: {
DstNodeID: suite.nodeID,
CollectionID: suite.collectionID,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: 1,
InsertChannel: "dmc0",
},
},
DeltaPositions: []*internalpb.MsgPosition{
{
ChannelName: "dmc0",
Timestamp: 2,
},
{
ChannelName: "dmc1",
Timestamp: 3,
},
},
},
2: {
DstNodeID: suite.nodeID,
CollectionID: suite.collectionID,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: 2,
InsertChannel: "dmc0",
},
},
DeltaPositions: []*internalpb.MsgPosition{
{
ChannelName: "dmc0",
Timestamp: 3,
},
{
ChannelName: "dmc1",
Timestamp: 2,
},
},
},
3: {
DstNodeID: suite.nodeID,
CollectionID: suite.collectionID,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: 3,
InsertChannel: "dmc0",
},
},
DeltaPositions: []*internalpb.MsgPosition{
{
ChannelName: "dmc0",
Timestamp: 1,
},
{
ChannelName: "dmc1",
Timestamp: 1,
},
},
},
}
}
func (suite *MergerSuite) SetupTest() {
suite.merger = NewMerger[segmentIndex, *querypb.LoadSegmentsRequest]()
}
func (suite *MergerSuite) TestMerge() {
const (
requestNum = 5
timeout = 5 * time.Second
)
ctx := context.Background()
for segmentID := int64(1); segmentID <= 3; segmentID++ {
task := NewSegmentTask(ctx, timeout, 0, suite.collectionID, suite.replicaID,
NewSegmentAction(suite.nodeID, ActionTypeGrow, segmentID))
suite.merger.Add(NewLoadSegmentsTask(task, 0, suite.requests[segmentID]))
}
suite.merger.Start(ctx)
defer suite.merger.Stop()
taskI := <-suite.merger.Chan()
task := taskI.(*LoadSegmentsTask)
suite.Len(task.tasks, 3)
suite.Len(task.steps, 3)
suite.EqualValues(1, task.Result().DeltaPositions[0].Timestamp)
suite.EqualValues(1, task.Result().DeltaPositions[1].Timestamp)
}
func TestMerger(t *testing.T) {
suite.Run(t, new(MergerSuite))
}
......@@ -2,7 +2,11 @@
package task
import mock "github.com/stretchr/testify/mock"
import (
context "context"
mock "github.com/stretchr/testify/mock"
)
// MockScheduler is an autogenerated mock type for the Scheduler type
type MockScheduler struct {
......@@ -184,6 +188,61 @@ func (_c *MockScheduler_RemoveByNode_Call) Return() *MockScheduler_RemoveByNode_
return _c
}
// Start provides a mock function with given fields: ctx
func (_m *MockScheduler) Start(ctx context.Context) {
_m.Called(ctx)
}
// MockScheduler_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start'
type MockScheduler_Start_Call struct {
*mock.Call
}
// Start is a helper method to define mock.On call
// - ctx context.Context
func (_e *MockScheduler_Expecter) Start(ctx interface{}) *MockScheduler_Start_Call {
return &MockScheduler_Start_Call{Call: _e.mock.On("Start", ctx)}
}
func (_c *MockScheduler_Start_Call) Run(run func(ctx context.Context)) *MockScheduler_Start_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context))
})
return _c
}
func (_c *MockScheduler_Start_Call) Return() *MockScheduler_Start_Call {
_c.Call.Return()
return _c
}
// Stop provides a mock function with given fields:
func (_m *MockScheduler) Stop() {
_m.Called()
}
// MockScheduler_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop'
type MockScheduler_Stop_Call struct {
*mock.Call
}
// Stop is a helper method to define mock.On call
func (_e *MockScheduler_Expecter) Stop() *MockScheduler_Stop_Call {
return &MockScheduler_Stop_Call{Call: _e.mock.On("Stop")}
}
func (_c *MockScheduler_Stop_Call) Run(run func()) *MockScheduler_Stop_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockScheduler_Stop_Call) Return() *MockScheduler_Stop_Call {
_c.Call.Return()
return _c
}
type mockConstructorTestingTNewMockScheduler interface {
mock.TestingT
Cleanup(func())
......
......@@ -111,6 +111,8 @@ func (queue *taskQueue) Range(fn func(task Task) bool) {
}
type Scheduler interface {
Start(ctx context.Context)
Stop()
Add(task Task) error
Dispatch(node int64)
RemoveByNode(node int64)
......@@ -167,6 +169,14 @@ func NewScheduler(ctx context.Context,
}
}
func (scheduler *taskScheduler) Start(ctx context.Context) {
scheduler.executor.Start(ctx)
}
func (scheduler *taskScheduler) Stop() {
scheduler.executor.Stop()
}
func (scheduler *taskScheduler) Add(task Task) error {
scheduler.rwmutex.Lock()
defer scheduler.rwmutex.Unlock()
......@@ -482,11 +492,11 @@ func (scheduler *taskScheduler) process(task Task) bool {
task.SetErr(ErrTaskStale)
}
actions, step := task.Actions(), task.Step()
step := task.Step()
log = log.With(zap.Int("step", step))
switch task.Status() {
case TaskStatusStarted:
if scheduler.executor.Execute(task, step, actions[step]) {
if scheduler.executor.Execute(task, step) {
return true
}
......
......@@ -114,6 +114,7 @@ func (suite *TaskSuite) SetupTest() {
suite.cluster = session.NewMockCluster(suite.T())
suite.scheduler = suite.newScheduler()
suite.scheduler.Start(context.Background())
}
func (suite *TaskSuite) BeforeTest(suiteName, testName string) {
......@@ -814,18 +815,24 @@ func (suite *TaskSuite) AssertTaskNum(process, wait, channel, segment int) {
}
func (suite *TaskSuite) dispatchAndWait(node int64) {
timeout := 10 * time.Second
suite.scheduler.Dispatch(node)
for {
count := 0
var keys []any
count := 0
for start := time.Now(); time.Since(start) < timeout; {
count = 0
keys = make([]any, 0)
suite.scheduler.executor.executingActions.Range(func(key, value any) bool {
keys = append(keys, key)
count++
return true
})
if count == 0 {
break
return
}
time.Sleep(200 * time.Millisecond)
}
suite.FailNow("executor hangs in executing tasks", "count=%d keys=%+v", count, keys)
}
func (suite *TaskSuite) newScheduler() *taskScheduler {
......