diff --git a/internal/msgstream/msgstream.go b/internal/msgstream/msgstream.go index 376e6ee55d504ffa1d1ebf5ab2080bcf3bc6951c..cd5814f69ab19711db38b60b0bc84cb476fd3d7e 100644 --- a/internal/msgstream/msgstream.go +++ b/internal/msgstream/msgstream.go @@ -2,10 +2,12 @@ package msgstream import ( "context" + "github.com/zilliztech/milvus-distributed/internal/errors" "log" "sync" "github.com/golang/protobuf/proto" + commonPb "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" "github.com/apache/pulsar-client-go/pulsar" @@ -22,7 +24,7 @@ type MsgPack struct { Msgs []*TsMsg } -type RepackFunc func(msgs []*TsMsg, hashKeys [][]int32) map[int32]*MsgPack +type RepackFunc func(msgs []*TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) type MsgStream interface { Start() @@ -145,22 +147,23 @@ func (ms *PulsarMsgStream) Produce(msgPack *MsgPack) error { } var result map[int32]*MsgPack + var err error if ms.repackFunc != nil { - result = ms.repackFunc(tsMsgs, reBucketValues) + result, err = ms.repackFunc(tsMsgs, reBucketValues) } else { - result = make(map[int32]*MsgPack) - for i, request := range tsMsgs { - keys := reBucketValues[i] - for _, channelID := range keys { - _, ok := result[channelID] - if !ok { - msgPack := MsgPack{} - result[channelID] = &msgPack - } - result[channelID].Msgs = append(result[channelID].Msgs, request) - } + msgType := (*tsMsgs[0]).Type() + switch msgType { + case internalPb.MsgType_kInsert: + result, err = insertRepackFunc(tsMsgs, reBucketValues) + case internalPb.MsgType_kDelete: + result, err = deleteRepackFunc(tsMsgs, reBucketValues) + default: + result, err = defaultRepackFunc(tsMsgs, reBucketValues) } } + if err != nil { + return err + } for k, v := range result { for i := 0; i < len(v.Msgs); i++ { mb, err := (*v.Msgs[i]).Marshal(v.Msgs[i]) @@ -381,3 +384,113 @@ func checkTimeTickMsg(msg map[int]Timestamp) (Timestamp, bool) { } return 0, false } + +func insertRepackFunc(tsMsgs []*TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { + result := make(map[int32]*MsgPack) + for i, request := range tsMsgs { + if (*request).Type() != internalPb.MsgType_kInsert { + return nil, errors.New(string("msg's must be Insert")) + } + insertRequest := (*request).(*InsertMsg) + keys := hashKeys[i] + + timestampLen := len(insertRequest.Timestamps) + rowIDLen := len(insertRequest.RowIds) + rowDataLen := len(insertRequest.RowData) + keysLen := len(keys) + + if keysLen != timestampLen || keysLen != rowIDLen || keysLen != rowDataLen { + return nil, errors.New(string("the length of hashValue, timestamps, rowIDs, RowData are not equal")) + } + for index, key := range keys { + _, ok := result[key] + if !ok { + msgPack := MsgPack{} + result[key] = &msgPack + } + + sliceRequest := internalPb.InsertRequest{ + MsgType: internalPb.MsgType_kInsert, + ReqId: insertRequest.ReqId, + CollectionName: insertRequest.CollectionName, + PartitionTag: insertRequest.PartitionTag, + SegmentId: insertRequest.SegmentId, + ChannelId: insertRequest.ChannelId, + ProxyId: insertRequest.ProxyId, + Timestamps: []uint64{insertRequest.Timestamps[index]}, + RowIds: []int64{insertRequest.RowIds[index]}, + RowData: []*commonPb.Blob{insertRequest.RowData[index]}, + } + + var msg TsMsg = &InsertMsg{ + InsertRequest: sliceRequest, + } + + result[key].Msgs = append(result[key].Msgs, &msg) + } + } + return result, nil +} + +func deleteRepackFunc(tsMsgs []*TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { + result := make(map[int32]*MsgPack) + for i, request := range tsMsgs { + if (*request).Type() != internalPb.MsgType_kDelete { + return nil, errors.New(string("msg's must be Delete")) + } + deleteRequest := (*request).(*DeleteMsg) + keys := hashKeys[i] + + timestampLen := len(deleteRequest.Timestamps) + primaryKeysLen := len(deleteRequest.PrimaryKeys) + keysLen := len(keys) + + if keysLen != timestampLen || keysLen != primaryKeysLen { + return nil, errors.New(string("the length of hashValue, timestamps, primaryKeys are not equal")) + } + + for index, key := range keys { + _, ok := result[key] + if !ok { + msgPack := MsgPack{} + result[key] = &msgPack + } + + sliceRequest := internalPb.DeleteRequest{ + MsgType: internalPb.MsgType_kDelete, + ReqId: deleteRequest.ReqId, + CollectionName: deleteRequest.CollectionName, + ChannelId: deleteRequest.ChannelId, + ProxyId: deleteRequest.ProxyId, + Timestamps: []uint64{deleteRequest.Timestamps[index]}, + PrimaryKeys: []int64{deleteRequest.PrimaryKeys[index]}, + } + + var msg TsMsg = &DeleteMsg{ + DeleteRequest: sliceRequest, + } + + result[key].Msgs = append(result[key].Msgs, &msg) + } + } + return result, nil +} + +func defaultRepackFunc(tsMsgs []*TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { + result := make(map[int32]*MsgPack) + for i, request := range tsMsgs { + keys := hashKeys[i] + if len(keys) != 1 { + return nil, errors.New(string("len(msg.hashValue) must equal 1")) + } + key := keys[0] + _, ok := result[key] + if !ok { + msgPack := MsgPack{} + result[key] = &msgPack + } + + result[key].Msgs = append(result[key].Msgs, request) + } + return result, nil +} diff --git a/internal/msgstream/msgstream_test.go b/internal/msgstream/msgstream_test.go index a0060583be3788c690cc000323e901b6ef7e0492..c168961cc57f2aafe476f56c9a09a4a5f9586247 100644 --- a/internal/msgstream/msgstream_test.go +++ b/internal/msgstream/msgstream_test.go @@ -3,13 +3,14 @@ package msgstream import ( "context" "fmt" + "log" "testing" commonPb "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" ) -func repackFunc(msgs []*TsMsg, hashKeys [][]int32) map[int32]*MsgPack { +func repackFunc(msgs []*TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { result := make(map[int32]*MsgPack) for i, request := range msgs { keys := hashKeys[i] @@ -22,7 +23,7 @@ func repackFunc(msgs []*TsMsg, hashKeys [][]int32) map[int32]*MsgPack { result[channelID].Msgs = append(result[channelID].Msgs, request) } } - return result + return result, nil } func getTsMsg(msgType MsgType, reqID UniqueID, hashValue int32) *TsMsg { @@ -43,6 +44,8 @@ func getTsMsg(msgType MsgType, reqID UniqueID, hashValue int32) *TsMsg { ChannelId: 1, ProxyId: 1, Timestamps: []Timestamp{1}, + RowIds: []int64{1}, + RowData: []*commonPb.Blob{{}}, } insertMsg := &InsertMsg{ BaseMsg: baseMsg, @@ -209,7 +212,11 @@ func TestStream_PulsarMsgStream_Insert(t *testing.T) { msgPack.Msgs = append(msgPack.Msgs, getTsMsg(internalPb.MsgType_kInsert, 3, 3)) inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - (*inputStream).Produce(&msgPack) + err := (*inputStream).Produce(&msgPack) + if err != nil { + log.Fatalf("produce error = %v", err) + } + receiveMsg(outputStream, len(msgPack.Msgs)) (*inputStream).Close() (*outputStream).Close() @@ -227,7 +234,10 @@ func TestStream_PulsarMsgStream_Delete(t *testing.T) { msgPack.Msgs = append(msgPack.Msgs, getTsMsg(internalPb.MsgType_kDelete, 3, 3)) inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - (*inputStream).Produce(&msgPack) + err := (*inputStream).Produce(&msgPack) + if err != nil { + log.Fatalf("produce error = %v", err) + } receiveMsg(outputStream, len(msgPack.Msgs)) (*inputStream).Close() (*outputStream).Close() @@ -244,7 +254,10 @@ func TestStream_PulsarMsgStream_Search(t *testing.T) { msgPack.Msgs = append(msgPack.Msgs, getTsMsg(internalPb.MsgType_kSearch, 3, 3)) inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - (*inputStream).Produce(&msgPack) + err := (*inputStream).Produce(&msgPack) + if err != nil { + log.Fatalf("produce error = %v", err) + } receiveMsg(outputStream, len(msgPack.Msgs)) (*inputStream).Close() (*outputStream).Close() @@ -261,7 +274,10 @@ func TestStream_PulsarMsgStream_SearchResult(t *testing.T) { msgPack.Msgs = append(msgPack.Msgs, getTsMsg(internalPb.MsgType_kSearchResult, 3, 3)) inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - (*inputStream).Produce(&msgPack) + err := (*inputStream).Produce(&msgPack) + if err != nil { + log.Fatalf("produce error = %v", err) + } receiveMsg(outputStream, len(msgPack.Msgs)) (*inputStream).Close() (*outputStream).Close() @@ -278,7 +294,10 @@ func TestStream_PulsarMsgStream_TimeTick(t *testing.T) { msgPack.Msgs = append(msgPack.Msgs, getTsMsg(internalPb.MsgType_kTimeTick, 3, 3)) inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - (*inputStream).Produce(&msgPack) + err := (*inputStream).Produce(&msgPack) + if err != nil { + log.Fatalf("produce error = %v", err) + } receiveMsg(outputStream, len(msgPack.Msgs)) (*inputStream).Close() (*outputStream).Close() @@ -295,7 +314,10 @@ func TestStream_PulsarMsgStream_BroadCast(t *testing.T) { msgPack.Msgs = append(msgPack.Msgs, getTsMsg(internalPb.MsgType_kTimeTick, 3, 3)) inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - (*inputStream).Broadcast(&msgPack) + err := (*inputStream).Broadcast(&msgPack) + if err != nil { + log.Fatalf("produce error = %v", err) + } receiveMsg(outputStream, len(consumerChannels)*len(msgPack.Msgs)) (*inputStream).Close() (*outputStream).Close() @@ -312,12 +334,164 @@ func TestStream_PulsarMsgStream_RepackFunc(t *testing.T) { msgPack.Msgs = append(msgPack.Msgs, getTsMsg(internalPb.MsgType_kInsert, 3, 3)) inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName, repackFunc) - (*inputStream).Produce(&msgPack) + err := (*inputStream).Produce(&msgPack) + if err != nil { + log.Fatalf("produce error = %v", err) + } receiveMsg(outputStream, len(msgPack.Msgs)) (*inputStream).Close() (*outputStream).Close() } +func TestStream_PulsarMsgStream_InsertRepackFunc(t *testing.T) { + pulsarAddress := "pulsar://localhost:6650" + producerChannels := []string{"insert1", "insert2"} + consumerChannels := []string{"insert1", "insert2"} + consumerSubName := "subInsert" + + baseMsg := BaseMsg{ + BeginTimestamp: 0, + EndTimestamp: 0, + HashValues: []int32{1, 3}, + } + + insertRequest := internalPb.InsertRequest{ + MsgType: internalPb.MsgType_kInsert, + ReqId: 1, + CollectionName: "Collection", + PartitionTag: "Partition", + SegmentId: 1, + ChannelId: 1, + ProxyId: 1, + Timestamps: []Timestamp{1, 1}, + RowIds: []int64{1, 3}, + RowData: []*commonPb.Blob{{}, {}}, + } + insertMsg := &InsertMsg{ + BaseMsg: baseMsg, + InsertRequest: insertRequest, + } + var tsMsg TsMsg = insertMsg + msgPack := MsgPack{} + msgPack.Msgs = append(msgPack.Msgs, &tsMsg) + + inputStream := NewPulsarMsgStream(context.Background(), 100) + inputStream.SetPulsarCient(pulsarAddress) + inputStream.CreatePulsarProducers(producerChannels) + inputStream.Start() + + outputStream := NewPulsarMsgStream(context.Background(), 100) + outputStream.SetPulsarCient(pulsarAddress) + unmarshalDispatcher := NewUnmarshalDispatcher() + outputStream.CreatePulsarConsumers(consumerChannels, consumerSubName, unmarshalDispatcher, 100) + outputStream.Start() + var output MsgStream = outputStream + + err := (*inputStream).Produce(&msgPack) + if err != nil { + log.Fatalf("produce error = %v", err) + } + receiveMsg(&output, len(msgPack.Msgs)*2) + (*inputStream).Close() + (*outputStream).Close() +} + +func TestStream_PulsarMsgStream_DeleteRepackFunc(t *testing.T) { + pulsarAddress := "pulsar://localhost:6650" + producerChannels := []string{"insert1", "insert2"} + consumerChannels := []string{"insert1", "insert2"} + consumerSubName := "subInsert" + + baseMsg := BaseMsg{ + BeginTimestamp: 0, + EndTimestamp: 0, + HashValues: []int32{1, 3}, + } + + deleteRequest := internalPb.DeleteRequest{ + MsgType: internalPb.MsgType_kDelete, + ReqId: 1, + CollectionName: "Collection", + ChannelId: 1, + ProxyId: 1, + Timestamps: []Timestamp{1, 1}, + PrimaryKeys: []int64{1, 3}, + } + deleteMsg := &DeleteMsg{ + BaseMsg: baseMsg, + DeleteRequest: deleteRequest, + } + var tsMsg TsMsg = deleteMsg + msgPack := MsgPack{} + msgPack.Msgs = append(msgPack.Msgs, &tsMsg) + + inputStream := NewPulsarMsgStream(context.Background(), 100) + inputStream.SetPulsarCient(pulsarAddress) + inputStream.CreatePulsarProducers(producerChannels) + inputStream.Start() + + outputStream := NewPulsarMsgStream(context.Background(), 100) + outputStream.SetPulsarCient(pulsarAddress) + unmarshalDispatcher := NewUnmarshalDispatcher() + outputStream.CreatePulsarConsumers(consumerChannels, consumerSubName, unmarshalDispatcher, 100) + outputStream.Start() + var output MsgStream = outputStream + + err := (*inputStream).Produce(&msgPack) + if err != nil { + log.Fatalf("produce error = %v", err) + } + receiveMsg(&output, len(msgPack.Msgs)*2) + (*inputStream).Close() + (*outputStream).Close() +} + +func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) { + pulsarAddress := "pulsar://localhost:6650" + producerChannels := []string{"insert1", "insert2"} + consumerChannels := []string{"insert1", "insert2"} + consumerSubName := "subInsert" + + baseMsg := BaseMsg{ + BeginTimestamp: 0, + EndTimestamp: 0, + HashValues: []int32{1}, + } + + timeTickRequest := internalPb.TimeTickMsg{ + MsgType: internalPb.MsgType_kTimeTick, + PeerId: int64(1), + Timestamp: uint64(1), + } + timeTick := &TimeTickMsg{ + BaseMsg: baseMsg, + TimeTickMsg: timeTickRequest, + } + var tsMsg TsMsg = timeTick + msgPack := MsgPack{} + msgPack.Msgs = append(msgPack.Msgs, &tsMsg) + + inputStream := NewPulsarMsgStream(context.Background(), 100) + inputStream.SetPulsarCient(pulsarAddress) + inputStream.CreatePulsarProducers(producerChannels) + inputStream.Start() + + outputStream := NewPulsarMsgStream(context.Background(), 100) + outputStream.SetPulsarCient(pulsarAddress) + unmarshalDispatcher := NewUnmarshalDispatcher() + outputStream.CreatePulsarConsumers(consumerChannels, consumerSubName, unmarshalDispatcher, 100) + outputStream.Start() + var output MsgStream = outputStream + + err := (*inputStream).Produce(&msgPack) + if err != nil { + log.Fatalf("produce error = %v", err) + } + receiveMsg(&output, len(msgPack.Msgs)) + (*inputStream).Close() + (*outputStream).Close() +} + func TestStream_PulsarTtMsgStream_Insert(t *testing.T) { pulsarAddress := "pulsar://localhost:6650" producerChannels := []string{"insert1", "insert2"} @@ -335,9 +509,18 @@ func TestStream_PulsarTtMsgStream_Insert(t *testing.T) { msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(5, 5, 5)) inputStream, outputStream := initPulsarTtStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - (*inputStream).Broadcast(&msgPack0) - (*inputStream).Produce(&msgPack1) - (*inputStream).Broadcast(&msgPack2) + err := (*inputStream).Broadcast(&msgPack0) + if err != nil { + log.Fatalf("broadcast error = %v", err) + } + err = (*inputStream).Produce(&msgPack1) + if err != nil { + log.Fatalf("produce error = %v", err) + } + err = (*inputStream).Broadcast(&msgPack2) + if err != nil { + log.Fatalf("broadcast error = %v", err) + } receiveMsg(outputStream, len(msgPack1.Msgs)) outputTtStream := (*outputStream).(*PulsarTtMsgStream) fmt.Printf("timestamp = %v", outputTtStream.lastTimeStamp) diff --git a/internal/msgstream/task_test.go b/internal/msgstream/task_test.go index 4755adef8e30738125a35e93adbcceefedb901bc..3c1dc426c3a9b3cd51587c88e7dae81d953655a1 100644 --- a/internal/msgstream/task_test.go +++ b/internal/msgstream/task_test.go @@ -2,7 +2,10 @@ package msgstream import ( "context" + "errors" "fmt" + commonPb "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "log" "testing" "github.com/golang/protobuf/proto" @@ -37,6 +40,53 @@ func (tt *InsertTask) Unmarshal(input []byte) (*TsMsg, error) { return &tsMsg, nil } +func newRepackFunc(tsMsgs []*TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { + result := make(map[int32]*MsgPack) + for i, request := range tsMsgs { + if (*request).Type() != internalPb.MsgType_kInsert { + return nil, errors.New(string("msg's must be Insert")) + } + insertRequest := (*request).(*InsertTask).InsertRequest + keys := hashKeys[i] + + timestampLen := len(insertRequest.Timestamps) + rowIDLen := len(insertRequest.RowIds) + rowDataLen := len(insertRequest.RowData) + keysLen := len(keys) + + if keysLen != timestampLen || keysLen != rowIDLen || keysLen != rowDataLen { + return nil, errors.New(string("the length of hashValue, timestamps, rowIDs, RowData are not equal")) + } + for index, key := range keys { + _, ok := result[key] + if !ok { + msgPack := MsgPack{} + result[key] = &msgPack + } + + sliceRequest := internalPb.InsertRequest{ + MsgType: internalPb.MsgType_kInsert, + ReqId: insertRequest.ReqId, + CollectionName: insertRequest.CollectionName, + PartitionTag: insertRequest.PartitionTag, + SegmentId: insertRequest.SegmentId, + ChannelId: insertRequest.ChannelId, + ProxyId: insertRequest.ProxyId, + Timestamps: []uint64{insertRequest.Timestamps[index]}, + RowIds: []int64{insertRequest.RowIds[index]}, + RowData: []*commonPb.Blob{insertRequest.RowData[index]}, + } + + var msg TsMsg = &InsertTask{ + InsertMsg: InsertMsg{InsertRequest: sliceRequest}, + } + + result[key].Msgs = append(result[key].Msgs, &msg) + } + } + return result, nil +} + func getMsg(reqID UniqueID, hashValue int32) *TsMsg { var tsMsg TsMsg baseMsg := BaseMsg{ @@ -53,6 +103,8 @@ func getMsg(reqID UniqueID, hashValue int32) *TsMsg { ChannelId: 1, ProxyId: 1, Timestamps: []Timestamp{1}, + RowIds: []int64{1}, + RowData: []*commonPb.Blob{{}}, } insertMsg := InsertMsg{ BaseMsg: baseMsg, @@ -79,6 +131,7 @@ func TestStream_task_Insert(t *testing.T) { inputStream := NewPulsarMsgStream(context.Background(), 100) inputStream.SetPulsarCient(pulsarAddress) inputStream.CreatePulsarProducers(producerChannels) + inputStream.SetRepackFunc(newRepackFunc) inputStream.Start() outputStream := NewPulsarMsgStream(context.Background(), 100) @@ -89,7 +142,10 @@ func TestStream_task_Insert(t *testing.T) { outputStream.CreatePulsarConsumers(consumerChannels, consumerSubName, unmarshalDispatcher, 100) outputStream.Start() - inputStream.Produce(&msgPack) + err := inputStream.Produce(&msgPack) + if err != nil { + log.Fatalf("produce error = %v", err) + } receiveCount := 0 for { result := (*outputStream).Consume() diff --git a/internal/msgstream/unmarshal_test.go b/internal/msgstream/unmarshal_test.go index 0c12faa71829c0d79db2adeab3f7e796bd521dd2..24812eb520cccdc10c09816d693dd9b1d4f57f19 100644 --- a/internal/msgstream/unmarshal_test.go +++ b/internal/msgstream/unmarshal_test.go @@ -3,6 +3,7 @@ package msgstream import ( "context" "fmt" + "log" "testing" "github.com/golang/protobuf/proto" @@ -47,7 +48,10 @@ func TestStream_unmarshal_Insert(t *testing.T) { outputStream.CreatePulsarConsumers(consumerChannels, consumerSubName, unmarshalDispatcher, 100) outputStream.Start() - inputStream.Produce(&msgPack) + err := inputStream.Produce(&msgPack) + if err != nil { + log.Fatalf("produce error = %v", err) + } receiveCount := 0 for { result := (*outputStream).Consume()