diff --git a/internal/querynodev2/tasks/scheduler.go b/internal/querynodev2/tasks/scheduler.go index 404d5aea8c21bf1a1afec5e5c0f27468285efe46..fd78c98034af523c8a65a453214a2d11c3aa05da 100644 --- a/internal/querynodev2/tasks/scheduler.go +++ b/internal/querynodev2/tasks/scheduler.go @@ -3,6 +3,7 @@ package tasks import ( "context" "fmt" + "sync" "go.uber.org/atomic" @@ -25,7 +26,9 @@ type Scheduler struct { queryProcessQueue chan *QueryTask queryWaitQueue chan *QueryTask - pool *conc.Pool[any] + pool *conc.Pool[any] + runningThreadNum int + cond *sync.Cond } func NewScheduler() *Scheduler { @@ -39,6 +42,7 @@ func NewScheduler() *Scheduler { // queryProcessQueue: make(chan), pool: conc.NewPool[any](maxReadConcurrency, ants.WithPreAlloc(true)), + cond: sync.NewCond(&sync.Mutex{}), } } @@ -151,7 +155,23 @@ func (s *Scheduler) processAll(ctx context.Context) { } func (s *Scheduler) process(t Task) { + s.cond.L.Lock() + for s.runningThreadNum >= s.pool.Cap() { + s.cond.Wait() + } + s.runningThreadNum += t.Weight() + s.cond.L.Unlock() + s.pool.Submit(func() (any, error) { + defer func() { + s.cond.L.Lock() + defer s.cond.L.Unlock() + s.runningThreadNum -= t.Weight() + if s.runningThreadNum < s.pool.Cap() { + s.cond.Broadcast() + } + }() + metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc() err := t.Execute() diff --git a/internal/querynodev2/tasks/task.go b/internal/querynodev2/tasks/task.go index cc1f5f5d8904dccb864e80450a7b9bc88891af34..322e08b35f29f0204134847f3d42c32c25c378f4 100644 --- a/internal/querynodev2/tasks/task.go +++ b/internal/querynodev2/tasks/task.go @@ -25,6 +25,7 @@ type Task interface { Done(err error) Canceled() error Wait() error + Weight() int } type SearchTask struct { @@ -235,6 +236,10 @@ func (t *SearchTask) Wait() error { return <-t.notifier } +func (t *SearchTask) Weight() int { + return int(t.nq) +} + func (t *SearchTask) Result() *internalpb.SearchResults { return t.result }