diff --git a/internal/proxy/lb_balancer.go b/internal/proxy/lb_balancer.go
index 8747df9bb9a16e1ced8a06a0f525323edd54a2ea..f918ab20a7581c61d2961ff7bb32487efa40dcd1 100644
--- a/internal/proxy/lb_balancer.go
+++ b/internal/proxy/lb_balancer.go
@@ -16,11 +16,16 @@
 
 package proxy
 
-import "github.com/milvus-io/milvus/internal/proto/internalpb"
+import (
+    "context"
+
+    "github.com/milvus-io/milvus/internal/proto/internalpb"
+)
 
 type LBBalancer interface {
-    SelectNode(availableNodes []int64, nq int64) (int64, error)
+    SelectNode(ctx context.Context, availableNodes []int64, nq int64) (int64, error)
     CancelWorkload(node int64, nq int64)
     UpdateCostMetrics(node int64, cost *internalpb.CostAggregation)
+    Start(ctx context.Context)
     Close()
 }
diff --git a/internal/proxy/lb_policy.go b/internal/proxy/lb_policy.go
index 7294f9cfbb1c9a086ba2b3bfd63d6ec975444fdb..b28f041c3f505ef4fe05671e7d94027c35532621 100644
--- a/internal/proxy/lb_policy.go
+++ b/internal/proxy/lb_policy.go
@@ -54,6 +54,7 @@ type LBPolicy interface {
     Execute(ctx context.Context, workload CollectionWorkLoad) error
     ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error
     UpdateCostMetrics(node int64, cost *internalpb.CostAggregation)
+    Start(ctx context.Context)
     Close()
 }
 
@@ -81,6 +82,10 @@ func NewLBPolicyImpl(clientMgr shardClientMgr) *LBPolicyImpl {
     }
 }
 
+func (lb *LBPolicyImpl) Start(ctx context.Context) {
+    lb.balancer.Start(ctx)
+}
+
 // try to select the best node from the available nodes
 func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (int64, error) {
     log := log.With(
@@ -102,7 +107,7 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (int64, error) {
     }
 
     availableNodes := lo.Filter(workload.shardLeaders, filterAvailableNodes)
-    targetNode, err := lb.balancer.SelectNode(availableNodes, workload.nq)
+    targetNode, err := lb.balancer.SelectNode(ctx, availableNodes, workload.nq)
     if err != nil {
         globalMetaCache.DeprecateShardCache(workload.db, workload.collection)
         nodes, err := getShardLeaders()
@@ -120,7 +125,7 @@
             return -1, merr.WrapErrNoAvailableNode("all available nodes has been excluded")
         }
 
-        targetNode, err = lb.balancer.SelectNode(availableNodes, workload.nq)
+        targetNode, err = lb.balancer.SelectNode(ctx, availableNodes, workload.nq)
         if err != nil {
             log.Warn("failed to select shard",
                 zap.Error(err))
diff --git a/internal/proxy/lb_policy_test.go b/internal/proxy/lb_policy_test.go
index ba253a2e3129e0d5ea01f8193e1559a81dfac3b4..7694c9f995eda61451fa0064c62fb7a8e69acf6b 100644
--- a/internal/proxy/lb_policy_test.go
+++ b/internal/proxy/lb_policy_test.go
@@ -99,7 +99,9 @@ func (s *LBPolicySuite) SetupTest() {
     s.mgr = NewMockShardClientManager(s.T())
     s.mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
     s.lbBalancer = NewMockLBBalancer(s.T())
+    s.lbBalancer.EXPECT().Start(context.Background()).Maybe()
     s.lbPolicy = NewLBPolicyImpl(s.mgr)
+    s.lbPolicy.Start(context.Background())
     s.lbPolicy.balancer = s.lbBalancer
 
     err := InitMetaCache(context.Background(), s.rc, s.qc, s.mgr)
@@ -160,7 +162,7 @@ func (s *LBPolicySuite) loadCollection() {
 
 func (s *LBPolicySuite) TestSelectNode() {
     ctx := context.Background()
-    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(5, nil)
+    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(5, nil)
     targetNode, err := s.lbPolicy.selectNode(ctx, ChannelWorkload{
         db:         dbName,
         collection: s.collection,
@@ -173,8 +175,8 @@ func (s *LBPolicySuite) TestSelectNode() {
 
     // test select node failed, then update shard leader cache and retry, expect success
     s.lbBalancer.ExpectedCalls = nil
-    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1)
-    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(3, nil)
+    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1)
+    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil)
     targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
         db:         dbName,
         collection: s.collection,
@@ -187,7 +189,7 @@ func (s *LBPolicySuite) TestSelectNode() {
 
     // test select node always fails, expected failure
     s.lbBalancer.ExpectedCalls = nil
-    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(-1, merr.ErrNoAvailableNode)
+    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNoAvailableNode)
     targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
         db:         dbName,
         collection: s.collection,
@@ -200,7 +202,7 @@ func (s *LBPolicySuite) TestSelectNode() {
 
     // test all nodes has been excluded, expected failure
     s.lbBalancer.ExpectedCalls = nil
-    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(-1, merr.ErrNoAvailableNode)
+    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNoAvailableNode)
     targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
         db:         dbName,
         collection: s.collection,
@@ -213,7 +215,7 @@ func (s *LBPolicySuite) TestSelectNode() {
 
     // test get shard leaders failed, retry to select node failed
     s.lbBalancer.ExpectedCalls = nil
-    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(-1, merr.ErrNoAvailableNode)
+    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNoAvailableNode)
     s.qc.ExpectedCalls = nil
     s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(nil, merr.ErrNoAvailableNodeInReplica)
     targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
@@ -233,7 +235,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
     // test execute success
     s.lbBalancer.ExpectedCalls = nil
     s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
-    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil)
+    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
     s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
     err := s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
         db:         dbName,
@@ -250,7 +252,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
 
     // test select node failed, expected error
     s.lbBalancer.ExpectedCalls = nil
-    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(-1, merr.ErrNoAvailableNode)
+    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNoAvailableNode)
     err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
         db:         dbName,
         collection: s.collection,
@@ -268,7 +270,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
     s.mgr.ExpectedCalls = nil
     s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1)
     s.lbBalancer.ExpectedCalls = nil
-    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil)
+    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
     s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
     err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
         db:         dbName,
@@ -304,7 +306,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
     s.mgr.ExpectedCalls = nil
     s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
     s.lbBalancer.ExpectedCalls = nil
-    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil)
+    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
     s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
     counter := 0
     err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
@@ -329,7 +331,7 @@ func (s *LBPolicySuite) TestExecute() {
     ctx := context.Background()
     // test all channel success
     s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
-    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil)
+    s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
     s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
     err := s.lbPolicy.Execute(ctx, CollectionWorkLoad{
         db:         dbName,
diff --git a/internal/proxy/look_aside_balancer.go b/internal/proxy/look_aside_balancer.go
index 899b1c2e6760bd3bc621bb11cafd8ee33fa4123f..da2532f3bbfe107f5279ce5372184d4d2960bf06 100644
--- a/internal/proxy/look_aside_balancer.go
+++ b/internal/proxy/look_aside_balancer.go
@@ -19,13 +19,16 @@ package proxy
 import (
     "context"
     "math"
+    "strconv"
     "sync"
     "time"
 
     "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
     "github.com/milvus-io/milvus/internal/proto/internalpb"
     "github.com/milvus-io/milvus/pkg/log"
+    "github.com/milvus-io/milvus/pkg/metrics"
     "github.com/milvus-io/milvus/pkg/util/merr"
+    "github.com/milvus-io/milvus/pkg/util/paramtable"
     "github.com/milvus-io/milvus/pkg/util/typeutil"
     "go.uber.org/atomic"
     "go.uber.org/zap"
@@ -64,11 +67,14 @@ func NewLookAsideBalancer(clientMgr shardClientMgr) *LookAsideBalancer {
         closeCh:               make(chan struct{}),
     }
 
-    balancer.wg.Add(1)
-    go balancer.checkQueryNodeHealthLoop()
     return balancer
 }
 
+func (b *LookAsideBalancer) Start(ctx context.Context) {
+    b.wg.Add(1)
+    go b.checkQueryNodeHealthLoop(ctx)
+}
+
 func (b *LookAsideBalancer) Close() {
     b.closeOnce.Do(func() {
         close(b.closeCh)
@@ -76,11 +82,14 @@ func (b *LookAsideBalancer) Close() {
     })
 }
 
-func (b *LookAsideBalancer) SelectNode(availableNodes []int64, cost int64) (int64, error) {
+func (b *LookAsideBalancer) SelectNode(ctx context.Context, availableNodes []int64, cost int64) (int64, error) {
+    log := log.Ctx(ctx).WithRateGroup("proxy.LookAsideBalancer", 60, 1)
     targetNode := int64(-1)
     targetScore := float64(math.MaxFloat64)
     for _, node := range availableNodes {
         if b.unreachableQueryNodes.Contain(node) {
+            log.RatedWarn(30, "query node is unreachable, skip it",
+                zap.Int64("nodeID", node))
             continue
         }
 
@@ -92,18 +101,19 @@ func (b *LookAsideBalancer) SelectNode(availableNodes []int64, cost int64) (int64, error) {
         }
 
         score := b.calculateScore(cost, executingNQ.Load())
+        metrics.ProxyWorkLoadScore.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(score)
+
         if targetNode == -1 || score < targetScore {
             targetScore = score
             targetNode = node
         }
     }
 
-    // update executing task cost
-    totalNQ, ok := b.executingTaskTotalNQ.Get(targetNode)
-    if !ok {
-        totalNQ = atomic.NewInt64(0)
+    if targetNode != -1 {
+        // update executing task cost
+        totalNQ, _ := b.executingTaskTotalNQ.Get(targetNode)
+        totalNQ.Add(cost)
     }
-    totalNQ.Add(cost)
 
     return targetNode, nil
 }
@@ -132,7 +142,8 @@ func (b *LookAsideBalancer) calculateScore(cost *internalpb.CostAggregation, executingNQ int64) float64 {
     return float64(cost.ResponseTime) - float64(1)/float64(cost.ServiceTime) + math.Pow(float64(1+cost.TotalNQ+executingNQ), 3.0)/float64(cost.ServiceTime)
 }
 
-func (b *LookAsideBalancer) checkQueryNodeHealthLoop() {
+func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) {
+    log := log.Ctx(context.TODO()).WithRateGroup("proxy.LookAsideBalancer", 60, 1)
     defer b.wg.Done()
 
     ticker := time.NewTicker(checkQueryNodeHealthInterval)
@@ -152,7 +163,7 @@
                     defer cancel()
 
                     checkHealthFailed := func(err error) bool {
-                        log.Warn("query node check health failed, add it to unreachable nodes list",
+                        log.RatedWarn(30, "query node check health failed, add it to unreachable nodes list",
                             zap.Int64("nodeID", node),
                             zap.Error(err))
                         b.unreachableQueryNodes.Insert(node)
@@ -174,8 +185,12 @@
                     }
 
                     // check health successfully, update check health ts
-                    b.metricsUpdateTs.Insert(node, time.Now().Local().UnixMilli())
-                    b.unreachableQueryNodes.Remove(node)
+                    b.metricsUpdateTs.Insert(node, time.Now().UnixMilli())
+                    if b.unreachableQueryNodes.Contain(node) {
+                        b.unreachableQueryNodes.Remove(node)
+                        log.Info("query node check health success, remove it from unreachable nodes list",
+                            zap.Int64("nodeID", node))
+                    }
                 }
 
                 return true
diff --git a/internal/proxy/look_aside_balancer_test.go b/internal/proxy/look_aside_balancer_test.go
index 07d4658f18a6565e3dda26ad07d56fa629b68eaa..3f6778ad9c3c98f5d7c9e16a0b52747f52de955d 100644
--- a/internal/proxy/look_aside_balancer_test.go
+++ b/internal/proxy/look_aside_balancer_test.go
@@ -17,6 +17,7 @@
 package proxy
 
 import (
+    "context"
     "testing"
     "time"
 
@@ -40,6 +41,7 @@ type LookAsideBalancerSuite struct {
 func (suite *LookAsideBalancerSuite) SetupTest() {
     suite.clientMgr = NewMockShardClientManager(suite.T())
     suite.balancer = NewLookAsideBalancer(suite.clientMgr)
+    suite.balancer.Start(context.Background())
 
     qn := types.NewMockQueryNode(suite.T())
     suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(1)).Return(qn, nil).Maybe()
@@ -118,6 +120,18 @@ func (suite *LookAsideBalancerSuite) TestSelectNode() {
     }
 
     cases := []testcase{
+        {
+            name: "qn with empty metrics",
+            costMetrics: map[int64]*internalpb.CostAggregation{
+                1: {},
+                2: {},
+                3: {},
+            },
+
+            executingNQ:  map[int64]int64{},
+            requestCount: 100,
+            result:       map[int64]int64{1: 34, 2: 33, 3: 33},
+        },
         {
             name: "each qn has same cost metrics",
             costMetrics: map[int64]*internalpb.CostAggregation{
@@ -219,18 +233,6 @@ func (suite *LookAsideBalancerSuite) TestSelectNode() {
             requestCount: 100,
             result:       map[int64]int64{1: 40, 2: 40, 3: 20},
         },
-        {
-            name: "qn with empty metrics",
-            costMetrics: map[int64]*internalpb.CostAggregation{
-                1: {},
-                2: {},
-                3: {},
-            },
-
-            executingNQ:  map[int64]int64{1: 0, 2: 0, 3: 0},
-            requestCount: 100,
-            result:       map[int64]int64{1: 34, 2: 33, 3: 33},
-        },
     }
 
     for _, c := range cases {
@@ -242,10 +244,9 @@ func (suite *LookAsideBalancerSuite) TestSelectNode() {
             for node, executingNQ := range c.executingNQ {
                 suite.balancer.executingTaskTotalNQ.Insert(node, atomic.NewInt64(executingNQ))
             }
-
             counter := make(map[int64]int64)
             for i := 0; i < c.requestCount; i++ {
-                node, err := suite.balancer.SelectNode([]int64{1, 2, 3}, 1)
+                node, err := suite.balancer.SelectNode(context.TODO(), []int64{1, 2, 3}, 1)
                 suite.NoError(err)
                 counter[node]++
             }
@@ -258,7 +259,7 @@ func (suite *LookAsideBalancerSuite) TestSelectNode() {
 }
 
 func (suite *LookAsideBalancerSuite) TestCancelWorkload() {
-    node, err := suite.balancer.SelectNode([]int64{1, 2, 3}, 10)
+    node, err := suite.balancer.SelectNode(context.TODO(), []int64{1, 2, 3}, 10)
     suite.NoError(err)
     suite.balancer.CancelWorkload(node, 10)
 
@@ -281,12 +282,41 @@ func (suite *LookAsideBalancerSuite) TestCheckHealthLoop() {
     suite.Eventually(func() bool {
         return suite.balancer.unreachableQueryNodes.Contain(1)
     }, 2*time.Second, 100*time.Millisecond)
+    targetNode, err := suite.balancer.SelectNode(context.Background(), []int64{1}, 1)
+    suite.NoError(err)
+    suite.Equal(int64(-1), targetNode)
 
     suite.Eventually(func() bool {
         return !suite.balancer.unreachableQueryNodes.Contain(2)
     }, 3*time.Second, 100*time.Millisecond)
 }
 
+func (suite *LookAsideBalancerSuite) TestNodeRecover() {
+    // mock qn down for a while and then recover
+    qn3 := types.NewMockQueryNode(suite.T())
+    suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(3)).Return(qn3, nil)
+    qn3.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{
+        State: &milvuspb.ComponentInfo{
+            StateCode: commonpb.StateCode_Abnormal,
+        },
+    }, nil).Times(3)
+
+    qn3.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{
+        State: &milvuspb.ComponentInfo{
+            StateCode: commonpb.StateCode_Healthy,
+        },
+    }, nil)
+
+    suite.balancer.metricsUpdateTs.Insert(3, time.Now().UnixMilli())
+    suite.Eventually(func() bool {
+        return suite.balancer.unreachableQueryNodes.Contain(3)
+    }, 2*time.Second, 100*time.Millisecond)
+
+    suite.Eventually(func() bool {
+        return !suite.balancer.unreachableQueryNodes.Contain(3)
+    }, 3*time.Second, 100*time.Millisecond)
+}
+
 func TestLookAsideBalancerSuite(t *testing.T) {
     suite.Run(t, new(LookAsideBalancerSuite))
 }
diff --git a/internal/proxy/mock_lb_balancer.go b/internal/proxy/mock_lb_balancer.go
index 0a0550b48b3824d1856e91a8bc3808765d5ba894..3cc78254a98f27fdf5f687c7a467048cb124a571 100644
--- a/internal/proxy/mock_lb_balancer.go
+++ b/internal/proxy/mock_lb_balancer.go
@@ -3,6 +3,8 @@
 package proxy
 
 import (
+    context "context"
+
     internalpb "github.com/milvus-io/milvus/internal/proto/internalpb"
     mock "github.com/stretchr/testify/mock"
 )
@@ -86,23 +88,23 @@ func (_c *MockLBBalancer_Close_Call) RunAndReturn(run func()) *MockLBBalancer_Close_Call {
     return _c
 }
 
-// SelectNode provides a mock function with given fields: availableNodes, nq
-func (_m *MockLBBalancer) SelectNode(availableNodes []int64, nq int64) (int64, error) {
-    ret := _m.Called(availableNodes, nq)
+// SelectNode provides a mock function with given fields: ctx, availableNodes, nq
+func (_m *MockLBBalancer) SelectNode(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
+    ret := _m.Called(ctx, availableNodes, nq)
 
     var r0 int64
     var r1 error
-    if rf, ok := ret.Get(0).(func([]int64, int64) (int64, error)); ok {
-        return rf(availableNodes, nq)
+    if rf, ok := ret.Get(0).(func(context.Context, []int64, int64) (int64, error)); ok {
+        return rf(ctx, availableNodes, nq)
     }
 
-    if rf, ok := ret.Get(0).(func([]int64, int64) int64); ok {
-        r0 = rf(availableNodes, nq)
+    if rf, ok := ret.Get(0).(func(context.Context, []int64, int64) int64); ok {
+        r0 = rf(ctx, availableNodes, nq)
     } else {
         r0 = ret.Get(0).(int64)
     }
 
-    if rf, ok := ret.Get(1).(func([]int64, int64) error); ok {
-        r1 = rf(availableNodes, nq)
+    if rf, ok := ret.Get(1).(func(context.Context, []int64, int64) error); ok {
+        r1 = rf(ctx, availableNodes, nq)
     } else {
         r1 = ret.Error(1)
     }
@@ -116,15 +118,16 @@ type MockLBBalancer_SelectNode_Call struct {
 }
 
 // SelectNode is a helper method to define mock.On call
+//   - ctx context.Context
 //   - availableNodes []int64
 //   - nq int64
-func (_e *MockLBBalancer_Expecter) SelectNode(availableNodes interface{}, nq interface{}) *MockLBBalancer_SelectNode_Call {
-    return &MockLBBalancer_SelectNode_Call{Call: _e.mock.On("SelectNode", availableNodes, nq)}
+func (_e *MockLBBalancer_Expecter) SelectNode(ctx interface{}, availableNodes interface{}, nq interface{}) *MockLBBalancer_SelectNode_Call {
+    return &MockLBBalancer_SelectNode_Call{Call: _e.mock.On("SelectNode", ctx, availableNodes, nq)}
 }
 
-func (_c *MockLBBalancer_SelectNode_Call) Run(run func(availableNodes []int64, nq int64)) *MockLBBalancer_SelectNode_Call {
+func (_c *MockLBBalancer_SelectNode_Call) Run(run func(ctx context.Context, availableNodes []int64, nq int64)) *MockLBBalancer_SelectNode_Call {
     _c.Call.Run(func(args mock.Arguments) {
-        run(args[0].([]int64), args[1].(int64))
+        run(args[0].(context.Context), args[1].([]int64), args[2].(int64))
     })
     return _c
 }
@@ -134,7 +137,40 @@ func (_c *MockLBBalancer_SelectNode_Call) Return(_a0 int64, _a1 error) *MockLBBalancer_SelectNode_Call {
     return _c
 }
 
-func (_c *MockLBBalancer_SelectNode_Call) RunAndReturn(run func([]int64, int64) (int64, error)) *MockLBBalancer_SelectNode_Call {
+func (_c *MockLBBalancer_SelectNode_Call) RunAndReturn(run func(context.Context, []int64, int64) (int64, error)) *MockLBBalancer_SelectNode_Call {
+    _c.Call.Return(run)
+    return _c
+}
+
+// Start provides a mock function with given fields: ctx
+func (_m *MockLBBalancer) Start(ctx context.Context) {
+    _m.Called(ctx)
+}
+
+// MockLBBalancer_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start'
+type MockLBBalancer_Start_Call struct {
+    *mock.Call
+}
+
+// Start is a helper method to define mock.On call
+//   - ctx context.Context
+func (_e *MockLBBalancer_Expecter) Start(ctx interface{}) *MockLBBalancer_Start_Call {
+    return &MockLBBalancer_Start_Call{Call: _e.mock.On("Start", ctx)}
+}
+
+func (_c *MockLBBalancer_Start_Call) Run(run func(ctx context.Context)) *MockLBBalancer_Start_Call {
+    _c.Call.Run(func(args mock.Arguments) {
+        run(args[0].(context.Context))
+    })
+    return _c
+}
+
+func (_c *MockLBBalancer_Start_Call) Return() *MockLBBalancer_Start_Call {
+    _c.Call.Return()
+    return _c
+}
+
+func (_c *MockLBBalancer_Start_Call) RunAndReturn(run func(context.Context)) *MockLBBalancer_Start_Call {
     _c.Call.Return(run)
     return _c
 }
diff --git a/internal/proxy/mock_lb_policy.go b/internal/proxy/mock_lb_policy.go
index c6d907635f6969e991feead644b07cbd4fa70637..5eae7bc858af992de40b6f1110bc4d7a6a5712f2 100644
--- a/internal/proxy/mock_lb_policy.go
+++ b/internal/proxy/mock_lb_policy.go
@@ -22,6 +22,38 @@ func (_m *MockLBPolicy) EXPECT() *MockLBPolicy_Expecter {
     return &MockLBPolicy_Expecter{mock: &_m.Mock}
 }
 
+// Close provides a mock function with given fields:
+func (_m *MockLBPolicy) Close() {
+    _m.Called()
+}
+
+// MockLBPolicy_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
+type MockLBPolicy_Close_Call struct {
+    *mock.Call
+}
+
+// Close is a helper method to define mock.On call
+func (_e *MockLBPolicy_Expecter) Close() *MockLBPolicy_Close_Call {
+    return &MockLBPolicy_Close_Call{Call: _e.mock.On("Close")}
+}
+
+func (_c *MockLBPolicy_Close_Call) Run(run func()) *MockLBPolicy_Close_Call {
+    _c.Call.Run(func(args mock.Arguments) {
+        run()
+    })
+    return _c
+}
+
+func (_c *MockLBPolicy_Close_Call) Return() *MockLBPolicy_Close_Call {
+    _c.Call.Return()
+    return _c
+}
+
+func (_c *MockLBPolicy_Close_Call) RunAndReturn(run func()) *MockLBPolicy_Close_Call {
+    _c.Call.Return(run)
+    return _c
+}
+
 // Execute provides a mock function with given fields: ctx, workload
 func (_m *MockLBPolicy) Execute(ctx context.Context, workload CollectionWorkLoad) error {
     ret := _m.Called(ctx, workload)
@@ -108,6 +140,39 @@ func (_c *MockLBPolicy_ExecuteWithRetry_Call) RunAndReturn(run func(context.Context, ChannelWorkload) error) *MockLBPolicy_ExecuteWithRetry_Call {
     return _c
 }
 
+// Start provides a mock function with given fields: ctx
+func (_m *MockLBPolicy) Start(ctx context.Context) {
+    _m.Called(ctx)
+}
+
+// MockLBPolicy_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start'
+type MockLBPolicy_Start_Call struct {
+    *mock.Call
+}
+
+// Start is a helper method to define mock.On call
+//   - ctx context.Context
+func (_e *MockLBPolicy_Expecter) Start(ctx interface{}) *MockLBPolicy_Start_Call {
+    return &MockLBPolicy_Start_Call{Call: _e.mock.On("Start", ctx)}
+}
+
+func (_c *MockLBPolicy_Start_Call) Run(run func(ctx context.Context)) *MockLBPolicy_Start_Call {
+    _c.Call.Run(func(args mock.Arguments) {
+        run(args[0].(context.Context))
+    })
+    return _c
+}
+
+func (_c *MockLBPolicy_Start_Call) Return() *MockLBPolicy_Start_Call {
+    _c.Call.Return()
+    return _c
+}
+
+func (_c *MockLBPolicy_Start_Call) RunAndReturn(run func(context.Context)) *MockLBPolicy_Start_Call {
+    _c.Call.Return(run)
+    return _c
+}
+
 // UpdateCostMetrics provides a mock function with given fields: node, cost
 func (_m *MockLBPolicy) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {
     _m.Called(node, cost)
diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go
index e11ddba58914a5942af6a7d5d5f3d32b33b0a239..9a58b846b0d8ddac567cfed06689aced618595c6 100644
--- a/internal/proxy/proxy.go
+++ b/internal/proxy/proxy.go
@@ -120,6 +120,8 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
     ctx1, cancel := context.WithCancel(ctx)
     n := 1024 // better to be configurable
     mgr := newShardClientMgr()
+    lbPolicy := NewLBPolicyImpl(mgr)
+    lbPolicy.Start(ctx)
     node := &Proxy{
         ctx:              ctx1,
         cancel:           cancel,
@@ -127,7 +129,7 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
         searchResultCh:   make(chan *internalpb.SearchResults, n),
         shardMgr:         mgr,
         multiRateLimiter: NewMultiRateLimiter(),
-        lbPolicy:         NewLBPolicyImpl(mgr),
+        lbPolicy:         lbPolicy,
     }
     node.UpdateStateCode(commonpb.StateCode_Abnormal)
     logutil.Logger(ctx).Debug("create a new Proxy instance", zap.Any("state", node.stateCode.Load()))
diff --git a/internal/proxy/roundrobin_balancer.go b/internal/proxy/roundrobin_balancer.go
index 3ab0d7823164c5dff1b9d463ac3950b726a1759e..962dc6e4ff083e910347f16abf349ecc3960802f 100644
--- a/internal/proxy/roundrobin_balancer.go
+++ b/internal/proxy/roundrobin_balancer.go
@@ -16,6 +16,8 @@
 package proxy
 
 import (
+    "context"
+
     "github.com/milvus-io/milvus/internal/proto/internalpb"
     "github.com/milvus-io/milvus/pkg/util/merr"
     "github.com/milvus-io/milvus/pkg/util/typeutil"
@@ -33,7 +35,7 @@ func NewRoundRobinBalancer() *RoundRobinBalancer {
     }
 }
 
-func (b *RoundRobinBalancer) SelectNode(availableNodes []int64, cost int64) (int64, error) {
+func (b *RoundRobinBalancer) SelectNode(ctx context.Context, availableNodes []int64, cost int64) (int64, error) {
     if len(availableNodes) == 0 {
         return -1, merr.ErrNoAvailableNode
     }
@@ -68,4 +70,6 @@ func (b *RoundRobinBalancer) CancelWorkload(node int64, nq int64) {
 
 func (b *RoundRobinBalancer) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {}
 
+func (b *RoundRobinBalancer) Start(ctx context.Context) {}
+
 func (b *RoundRobinBalancer) Close() {}
diff --git a/internal/proxy/roundrobin_balancer_test.go b/internal/proxy/roundrobin_balancer_test.go
index 71a7b7ac4237da0d0769fc90d22bd08e21deb510..8840099bc2fc8d40863b2b7a5be36038ed7de6fc 100644
--- a/internal/proxy/roundrobin_balancer_test.go
+++ b/internal/proxy/roundrobin_balancer_test.go
@@ -16,6 +16,7 @@
 package proxy
 
 import (
+    "context"
     "testing"
 
     "github.com/stretchr/testify/suite"
@@ -29,14 +30,15 @@ type RoundRobinBalancerSuite struct {
 
 func (s *RoundRobinBalancerSuite) SetupTest() {
     s.balancer = NewRoundRobinBalancer()
+    s.balancer.Start(context.Background())
 }
 
 func (s *RoundRobinBalancerSuite) TestRoundRobin() {
     availableNodes := []int64{1, 2}
-    s.balancer.SelectNode(availableNodes, 1)
-    s.balancer.SelectNode(availableNodes, 1)
-    s.balancer.SelectNode(availableNodes, 1)
-    s.balancer.SelectNode(availableNodes, 1)
+    s.balancer.SelectNode(context.TODO(), availableNodes, 1)
+    s.balancer.SelectNode(context.TODO(), availableNodes, 1)
+    s.balancer.SelectNode(context.TODO(), availableNodes, 1)
+    s.balancer.SelectNode(context.TODO(), availableNodes, 1)
 
     workload, ok := s.balancer.nodeWorkload.Get(1)
     s.True(ok)
@@ -45,10 +47,10 @@ func (s *RoundRobinBalancerSuite) TestRoundRobin() {
     s.True(ok)
     s.Equal(int64(2), workload.Load())
 
-    s.balancer.SelectNode(availableNodes, 3)
-    s.balancer.SelectNode(availableNodes, 1)
-    s.balancer.SelectNode(availableNodes, 1)
-    s.balancer.SelectNode(availableNodes, 1)
+    s.balancer.SelectNode(context.TODO(), availableNodes, 3)
+    s.balancer.SelectNode(context.TODO(), availableNodes, 1)
+    s.balancer.SelectNode(context.TODO(), availableNodes, 1)
+    s.balancer.SelectNode(context.TODO(), availableNodes, 1)
 
     workload, ok = s.balancer.nodeWorkload.Get(1)
     s.True(ok)
@@ -60,13 +62,13 @@ func (s *RoundRobinBalancerSuite) TestRoundRobin() {
 
 func (s *RoundRobinBalancerSuite) TestNoAvailableNode() {
     availableNodes := []int64{}
-    _, err := s.balancer.SelectNode(availableNodes, 1)
+    _, err := s.balancer.SelectNode(context.TODO(), availableNodes, 1)
     s.Error(err)
 }
 
 func (s *RoundRobinBalancerSuite) TestCancelWorkload() {
     availableNodes := []int64{101}
-    _, err := s.balancer.SelectNode(availableNodes, 5)
+    _, err := s.balancer.SelectNode(context.TODO(), availableNodes, 5)
     s.NoError(err)
     workload, ok := s.balancer.nodeWorkload.Get(101)
     s.True(ok)
diff --git a/pkg/metrics/proxy_metrics.go b/pkg/metrics/proxy_metrics.go
index 5d115a073b7846abd90a4e7449a319c5e271b09f..e7e34ca0e68cc207934821dbd278163c1e96653f 100644
--- a/pkg/metrics/proxy_metrics.go
+++ b/pkg/metrics/proxy_metrics.go
@@ -247,6 +247,16 @@ var (
             Name: "user_rpc_count",
             Help: "the rpc count of a user",
         }, []string{usernameLabelName})
+
+    // ProxyWorkLoadScore record the score that measured query node's workload.
+    ProxyWorkLoadScore = prometheus.NewHistogramVec(
+        prometheus.HistogramOpts{
+            Namespace: milvusNamespace,
+            Subsystem: typeutil.ProxyRole,
+            Name:      "workload_score",
+            Help:      "score that measured query node's workload",
+            Buckets:   buckets,
+        }, []string{nodeIDLabelName})
 )
 
 // RegisterProxy registers Proxy metrics
@@ -284,6 +294,8 @@ func RegisterProxy(registry *prometheus.Registry) {
     registry.MustRegister(ProxyLimiterRate)
     registry.MustRegister(ProxyHookFunc)
     registry.MustRegister(UserRPCCounter)
+
+    registry.MustRegister(ProxyWorkLoadScore)
 }
 
 func CleanupCollectionMetrics(nodeID int64, collection string) {
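
Note (not part of the patch): the sketch below is a minimal illustration of how the reworked LBBalancer lifecycle is intended to be driven after this change: the owner calls Start(ctx) once instead of the constructor spawning the health-check loop, and SelectNode now receives the request context. The method signatures mirror the interface in the diff; the helper name runBalancerSketch, the node IDs, and the nq value are illustrative assumptions only.

// Illustrative sketch only; assumes it lives in the proxy package alongside LBBalancer.
func runBalancerSketch(ctx context.Context, balancer LBBalancer) (int64, error) {
    // Start the background health-check loop explicitly; before this change it was
    // spawned inside NewLookAsideBalancer, now the caller owns the lifecycle.
    balancer.Start(ctx)
    defer balancer.Close()

    // SelectNode now takes the request context and reserves nq on the chosen node.
    node, err := balancer.SelectNode(ctx, []int64{1, 2, 3}, 10)
    if err != nil {
        return -1, err
    }
    // Release the reserved nq once the request finishes.
    defer balancer.CancelWorkload(node, 10)
    return node, nil
}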