diff --git a/internal/querycoord/task.go b/internal/querycoord/task.go
index 117e9bd6658d5f06c68eb2bd648afbbe4aa8be04..7f94d5d62eff876c956a006b7a15f4bf8d5c2c71 100644
--- a/internal/querycoord/task.go
+++ b/internal/querycoord/task.go
@@ -15,6 +15,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"reflect"
 	"sync"
 	"time"
 
@@ -1761,13 +1762,23 @@ func assignInternalTask(ctx context.Context,
 	log.Debug("assignInternalTask: watch request to node",
 		zap.Any("request map", watchRequest2Nodes), zap.Int64("collectionID", collectionID))
 	watchQueryChannelInfo := make(map[int64]bool)
-	node2Segments := make(map[int64]*querypb.LoadSegmentsRequest)
+	node2Segments := make(map[int64][]*querypb.LoadSegmentsRequest)
+	sizeCounts := make(map[int64]int)
 	for index, nodeID := range segment2Nodes {
 		if _, ok := node2Segments[nodeID]; !ok {
-			node2Segments[nodeID] = loadSegmentRequests[index]
-		} else {
-			node2Segments[nodeID].Infos = append(node2Segments[nodeID].Infos, loadSegmentRequests[index].Infos...)
+			node2Segments[nodeID] = make([]*querypb.LoadSegmentsRequest, 0)
+		}
+		sizeOfReq := getSizeOfLoadSegmentReq(loadSegmentRequests[index])
+		// start a new request if the node has none yet or the pending one would grow past 2 MB (2097152 bytes)
+		if len(node2Segments[nodeID]) == 0 || sizeCounts[nodeID]+sizeOfReq > 2097152 {
+			node2Segments[nodeID] = append(node2Segments[nodeID], loadSegmentRequests[index])
+			sizeCounts[nodeID] = sizeOfReq
+		} else {
+			lastReq := node2Segments[nodeID][len(node2Segments[nodeID])-1]
+			lastReq.Infos = append(lastReq.Infos, loadSegmentRequests[index].Infos...)
+			sizeCounts[nodeID] += sizeOfReq
 		}
+
 		if cluster.hasWatchedQueryChannel(parentTask.traceCtx(), nodeID, collectionID) {
 			watchQueryChannelInfo[nodeID] = true
 			continue
@@ -1782,20 +1793,22 @@
 		watchQueryChannelInfo[nodeID] = false
 	}
 
-	for nodeID, loadSegmentsReq := range node2Segments {
-		ctx = opentracing.ContextWithSpan(context.Background(), sp)
-		loadSegmentsReq.NodeID = nodeID
-		baseTask := newBaseTask(ctx, parentTask.getTriggerCondition())
-		baseTask.setParentTask(parentTask)
-		loadSegmentTask := &loadSegmentTask{
-			baseTask:            baseTask,
-			LoadSegmentsRequest: loadSegmentsReq,
-			meta:                meta,
-			cluster:             cluster,
-			excludeNodeIDs:      []int64{},
+	for nodeID, loadSegmentsReqs := range node2Segments {
+		for _, req := range loadSegmentsReqs {
+			ctx = opentracing.ContextWithSpan(context.Background(), sp)
+			req.NodeID = nodeID
+			baseTask := newBaseTask(ctx, parentTask.getTriggerCondition())
+			baseTask.setParentTask(parentTask)
+			loadSegmentTask := &loadSegmentTask{
+				baseTask:            baseTask,
+				LoadSegmentsRequest: req,
+				meta:                meta,
+				cluster:             cluster,
+				excludeNodeIDs:      []int64{},
+			}
+			parentTask.addChildTask(loadSegmentTask)
+			log.Debug("assignInternalTask: add a loadSegmentTask childTask", zap.Any("task", loadSegmentTask))
 		}
-		parentTask.addChildTask(loadSegmentTask)
-		log.Debug("assignInternalTask: add a loadSegmentTask childTask", zap.Any("task", loadSegmentTask))
 	}
 
 	for index, nodeID := range watchRequest2Nodes {
@@ -1846,3 +1859,30 @@
 	}
 	return nil
 }
+
+func getSizeOfLoadSegmentReq(req *querypb.LoadSegmentsRequest) int {
+	var totalSize = 0
+	totalSize += int(reflect.ValueOf(*req).Type().Size())
+	for _, info := range req.Infos {
+		totalSize += int(reflect.ValueOf(*info).Type().Size())
+		for _, fieldBinlog := range info.BinlogPaths {
+			totalSize += int(reflect.ValueOf(*fieldBinlog).Type().Size())
+			for _, path := range fieldBinlog.Binlogs {
+				totalSize += len(path)
+			}
+		}
+	}
+
+	totalSize += len(req.Schema.Name) + len(req.Schema.Description) + int(reflect.ValueOf(*req.Schema).Type().Size())
+	for _, fieldSchema := range req.Schema.Fields {
+		totalSize += len(fieldSchema.Name) + len(fieldSchema.Description) + int(reflect.ValueOf(*fieldSchema).Type().Size())
+		for _, typeParam := range fieldSchema.TypeParams {
+			totalSize += len(typeParam.Key) + len(typeParam.Value)
+		}
+		for _, indexParam := range fieldSchema.IndexParams {
+			totalSize += len(indexParam.Key) + len(indexParam.Value)
+		}
+	}
+
+	return totalSize
+}
diff --git a/internal/querycoord/task_test.go b/internal/querycoord/task_test.go
index 9399a1de3ce313d6141e68f1b6f6848f4629069f..7f155e11f8cc69443c06d55d10c32cd4e0a1ef5a 100644
--- a/internal/querycoord/task_test.go
+++ b/internal/querycoord/task_test.go
@@ -19,6 +19,7 @@ import (
 	"github.com/milvus-io/milvus/internal/proto/commonpb"
 	"github.com/milvus-io/milvus/internal/proto/datapb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
+	"github.com/milvus-io/milvus/internal/util/funcutil"
 )
 
 func genLoadCollectionTask(ctx context.Context, queryCoord *QueryCoord) *loadCollectionTask {
@@ -48,6 +49,7 @@ func genLoadPartitionTask(ctx context.Context, queryCoord *QueryCoord) *loadPart
 		},
 		CollectionID: defaultCollectionID,
 		PartitionIDs: []UniqueID{defaultPartitionID},
+		Schema:       genCollectionSchema(defaultCollectionID, false),
 	}
 	baseTask := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest)
 	loadPartitionTask := &loadPartitionTask{
@@ -653,3 +655,49 @@ func Test_RescheduleDmChannelsEndWithFail(t *testing.T) {
 	err = removeAllSession()
 	assert.Nil(t, err)
 }
+
+func Test_assignInternalTask(t *testing.T) {
+	refreshParams()
+	ctx := context.Background()
+	queryCoord, err := startQueryCoord(ctx)
+	assert.Nil(t, err)
+
+	node1, err := startQueryNodeServer(ctx)
+	assert.Nil(t, err)
+	waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID)
+
+	schema := genCollectionSchema(defaultCollectionID, false)
+	loadCollectionTask := genLoadCollectionTask(ctx, queryCoord)
+	loadSegmentRequests := make([]*querypb.LoadSegmentsRequest, 0)
+	binlogs := make([]*datapb.FieldBinlog, 0)
+	binlogs = append(binlogs, &datapb.FieldBinlog{
+		FieldID: 0,
+		Binlogs: []string{funcutil.RandomString(1000)},
+	})
+	for id := 0; id < 10000; id++ {
+		segmentInfo := &querypb.SegmentLoadInfo{
+			SegmentID:    UniqueID(id),
+			PartitionID:  defaultPartitionID,
+			CollectionID: defaultCollectionID,
+			BinlogPaths:  binlogs,
+		}
+		req := &querypb.LoadSegmentsRequest{
+			Base: &commonpb.MsgBase{
+				MsgType: commonpb.MsgType_LoadSegments,
+			},
+			NodeID: node1.queryNodeID,
+			Schema: schema,
+			Infos:  []*querypb.SegmentLoadInfo{segmentInfo},
+		}
+		loadSegmentRequests = append(loadSegmentRequests, req)
+	}
+
+	err = assignInternalTask(queryCoord.loopCtx, defaultCollectionID, loadCollectionTask, queryCoord.meta, queryCoord.cluster, loadSegmentRequests, nil, false)
+	assert.Nil(t, err)
+
+	assert.NotEqual(t, 1, len(loadCollectionTask.getChildTask()))
+
+	queryCoord.Stop()
+	err = removeAllSession()
+	assert.Nil(t, err)
+}