diff --git a/internal/querycoordv2/task/executor.go b/internal/querycoordv2/task/executor.go index cfa4b93c209d0c35c590c8fdecd68be09b4a5bfa..37f020c8ee2684c73fd1431bbaabaa80c2e6f8ab 100644 --- a/internal/querycoordv2/task/executor.go +++ b/internal/querycoordv2/task/executor.go @@ -396,7 +396,12 @@ func (ex *Executor) subDmChannel(task *ChannelTask, step int) error { log.Warn("failed to get partitions of collection") return err } - metricType, err := getMetricType(ctx, task.CollectionID(), schema, ex.broker) + indexInfo, err := ex.broker.DescribeIndex(ctx, task.CollectionID()) + if err != nil { + log.Warn("fail to get index meta of collection") + return err + } + metricType, err := getMetricType(indexInfo, schema) if err != nil { log.Warn("failed to get metric type", zap.Error(err)) return err @@ -414,7 +419,7 @@ func (ex *Executor) subDmChannel(task *ChannelTask, step int) error { log.Warn(msg, zap.String("channelName", action.ChannelName())) return merr.WrapErrChannelReduplicate(action.ChannelName()) } - req := packSubChannelRequest(task, action, schema, loadMeta, dmChannel) + req := packSubChannelRequest(task, action, schema, loadMeta, dmChannel, indexInfo) err = fillSubChannelRequest(ctx, req, ex.broker) if err != nil { log.Warn("failed to subscribe channel, failed to fill the request with segments", diff --git a/internal/querycoordv2/task/utils.go b/internal/querycoordv2/task/utils.go index d31c5e8d6ff329d957be68c9d7515cfb44260aca..50368ca14400e816866b5d79cb9f8da9784f15f1 100644 --- a/internal/querycoordv2/task/utils.go +++ b/internal/querycoordv2/task/utils.go @@ -138,19 +138,21 @@ func packSubChannelRequest( schema *schemapb.CollectionSchema, loadMeta *querypb.LoadMetaInfo, channel *meta.DmChannel, + indexInfo []*indexpb.IndexInfo, ) *querypb.WatchDmChannelsRequest { return &querypb.WatchDmChannelsRequest{ Base: commonpbutil.NewMsgBase( commonpbutil.WithMsgType(commonpb.MsgType_WatchDmChannels), commonpbutil.WithMsgID(task.ID()), ), - NodeID: action.Node(), - CollectionID: task.CollectionID(), - Infos: []*datapb.VchannelInfo{channel.VchannelInfo}, - Schema: schema, // assign it for compatibility of rolling upgrade from 2.2.x to 2.3 - LoadMeta: loadMeta, // assign it for compatibility of rolling upgrade from 2.2.x to 2.3 - ReplicaID: task.ReplicaID(), - Version: time.Now().UnixNano(), + NodeID: action.Node(), + CollectionID: task.CollectionID(), + Infos: []*datapb.VchannelInfo{channel.VchannelInfo}, + Schema: schema, // assign it for compatibility of rolling upgrade from 2.2.x to 2.3 + LoadMeta: loadMeta, // assign it for compatibility of rolling upgrade from 2.2.x to 2.3 + ReplicaID: task.ReplicaID(), + Version: time.Now().UnixNano(), + IndexInfoList: indexInfo, } } @@ -201,11 +203,7 @@ func getShardLeader(replicaMgr *meta.ReplicaManager, distMgr *meta.DistributionM return distMgr.GetShardLeader(replica, channel) } -func getMetricType(ctx context.Context, collection int64, schema *schemapb.CollectionSchema, broker meta.Broker) (string, error) { - indexInfos, err := broker.DescribeIndex(ctx, collection) - if err != nil { - return "", err - } +func getMetricType(indexInfos []*indexpb.IndexInfo, schema *schemapb.CollectionSchema) (string, error) { vecField, err := typeutil.GetVectorFieldSchema(schema) if err != nil { return "", err diff --git a/internal/querycoordv2/task/utils_test.go b/internal/querycoordv2/task/utils_test.go index 4eb5cabb011b32110046e9f259cfc840f64774cc..788cadc4718dea8d541922a2b407f56ff8870d6a 100644 --- a/internal/querycoordv2/task/utils_test.go +++ b/internal/querycoordv2/task/utils_test.go @@ -17,22 +17,17 @@ package task import ( - "context" - "fmt" "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/indexpb" - "github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/pkg/common" ) func Test_getMetricType(t *testing.T) { - ctx := context.Background() collection := int64(1) schema := &schemapb.CollectionSchema{ Name: "TestGetMetricType", @@ -50,50 +45,35 @@ func Test_getMetricType(t *testing.T) { }, }, } + + indexInfo2 := &indexpb.IndexInfo{ + CollectionID: collection, + FieldID: 100, + } + t.Run("test normal", func(t *testing.T) { - broker := meta.NewMockBroker(t) - broker.EXPECT().DescribeIndex(mock.Anything, collection). - Return([]*indexpb.IndexInfo{indexInfo}, nil) - metricType, err := getMetricType(ctx, collection, schema, broker) + metricType, err := getMetricType([]*indexpb.IndexInfo{indexInfo}, schema) assert.NoError(t, err) assert.Equal(t, "L2", metricType) }) - t.Run("test describe index failed", func(t *testing.T) { - broker := meta.NewMockBroker(t) - broker.EXPECT().DescribeIndex(mock.Anything, collection). - Return(nil, fmt.Errorf("mock err")) - _, err := getMetricType(ctx, collection, schema, broker) - assert.Error(t, err) - }) + t.Run("test get vec field failed", func(t *testing.T) { - broker := meta.NewMockBroker(t) - broker.EXPECT().DescribeIndex(mock.Anything, collection). - Return([]*indexpb.IndexInfo{indexInfo}, nil) - _, err := getMetricType(ctx, collection, &schemapb.CollectionSchema{ + _, err := getMetricType([]*indexpb.IndexInfo{indexInfo}, &schemapb.CollectionSchema{ Name: "TestGetMetricType", - }, broker) + }) assert.Error(t, err) }) t.Run("test field id mismatch", func(t *testing.T) { - broker := meta.NewMockBroker(t) - broker.EXPECT().DescribeIndex(mock.Anything, collection). - Return([]*indexpb.IndexInfo{indexInfo}, nil) - _, err := getMetricType(ctx, collection, &schemapb.CollectionSchema{ + _, err := getMetricType([]*indexpb.IndexInfo{indexInfo}, &schemapb.CollectionSchema{ Name: "TestGetMetricType", Fields: []*schemapb.FieldSchema{ {FieldID: -1, Name: "vec", DataType: schemapb.DataType_FloatVector}, }, - }, broker) + }) assert.Error(t, err) }) t.Run("test no metric type", func(t *testing.T) { - broker := meta.NewMockBroker(t) - broker.EXPECT().DescribeIndex(mock.Anything, collection). - Return([]*indexpb.IndexInfo{{ - CollectionID: collection, - FieldID: 100, - }}, nil) - _, err := getMetricType(ctx, collection, schema, broker) + _, err := getMetricType([]*indexpb.IndexInfo{indexInfo2}, schema) assert.Error(t, err) }) }