From 6766169878a58bd66819edaa9aae330520984e3d Mon Sep 17 00:00:00 2001 From: "zhenshan.cao" Date: Tue, 25 May 2021 19:53:15 +0800 Subject: [PATCH] Refactor repack logic for insertion (#5399) Signed-off-by: zhenshan.cao --- go.mod | 1 + go.sum | 6 + internal/msgstream/mem_msgstream.go | 4 + internal/msgstream/mq_msgstream.go | 5 + internal/msgstream/msgstream.go | 1 + internal/msgstream/msgstream_mock.go | 4 + internal/proxynode/impl.go | 1 + internal/proxynode/insert_channels.go | 5 +- internal/proxynode/repack_func.go | 277 +------------------------- internal/proxynode/task.go | 217 +++++++++++++++++++- 10 files changed, 244 insertions(+), 277 deletions(-) diff --git a/go.mod b/go.mod index 0ee9de46a..e4301e786 100644 --- a/go.mod +++ b/go.mod @@ -26,6 +26,7 @@ require ( github.com/pierrec/lz4 v2.5.2+incompatible // indirect github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.7.1 + github.com/quasilyte/go-ruleguard v0.2.1 // indirect github.com/sirupsen/logrus v1.6.0 // indirect github.com/spaolacci/murmur3 v1.1.0 github.com/spf13/cast v1.3.0 diff --git a/go.sum b/go.sum index bcd73eef8..78b73f39d 100644 --- a/go.sum +++ b/go.sum @@ -320,6 +320,8 @@ github.com/prometheus/procfs v0.1.3 h1:F0+tqvhOksq22sc6iCHF5WGlWjdwj92p0udFh1VFB github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= github.com/protocolbuffers/protobuf v3.17.0+incompatible h1:MYhKKlaNOl8FB3F4u6oM2AlpcyLtT+p8Ec1w/9YeHss= +github.com/quasilyte/go-ruleguard v0.2.1 h1:56eRm0daAyny9UhJnmtJW/UyLZQusukBAB8oT8AHKHo= +github.com/quasilyte/go-ruleguard v0.2.1/go.mod h1:hN2rVc/uS4bQhQKTio2XaSJSafJwqBUWWwtssT3cQmc= github.com/rivo/tview v0.0.0-20200219210816-cd38d7432498/go.mod h1:6lkG1x+13OShEf0EaOCaTQYyB7d5nSbb181KtjlS+84= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= @@ -380,6 +382,7 @@ github.com/yahoo/athenz v1.8.55/go.mod h1:G7LLFUH7Z/r4QAB7FfudfuA7Am/eCzO1GlzBhD github.com/yahoo/athenz v1.9.16 h1:2s8KtIxwAbcJIYySsfrT/t/WO0Ss5O7BPGUN/q8x2bg= github.com/yahoo/athenz v1.9.16/go.mod h1:guj+0Ut6F33wj+OcSRlw69O0itsR7tVocv15F2wJnIo= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0= go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= @@ -453,6 +456,7 @@ golang.org/x/net v0.0.0-20190921015927-1a5e07d1ff72/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb h1:eBmm0M9fYhWpKZLjQUUKka/LtIxf46G4fxeEz5KJr9U= @@ -468,6 +472,7 @@ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -531,6 +536,7 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200812195022-5ae4c3c160a0/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a h1:CB3a9Nez8M13wwlr/E2YtwoU+qYHKfC+JrDa45RXXoQ= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/msgstream/mem_msgstream.go b/internal/msgstream/mem_msgstream.go index 5376010f5..2257852d8 100644 --- a/internal/msgstream/mem_msgstream.go +++ b/internal/msgstream/mem_msgstream.go @@ -67,6 +67,10 @@ func (mms *MemMsgStream) SetRepackFunc(repackFunc RepackFunc) { mms.repackFunc = repackFunc } +func (mms *MemMsgStream) GetProduceChannels() []string { + return mms.producers +} + func (mms *MemMsgStream) AsProducer(channels []string) { for _, channel := range channels { err := Mmq.CreateChannel(channel) diff --git a/internal/msgstream/mq_msgstream.go b/internal/msgstream/mq_msgstream.go index f40e9ceb6..133fda142 100644 --- a/internal/msgstream/mq_msgstream.go +++ b/internal/msgstream/mq_msgstream.go @@ -170,6 +170,7 @@ func (ms *mqMsgStream) ComputeProduceChannelIndexes(tsMsgs []TsMsg) [][]int32 { } reBucketValues := make([][]int32, len(tsMsgs)) channelNum := uint32(len(ms.producerChannels)) + if channelNum == 0 { return nil } @@ -184,6 +185,10 @@ func (ms *mqMsgStream) ComputeProduceChannelIndexes(tsMsgs []TsMsg) [][]int32 { return reBucketValues } +func (ms *mqMsgStream) GetProduceChannels() []string { + return ms.producerChannels +} + func (ms *mqMsgStream) Produce(msgPack *MsgPack) error { tsMsgs := msgPack.Msgs if len(tsMsgs) <= 0 { diff --git a/internal/msgstream/msgstream.go b/internal/msgstream/msgstream.go index 45e0c3021..e4715cc2f 100644 --- a/internal/msgstream/msgstream.go +++ b/internal/msgstream/msgstream.go @@ -41,6 +41,7 @@ type MsgStream interface { AsConsumer(channels []string, subName string) SetRepackFunc(repackFunc RepackFunc) ComputeProduceChannelIndexes(tsMsgs []TsMsg) [][]int32 + GetProduceChannels() []string Produce(*MsgPack) error Broadcast(*MsgPack) error Consume() *MsgPack diff --git a/internal/msgstream/msgstream_mock.go b/internal/msgstream/msgstream_mock.go index 0ff3d715c..5f0d7b1bb 100644 --- a/internal/msgstream/msgstream_mock.go +++ b/internal/msgstream/msgstream_mock.go @@ -76,6 +76,10 @@ func (ms *SimpleMsgStream) Broadcast(pack *MsgPack) error { return nil } +func (ms *SimpleMsgStream) GetProduceChannels() []string { + return nil +} + func (ms *SimpleMsgStream) Consume() *MsgPack { if ms.getMsgCount() <= 0 { return nil diff --git a/internal/proxynode/impl.go b/internal/proxynode/impl.go index 81ad07d72..81158108b 100644 --- a/internal/proxynode/impl.go +++ b/internal/proxynode/impl.go @@ -1058,6 +1058,7 @@ func (node *ProxyNode) Insert(ctx context.Context, request *milvuspb.InsertReque }, }, rowIDAllocator: node.idAllocator, + segIDAssigner: node.segAssigner, } if len(it.PartitionName) <= 0 { it.PartitionName = Params.DefaultPartitionName diff --git a/internal/proxynode/insert_channels.go b/internal/proxynode/insert_channels.go index e74c48809..b211a3a00 100644 --- a/internal/proxynode/insert_channels.go +++ b/internal/proxynode/insert_channels.go @@ -61,10 +61,7 @@ func (m *insertChannelsMap) CreateInsertMsgStream(collID UniqueID, channels []st stream, _ := m.msFactory.NewMsgStream(context.Background()) stream.AsProducer(channels) log.Debug("proxynode", zap.Strings("proxynode AsProducer: ", channels)) - repack := func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) { - return insertRepackFunc(tsMsgs, hashKeys, m.nodeInstance.segAssigner, true) - } - stream.SetRepackFunc(repack) + stream.SetRepackFunc(insertRepackFunc) stream.Start() m.insertMsgStreams = append(m.insertMsgStreams, stream) m.droppedBitMap = append(m.droppedBitMap, 0) diff --git a/internal/proxynode/repack_func.go b/internal/proxynode/repack_func.go index a34962cb4..f3cdb2194 100644 --- a/internal/proxynode/repack_func.go +++ b/internal/proxynode/repack_func.go @@ -12,290 +12,23 @@ package proxynode import ( - "errors" - "sort" - "unsafe" - - "go.uber.org/zap" - - "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/msgstream" - "github.com/milvus-io/milvus/internal/proto/commonpb" - "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/util/typeutil" ) func insertRepackFunc(tsMsgs []msgstream.TsMsg, - hashKeys [][]int32, - segIDAssigner *SegIDAssigner, - together bool) (map[int32]*msgstream.MsgPack, error) { + hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) { result := make(map[int32]*msgstream.MsgPack) - - channelCountMap := make(map[UniqueID]map[int32]uint32) // reqID --> channelID to count - channelMaxTSMap := make(map[UniqueID]map[int32]Timestamp) // reqID --> channelID to max Timestamp - reqSchemaMap := make(map[UniqueID][]UniqueID) // reqID --> channelID [2]UniqueID {CollectionID, PartitionID} - channelNamesMap := make(map[UniqueID][]string) // collectionID --> channelNames - for i, request := range tsMsgs { - if request.Type() != commonpb.MsgType_Insert { - return nil, errors.New("msg's must be Insert") - } - insertRequest, ok := request.(*msgstream.InsertMsg) - if !ok { - return nil, errors.New("msg's must be Insert") - } - keys := hashKeys[i] - timestampLen := len(insertRequest.Timestamps) - rowIDLen := len(insertRequest.RowIDs) - rowDataLen := len(insertRequest.RowData) - keysLen := len(keys) - - if keysLen != timestampLen || keysLen != rowIDLen || keysLen != rowDataLen { - return nil, errors.New("the length of hashValue, timestamps, rowIDs, RowData are not equal") - } - - reqID := insertRequest.Base.MsgID - if _, ok := channelCountMap[reqID]; !ok { - channelCountMap[reqID] = make(map[int32]uint32) - } - - if _, ok := channelMaxTSMap[reqID]; !ok { - channelMaxTSMap[reqID] = make(map[int32]Timestamp) - } - - if _, ok := reqSchemaMap[reqID]; !ok { - reqSchemaMap[reqID] = []UniqueID{insertRequest.CollectionID, insertRequest.PartitionID} - } - - for idx, channelID := range keys { - channelCountMap[reqID][channelID]++ - if _, ok := channelMaxTSMap[reqID][channelID]; !ok { - channelMaxTSMap[reqID][channelID] = typeutil.ZeroTimestamp - } - ts := insertRequest.Timestamps[idx] - if channelMaxTSMap[reqID][channelID] < ts { - channelMaxTSMap[reqID][channelID] = ts - } - } - - collID := insertRequest.CollectionID - if _, ok := channelNamesMap[collID]; !ok { - channelNames, err := globalInsertChannelsMap.GetInsertChannels(collID) - if err != nil { - return nil, err - } - channelNamesMap[collID] = channelNames - } - } - - var getChannelName = func(collID UniqueID, channelID int32) string { - if _, ok := channelNamesMap[collID]; !ok { - return "" - } - names := channelNamesMap[collID] - return names[channelID] - } - - reqSegCountMap := make(map[UniqueID]map[int32]map[UniqueID]uint32) - - for reqID, countInfo := range channelCountMap { - if _, ok := reqSegCountMap[reqID]; !ok { - reqSegCountMap[reqID] = make(map[int32]map[UniqueID]uint32) - } - schema := reqSchemaMap[reqID] - collID, partitionID := schema[0], schema[1] - for channelID, count := range countInfo { - ts, ok := channelMaxTSMap[reqID][channelID] - if !ok { - ts = typeutil.ZeroTimestamp - log.Debug("Warning: did not get max Timstamp!") - } - channelName := getChannelName(collID, channelID) - if channelName == "" { - return nil, errors.New("ProxyNode, repack_func, can not found channelName") - } - mapInfo, err := segIDAssigner.GetSegmentID(collID, partitionID, channelName, count, ts) - if err != nil { - return nil, err - } - reqSegCountMap[reqID][channelID] = make(map[UniqueID]uint32) - reqSegCountMap[reqID][channelID] = mapInfo - log.Debug("proxynode", zap.Int64("repackFunc, reqSegCountMap, reqID", reqID), zap.Any("mapinfo", mapInfo)) - } - } - - reqSegAccumulateCountMap := make(map[UniqueID]map[int32][]uint32) - reqSegIDMap := make(map[UniqueID]map[int32][]UniqueID) - reqSegAllocateCounter := make(map[UniqueID]map[int32]uint32) - - for reqID, channelInfo := range reqSegCountMap { - if _, ok := reqSegAccumulateCountMap[reqID]; !ok { - reqSegAccumulateCountMap[reqID] = make(map[int32][]uint32) - } - if _, ok := reqSegIDMap[reqID]; !ok { - reqSegIDMap[reqID] = make(map[int32][]UniqueID) - } - if _, ok := reqSegAllocateCounter[reqID]; !ok { - reqSegAllocateCounter[reqID] = make(map[int32]uint32) - } - for channelID, segInfo := range channelInfo { - reqSegAllocateCounter[reqID][channelID] = 0 - keys := make([]UniqueID, len(segInfo)) - i := 0 - for key := range segInfo { - keys[i] = key - i++ - } - sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] }) - accumulate := uint32(0) - for _, key := range keys { - accumulate += segInfo[key] - if _, ok := reqSegAccumulateCountMap[reqID][channelID]; !ok { - reqSegAccumulateCountMap[reqID][channelID] = make([]uint32, 0) - } - reqSegAccumulateCountMap[reqID][channelID] = append( - reqSegAccumulateCountMap[reqID][channelID], - accumulate, - ) - if _, ok := reqSegIDMap[reqID][channelID]; !ok { - reqSegIDMap[reqID][channelID] = make([]UniqueID, 0) - } - reqSegIDMap[reqID][channelID] = append( - reqSegIDMap[reqID][channelID], - key, - ) - } - } - } - - var getSegmentID = func(reqID UniqueID, channelID int32) UniqueID { - reqSegAllocateCounter[reqID][channelID]++ - cur := reqSegAllocateCounter[reqID][channelID] - accumulateSlice := reqSegAccumulateCountMap[reqID][channelID] - segIDSlice := reqSegIDMap[reqID][channelID] - for index, count := range accumulateSlice { - if cur <= count { - return segIDSlice[index] - } - } - log.Warn("Can't Found SegmentID") - return 0 - } - - factor := 10 - threshold := Params.PulsarMaxMessageSize / factor - log.Debug("proxynode", zap.Int("threshold of message size: ", threshold)) - // not accurate - getSizeOfInsertMsg := func(msg *msgstream.InsertMsg) int { - // if real struct, call unsafe.Sizeof directly, - // if reference, dereference and then call unsafe.Sizeof, - // if slice, todo: a common function to calculate size of slice, - // if map, a little complicated - size := 0 - size += int(unsafe.Sizeof(msg.Ctx)) - size += int(unsafe.Sizeof(msg.BeginTimestamp)) - size += int(unsafe.Sizeof(msg.EndTimestamp)) - size += int(unsafe.Sizeof(msg.HashValues)) - size += len(msg.HashValues) * 4 - size += int(unsafe.Sizeof(*msg.MsgPosition)) - size += int(unsafe.Sizeof(*msg.Base)) - size += int(unsafe.Sizeof(msg.DbName)) - size += int(unsafe.Sizeof(msg.CollectionName)) - size += int(unsafe.Sizeof(msg.PartitionName)) - size += int(unsafe.Sizeof(msg.DbID)) - size += int(unsafe.Sizeof(msg.CollectionID)) - size += int(unsafe.Sizeof(msg.PartitionID)) - size += int(unsafe.Sizeof(msg.SegmentID)) - size += int(unsafe.Sizeof(msg.ChannelID)) - size += int(unsafe.Sizeof(msg.Timestamps)) - size += int(unsafe.Sizeof(msg.RowIDs)) - size += len(msg.RowIDs) * 8 - for _, blob := range msg.RowData { - size += int(unsafe.Sizeof(blob.Value)) - size += len(blob.Value) - } - - //log.Debug("proxynode", zap.Int("insert message size", size)) - return size - } - // not accurate - // getSizeOfMsgPack := func(mp *msgstream.MsgPack) int { - // size := 0 - // for _, msg := range mp.Msgs { - // insertMsg, ok := msg.(*msgstream.InsertMsg) - // if !ok { - // log.Panic("only insert message is supported!") - // } - // size += getSizeOfInsertMsg(insertMsg) - // } - // return size - // } - - for i, request := range tsMsgs { - insertRequest := request.(*msgstream.InsertMsg) - keys := hashKeys[i] - reqID := insertRequest.Base.MsgID - collectionName := insertRequest.CollectionName - collectionID := insertRequest.CollectionID - partitionID := insertRequest.PartitionID - partitionName := insertRequest.PartitionName - proxyID := insertRequest.Base.SourceID - channelNames := channelNamesMap[collectionID] - for index, key := range keys { - ts := insertRequest.Timestamps[index] - rowID := insertRequest.RowIDs[index] - row := insertRequest.RowData[index] + if len(keys) > 0 { + key := keys[0] _, ok := result[key] if !ok { - msgPack := msgstream.MsgPack{} - result[key] = &msgPack - } - segmentID := getSegmentID(reqID, key) - channelID := channelNames[int(key)%len(channelNames)] - sliceRequest := internalpb.InsertRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Insert, - MsgID: reqID, - Timestamp: ts, - SourceID: proxyID, - }, - CollectionID: collectionID, - PartitionID: partitionID, - CollectionName: collectionName, - PartitionName: partitionName, - SegmentID: segmentID, - // todo rename to ChannelName - // ChannelID: strconv.FormatInt(int64(key), 10), - ChannelID: channelID, - Timestamps: []uint64{ts}, - RowIDs: []int64{rowID}, - RowData: []*commonpb.Blob{row}, - } - insertMsg := &msgstream.InsertMsg{ - BaseMsg: msgstream.BaseMsg{ - Ctx: request.TraceCtx(), - }, - InsertRequest: sliceRequest, - } - if together { // all rows with same hash value are accumulated to only one message - msgNums := len(result[key].Msgs) - if len(result[key].Msgs) <= 0 { - result[key].Msgs = append(result[key].Msgs, insertMsg) - } else if getSizeOfInsertMsg(result[key].Msgs[msgNums-1].(*msgstream.InsertMsg)) >= threshold { - result[key].Msgs = append(result[key].Msgs, insertMsg) - } else { - accMsgs, _ := result[key].Msgs[msgNums-1].(*msgstream.InsertMsg) - accMsgs.Timestamps = append(accMsgs.Timestamps, ts) - accMsgs.RowIDs = append(accMsgs.RowIDs, rowID) - accMsgs.RowData = append(accMsgs.RowData, row) - } - } else { // every row is a message - result[key].Msgs = append(result[key].Msgs, insertMsg) + result[key] = &msgstream.MsgPack{} } + result[key].Msgs = append(result[key].Msgs, request) } } - return result, nil } diff --git a/internal/proxynode/task.go b/internal/proxynode/task.go index 55cb8a16f..b24d5915e 100644 --- a/internal/proxynode/task.go +++ b/internal/proxynode/task.go @@ -18,8 +18,10 @@ import ( "math" "regexp" "runtime" + "sort" "strconv" "time" + "unsafe" "github.com/milvus-io/milvus/internal/proto/planpb" @@ -99,6 +101,7 @@ type InsertTask struct { dataService types.DataService result *milvuspb.InsertResponse rowIDAllocator *allocator.IDAllocator + segIDAssigner *SegIDAssigner } func (it *InsertTask) TraceCtx() context.Context { @@ -160,6 +163,211 @@ func (it *InsertTask) PreExecute(ctx context.Context) error { return nil } +func (it *InsertTask) _assignSegmentID(stream msgstream.MsgStream, pack *msgstream.MsgPack) (*msgstream.MsgPack, error) { + newPack := &msgstream.MsgPack{ + BeginTs: pack.BeginTs, + EndTs: pack.EndTs, + StartPositions: pack.StartPositions, + EndPositions: pack.EndPositions, + Msgs: nil, + } + tsMsgs := pack.Msgs + hashKeys := stream.ComputeProduceChannelIndexes(tsMsgs) + reqID := it.Base.MsgID + channelCountMap := make(map[int32]uint32) // channelID to count + channelMaxTSMap := make(map[int32]Timestamp) // channelID to max Timestamp + channelNames := stream.GetProduceChannels() + log.Debug("_assignSemgentID, produceChannels:", zap.Any("Channels", channelNames)) + + for i, request := range tsMsgs { + if request.Type() != commonpb.MsgType_Insert { + return nil, fmt.Errorf("msg's must be Insert") + } + insertRequest, ok := request.(*msgstream.InsertMsg) + if !ok { + return nil, fmt.Errorf("msg's must be Insert") + } + + keys := hashKeys[i] + timestampLen := len(insertRequest.Timestamps) + rowIDLen := len(insertRequest.RowIDs) + rowDataLen := len(insertRequest.RowData) + keysLen := len(keys) + + if keysLen != timestampLen || keysLen != rowIDLen || keysLen != rowDataLen { + return nil, fmt.Errorf("the length of hashValue, timestamps, rowIDs, RowData are not equal") + } + + for idx, channelID := range keys { + channelCountMap[channelID]++ + if _, ok := channelMaxTSMap[channelID]; !ok { + channelMaxTSMap[channelID] = typeutil.ZeroTimestamp + } + ts := insertRequest.Timestamps[idx] + if channelMaxTSMap[channelID] < ts { + channelMaxTSMap[channelID] = ts + } + } + } + + reqSegCountMap := make(map[int32]map[UniqueID]uint32) + + for channelID, count := range channelCountMap { + ts, ok := channelMaxTSMap[channelID] + if !ok { + ts = typeutil.ZeroTimestamp + log.Debug("Warning: did not get max Timestamp!") + } + channelName := channelNames[channelID] + if channelName == "" { + return nil, fmt.Errorf("ProxyNode, repack_func, can not found channelName") + } + mapInfo, err := it.segIDAssigner.GetSegmentID(it.CollectionID, it.PartitionID, channelName, count, ts) + if err != nil { + return nil, err + } + reqSegCountMap[channelID] = make(map[UniqueID]uint32) + reqSegCountMap[channelID] = mapInfo + log.Debug("ProxyNode", zap.Int64("repackFunc, reqSegCountMap, reqID", reqID), zap.Any("mapinfo", mapInfo)) + } + + reqSegAccumulateCountMap := make(map[int32][]uint32) + reqSegIDMap := make(map[int32][]UniqueID) + reqSegAllocateCounter := make(map[int32]uint32) + + for channelID, segInfo := range reqSegCountMap { + reqSegAllocateCounter[channelID] = 0 + keys := make([]UniqueID, len(segInfo)) + i := 0 + for key := range segInfo { + keys[i] = key + i++ + } + sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] }) + accumulate := uint32(0) + for _, key := range keys { + accumulate += segInfo[key] + if _, ok := reqSegAccumulateCountMap[channelID]; !ok { + reqSegAccumulateCountMap[channelID] = make([]uint32, 0) + } + reqSegAccumulateCountMap[channelID] = append( + reqSegAccumulateCountMap[channelID], + accumulate, + ) + if _, ok := reqSegIDMap[channelID]; !ok { + reqSegIDMap[channelID] = make([]UniqueID, 0) + } + reqSegIDMap[channelID] = append( + reqSegIDMap[channelID], + key, + ) + } + } + + var getSegmentID = func(channelID int32) UniqueID { + reqSegAllocateCounter[channelID]++ + cur := reqSegAllocateCounter[channelID] + accumulateSlice := reqSegAccumulateCountMap[channelID] + segIDSlice := reqSegIDMap[channelID] + for index, count := range accumulateSlice { + if cur <= count { + return segIDSlice[index] + } + } + log.Warn("Can't Found SegmentID") + return 0 + } + + factor := 10 + threshold := Params.PulsarMaxMessageSize / factor + log.Debug("ProxyNode", zap.Int("threshold of message size: ", threshold)) + // not accurate + getFixedSizeOfInsertMsg := func(msg *msgstream.InsertMsg) int { + size := 0 + + size += int(unsafe.Sizeof(*msg.Base)) + size += int(unsafe.Sizeof(msg.DbName)) + size += int(unsafe.Sizeof(msg.CollectionName)) + size += int(unsafe.Sizeof(msg.PartitionName)) + size += int(unsafe.Sizeof(msg.DbID)) + size += int(unsafe.Sizeof(msg.CollectionID)) + size += int(unsafe.Sizeof(msg.PartitionID)) + size += int(unsafe.Sizeof(msg.SegmentID)) + size += int(unsafe.Sizeof(msg.ChannelID)) + size += int(unsafe.Sizeof(msg.Timestamps)) + size += int(unsafe.Sizeof(msg.RowIDs)) + return size + } + + result := make(map[int32]msgstream.TsMsg) + curMsgSizeMap := make(map[int32]int) + + for i, request := range tsMsgs { + insertRequest := request.(*msgstream.InsertMsg) + keys := hashKeys[i] + collectionName := insertRequest.CollectionName + collectionID := insertRequest.CollectionID + partitionID := insertRequest.PartitionID + partitionName := insertRequest.PartitionName + proxyID := insertRequest.Base.SourceID + for index, key := range keys { + ts := insertRequest.Timestamps[index] + rowID := insertRequest.RowIDs[index] + row := insertRequest.RowData[index] + segmentID := getSegmentID(key) + _, ok := result[key] + if !ok { + sliceRequest := internalpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + MsgID: reqID, + Timestamp: ts, + SourceID: proxyID, + }, + CollectionID: collectionID, + PartitionID: partitionID, + CollectionName: collectionName, + PartitionName: partitionName, + SegmentID: segmentID, + // todo rename to ChannelName + ChannelID: channelNames[key], + } + insertMsg := &msgstream.InsertMsg{ + BaseMsg: msgstream.BaseMsg{ + Ctx: request.TraceCtx(), + }, + InsertRequest: sliceRequest, + } + result[key] = insertMsg + curMsgSizeMap[key] = getFixedSizeOfInsertMsg(insertMsg) + } + curMsg := result[key].(*msgstream.InsertMsg) + curMsgSize := curMsgSizeMap[key] + curMsg.HashValues = append(curMsg.HashValues, insertRequest.HashValues[index]) + curMsg.Timestamps = append(curMsg.Timestamps, ts) + curMsg.RowIDs = append(curMsg.RowIDs, rowID) + curMsg.RowData = append(curMsg.RowData, row) + curMsgSize += 4 + 8 + int(unsafe.Sizeof(row.Value)) + curMsgSize += len(row.Value) + + if curMsgSize >= threshold { + newPack.Msgs = append(newPack.Msgs, curMsg) + delete(result, key) + curMsgSize = 0 + } + + curMsgSizeMap[key] = curMsgSize + } + } + for _, msg := range result { + if msg != nil { + newPack.Msgs = append(newPack.Msgs, msg) + } + } + + return newPack, nil +} + func (it *InsertTask) Execute(ctx context.Context) error { collectionName := it.BaseInsertTask.CollectionName collSchema, err := globalMetaCache.GetCollectionSchema(ctx, collectionName) @@ -254,7 +462,14 @@ func (it *InsertTask) Execute(ctx context.Context) error { return err } - err = stream.Produce(&msgPack) + // Assign SegmentID + var pack *msgstream.MsgPack + pack, err = it._assignSegmentID(stream, &msgPack) + if err != nil { + return err + } + + err = stream.Produce(pack) if err != nil { it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError it.result.Status.Reason = err.Error() -- GitLab