Unverified commit 302897f8, authored by wei liu, committed by GitHub

refine look aside balance logic (#25837)

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
Parent a669440e
@@ -19,6 +19,7 @@ package proxy
import (
"context"
"math"
"math/rand"
"strconv"
"sync"
"time"
@@ -35,11 +36,6 @@ import (
"go.uber.org/zap"
)
var (
checkQueryNodeHealthInterval = 500 * time.Millisecond
CostMetricsExpireTime = 1000 * time.Millisecond
)
type LookAsideBalancer struct {
clientMgr shardClientMgr
@@ -88,6 +84,9 @@ func (b *LookAsideBalancer) SelectNode(ctx context.Context, availableNodes []int
log := log.Ctx(ctx).WithRateGroup("proxy.LookAsideBalancer", 1, 60)
targetNode := int64(-1)
targetScore := float64(math.MaxFloat64)
rand.Shuffle(len(availableNodes), func(i, j int) {
availableNodes[i], availableNodes[j] = availableNodes[j], availableNodes[i]
})
for _, node := range availableNodes {
if b.unreachableQueryNodes.Contain(node) {
log.RatedWarn(5, "query node is unreachable, skip it",
@@ -117,7 +116,8 @@ func (b *LookAsideBalancer) SelectNode(ctx context.Context, availableNodes []int
// update executing task cost
totalNQ, _ := b.executingTaskTotalNQ.Get(targetNode)
totalNQ.Add(cost)
nq := totalNQ.Add(cost)
metrics.ProxyExecutingTotalNq.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Set(float64(nq))
return targetNode, nil
}
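The shuffle added at the top of SelectNode is a tie-breaker: with a strict less-than comparison, nodes that share the lowest score (for example, several idle nodes all scoring 0) would otherwise always lose to whichever one appears first in availableNodes, concentrating traffic on a single node. A minimal sketch of the pattern, with a hypothetical score callback standing in for calculateScore:

```go
package main

import (
	"fmt"
	"math"
	"math/rand"
)

// selectNode mirrors the selection loop above in miniature: shuffle first,
// then keep the lowest-scoring node. score is a stand-in for calculateScore.
func selectNode(nodes []int64, score func(int64) float64) int64 {
	rand.Shuffle(len(nodes), func(i, j int) {
		nodes[i], nodes[j] = nodes[j], nodes[i]
	})
	target, best := int64(-1), math.MaxFloat64
	for _, n := range nodes {
		if s := score(n); s < best {
			target, best = n, s
		}
	}
	return target
}

func main() {
	// With identical scores, repeated calls no longer always return node 1.
	fmt.Println(selectNode([]int64{1, 2, 3}, func(int64) float64 { return 0 }))
}
```

This is also why TestSelectNode below now asserts per-node request counts within a tolerance of 1 rather than exact equality: the shuffle makes the exact distribution nondeterministic.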
@@ -126,28 +126,31 @@ func (b *LookAsideBalancer) SelectNode(ctx context.Context, availableNodes []int
func (b *LookAsideBalancer) CancelWorkload(node int64, nq int64) {
totalNQ, ok := b.executingTaskTotalNQ.Get(node)
if ok {
totalNQ.Sub(nq)
nq := totalNQ.Sub(nq)
metrics.ProxyExecutingTotalNq.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Set(float64(nq))
}
}
// UpdateCostMetrics caches the most recent search/query cost metrics reported by a query node
func (b *LookAsideBalancer) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {
// cache the latest query node cost metrics for updating the score
b.metricsMap.Insert(node, cost)
if cost != nil {
b.metricsMap.Insert(node, cost)
}
b.metricsUpdateTs.Insert(node, time.Now().UnixMilli())
}
// calculateScore computes the query node's workload score
// https://www.usenix.org/conference/nsdi15/technical-sessions/presentation/suresh
func (b *LookAsideBalancer) calculateScore(node int64, cost *internalpb.CostAggregation, executingNQ int64) float64 {
if cost == nil || cost.ResponseTime == 0 || cost.ServiceTime == 0 {
return math.Pow(float64(1+executingNQ), 3.0)
if cost == nil || cost.GetResponseTime() == 0 {
return math.Pow(float64(executingNQ), 3.0)
}
// for multi-replica cases, when there is no task waiting in the queue,
// the response time will affect the score; to prevent the score from being
// based on a value that is too old, we expire the cost metrics if no task is in the queue.
if executingNQ == 0 && cost.TotalNQ == 0 && b.isNodeCostMetricsTooOld(node) {
if executingNQ == 0 && b.isNodeCostMetricsTooOld(node) {
return 0
}
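calculateScore follows the C3 replica-selection paper linked above: the score combines an observed-latency term with a concurrency term raised to the third power, so a node's backlog dominates its score as it grows. A hedged reconstruction of the function's shape from the fields visible in this diff (the exact weighting in Milvus may differ):

```go
package main

import (
	"fmt"
	"math"
)

// c3Score is a hedged reconstruction, not the exact Milvus code. It assumes
// the CostAggregation fields shown in this diff: responseTime and serviceTime
// as reported by the query node, queuedNQ its waiting work, and executingNQ
// the proxy-side in-flight work for that node.
func c3Score(responseTime, serviceTime, queuedNQ, executingNQ int64) float64 {
	if responseTime == 0 {
		// no usable metrics yet: fall back to concurrency alone,
		// matching the new fallback in the hunk above
		return math.Pow(float64(executingNQ), 3.0)
	}
	// queueing/network overhead observed from the proxy side
	overhead := float64(responseTime - serviceTime)
	// outstanding work cubed, scaled by per-unit service time (C3's q^3 term)
	backlog := math.Pow(float64(1+queuedNQ+executingNQ), 3.0) * float64(serviceTime)
	return overhead + backlog
}

func main() {
	fmt.Println(c3Score(20, 15, 0, 4))  // lightly loaded node
	fmt.Println(c3Score(20, 15, 30, 4)) // backlogged node scores far worse
}
```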
@@ -167,13 +170,14 @@ func (b *LookAsideBalancer) isNodeCostMetricsTooOld(node int64) bool {
return false
}
return time.Now().UnixMilli()-lastUpdateTs > CostMetricsExpireTime.Milliseconds()
return time.Now().UnixMilli()-lastUpdateTs > Params.ProxyCfg.CostMetricsExpireTime.GetAsInt64()
}
func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) {
log := log.Ctx(ctx).WithRateGroup("proxy.LookAsideBalancer", 1, 60)
defer b.wg.Done()
checkQueryNodeHealthInterval := Params.ProxyCfg.CheckQueryNodeHealthInterval.GetAsDuration(time.Millisecond)
ticker := time.NewTicker(checkQueryNodeHealthInterval)
defer ticker.Stop()
log.Info("Start check query node health loop")
@@ -190,7 +194,7 @@ func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) {
b.metricsUpdateTs.Range(func(node int64, lastUpdateTs int64) bool {
if now-lastUpdateTs > checkQueryNodeHealthInterval.Milliseconds() {
futures = append(futures, pool.Submit(func() (any, error) {
checkInterval := paramtable.Get().ProxyCfg.HealthCheckTimetout.GetAsDuration(time.Millisecond)
checkInterval := Params.ProxyCfg.HealthCheckTimetout.GetAsDuration(time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), checkInterval)
defer cancel()
......
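checkQueryNodeHealthLoop wakes on the (now configurable) ticker and, for every node whose cost metrics have not been refreshed within the interval, probes the node with a deadline-bounded RPC before flipping it in or out of unreachableQueryNodes. A sketch of that staleness-probe pattern, with the conc-pool submission and node bookkeeping abstracted behind assumed callbacks:

```go
package main

import (
	"context"
	"time"
)

// healthLoop is a sketch with assumed helpers (staleNodes, probe,
// markUnreachable); the real loop ranges over metricsUpdateTs and
// submits the probes to a conc pool instead of running them inline.
func healthLoop(ctx context.Context, interval, probeTimeout time.Duration,
	staleNodes func(olderThan time.Duration) []int64,
	probe func(ctx context.Context, node int64) error,
	markUnreachable func(node int64, unreachable bool)) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			for _, node := range staleNodes(interval) {
				probeCtx, cancel := context.WithTimeout(context.Background(), probeTimeout)
				err := probe(probeCtx, node) // e.g. a GetComponentStates call
				cancel()
				markUnreachable(node, err != nil)
			}
		}
	}
}

func main() {}
```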
@@ -272,7 +272,7 @@ func (suite *LookAsideBalancerSuite) TestSelectNode() {
}
for node, result := range c.result {
suite.Equal(result, counter[node])
suite.True(math.Abs(float64(result-counter[node])) <= float64(1))
}
})
}
@@ -302,7 +302,7 @@ func (suite *LookAsideBalancerSuite) TestCheckHealthLoop() {
suite.balancer.unreachableQueryNodes.Insert(2)
suite.Eventually(func() bool {
return suite.balancer.unreachableQueryNodes.Contain(1)
}, 2*time.Second, 100*time.Millisecond)
}, 3*time.Second, 100*time.Millisecond)
targetNode, err := suite.balancer.SelectNode(context.Background(), []int64{1}, 1)
suite.ErrorIs(err, merr.ErrServiceUnavailable)
suite.Equal(int64(-1), targetNode)
@@ -331,11 +331,11 @@ func (suite *LookAsideBalancerSuite) TestNodeRecover() {
suite.balancer.metricsUpdateTs.Insert(3, time.Now().UnixMilli())
suite.Eventually(func() bool {
return suite.balancer.unreachableQueryNodes.Contain(3)
}, 2*time.Second, 100*time.Millisecond)
}, 5*time.Second, 100*time.Millisecond)
suite.Eventually(func() bool {
return !suite.balancer.unreachableQueryNodes.Contain(3)
}, 3*time.Second, 100*time.Millisecond)
}, 5*time.Second, 100*time.Millisecond)
}
func TestLookAsideBalancerSuite(t *testing.T) {
......
@@ -32,7 +32,6 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/querynodev2/tasks"
"github.com/milvus-io/milvus/internal/util"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
@@ -227,27 +226,6 @@ func (node *QueryNode) queryChannel(ctx context.Context, req *querypb.QueryReque
return ret, nil
}
func (node *QueryNode) querySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
collection := node.manager.Collection.Get(req.Req.GetCollectionID())
if collection == nil {
return nil, merr.WrapErrCollectionNotFound(req.Req.GetCollectionID())
}
// Send task to scheduler and wait until it finishes.
task := tasks.NewQueryTask(ctx, collection, node.manager, req)
if err := node.scheduler.Add(task); err != nil {
log.Warn("failed to add query task into scheduler", zap.Error(err))
return nil, err
}
err := task.Wait()
if err != nil {
log.Warn("failed to execute task by node scheduler", zap.Error(err))
return nil, err
}
return task.Result(), nil
}
func (node *QueryNode) optimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, deleg delegator.ShardDelegator) (*querypb.SearchRequest, error) {
// no hook applied, just return
if node.queryHook == nil {
......
@@ -102,7 +102,7 @@ func (w *LocalWorker) SearchSegments(ctx context.Context, req *querypb.SearchReq
}
func (w *LocalWorker) QuerySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
return w.node.querySegments(ctx, req)
return w.node.QuerySegments(ctx, req)
}
func (w *LocalWorker) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) {
......
@@ -75,11 +75,8 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult
return nil, err
}
requestCosts := lo.FilterMap(results, func(result *internalpb.SearchResults, _ int) (*internalpb.CostAggregation, bool) {
if result.CostAggregation == nil {
return nil, false
}
return result.CostAggregation, true
requestCosts := lo.Map(results, func(result *internalpb.SearchResults, _ int) *internalpb.CostAggregation {
return result.GetCostAggregation()
})
searchResults.CostAggregation = mergeRequestCost(requestCosts)
......
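Replacing lo.FilterMap with lo.Map means nil CostAggregations from shards are now passed through to mergeRequestCost instead of being dropped, so the merge must tolerate nil entries. A hypothetical nil-tolerant merge in that spirit (the real selection rule inside mergeRequestCost may differ), using a local stand-in for the proto type:

```go
package main

import "fmt"

// costAgg stands in for internalpb.CostAggregation in this sketch.
type costAgg struct{ ResponseTime, ServiceTime, TotalNQ int64 }

// mergeCosts keeps the non-nil cost with the largest ResponseTime;
// this is an assumed rule, not necessarily what mergeRequestCost does.
func mergeCosts(costs []*costAgg) *costAgg {
	var merged *costAgg
	for _, c := range costs {
		if c == nil {
			continue // nil entries now reach the merge and must be skipped
		}
		if merged == nil || c.ResponseTime > merged.ResponseTime {
			merged = c
		}
	}
	return merged
}

func main() {
	fmt.Println(mergeCosts([]*costAgg{nil, {ResponseTime: 12}, {ResponseTime: 7}}))
}
```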
@@ -742,10 +742,8 @@ func (node *QueryNode) SearchSegments(ctx context.Context, req *querypb.SearchRe
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel, metrics.FromLeader).Inc()
result := task.Result()
if result.CostAggregation != nil {
// update channel's response time
result.CostAggregation.ResponseTime = latency.Milliseconds()
}
result.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds()
result.GetCostAggregation().TotalNQ = node.scheduler.GetWaitingTaskTotalNQ()
return result, nil
}
@@ -767,6 +765,8 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
zap.Uint64("guaranteeTimestamp", req.GetReq().GetGuaranteeTimestamp()),
zap.Uint64("timeTravel", req.GetReq().GetTravelTimestamp()))
tr := timerecord.NewTimeRecorderWithTrace(ctx, "SearchRequest")
if !node.lifetime.Add(commonpbutil.IsHealthy) {
msg := fmt.Sprintf("query node %d is not ready", paramtable.GetNodeID())
err := merr.WrapErrServiceNotReady(msg)
@@ -844,7 +844,7 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
return failRet, nil
}
tr := timerecord.NewTimeRecorderWithTrace(ctx, "searchRequestReduce")
tr.RecordSpan()
result, err := segments.ReduceSearchResults(ctx, toReduceResults, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
if err != nil {
log.Warn("failed to reduce search results", zap.Error(err))
@@ -852,18 +852,17 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
failRet.Status.Reason = err.Error()
return failRet, nil
}
reduceLatency := tr.RecordSpan()
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.ReduceShards).
Observe(float64(tr.ElapseSpan().Milliseconds()))
Observe(float64(reduceLatency.Milliseconds()))
collector.Rate.Add(metricsinfo.NQPerSecond, float64(req.GetReq().GetNq()))
collector.Rate.Add(metricsinfo.SearchThroughput, float64(proto.Size(req)))
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.SearchLabel).
Add(float64(proto.Size(req)))
if result.CostAggregation != nil {
// update channel's response time
currentTotalNQ := node.scheduler.GetWaitingTaskTotalNQ()
result.CostAggregation.TotalNQ = currentTotalNQ
if result.GetCostAggregation() != nil {
result.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds()
}
return result, nil
}
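The reduce-latency change relies on the two methods of the time recorder created at the top of Search: RecordSpan marks a point and returns the duration since the previous mark, while ElapseSpan returns the total time since construction, which is what ends up in CostAggregation.ResponseTime. A minimal stand-in matching how this diff uses the internal timerecord package (semantics inferred from the hunks above):

```go
package main

import (
	"fmt"
	"time"
)

// recorder mimics the assumed timerecord.TimeRecorder semantics used above.
type recorder struct{ start, last time.Time }

func newRecorder() *recorder { now := time.Now(); return &recorder{start: now, last: now} }

// RecordSpan returns the duration since the previous mark and resets the mark.
func (r *recorder) RecordSpan() time.Duration {
	now := time.Now()
	d := now.Sub(r.last)
	r.last = now
	return d
}

// ElapseSpan returns the total duration since the recorder was created.
func (r *recorder) ElapseSpan() time.Duration { return time.Since(r.start) }

func main() {
	tr := newRecorder()
	tr.RecordSpan() // mark the start of the reduce phase
	time.Sleep(10 * time.Millisecond)
	reduceLatency := tr.RecordSpan() // duration of the reduce phase only
	fmt.Println(reduceLatency, tr.ElapseSpan())
}
```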
@@ -904,7 +903,18 @@ func (node *QueryNode) QuerySegments(ctx context.Context, req *querypb.QueryRequ
defer cancel()
tr := timerecord.NewTimeRecorder("querySegments")
results, err := node.querySegments(queryCtx, req)
collection := node.manager.Collection.Get(req.Req.GetCollectionID())
if collection == nil {
return nil, merr.WrapErrCollectionNotFound(req.Req.GetCollectionID())
}
// Send task to scheduler and wait until it finishes.
task := tasks.NewQueryTask(queryCtx, collection, node.manager, req)
if err := node.scheduler.Add(task); err != nil {
log.Warn("failed to add query task into scheduler", zap.Error(err))
return nil, err
}
err := task.Wait()
if err != nil {
log.Warn("failed to query channel", zap.Error(err))
failRet.Status.Reason = err.Error()
@@ -923,19 +933,17 @@ func (node *QueryNode) QuerySegments(ctx context.Context, req *querypb.QueryRequ
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel, metrics.FromLeader).Inc()
results.CostAggregation = &internalpb.CostAggregation{
ServiceTime: latency.Milliseconds(),
ResponseTime: latency.Milliseconds(),
TotalNQ: 0,
}
return results, nil
result := task.Result()
result.GetCostAggregation().ResponseTime = latency.Milliseconds()
result.GetCostAggregation().TotalNQ = node.scheduler.GetWaitingTaskTotalNQ()
return result, nil
}
// Query performs replica query tasks.
func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
if req.FromShardLeader {
// for compatible with rolling upgrade from version before v2.2.9
return node.querySegments(ctx, req)
return node.QuerySegments(ctx, req)
}
log := log.Ctx(ctx).With(
@@ -950,6 +958,7 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
zap.Uint64("travelTimestamp", req.GetReq().GetTravelTimestamp()),
zap.Bool("isCount", req.GetReq().GetIsCount()),
)
tr := timerecord.NewTimeRecorderWithTrace(ctx, "QueryRequest")
if !node.lifetime.Add(commonpbutil.IsHealthy) {
msg := fmt.Sprintf("query node %d is not ready", paramtable.GetNodeID())
@@ -1000,24 +1009,23 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
return WrapRetrieveResult(commonpb.ErrorCode_UnexpectedError, "failed to query channel", err), nil
}
tr := timerecord.NewTimeRecorderWithTrace(ctx, "queryRequestReduce")
tr.RecordSpan()
reducer := segments.CreateInternalReducer(req, node.manager.Collection.Get(req.GetReq().GetCollectionID()).Schema())
ret, err := reducer.Reduce(ctx, toMergeResults)
if err != nil {
return WrapRetrieveResult(commonpb.ErrorCode_UnexpectedError, "failed to query channel", err), nil
}
reduceLatency := tr.RecordSpan()
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.ReduceShards).
Observe(float64(tr.ElapseSpan().Milliseconds()))
Observe(float64(reduceLatency.Milliseconds()))
if !req.FromShardLeader {
collector.Rate.Add(metricsinfo.NQPerSecond, 1)
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Add(float64(proto.Size(req)))
}
if ret.CostAggregation != nil {
// update channel's response time
currentTotalNQ := node.scheduler.GetWaitingTaskTotalNQ()
ret.CostAggregation.TotalNQ = currentTotalNQ
if ret.GetCostAggregation() != nil {
ret.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds()
}
return ret, nil
}
......
@@ -75,6 +75,8 @@ func (t *QueryTask) PreExecute() error {
// Execute the task, only call once.
func (t *QueryTask) Execute() error {
tr := timerecord.NewTimeRecorderWithTrace(t.ctx, "QueryTask")
retrievePlan, err := segments.NewRetrievePlan(
t.collection,
t.req.Req.GetSerializedExprPlan(),
@@ -124,6 +126,9 @@ func (t *QueryTask) Execute() error {
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
Ids: reducedResult.Ids,
FieldsData: reducedResult.FieldsData,
CostAggregation: &internalpb.CostAggregation{
ServiceTime: tr.ElapseSpan().Milliseconds(),
},
}
return nil
}
......
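Taken together with the handler hunks above, the task layer now stamps ServiceTime from its own execution timer, and the RPC layer then overwrites ResponseTime and TotalNQ before the result leaves the node. A small sketch of that division of responsibility (field roles inferred from the hunks, local stand-in type):

```go
package main

import "fmt"

type costAgg struct{ ResponseTime, ServiceTime, TotalNQ int64 }

// buildCost shows the assembly order implied by this PR: the task sets
// ServiceTime, then the Search/Query handler fills ResponseTime and the
// scheduler's waiting NQ. Semantics are inferred from the diff.
func buildCost(taskLatencyMs, handlerLatencyMs, waitingNQ int64) *costAgg {
	cost := &costAgg{ServiceTime: taskLatencyMs} // set inside QueryTask/SearchTask
	cost.ResponseTime = handlerLatencyMs         // overwritten in the RPC handler
	cost.TotalNQ = waitingNQ                     // scheduler backlog at send time
	return cost
}

func main() { fmt.Println(buildCost(15, 20, 8)) }
```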
@@ -114,7 +114,7 @@ func (t *SearchTask) Execute() error {
zap.String("shard", t.req.GetDmlChannels()[0]),
)
executeRecord := timerecord.NewTimeRecorderWithTrace(t.ctx, "searchTaskExecute")
tr := timerecord.NewTimeRecorderWithTrace(t.ctx, "SearchTask")
req := t.req
t.combinePlaceHolderGroups()
@@ -166,14 +166,14 @@ func (t *SearchTask) Execute() error {
SlicedOffset: 1,
SlicedNumCount: 1,
CostAggregation: &internalpb.CostAggregation{
ServiceTime: executeRecord.ElapseSpan().Milliseconds(),
ServiceTime: tr.ElapseSpan().Milliseconds(),
},
}
}
return nil
}
reduceRecord := timerecord.NewTimeRecorderWithTrace(t.ctx, "searchTaskReduce")
tr.RecordSpan()
blobs, err := segments.ReduceSearchResultsAndFillData(
searchReq.Plan(),
results,
@@ -186,6 +186,7 @@ func (t *SearchTask) Execute() error {
return err
}
defer segments.DeleteSearchResultDataBlobs(blobs)
reduceLatency := tr.RecordSpan()
for i := range t.originNqs {
blob, err := segments.GetSearchResultDataBlob(blobs, i)
@@ -208,7 +209,7 @@ func (t *SearchTask) Execute() error {
fmt.Sprint(paramtable.GetNodeID()),
metrics.SearchLabel,
metrics.ReduceSegments).
Observe(float64(reduceRecord.ElapseSpan().Milliseconds()))
Observe(float64(reduceLatency.Milliseconds()))
task.result = &internalpb.SearchResults{
Status: util.WrapStatus(commonpb.ErrorCode_Success, ""),
@@ -219,7 +220,7 @@ func (t *SearchTask) Execute() error {
SlicedOffset: 1,
SlicedNumCount: 1,
CostAggregation: &internalpb.CostAggregation{
ServiceTime: executeRecord.ElapseSpan().Milliseconds(),
ServiceTime: tr.ElapseSpan().Milliseconds(),
},
}
}
......
@@ -258,6 +258,16 @@ var (
}, []string{
nodeIDLabelName,
})
ProxyExecutingTotalNq = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: milvusNamespace,
Subsystem: typeutil.ProxyRole,
Name: "executing_total_nq",
Help: "total nq of executing search/query",
}, []string{
nodeIDLabelName,
})
)
// RegisterProxy registers Proxy metrics
......
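ProxyExecutingTotalNq is a per-proxy gauge, set in SelectNode and CancelWorkload above. For it to be scraped it must also be registered; presumably RegisterProxy gains a line like the following (the registration call site is elided from this diff):

```go
package metrics

import "github.com/prometheus/client_golang/prometheus"

// Assumed addition to RegisterProxy, mirroring how the sibling proxy
// metrics are registered; the actual call site is not shown here.
func RegisterProxy(registry *prometheus.Registry) {
	// ... existing MustRegister calls ...
	registry.MustRegister(ProxyExecutingTotalNq)
}
```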
@@ -892,23 +892,25 @@ type proxyConfig struct {
// Alias string
SoPath ParamItem `refreshable:"false"`
TimeTickInterval ParamItem `refreshable:"false"`
HealthCheckTimetout ParamItem `refreshable:"true"`
MsgStreamTimeTickBufSize ParamItem `refreshable:"true"`
MaxNameLength ParamItem `refreshable:"true"`
MaxUsernameLength ParamItem `refreshable:"true"`
MinPasswordLength ParamItem `refreshable:"true"`
MaxPasswordLength ParamItem `refreshable:"true"`
MaxFieldNum ParamItem `refreshable:"true"`
MaxShardNum ParamItem `refreshable:"true"`
MaxDimension ParamItem `refreshable:"true"`
GinLogging ParamItem `refreshable:"false"`
MaxUserNum ParamItem `refreshable:"true"`
MaxRoleNum ParamItem `refreshable:"true"`
MaxTaskNum ParamItem `refreshable:"false"`
AccessLog AccessLogConfig
ShardLeaderCacheInterval ParamItem `refreshable:"false"`
ReplicaSelectionPolicy ParamItem `refreshable:"false"`
TimeTickInterval ParamItem `refreshable:"false"`
HealthCheckTimetout ParamItem `refreshable:"true"`
MsgStreamTimeTickBufSize ParamItem `refreshable:"true"`
MaxNameLength ParamItem `refreshable:"true"`
MaxUsernameLength ParamItem `refreshable:"true"`
MinPasswordLength ParamItem `refreshable:"true"`
MaxPasswordLength ParamItem `refreshable:"true"`
MaxFieldNum ParamItem `refreshable:"true"`
MaxShardNum ParamItem `refreshable:"true"`
MaxDimension ParamItem `refreshable:"true"`
GinLogging ParamItem `refreshable:"false"`
MaxUserNum ParamItem `refreshable:"true"`
MaxRoleNum ParamItem `refreshable:"true"`
MaxTaskNum ParamItem `refreshable:"false"`
AccessLog AccessLogConfig
ShardLeaderCacheInterval ParamItem `refreshable:"false"`
ReplicaSelectionPolicy ParamItem `refreshable:"false"`
CheckQueryNodeHealthInterval ParamItem `refreshable:"false"`
CostMetricsExpireTime ParamItem `refreshable:"true"`
}
func (p *proxyConfig) init(base *BaseTable) {
@@ -1124,7 +1126,7 @@ please adjust in embedded Milvus: false`,
p.ShardLeaderCacheInterval = ParamItem{
Key: "proxy.shardLeaderCacheInterval",
Version: "2.2.4",
DefaultValue: "30",
DefaultValue: "10",
Doc: "time interval to update shard leader cache, in seconds",
}
p.ShardLeaderCacheInterval.Init(base.mgr)
@@ -1136,6 +1138,23 @@
Doc: "replica selection policy in multiple replicas load balancing, support round_robin and look_aside",
}
p.ReplicaSelectionPolicy.Init(base.mgr)
p.CheckQueryNodeHealthInterval = ParamItem{
Key: "proxy.checkQueryNodeHealthInterval",
Version: "2.3.0",
DefaultValue: "1000",
Doc: "time interval to check health for query node, in ms",
}
p.CheckQueryNodeHealthInterval.Init(base.mgr)
p.CostMetricsExpireTime = ParamItem{
Key: "proxy.costMetricsExpireTime",
Version: "2.3.0",
DefaultValue: "1000",
Doc: "expire time for query node cost metrics, in ms",
}
p.CostMetricsExpireTime.Init(base.mgr)
}
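Both new knobs live under the proxy section of the configuration and are expressed in milliseconds, so they can be tuned without a rebuild. A short usage sketch based only on the accessors this diff already exercises (GetAsDuration in the balancer, params.Save in the test below):

```go
// Override at runtime and read back, mirroring the balancer and the test;
// Params is the proxy's paramtable handle.
params.Save(Params.ProxyCfg.CheckQueryNodeHealthInterval.Key, "500")
interval := Params.ProxyCfg.CheckQueryNodeHealthInterval.GetAsDuration(time.Millisecond) // 500ms
expireMs := Params.ProxyCfg.CostMetricsExpireTime.GetAsInt64() // defaults to 1000
```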
// /////////////////////////////////////////////////////////////////////////////
......
@@ -183,6 +183,8 @@ func TestComponentParam(t *testing.T) {
assert.Equal(t, Params.ReplicaSelectionPolicy.GetValue(), "round_robin")
params.Save(Params.ReplicaSelectionPolicy.Key, "look_aside")
assert.Equal(t, Params.ReplicaSelectionPolicy.GetValue(), "look_aside")
assert.Equal(t, Params.CheckQueryNodeHealthInterval.GetAsInt(), 1000)
assert.Equal(t, Params.CostMetricsExpireTime.GetAsInt(), 1000)
})
// t.Run("test proxyConfig panic", func(t *testing.T) {
......