未验证 提交 67661698 编写于 作者: Z zhenshan.cao 提交者: GitHub

Refactor repack logic for insertion (#5399)

Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>
上级 1c49ddc8
......@@ -26,6 +26,7 @@ require (
github.com/pierrec/lz4 v2.5.2+incompatible // indirect
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.7.1
github.com/quasilyte/go-ruleguard v0.2.1 // indirect
github.com/sirupsen/logrus v1.6.0 // indirect
github.com/spaolacci/murmur3 v1.1.0
github.com/spf13/cast v1.3.0
......
......@@ -320,6 +320,8 @@ github.com/prometheus/procfs v0.1.3 h1:F0+tqvhOksq22sc6iCHF5WGlWjdwj92p0udFh1VFB
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/protocolbuffers/protobuf v3.17.0+incompatible h1:MYhKKlaNOl8FB3F4u6oM2AlpcyLtT+p8Ec1w/9YeHss=
github.com/quasilyte/go-ruleguard v0.2.1 h1:56eRm0daAyny9UhJnmtJW/UyLZQusukBAB8oT8AHKHo=
github.com/quasilyte/go-ruleguard v0.2.1/go.mod h1:hN2rVc/uS4bQhQKTio2XaSJSafJwqBUWWwtssT3cQmc=
github.com/rivo/tview v0.0.0-20200219210816-cd38d7432498/go.mod h1:6lkG1x+13OShEf0EaOCaTQYyB7d5nSbb181KtjlS+84=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
......@@ -380,6 +382,7 @@ github.com/yahoo/athenz v1.8.55/go.mod h1:G7LLFUH7Z/r4QAB7FfudfuA7Am/eCzO1GlzBhD
github.com/yahoo/athenz v1.9.16 h1:2s8KtIxwAbcJIYySsfrT/t/WO0Ss5O7BPGUN/q8x2bg=
github.com/yahoo/athenz v1.9.16/go.mod h1:guj+0Ut6F33wj+OcSRlw69O0itsR7tVocv15F2wJnIo=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0=
go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
......@@ -453,6 +456,7 @@ golang.org/x/net v0.0.0-20190921015927-1a5e07d1ff72/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb h1:eBmm0M9fYhWpKZLjQUUKka/LtIxf46G4fxeEz5KJr9U=
......@@ -468,6 +472,7 @@ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
......@@ -531,6 +536,7 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20200812195022-5ae4c3c160a0/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a h1:CB3a9Nez8M13wwlr/E2YtwoU+qYHKfC+JrDa45RXXoQ=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
......
......@@ -67,6 +67,10 @@ func (mms *MemMsgStream) SetRepackFunc(repackFunc RepackFunc) {
mms.repackFunc = repackFunc
}
// GetProduceChannels returns the names of the channels this stream produces to.
func (mms *MemMsgStream) GetProduceChannels() []string {
	producerChannels := mms.producers
	return producerChannels
}
func (mms *MemMsgStream) AsProducer(channels []string) {
for _, channel := range channels {
err := Mmq.CreateChannel(channel)
......
......@@ -170,6 +170,7 @@ func (ms *mqMsgStream) ComputeProduceChannelIndexes(tsMsgs []TsMsg) [][]int32 {
}
reBucketValues := make([][]int32, len(tsMsgs))
channelNum := uint32(len(ms.producerChannels))
if channelNum == 0 {
return nil
}
......@@ -184,6 +185,10 @@ func (ms *mqMsgStream) ComputeProduceChannelIndexes(tsMsgs []TsMsg) [][]int32 {
return reBucketValues
}
// GetProduceChannels returns the names of the channels this stream produces to.
func (ms *mqMsgStream) GetProduceChannels() []string {
	channels := ms.producerChannels
	return channels
}
func (ms *mqMsgStream) Produce(msgPack *MsgPack) error {
tsMsgs := msgPack.Msgs
if len(tsMsgs) <= 0 {
......
......@@ -41,6 +41,7 @@ type MsgStream interface {
AsConsumer(channels []string, subName string)
SetRepackFunc(repackFunc RepackFunc)
ComputeProduceChannelIndexes(tsMsgs []TsMsg) [][]int32
GetProduceChannels() []string
Produce(*MsgPack) error
Broadcast(*MsgPack) error
Consume() *MsgPack
......
......@@ -76,6 +76,10 @@ func (ms *SimpleMsgStream) Broadcast(pack *MsgPack) error {
return nil
}
// GetProduceChannels reports the produce channels of this stream.
// SimpleMsgStream tracks no produce channels, so the result is always a nil slice.
func (ms *SimpleMsgStream) GetProduceChannels() []string {
	var channels []string
	return channels
}
func (ms *SimpleMsgStream) Consume() *MsgPack {
if ms.getMsgCount() <= 0 {
return nil
......
......@@ -1058,6 +1058,7 @@ func (node *ProxyNode) Insert(ctx context.Context, request *milvuspb.InsertReque
},
},
rowIDAllocator: node.idAllocator,
segIDAssigner: node.segAssigner,
}
if len(it.PartitionName) <= 0 {
it.PartitionName = Params.DefaultPartitionName
......
......@@ -61,10 +61,7 @@ func (m *insertChannelsMap) CreateInsertMsgStream(collID UniqueID, channels []st
stream, _ := m.msFactory.NewMsgStream(context.Background())
stream.AsProducer(channels)
log.Debug("proxynode", zap.Strings("proxynode AsProducer: ", channels))
repack := func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
return insertRepackFunc(tsMsgs, hashKeys, m.nodeInstance.segAssigner, true)
}
stream.SetRepackFunc(repack)
stream.SetRepackFunc(insertRepackFunc)
stream.Start()
m.insertMsgStreams = append(m.insertMsgStreams, stream)
m.droppedBitMap = append(m.droppedBitMap, 0)
......
......@@ -12,290 +12,23 @@
package proxynode
import (
"errors"
"sort"
"unsafe"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/msgstream"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
func insertRepackFunc(tsMsgs []msgstream.TsMsg,
hashKeys [][]int32,
segIDAssigner *SegIDAssigner,
together bool) (map[int32]*msgstream.MsgPack, error) {
hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
result := make(map[int32]*msgstream.MsgPack)
channelCountMap := make(map[UniqueID]map[int32]uint32) // reqID --> channelID to count
channelMaxTSMap := make(map[UniqueID]map[int32]Timestamp) // reqID --> channelID to max Timestamp
reqSchemaMap := make(map[UniqueID][]UniqueID) // reqID --> channelID [2]UniqueID {CollectionID, PartitionID}
channelNamesMap := make(map[UniqueID][]string) // collectionID --> channelNames
for i, request := range tsMsgs {
if request.Type() != commonpb.MsgType_Insert {
return nil, errors.New("msg's must be Insert")
}
insertRequest, ok := request.(*msgstream.InsertMsg)
if !ok {
return nil, errors.New("msg's must be Insert")
}
keys := hashKeys[i]
timestampLen := len(insertRequest.Timestamps)
rowIDLen := len(insertRequest.RowIDs)
rowDataLen := len(insertRequest.RowData)
keysLen := len(keys)
if keysLen != timestampLen || keysLen != rowIDLen || keysLen != rowDataLen {
return nil, errors.New("the length of hashValue, timestamps, rowIDs, RowData are not equal")
}
reqID := insertRequest.Base.MsgID
if _, ok := channelCountMap[reqID]; !ok {
channelCountMap[reqID] = make(map[int32]uint32)
}
if _, ok := channelMaxTSMap[reqID]; !ok {
channelMaxTSMap[reqID] = make(map[int32]Timestamp)
}
if _, ok := reqSchemaMap[reqID]; !ok {
reqSchemaMap[reqID] = []UniqueID{insertRequest.CollectionID, insertRequest.PartitionID}
}
for idx, channelID := range keys {
channelCountMap[reqID][channelID]++
if _, ok := channelMaxTSMap[reqID][channelID]; !ok {
channelMaxTSMap[reqID][channelID] = typeutil.ZeroTimestamp
}
ts := insertRequest.Timestamps[idx]
if channelMaxTSMap[reqID][channelID] < ts {
channelMaxTSMap[reqID][channelID] = ts
}
}
collID := insertRequest.CollectionID
if _, ok := channelNamesMap[collID]; !ok {
channelNames, err := globalInsertChannelsMap.GetInsertChannels(collID)
if err != nil {
return nil, err
}
channelNamesMap[collID] = channelNames
}
}
var getChannelName = func(collID UniqueID, channelID int32) string {
if _, ok := channelNamesMap[collID]; !ok {
return ""
}
names := channelNamesMap[collID]
return names[channelID]
}
reqSegCountMap := make(map[UniqueID]map[int32]map[UniqueID]uint32)
for reqID, countInfo := range channelCountMap {
if _, ok := reqSegCountMap[reqID]; !ok {
reqSegCountMap[reqID] = make(map[int32]map[UniqueID]uint32)
}
schema := reqSchemaMap[reqID]
collID, partitionID := schema[0], schema[1]
for channelID, count := range countInfo {
ts, ok := channelMaxTSMap[reqID][channelID]
if !ok {
ts = typeutil.ZeroTimestamp
log.Debug("Warning: did not get max Timstamp!")
}
channelName := getChannelName(collID, channelID)
if channelName == "" {
return nil, errors.New("ProxyNode, repack_func, can not found channelName")
}
mapInfo, err := segIDAssigner.GetSegmentID(collID, partitionID, channelName, count, ts)
if err != nil {
return nil, err
}
reqSegCountMap[reqID][channelID] = make(map[UniqueID]uint32)
reqSegCountMap[reqID][channelID] = mapInfo
log.Debug("proxynode", zap.Int64("repackFunc, reqSegCountMap, reqID", reqID), zap.Any("mapinfo", mapInfo))
}
}
reqSegAccumulateCountMap := make(map[UniqueID]map[int32][]uint32)
reqSegIDMap := make(map[UniqueID]map[int32][]UniqueID)
reqSegAllocateCounter := make(map[UniqueID]map[int32]uint32)
for reqID, channelInfo := range reqSegCountMap {
if _, ok := reqSegAccumulateCountMap[reqID]; !ok {
reqSegAccumulateCountMap[reqID] = make(map[int32][]uint32)
}
if _, ok := reqSegIDMap[reqID]; !ok {
reqSegIDMap[reqID] = make(map[int32][]UniqueID)
}
if _, ok := reqSegAllocateCounter[reqID]; !ok {
reqSegAllocateCounter[reqID] = make(map[int32]uint32)
}
for channelID, segInfo := range channelInfo {
reqSegAllocateCounter[reqID][channelID] = 0
keys := make([]UniqueID, len(segInfo))
i := 0
for key := range segInfo {
keys[i] = key
i++
}
sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] })
accumulate := uint32(0)
for _, key := range keys {
accumulate += segInfo[key]
if _, ok := reqSegAccumulateCountMap[reqID][channelID]; !ok {
reqSegAccumulateCountMap[reqID][channelID] = make([]uint32, 0)
}
reqSegAccumulateCountMap[reqID][channelID] = append(
reqSegAccumulateCountMap[reqID][channelID],
accumulate,
)
if _, ok := reqSegIDMap[reqID][channelID]; !ok {
reqSegIDMap[reqID][channelID] = make([]UniqueID, 0)
}
reqSegIDMap[reqID][channelID] = append(
reqSegIDMap[reqID][channelID],
key,
)
}
}
}
var getSegmentID = func(reqID UniqueID, channelID int32) UniqueID {
reqSegAllocateCounter[reqID][channelID]++
cur := reqSegAllocateCounter[reqID][channelID]
accumulateSlice := reqSegAccumulateCountMap[reqID][channelID]
segIDSlice := reqSegIDMap[reqID][channelID]
for index, count := range accumulateSlice {
if cur <= count {
return segIDSlice[index]
}
}
log.Warn("Can't Found SegmentID")
return 0
}
factor := 10
threshold := Params.PulsarMaxMessageSize / factor
log.Debug("proxynode", zap.Int("threshold of message size: ", threshold))
// not accurate
getSizeOfInsertMsg := func(msg *msgstream.InsertMsg) int {
// if real struct, call unsafe.Sizeof directly,
// if reference, dereference and then call unsafe.Sizeof,
// if slice, todo: a common function to calculate size of slice,
// if map, a little complicated
size := 0
size += int(unsafe.Sizeof(msg.Ctx))
size += int(unsafe.Sizeof(msg.BeginTimestamp))
size += int(unsafe.Sizeof(msg.EndTimestamp))
size += int(unsafe.Sizeof(msg.HashValues))
size += len(msg.HashValues) * 4
size += int(unsafe.Sizeof(*msg.MsgPosition))
size += int(unsafe.Sizeof(*msg.Base))
size += int(unsafe.Sizeof(msg.DbName))
size += int(unsafe.Sizeof(msg.CollectionName))
size += int(unsafe.Sizeof(msg.PartitionName))
size += int(unsafe.Sizeof(msg.DbID))
size += int(unsafe.Sizeof(msg.CollectionID))
size += int(unsafe.Sizeof(msg.PartitionID))
size += int(unsafe.Sizeof(msg.SegmentID))
size += int(unsafe.Sizeof(msg.ChannelID))
size += int(unsafe.Sizeof(msg.Timestamps))
size += int(unsafe.Sizeof(msg.RowIDs))
size += len(msg.RowIDs) * 8
for _, blob := range msg.RowData {
size += int(unsafe.Sizeof(blob.Value))
size += len(blob.Value)
}
//log.Debug("proxynode", zap.Int("insert message size", size))
return size
}
// not accurate
// getSizeOfMsgPack := func(mp *msgstream.MsgPack) int {
// size := 0
// for _, msg := range mp.Msgs {
// insertMsg, ok := msg.(*msgstream.InsertMsg)
// if !ok {
// log.Panic("only insert message is supported!")
// }
// size += getSizeOfInsertMsg(insertMsg)
// }
// return size
// }
for i, request := range tsMsgs {
insertRequest := request.(*msgstream.InsertMsg)
keys := hashKeys[i]
reqID := insertRequest.Base.MsgID
collectionName := insertRequest.CollectionName
collectionID := insertRequest.CollectionID
partitionID := insertRequest.PartitionID
partitionName := insertRequest.PartitionName
proxyID := insertRequest.Base.SourceID
channelNames := channelNamesMap[collectionID]
for index, key := range keys {
ts := insertRequest.Timestamps[index]
rowID := insertRequest.RowIDs[index]
row := insertRequest.RowData[index]
if len(keys) > 0 {
key := keys[0]
_, ok := result[key]
if !ok {
msgPack := msgstream.MsgPack{}
result[key] = &msgPack
}
segmentID := getSegmentID(reqID, key)
channelID := channelNames[int(key)%len(channelNames)]
sliceRequest := internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
MsgID: reqID,
Timestamp: ts,
SourceID: proxyID,
},
CollectionID: collectionID,
PartitionID: partitionID,
CollectionName: collectionName,
PartitionName: partitionName,
SegmentID: segmentID,
// todo rename to ChannelName
// ChannelID: strconv.FormatInt(int64(key), 10),
ChannelID: channelID,
Timestamps: []uint64{ts},
RowIDs: []int64{rowID},
RowData: []*commonpb.Blob{row},
}
insertMsg := &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{
Ctx: request.TraceCtx(),
},
InsertRequest: sliceRequest,
}
if together { // all rows with same hash value are accumulated to only one message
msgNums := len(result[key].Msgs)
if len(result[key].Msgs) <= 0 {
result[key].Msgs = append(result[key].Msgs, insertMsg)
} else if getSizeOfInsertMsg(result[key].Msgs[msgNums-1].(*msgstream.InsertMsg)) >= threshold {
result[key].Msgs = append(result[key].Msgs, insertMsg)
} else {
accMsgs, _ := result[key].Msgs[msgNums-1].(*msgstream.InsertMsg)
accMsgs.Timestamps = append(accMsgs.Timestamps, ts)
accMsgs.RowIDs = append(accMsgs.RowIDs, rowID)
accMsgs.RowData = append(accMsgs.RowData, row)
}
} else { // every row is a message
result[key].Msgs = append(result[key].Msgs, insertMsg)
result[key] = &msgstream.MsgPack{}
}
result[key].Msgs = append(result[key].Msgs, request)
}
}
return result, nil
}
......@@ -18,8 +18,10 @@ import (
"math"
"regexp"
"runtime"
"sort"
"strconv"
"time"
"unsafe"
"github.com/milvus-io/milvus/internal/proto/planpb"
......@@ -99,6 +101,7 @@ type InsertTask struct {
dataService types.DataService
result *milvuspb.InsertResponse
rowIDAllocator *allocator.IDAllocator
segIDAssigner *SegIDAssigner
}
func (it *InsertTask) TraceCtx() context.Context {
......@@ -160,6 +163,211 @@ func (it *InsertTask) PreExecute(ctx context.Context) error {
return nil
}
// _assignSegmentID repacks the insert rows in pack by produce-channel index,
// requests segment IDs from the task's SegIDAssigner per channel, and rebuilds
// the pack as per-channel InsertMsgs, flushing a message to the new pack once
// its estimated serialized size reaches a threshold derived from
// Params.PulsarMaxMessageSize.
//
// It returns a new MsgPack carrying the repacked messages (the input pack is
// not modified), or an error if a message is not an InsertMsg, its parallel
// slices have mismatched lengths, a channel name cannot be resolved, or
// segment ID allocation fails.
func (it *InsertTask) _assignSegmentID(stream msgstream.MsgStream, pack *msgstream.MsgPack) (*msgstream.MsgPack, error) {
	newPack := &msgstream.MsgPack{
		BeginTs:        pack.BeginTs,
		EndTs:          pack.EndTs,
		StartPositions: pack.StartPositions,
		EndPositions:   pack.EndPositions,
		Msgs:           nil,
	}
	tsMsgs := pack.Msgs
	hashKeys := stream.ComputeProduceChannelIndexes(tsMsgs)
	reqID := it.Base.MsgID
	channelCountMap := make(map[int32]uint32)    // channelID -> number of rows routed to it
	channelMaxTSMap := make(map[int32]Timestamp) // channelID -> max timestamp among its rows
	channelNames := stream.GetProduceChannels()
	log.Debug("_assignSegmentID, produceChannels:", zap.Any("Channels", channelNames))

	// Pass 1: validate each message and accumulate per-channel row counts and
	// max timestamps, which the segment assigner needs.
	for i, request := range tsMsgs {
		if request.Type() != commonpb.MsgType_Insert {
			return nil, fmt.Errorf("msg's must be Insert")
		}
		insertRequest, ok := request.(*msgstream.InsertMsg)
		if !ok {
			return nil, fmt.Errorf("msg's must be Insert")
		}
		keys := hashKeys[i]
		timestampLen := len(insertRequest.Timestamps)
		rowIDLen := len(insertRequest.RowIDs)
		rowDataLen := len(insertRequest.RowData)
		keysLen := len(keys)
		// Timestamps, RowIDs, RowData and the hash keys are parallel slices;
		// a mismatch means the message is malformed.
		if keysLen != timestampLen || keysLen != rowIDLen || keysLen != rowDataLen {
			return nil, fmt.Errorf("the length of hashValue, timestamps, rowIDs, RowData are not equal")
		}
		for idx, channelID := range keys {
			channelCountMap[channelID]++
			if _, ok := channelMaxTSMap[channelID]; !ok {
				channelMaxTSMap[channelID] = typeutil.ZeroTimestamp
			}
			ts := insertRequest.Timestamps[idx]
			if channelMaxTSMap[channelID] < ts {
				channelMaxTSMap[channelID] = ts
			}
		}
	}

	// Pass 2: ask the segment assigner, per channel, how many rows go into
	// which segment. reqSegCountMap: channelID -> (segmentID -> row count).
	reqSegCountMap := make(map[int32]map[UniqueID]uint32)
	for channelID, count := range channelCountMap {
		ts, ok := channelMaxTSMap[channelID]
		if !ok {
			ts = typeutil.ZeroTimestamp
			log.Debug("Warning: did not get max Timestamp!")
		}
		// Guard the index: a channelID outside the produce-channel list would
		// otherwise panic.
		if channelID < 0 || int(channelID) >= len(channelNames) {
			return nil, fmt.Errorf("ProxyNode, repack_func, can not found channelName")
		}
		channelName := channelNames[channelID]
		if channelName == "" {
			return nil, fmt.Errorf("ProxyNode, repack_func, can not found channelName")
		}
		mapInfo, err := it.segIDAssigner.GetSegmentID(it.CollectionID, it.PartitionID, channelName, count, ts)
		if err != nil {
			return nil, err
		}
		reqSegCountMap[channelID] = mapInfo
		log.Debug("ProxyNode", zap.Int64("repackFunc, reqSegCountMap, reqID", reqID), zap.Any("mapinfo", mapInfo))
	}

	// Pass 3: turn each channel's (segmentID -> count) map into two parallel
	// slices sorted by segmentID: cumulative row counts and segment IDs, so
	// rows can be dealt out to segments in order.
	reqSegAccumulateCountMap := make(map[int32][]uint32)
	reqSegIDMap := make(map[int32][]UniqueID)
	reqSegAllocateCounter := make(map[int32]uint32)
	for channelID, segInfo := range reqSegCountMap {
		reqSegAllocateCounter[channelID] = 0
		keys := make([]UniqueID, 0, len(segInfo))
		for key := range segInfo {
			keys = append(keys, key)
		}
		sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] })
		accumulate := uint32(0)
		for _, key := range keys {
			accumulate += segInfo[key]
			reqSegAccumulateCountMap[channelID] = append(reqSegAccumulateCountMap[channelID], accumulate)
			reqSegIDMap[channelID] = append(reqSegIDMap[channelID], key)
		}
	}

	// getSegmentID hands out the segment ID for the next row on channelID,
	// walking the cumulative counts. Returns 0 (and warns) if the channel's
	// allocation is exhausted.
	var getSegmentID = func(channelID int32) UniqueID {
		reqSegAllocateCounter[channelID]++
		cur := reqSegAllocateCounter[channelID]
		accumulateSlice := reqSegAccumulateCountMap[channelID]
		segIDSlice := reqSegIDMap[channelID]
		for index, count := range accumulateSlice {
			if cur <= count {
				return segIDSlice[index]
			}
		}
		log.Warn("Can't Found SegmentID")
		return 0
	}

	factor := 10
	threshold := Params.PulsarMaxMessageSize / factor
	log.Debug("ProxyNode", zap.Int("threshold of message size: ", threshold))

	// Rough (not accurate) estimate of the fixed-size header portion of an
	// InsertMsg; per-row sizes are added incrementally below.
	getFixedSizeOfInsertMsg := func(msg *msgstream.InsertMsg) int {
		size := 0
		size += int(unsafe.Sizeof(*msg.Base))
		size += int(unsafe.Sizeof(msg.DbName))
		size += int(unsafe.Sizeof(msg.CollectionName))
		size += int(unsafe.Sizeof(msg.PartitionName))
		size += int(unsafe.Sizeof(msg.DbID))
		size += int(unsafe.Sizeof(msg.CollectionID))
		size += int(unsafe.Sizeof(msg.PartitionID))
		size += int(unsafe.Sizeof(msg.SegmentID))
		size += int(unsafe.Sizeof(msg.ChannelID))
		size += int(unsafe.Sizeof(msg.Timestamps))
		size += int(unsafe.Sizeof(msg.RowIDs))
		return size
	}

	// Pass 4: deal every row into a per-channel InsertMsg, flushing a message
	// into newPack whenever its estimated size crosses the threshold.
	result := make(map[int32]msgstream.TsMsg) // channelID -> message being filled
	curMsgSizeMap := make(map[int32]int)      // channelID -> estimated size of that message
	for i, request := range tsMsgs {
		insertRequest := request.(*msgstream.InsertMsg)
		keys := hashKeys[i]
		collectionName := insertRequest.CollectionName
		collectionID := insertRequest.CollectionID
		partitionID := insertRequest.PartitionID
		partitionName := insertRequest.PartitionName
		proxyID := insertRequest.Base.SourceID
		for index, key := range keys {
			ts := insertRequest.Timestamps[index]
			rowID := insertRequest.RowIDs[index]
			row := insertRequest.RowData[index]
			segmentID := getSegmentID(key)
			// NOTE(review): segmentID advances per row but is only recorded
			// when a new message is created, so rows allocated to a later
			// segment may land in a message tagged with the earlier segment's
			// ID — confirm whether the message should be flushed when the
			// segment changes.
			if _, ok := result[key]; !ok {
				sliceRequest := internalpb.InsertRequest{
					Base: &commonpb.MsgBase{
						MsgType:   commonpb.MsgType_Insert,
						MsgID:     reqID,
						Timestamp: ts,
						SourceID:  proxyID,
					},
					CollectionID:   collectionID,
					PartitionID:    partitionID,
					CollectionName: collectionName,
					PartitionName:  partitionName,
					SegmentID:      segmentID,
					// todo rename to ChannelName
					ChannelID: channelNames[key],
				}
				insertMsg := &msgstream.InsertMsg{
					BaseMsg: msgstream.BaseMsg{
						Ctx: request.TraceCtx(),
					},
					InsertRequest: sliceRequest,
				}
				result[key] = insertMsg
				curMsgSizeMap[key] = getFixedSizeOfInsertMsg(insertMsg)
			}
			curMsg := result[key].(*msgstream.InsertMsg)
			curMsgSize := curMsgSizeMap[key]
			curMsg.HashValues = append(curMsg.HashValues, insertRequest.HashValues[index])
			curMsg.Timestamps = append(curMsg.Timestamps, ts)
			curMsg.RowIDs = append(curMsg.RowIDs, rowID)
			curMsg.RowData = append(curMsg.RowData, row)
			// Per-row cost: 4 bytes hash value + 8 bytes row ID + blob header
			// and payload.
			curMsgSize += 4 + 8 + int(unsafe.Sizeof(row.Value))
			curMsgSize += len(row.Value)
			if curMsgSize >= threshold {
				newPack.Msgs = append(newPack.Msgs, curMsg)
				delete(result, key)
				curMsgSize = 0
			}
			curMsgSizeMap[key] = curMsgSize
		}
	}

	// Flush any partially-filled messages that never crossed the threshold.
	for _, msg := range result {
		if msg != nil {
			newPack.Msgs = append(newPack.Msgs, msg)
		}
	}
	return newPack, nil
}
func (it *InsertTask) Execute(ctx context.Context) error {
collectionName := it.BaseInsertTask.CollectionName
collSchema, err := globalMetaCache.GetCollectionSchema(ctx, collectionName)
......@@ -254,7 +462,14 @@ func (it *InsertTask) Execute(ctx context.Context) error {
return err
}
err = stream.Produce(&msgPack)
// Assign SegmentID
var pack *msgstream.MsgPack
pack, err = it._assignSegmentID(stream, &msgPack)
if err != nil {
return err
}
err = stream.Produce(pack)
if err != nil {
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
it.result.Status.Reason = err.Error()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册