Unverified commit 9614e61f, authored by smellthemoon, committed by GitHub

Fix collection and channel not match (#25859)

Signed-off-by: lixinguo <xinguo.li@zilliz.com>
Co-authored-by: lixinguo <xinguo.li@zilliz.com>
Parent: eade5f9b
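The diff below threads the collection ID from the proxy tasks down into the shard-leader lookup, so a cache entry keyed by collection name can be recognized as stale when that name now maps to a different (for example, dropped and recreated) collection, instead of routing requests to the old collection's channels. A minimal, self-contained sketch of that idea, using hypothetical types rather than the actual Milvus code:

```go
package main

import "fmt"

// cachedInfo is a stand-in for the proxy's cached collection entry.
type cachedInfo struct {
	collID int64
	shards map[string][]int64 // channel -> shard-leader node IDs
}

// getShards mirrors the idea of the fix: the caller supplies both name and ID,
// and a cached entry whose ID no longer matches is treated as a miss.
func getShards(cache map[string]*cachedInfo, name string, id int64,
	refresh func(id int64) *cachedInfo) map[string][]int64 {
	info, ok := cache[name]
	if !ok || info.collID != id { // stale: collection was dropped and recreated
		info = refresh(id) // re-describe by ID, not by the untrusted name
		cache[name] = info
	}
	return info.shards
}

func main() {
	cache := map[string]*cachedInfo{
		"demo": {collID: 1, shards: map[string][]int64{"old-channel": {7}}},
	}
	refresh := func(id int64) *cachedInfo {
		return &cachedInfo{collID: id, shards: map[string][]int64{"new-channel": {9}}}
	}
	// Collection "demo" was recreated with ID 2; the stale channel is not used.
	fmt.Println(getShards(cache, "demo", 2, refresh))
}
```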
@@ -34,20 +34,22 @@ import (
 type executeFunc func(context.Context, UniqueID, types.QueryNode, ...string) error

 type ChannelWorkload struct {
-    db           string
-    collection   string
-    channel      string
-    shardLeaders []int64
-    nq           int64
-    exec         executeFunc
-    retryTimes   uint
+    db             string
+    collectionName string
+    collectionID   int64
+    channel        string
+    shardLeaders   []int64
+    nq             int64
+    exec           executeFunc
+    retryTimes     uint
 }

 type CollectionWorkLoad struct {
-    db         string
-    collection string
-    nq         int64
-    exec       executeFunc
+    db             string
+    collectionName string
+    collectionID   int64
+    nq             int64
+    exec           executeFunc
 }

 type LBPolicy interface {
@@ -89,7 +91,8 @@ func (lb *LBPolicyImpl) Start(ctx context.Context) {
 // try to select the best node from the available nodes
 func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (int64, error) {
     log := log.With(
-        zap.String("collectionName", workload.collection),
+        zap.Int64("collectionID", workload.collectionID),
+        zap.String("collectionName", workload.collectionName),
         zap.String("channelName", workload.channel),
     )
@@ -98,7 +101,7 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload
     }

     getShardLeaders := func() ([]int64, error) {
-        shardLeaders, err := globalMetaCache.GetShards(ctx, false, workload.db, workload.collection)
+        shardLeaders, err := globalMetaCache.GetShards(ctx, false, workload.db, workload.collectionName, workload.collectionID)
         if err != nil {
             return nil, err
         }
@@ -109,7 +112,7 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload
     availableNodes := lo.Filter(workload.shardLeaders, filterAvailableNodes)
     targetNode, err := lb.balancer.SelectNode(ctx, availableNodes, workload.nq)
     if err != nil {
-        globalMetaCache.DeprecateShardCache(workload.db, workload.collection)
+        globalMetaCache.DeprecateShardCache(workload.db, workload.collectionName)
         nodes, err := getShardLeaders()
         if err != nil || len(nodes) == 0 {
             log.Warn("failed to get shard delegator",
@@ -141,7 +144,8 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload
 func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error {
     excludeNodes := typeutil.NewUniqueSet()
     log := log.Ctx(ctx).With(
-        zap.String("collectionName", workload.collection),
+        zap.Int64("collectionID", workload.collectionID),
+        zap.String("collectionName", workload.collectionName),
         zap.String("channelName", workload.channel),
     )
@@ -185,7 +189,7 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
 // Execute will execute collection workload in parallel
 func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad) error {
-    dml2leaders, err := globalMetaCache.GetShards(ctx, true, workload.db, workload.collection)
+    dml2leaders, err := globalMetaCache.GetShards(ctx, true, workload.db, workload.collectionName, workload.collectionID)
     if err != nil {
         log.Ctx(ctx).Warn("failed to get shards", zap.Error(err))
         return err
@@ -197,13 +201,14 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad
         nodes := lo.Map(nodes, func(node nodeInfo, _ int) int64 { return node.nodeID })
         wg.Go(func() error {
             err := lb.ExecuteWithRetry(ctx, ChannelWorkload{
-                db:           workload.db,
-                collection:   workload.collection,
-                channel:      channel,
-                shardLeaders: nodes,
-                nq:           workload.nq,
-                exec:         workload.exec,
-                retryTimes:   uint(len(nodes)),
+                db:             workload.db,
+                collectionName: workload.collectionName,
+                collectionID:   workload.collectionID,
+                channel:        channel,
+                shardLeaders:   nodes,
+                nq:             workload.nq,
+                exec:           workload.exec,
+                retryTimes:     uint(len(nodes)),
             })
             return err
         })
......
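For readers unfamiliar with the retry path above: ExecuteWithRetry keeps a set of nodes that have already failed for a channel and asks the balancer for a different shard leader on each attempt, with retryTimes sized to the number of leaders. The stand-alone sketch below shows only the general exclude-and-retry pattern; the real method also refreshes leaders through the ID-aware GetShards shown above, and its signatures differ:

```go
package main

import (
	"errors"
	"fmt"
)

// executeWithRetry is a simplified version of a per-channel retry: pick a node
// that has not failed yet, run the workload, and on error record the node and
// try the next one, up to retryTimes attempts.
func executeWithRetry(leaders []int64, retryTimes int, exec func(node int64) error) error {
	excluded := map[int64]bool{}
	var lastErr error
	for i := 0; i < retryTimes; i++ {
		node := int64(-1)
		for _, n := range leaders {
			if !excluded[n] {
				node = n
				break
			}
		}
		if node == -1 {
			return errors.New("no available shard leader")
		}
		if lastErr = exec(node); lastErr == nil {
			return nil
		}
		excluded[node] = true // do not retry on the node that just failed
	}
	return lastErr
}

func main() {
	calls := 0
	err := executeWithRetry([]int64{1, 2, 3}, 3, func(node int64) error {
		calls++
		if node == 1 {
			return errors.New("node 1 unavailable")
		}
		return nil
	})
	fmt.Println(calls, err) // 2 <nil>: failed on node 1, succeeded on node 2
}
```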
@@ -53,7 +53,8 @@ type LBPolicySuite struct {
     channels []string
     qnList   []*mocks.MockQueryNode

-    collection string
+    collectionName string
+    collectionID   int64
 }

 func (s *LBPolicySuite) SetupSuite() {
@@ -108,7 +109,7 @@ func (s *LBPolicySuite) SetupTest() {
     err := InitMetaCache(context.Background(), s.rc, s.qc, s.mgr)
     s.NoError(err)

-    s.collection = "test_lb_policy"
+    s.collectionName = "test_lb_policy"
     s.loadCollection()
 }
@@ -125,7 +126,7 @@ func (s *LBPolicySuite) loadCollection() {
         fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
     }

-    schema := constructCollectionSchemaByDataType(s.collection, fieldName2Types, testInt64Field, false)
+    schema := constructCollectionSchemaByDataType(s.collectionName, fieldName2Types, testInt64Field, false)
     marshaledSchema, err := proto.Marshal(schema)
     s.NoError(err)
@@ -133,7 +134,7 @@ func (s *LBPolicySuite) loadCollection() {
     createColT := &createCollectionTask{
         Condition: NewTaskCondition(ctx),
         CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
-            CollectionName: s.collection,
+            CollectionName: s.collectionName,
             DbName:         dbName,
             Schema:         marshaledSchema,
             ShardsNum:      common.DefaultShardsNum,
@@ -147,7 +148,7 @@ func (s *LBPolicySuite) loadCollection() {
     s.NoError(createColT.Execute(ctx))
     s.NoError(createColT.PostExecute(ctx))

-    collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, s.collection)
+    collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, s.collectionName)
     s.NoError(err)

     status, err := s.qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
@@ -159,17 +160,19 @@ func (s *LBPolicySuite) loadCollection() {
     })
     s.NoError(err)
     s.Equal(commonpb.ErrorCode_Success, status.ErrorCode)
+    s.collectionID = collectionID
 }

 func (s *LBPolicySuite) TestSelectNode() {
     ctx := context.Background()
     s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(5, nil)
     targetNode, err := s.lbPolicy.selectNode(ctx, ChannelWorkload{
-        db:           dbName,
-        collection:   s.collection,
-        channel:      s.channels[0],
-        shardLeaders: s.nodes,
-        nq:           1,
+        db:             dbName,
+        collectionName: s.collectionName,
+        collectionID:   s.collectionID,
+        channel:        s.channels[0],
+        shardLeaders:   s.nodes,
+        nq:             1,
     }, typeutil.NewUniqueSet())
     s.NoError(err)
     s.Equal(int64(5), targetNode)
@@ -179,11 +182,12 @@ func (s *LBPolicySuite) TestSelectNode() {
     s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1)
     s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil)
     targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
-        db:           dbName,
-        collection:   s.collection,
-        channel:      s.channels[0],
-        shardLeaders: []int64{},
-        nq:           1,
+        db:             dbName,
+        collectionName: s.collectionName,
+        collectionID:   s.collectionID,
+        channel:        s.channels[0],
+        shardLeaders:   []int64{},
+        nq:             1,
     }, typeutil.NewUniqueSet())
     s.NoError(err)
     s.Equal(int64(3), targetNode)
@@ -192,11 +196,12 @@ func (s *LBPolicySuite) TestSelectNode() {
     s.lbBalancer.ExpectedCalls = nil
     s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
     targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
-        db:           dbName,
-        collection:   s.collection,
-        channel:      s.channels[0],
-        shardLeaders: []int64{},
-        nq:           1,
+        db:             dbName,
+        collectionName: s.collectionName,
+        collectionID:   s.collectionID,
+        channel:        s.channels[0],
+        shardLeaders:   []int64{},
+        nq:             1,
     }, typeutil.NewUniqueSet())
     s.ErrorIs(err, merr.ErrNodeNotAvailable)
     s.Equal(int64(-1), targetNode)
@@ -205,11 +210,12 @@ func (s *LBPolicySuite) TestSelectNode() {
     s.lbBalancer.ExpectedCalls = nil
     s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
     targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
-        db:           dbName,
-        collection:   s.collection,
-        channel:      s.channels[0],
-        shardLeaders: s.nodes,
-        nq:           1,
+        db:             dbName,
+        collectionName: s.collectionName,
+        collectionID:   s.collectionID,
+        channel:        s.channels[0],
+        shardLeaders:   s.nodes,
+        nq:             1,
     }, typeutil.NewUniqueSet(s.nodes...))
     s.ErrorIs(err, merr.ErrServiceUnavailable)
     s.Equal(int64(-1), targetNode)
@@ -220,11 +226,12 @@ func (s *LBPolicySuite) TestSelectNode() {
     s.qc.ExpectedCalls = nil
     s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(nil, merr.ErrServiceUnavailable)
     targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
-        db:           dbName,
-        collection:   s.collection,
-        channel:      s.channels[0],
-        shardLeaders: s.nodes,
-        nq:           1,
+        db:             dbName,
+        collectionName: s.collectionName,
+        collectionID:   s.collectionID,
+        channel:        s.channels[0],
+        shardLeaders:   s.nodes,
+        nq:             1,
     }, typeutil.NewUniqueSet())
     s.ErrorIs(err, merr.ErrServiceUnavailable)
     s.Equal(int64(-1), targetNode)
@@ -239,11 +246,12 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
     s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
     s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
     err := s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
-        db:           dbName,
-        collection:   s.collection,
-        channel:      s.channels[0],
-        shardLeaders: s.nodes,
-        nq:           1,
+        db:             dbName,
+        collectionName: s.collectionName,
+        collectionID:   s.collectionID,
+        channel:        s.channels[0],
+        shardLeaders:   s.nodes,
+        nq:             1,
         exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
             return nil
         },
@@ -255,11 +263,12 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
     s.lbBalancer.ExpectedCalls = nil
     s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
     err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
-        db:           dbName,
-        collection:   s.collection,
-        channel:      s.channels[0],
-        shardLeaders: s.nodes,
-        nq:           1,
+        db:             dbName,
+        collectionName: s.collectionName,
+        collectionID:   s.collectionID,
+        channel:        s.channels[0],
+        shardLeaders:   s.nodes,
+        nq:             1,
         exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
             return nil
         },
@@ -274,11 +283,12 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
     s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
     s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
     err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
-        db:           dbName,
-        collection:   s.collection,
-        channel:      s.channels[0],
-        shardLeaders: s.nodes,
-        nq:           1,
+        db:             dbName,
+        collectionName: s.collectionName,
+        collectionID:   s.collectionID,
+        channel:        s.channels[0],
+        shardLeaders:   s.nodes,
+        nq:             1,
         exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
             return nil
         },
@@ -291,11 +301,12 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
     s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
     s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
     err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
-        db:           dbName,
-        collection:   s.collection,
-        channel:      s.channels[0],
-        shardLeaders: s.nodes,
-        nq:           1,
+        db:             dbName,
+        collectionName: s.collectionName,
+        collectionID:   s.collectionID,
+        channel:        s.channels[0],
+        shardLeaders:   s.nodes,
+        nq:             1,
         exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
             return nil
         },
@@ -311,11 +322,12 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
     s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
     counter := 0
     err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
-        db:           dbName,
-        collection:   s.collection,
-        channel:      s.channels[0],
-        shardLeaders: s.nodes,
-        nq:           1,
+        db:             dbName,
+        collectionName: s.collectionName,
+        collectionID:   s.collectionID,
+        channel:        s.channels[0],
+        shardLeaders:   s.nodes,
+        nq:             1,
         exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
             counter++
             if counter == 1 {
@@ -336,9 +348,10 @@ func (s *LBPolicySuite) TestExecute() {
     s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
     s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
     err := s.lbPolicy.Execute(ctx, CollectionWorkLoad{
-        db:         dbName,
-        collection: s.collection,
-        nq:         1,
+        db:             dbName,
+        collectionName: s.collectionName,
+        collectionID:   s.collectionID,
+        nq:             1,
         exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
             return nil
         },
@@ -348,9 +361,10 @@ func (s *LBPolicySuite) TestExecute() {
     // test some channel failed
     counter := atomic.NewInt64(0)
     err = s.lbPolicy.Execute(ctx, CollectionWorkLoad{
-        db:         dbName,
-        collection: s.collection,
-        nq:         1,
+        db:             dbName,
+        collectionName: s.collectionName,
+        collectionID:   s.collectionID,
+        nq:             1,
         exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
             if counter.Add(1) == 1 {
                 return nil
@@ -363,12 +377,13 @@ func (s *LBPolicySuite) TestExecute() {
     // test get shard leader failed
     s.qc.ExpectedCalls = nil
-    globalMetaCache.DeprecateShardCache(dbName, s.collection)
+    globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
     s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(nil, mockErr)
     err = s.lbPolicy.Execute(ctx, CollectionWorkLoad{
-        db:         dbName,
-        collection: s.collection,
-        nq:         1,
+        db:             dbName,
+        collectionName: s.collectionName,
+        collectionID:   s.collectionID,
+        nq:             1,
         exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
             return nil
         },
......
@@ -57,8 +57,8 @@ type Cache interface {
     GetCollectionID(ctx context.Context, database, collectionName string) (typeutil.UniqueID, error)
     // GetDatabaseAndCollectionName get collection's name and database by id
     GetDatabaseAndCollectionName(ctx context.Context, collectionID int64) (string, string, error)
-    // GetCollectionInfo get collection's information by name, such as collection id, schema, and etc.
-    GetCollectionInfo(ctx context.Context, database, collectionName string) (*collectionInfo, error)
+    // GetCollectionInfo get collection's information by name or collection id, such as schema, and etc.
+    GetCollectionInfo(ctx context.Context, database, collectionName string, collectionID int64) (*collectionInfo, error)
     // GetPartitionID get partition's identifier of specific collection.
     GetPartitionID(ctx context.Context, database, collectionName string, partitionName string) (typeutil.UniqueID, error)
     // GetPartitions get all partitions' id of specific collection.
@@ -67,7 +67,7 @@ type Cache interface {
     GetPartitionInfo(ctx context.Context, database, collectionName string, partitionName string) (*partitionInfo, error)
     // GetCollectionSchema get collection's schema.
     GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemapb.CollectionSchema, error)
-    GetShards(ctx context.Context, withCache bool, database, collectionName string) (map[string][]nodeInfo, error)
+    GetShards(ctx context.Context, withCache bool, database, collectionName string, collectionID int64) (map[string][]nodeInfo, error)
     DeprecateShardCache(database, collectionName string)
     expireShardLeaderCache(ctx context.Context)
     RemoveCollection(ctx context.Context, database, collectionName string)
@@ -229,7 +229,7 @@ func (m *MetaCache) GetCollectionID(ctx context.Context, database, collectionNam
         collInfo, ok = db[collectionName]
     }

-    method := "GeCollectionID"
+    method := "GetCollectionID"
     if !ok || !collInfo.isCollectionCached() {
         metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
         tr := timerecord.NewTimeRecorder("UpdateCache")
@@ -289,7 +289,7 @@ func (m *MetaCache) GetDatabaseAndCollectionName(ctx context.Context, collection
 // GetCollectionInfo returns the collection information related to provided collection name
 // If the information is not found, proxy will try to fetch information for other source (RootCoord for now)
-func (m *MetaCache) GetCollectionInfo(ctx context.Context, database, collectionName string) (*collectionInfo, error) {
+func (m *MetaCache) GetCollectionInfo(ctx context.Context, database, collectionName string, collectionID int64) (*collectionInfo, error) {
     m.mu.RLock()
     var collInfo *collectionInfo
     var ok bool
@@ -301,10 +301,17 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, database, collectionN
     m.mu.RUnlock()

     method := "GetCollectionInfo"
-    if !ok || !collInfo.isCollectionCached() {
+    // if collInfo.collID != collectionID, means that the cache is not trustable
+    // try to get collection according to collectionID
+    if !ok || !collInfo.isCollectionCached() || collInfo.collID != collectionID {
         tr := timerecord.NewTimeRecorder("UpdateCache")
         metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
-        coll, err := m.describeCollection(ctx, database, collectionName, 0)
+        var coll *milvuspb.DescribeCollectionResponse
+        var err error
+        // collectionName maybe not trustable, get collection according to id
+        coll, err = m.describeCollection(ctx, database, "", collectionID)
         if err != nil {
             return nil, err
         }
@@ -695,8 +702,12 @@ func (m *MetaCache) UpdateCredential(credInfo *internalpb.CredentialInfo) {
 }

 // GetShards update cache if withCache == false
-func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, collectionName string) (map[string][]nodeInfo, error) {
-    info, err := m.GetCollectionInfo(ctx, database, collectionName)
+func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, collectionName string, collectionID int64) (map[string][]nodeInfo, error) {
+    log := log.Ctx(ctx).With(
+        zap.String("collectionName", collectionName),
+        zap.Int64("collectionID", collectionID))
+
+    info, err := m.GetCollectionInfo(ctx, database, collectionName, collectionID)
     if err != nil {
         return nil, err
     }
@@ -715,8 +726,7 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
         }
         metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
-        log.Info("no shard cache for collection, try to get shard leaders from QueryCoord",
-            zap.String("collectionName", collectionName))
+        log.Info("no shard cache for collection, try to get shard leaders from QueryCoord")
     }
     req := &querypb.GetShardLeadersRequest{
         Base: commonpbutil.NewMsgBase(
@@ -754,9 +764,9 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
     shards := parseShardLeaderList2QueryNode(resp.GetShards())

-    info, err = m.GetCollectionInfo(ctx, database, collectionName)
+    info, err = m.GetCollectionInfo(ctx, database, collectionName, collectionID)
     if err != nil {
-        return nil, fmt.Errorf("failed to get shards, collection %s not found", collectionName)
+        return nil, fmt.Errorf("failed to get shards, collectionName %s, collectionID %d not found", collectionName, collectionID)
     }
     // lock leader
     info.leaderMutex.Lock()
......
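The MetaCache changes above do two things: GetCollectionInfo now treats a cached entry whose stored ID differs from the caller's collectionID as untrustworthy and re-describes the collection by ID, and GetShards carries the ID through while keeping its documented contract "update cache if withCache == false". A toy sketch of that cache contract follows, with a mocked coordinator callback standing in for QueryCoord; it is an illustration of the pattern, not the production logic:

```go
package main

import "fmt"

// shardCache is a toy stand-in for the proxy's per-collection shard-leader cache.
type shardCache struct {
	leaders map[string][]int64 // channel -> node IDs
}

// getShards follows the contract "update cache if withCache == false":
// serve from the cache when allowed and populated, otherwise ask the
// (mocked) query coordinator and store the fresh view.
func getShards(c *shardCache, withCache bool,
	fromQueryCoord func() map[string][]int64) map[string][]int64 {
	if withCache && len(c.leaders) > 0 {
		return c.leaders
	}
	// cache miss or forced refresh: fetch shard leaders from the coordinator
	c.leaders = fromQueryCoord()
	return c.leaders
}

func main() {
	c := &shardCache{}
	coord := func() map[string][]int64 {
		return map[string][]int64{"channel-1": {1, 2, 3}}
	}
	fmt.Println(getShards(c, true, coord))  // miss: fetched and cached
	fmt.Println(getShards(c, true, coord))  // hit: served from cache
	fmt.Println(getShards(c, false, coord)) // withCache=false: refreshed
}
```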
@@ -474,6 +474,7 @@ func TestMetaCache_GetShards(t *testing.T) {
     var (
         ctx            = context.Background()
         collectionName = "collection1"
+        collectionID   = int64(1)
     )

     rootCoord := &MockRootCoordClientInterface{}
@@ -488,7 +489,7 @@
     defer qc.Stop()

     t.Run("No collection in meta cache", func(t *testing.T) {
-        shards, err := globalMetaCache.GetShards(ctx, true, dbName, "non-exists")
+        shards, err := globalMetaCache.GetShards(ctx, true, dbName, "non-exists", 0)
         assert.Error(t, err)
         assert.Empty(t, shards)
     })
@@ -503,7 +504,7 @@
         qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
             Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
         }, nil)
-        shards, err := globalMetaCache.GetShards(ctx, false, dbName, collectionName)
+        shards, err := globalMetaCache.GetShards(ctx, false, dbName, collectionName, collectionID)
         assert.Error(t, err)
         assert.Empty(t, shards)
     })
@@ -524,7 +525,7 @@
         qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
             Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
         }, nil)
-        shards, err := globalMetaCache.GetShards(ctx, true, dbName, collectionName)
+        shards, err := globalMetaCache.GetShards(ctx, true, dbName, collectionName, collectionID)
         assert.NoError(t, err)
         assert.NotEmpty(t, shards)
         assert.Equal(t, 1, len(shards))
@@ -537,7 +538,7 @@
                 Reason:    "not implemented",
             },
         }, nil)
-        shards, err = globalMetaCache.GetShards(ctx, true, dbName, collectionName)
+        shards, err = globalMetaCache.GetShards(ctx, true, dbName, collectionName, collectionID)
         assert.NoError(t, err)
         assert.NotEmpty(t, shards)
@@ -550,6 +551,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
     var (
         ctx            = context.TODO()
         collectionName = "collection1"
+        collectionID   = int64(1)
     )

     rootCoord := &MockRootCoordClientInterface{}
@@ -588,7 +590,7 @@
         qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
             Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
         }, nil)
-        shards, err := globalMetaCache.GetShards(ctx, true, dbName, collectionName)
+        shards, err := globalMetaCache.GetShards(ctx, true, dbName, collectionName, collectionID)
         require.NoError(t, err)
         require.NotEmpty(t, shards)
         require.Equal(t, 1, len(shards))
@@ -602,7 +604,7 @@
                 Reason:    "not implemented",
             },
         }, nil)
-        shards, err = globalMetaCache.GetShards(ctx, true, dbName, collectionName)
+        shards, err = globalMetaCache.GetShards(ctx, true, dbName, collectionName, collectionID)
         assert.Error(t, err)
         assert.Empty(t, shards)
     })
@@ -706,26 +708,26 @@ func TestMetaCache_RemoveCollection(t *testing.T) {
         InMemoryPercentages: []int64{100, 50},
     }, nil)

-    _, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1")
+    _, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1", 1)
     assert.NoError(t, err)
     // no collectionInfo of collection1, should access RootCoord
     assert.Equal(t, rootCoord.GetAccessCount(), 1)

-    _, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1")
+    _, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1", 1)
     assert.NoError(t, err)
     // shouldn't access RootCoord again
     assert.Equal(t, rootCoord.GetAccessCount(), 1)

     globalMetaCache.RemoveCollection(ctx, dbName, "collection1")
     // no collectionInfo of collection2, should access RootCoord
-    _, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1")
+    _, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1", 1)
     assert.NoError(t, err)
     // shouldn't access RootCoord again
     assert.Equal(t, rootCoord.GetAccessCount(), 2)

     globalMetaCache.RemoveCollectionsByID(ctx, UniqueID(1))
     // no collectionInfo of collection2, should access RootCoord
-    _, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1")
+    _, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1", 1)
     assert.NoError(t, err)
     // shouldn't access RootCoord again
     assert.Equal(t, rootCoord.GetAccessCount(), 3)
@@ -761,7 +763,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) {
             },
         },
     }, nil)
-    nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1")
+    nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
     assert.NoError(t, err)
     assert.Len(t, nodeInfos["channel-1"], 3)
@@ -780,7 +782,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) {
     }, nil)

     assert.Eventually(t, func() bool {
-        nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1")
+        nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
         assert.NoError(t, err)
         return len(nodeInfos["channel-1"]) == 2
     }, 3*time.Second, 1*time.Second)
@@ -800,7 +802,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) {
     }, nil)

     assert.Eventually(t, func() bool {
-        nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1")
+        nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
         assert.NoError(t, err)
         return len(nodeInfos["channel-1"]) == 3
     }, 3*time.Second, 1*time.Second)
@@ -825,7 +827,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) {
     }, nil)

     assert.Eventually(t, func() bool {
-        nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1")
+        nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
         assert.NoError(t, err)
         return len(nodeInfos["channel-1"]) == 3 && len(nodeInfos["channel-2"]) == 3
     }, 3*time.Second, 1*time.Second)
......
This diff is collapsed.
-// Code generated by mockery v2.21.1. DO NOT EDIT.
+// Code generated by mockery v2.23.1. DO NOT EDIT.

 package proxy
......
@@ -709,7 +709,7 @@ func (sct *showCollectionsTask) Execute(ctx context.Context) error {
                 zap.Any("requestID", sct.Base.MsgID), zap.Any("requestType", "showCollections"))
             continue
         }
-        collectionInfo, err := globalMetaCache.GetCollectionInfo(ctx, sct.GetDbName(), collectionName)
+        collectionInfo, err := globalMetaCache.GetCollectionInfo(ctx, sct.GetDbName(), collectionName, id)
         if err != nil {
             log.Debug("Failed to get collection info.", zap.Any("collectionName", collectionName),
                 zap.Any("requestID", sct.Base.MsgID), zap.Any("requestType", "showCollections"))
......
@@ -370,10 +370,11 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
     if err != nil {
         return err
     }
-    collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName)
+    collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName, t.CollectionID)
     if err2 != nil {
         log.Warn("Proxy::queryTask::PreExecute failed to GetCollectionInfo from cache",
-            zap.String("collectionName", collectionName), zap.Error(err2))
+            zap.String("collectionName", collectionName), zap.Int64("collectionID", t.CollectionID),
+            zap.Error(err2))
         return err2
     }
@@ -417,10 +418,11 @@ func (t *queryTask) Execute(ctx context.Context) error {
     t.resultBuf = typeutil.NewConcurrentSet[*internalpb.RetrieveResults]()
     err := t.lb.Execute(ctx, CollectionWorkLoad{
-        db:         t.request.GetDbName(),
-        collection: t.collectionName,
-        nq:         1,
-        exec:       t.queryShard,
+        db:             t.request.GetDbName(),
+        collectionID:   t.CollectionID,
+        collectionName: t.collectionName,
+        nq:             1,
+        exec:           t.queryShard,
     })
     if err != nil {
         log.Warn("fail to execute query", zap.Error(err))
......
@@ -361,10 +361,10 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
     if err != nil {
         return err
     }
-    collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName)
+    collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName, t.CollectionID)
     if err2 != nil {
         log.Warn("Proxy::searchTask::PreExecute failed to GetCollectionInfo from cache",
-            zap.Any("collectionName", collectionName), zap.Error(err2))
+            zap.String("collectionName", collectionName), zap.Int64("collectionID", t.CollectionID), zap.Error(err2))
         return err2
     }

     guaranteeTs := t.request.GetGuaranteeTimestamp()
@@ -417,10 +417,11 @@ func (t *searchTask) Execute(ctx context.Context) error {
     t.resultBuf = typeutil.NewConcurrentSet[*internalpb.SearchResults]()
     err := t.lb.Execute(ctx, CollectionWorkLoad{
-        db:         t.request.GetDbName(),
-        collection: t.collectionName,
-        nq:         t.Nq,
-        exec:       t.searchShard,
+        db:             t.request.GetDbName(),
+        collectionID:   t.SearchRequest.CollectionID,
+        collectionName: t.collectionName,
+        nq:             t.Nq,
+        exec:           t.searchShard,
     })
     if err != nil {
         log.Warn("search execute failed", zap.Error(err))
......
@@ -139,26 +139,28 @@ func (g *getStatisticsTask) PreExecute(ctx context.Context) error {
     }

     // check if collection/partitions are loaded into query node
-    loaded, unloaded, err := checkFullLoaded(ctx, g.qc, g.collectionName, partIDs)
-    log := log.Ctx(ctx)
+    loaded, unloaded, err := checkFullLoaded(ctx, g.qc, g.collectionName, g.GetStatisticsRequest.CollectionID, partIDs)
+    log := log.Ctx(ctx).With(
+        zap.String("collectionName", g.collectionName),
+        zap.Int64("collectionID", g.CollectionID),
+    )
     if err != nil {
         g.fromDataCoord = true
         g.unloadedPartitionIDs = partIDs
-        log.Info("checkFullLoaded failed, try get statistics from DataCoord", zap.Error(err))
+        log.Info("checkFullLoaded failed, try get statistics from DataCoord",
+            zap.Error(err))
         return nil
     }
     if len(unloaded) > 0 {
         g.fromDataCoord = true
         g.unloadedPartitionIDs = unloaded
         log.Info("some partitions has not been loaded, try get statistics from DataCoord",
-            zap.String("collection", g.collectionName),
             zap.Int64s("unloaded partitions", unloaded))
     }
     if len(loaded) > 0 {
         g.fromQueryNode = true
         g.loadedPartitionIDs = loaded
         log.Info("some partitions has been loaded, try get statistics from QueryNode",
-            zap.String("collection", g.collectionName),
             zap.Int64s("loaded partitions", loaded))
     }
     return nil
@@ -266,10 +268,11 @@ func (g *getStatisticsTask) getStatisticsFromQueryNode(ctx context.Context) erro
         g.resultBuf = typeutil.NewConcurrentSet[*internalpb.GetStatisticsResponse]()
     }
     err := g.lb.Execute(ctx, CollectionWorkLoad{
-        db:         g.request.GetDbName(),
-        collection: g.collectionName,
-        nq:         1,
-        exec:       g.getStatisticsShard,
+        db:             g.request.GetDbName(),
+        collectionID:   g.GetStatisticsRequest.CollectionID,
+        collectionName: g.collectionName,
+        nq:             1,
+        exec:           g.getStatisticsShard,
     })
     if err != nil {
@@ -317,14 +320,14 @@ func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64
 // checkFullLoaded check if collection / partition was fully loaded into QueryNode
 // return loaded partitions, unloaded partitions and error
-func checkFullLoaded(ctx context.Context, qc types.QueryCoord, collectionName string, searchPartitionIDs []UniqueID) ([]UniqueID, []UniqueID, error) {
+func checkFullLoaded(ctx context.Context, qc types.QueryCoord, collectionName string, collectionID int64, searchPartitionIDs []UniqueID) ([]UniqueID, []UniqueID, error) {
     var loadedPartitionIDs []UniqueID
     var unloadPartitionIDs []UniqueID

     // TODO: Consider to check if partition loaded from cache to save rpc.
-    info, err := globalMetaCache.GetCollectionInfo(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
+    info, err := globalMetaCache.GetCollectionInfo(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName, collectionID)
     if err != nil {
-        return nil, nil, fmt.Errorf("GetCollectionInfo failed, collection = %s, err = %s", collectionName, err)
+        return nil, nil, fmt.Errorf("GetCollectionInfo failed, collectionName = %s, collectionID = %d, err = %s", collectionName, collectionID, err)
     }

     // If request to search partitions
@@ -338,10 +341,10 @@ func checkFullLoaded(ctx context.Context, qc types.QueryCoord, collectionName st
             PartitionIDs: searchPartitionIDs,
         })
         if err != nil {
-            return nil, nil, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, err = %s", collectionName, searchPartitionIDs, err)
+            return nil, nil, fmt.Errorf("showPartitions failed, collection = %d, partitionIDs = %v, err = %s", collectionID, searchPartitionIDs, err)
         }
         if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
-            return nil, nil, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, reason = %s", collectionName, searchPartitionIDs, resp.GetStatus().GetReason())
+            return nil, nil, fmt.Errorf("showPartitions failed, collection = %d, partitionIDs = %v, reason = %s", collectionID, searchPartitionIDs, resp.GetStatus().GetReason())
         }

         for i, percentage := range resp.GetInMemoryPercentages() {
@@ -363,10 +366,10 @@ func checkFullLoaded(ctx context.Context, qc types.QueryCoord, collectionName st
             CollectionID: info.collID,
         })
         if err != nil {
-            return nil, nil, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, err = %s", collectionName, searchPartitionIDs, err)
+            return nil, nil, fmt.Errorf("showPartitions failed, collection = %d, partitionIDs = %v, err = %s", collectionID, searchPartitionIDs, err)
         }
         if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
-            return nil, nil, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, reason = %s", collectionName, searchPartitionIDs, resp.GetStatus().GetReason())
+            return nil, nil, fmt.Errorf("showPartitions failed, collection = %d, partitionIDs = %v, reason = %s", collectionID, searchPartitionIDs, resp.GetStatus().GetReason())
         }

         loadedMap := make(map[UniqueID]bool)
......
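checkFullLoaded (changed above to also take the collection ID) classifies the requested partitions by the in-memory percentages reported by ShowPartitions, and getStatisticsTask then queries QueryNode for the loaded ones and DataCoord for the rest. A simplified sketch of that classification follows; treating exactly 100% as "loaded" is an assumption of this sketch, not a statement about the production threshold:

```go
package main

import "fmt"

// splitByLoadState is a simplified take on what checkFullLoaded computes:
// given per-partition in-memory percentages, split the partition IDs into
// loaded and unloaded groups (loaded assumed to mean 100% in this sketch).
func splitByLoadState(partitionIDs []int64, percentages []int64) (loaded, unloaded []int64) {
	for i, p := range percentages {
		if p >= 100 {
			loaded = append(loaded, partitionIDs[i])
		} else {
			unloaded = append(unloaded, partitionIDs[i])
		}
	}
	return loaded, unloaded
}

func main() {
	loaded, unloaded := splitByLoadState([]int64{10, 11}, []int64{100, 50})
	// The statistics task would query QueryNode for `loaded` and DataCoord for `unloaded`.
	fmt.Println(loaded, unloaded)
}
```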
@@ -44,7 +44,8 @@ type StatisticTaskSuite struct {
     lb LBPolicy

-    collection string
+    collectionName string
+    collectionID   int64
 }

 func (s *StatisticTaskSuite) SetupSuite() {
@@ -87,7 +88,7 @@ func (s *StatisticTaskSuite) SetupTest() {
     err := InitMetaCache(context.Background(), s.rc, s.qc, mgr)
     s.NoError(err)

-    s.collection = "test_statistics_task"
+    s.collectionName = "test_statistics_task"
     s.loadCollection()
 }
@@ -104,7 +105,7 @@ func (s *StatisticTaskSuite) loadCollection() {
         fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
     }

-    schema := constructCollectionSchemaByDataType(s.collection, fieldName2Types, testInt64Field, false)
+    schema := constructCollectionSchemaByDataType(s.collectionName, fieldName2Types, testInt64Field, false)
     marshaledSchema, err := proto.Marshal(schema)
     s.NoError(err)
@@ -112,7 +113,7 @@ func (s *StatisticTaskSuite) loadCollection() {
     createColT := &createCollectionTask{
         Condition: NewTaskCondition(ctx),
         CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
-            CollectionName: s.collection,
+            CollectionName: s.collectionName,
             Schema:         marshaledSchema,
             ShardsNum:      common.DefaultShardsNum,
         },
@@ -125,7 +126,7 @@ func (s *StatisticTaskSuite) loadCollection() {
     s.NoError(createColT.Execute(ctx))
     s.NoError(createColT.PostExecute(ctx))

-    collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), s.collection)
+    collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), s.collectionName)
     s.NoError(err)

     status, err := s.qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
@@ -137,6 +138,7 @@ func (s *StatisticTaskSuite) loadCollection() {
     })
     s.NoError(err)
     s.Equal(commonpb.ErrorCode_Success, status.ErrorCode)
+    s.collectionID = collectionID
 }

 func (s *StatisticTaskSuite) TearDownSuite() {
@@ -164,7 +166,7 @@ func (s *StatisticTaskSuite) getStatisticsTask(ctx context.Context) *getStatisti
     return &getStatisticsTask{
         Condition:      NewTaskCondition(ctx),
         ctx:            ctx,
-        collectionName: s.collection,
+        collectionName: s.collectionName,
         result: &milvuspb.GetStatisticsResponse{
             Status: &commonpb.Status{
                 ErrorCode: commonpb.ErrorCode_Success,
@@ -175,7 +177,7 @@ func (s *StatisticTaskSuite) getStatisticsTask(ctx context.Context) *getStatisti
                 MsgType:  commonpb.MsgType_Retrieve,
                 SourceID: paramtable.GetNodeID(),
             },
-            CollectionName: s.collection,
+            CollectionName: s.collectionName,
         },
         qc: s.qc,
         lb: s.lb,
@@ -195,6 +197,7 @@ func (s *StatisticTaskSuite) TestStatisticTask_NotShardLeader() {
             Reason:    "error",
         },
     }, nil)
+    s.NoError(task.PreExecute(ctx))
     s.Error(task.Execute(ctx))
     s.NoError(task.PostExecute(ctx))
 }
@@ -211,6 +214,7 @@ func (s *StatisticTaskSuite) TestStatisticTask_UnexpectedError() {
             Reason:    "error",
         },
     }, nil)
+    s.NoError(task.PreExecute(ctx))
     s.Error(task.Execute(ctx))
     s.NoError(task.PostExecute(ctx))
 }
@@ -220,8 +224,10 @@ func (s *StatisticTaskSuite) TestStatisticTask_Success() {
     task := s.getStatisticsTask(ctx)
     s.NoError(task.OnEnqueue())

-    task.fromQueryNode = true
     s.qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(nil, nil)
+    s.NoError(task.PreExecute(ctx))
+    task.fromQueryNode = true
+    task.fromDataCoord = false
     s.NoError(task.Execute(ctx))
     s.NoError(task.PostExecute(ctx))
 }
......