未验证 提交 e7117f67 编写于 作者: C congqixia 提交者: GitHub

Add BroadcastMark for Msgstream returning MessageIDs (#8654)

Signed-off-by: NCongqi Xia <congqi.xia@zilliz.com>
上级 ba2eb746
......@@ -75,6 +75,9 @@ func (mtm *mockTtMsgStream) Produce(*msgstream.MsgPack) error {
func (mtm *mockTtMsgStream) Broadcast(*msgstream.MsgPack) error {
return nil
}
func (mtm *mockTtMsgStream) BroadcastMark(*msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
return map[string][]msgstream.MessageID{}, nil
}
func (mtm *mockTtMsgStream) Consume() *msgstream.MsgPack {
return nil
}
......
......@@ -32,6 +32,8 @@ import (
"github.com/milvus-io/milvus/internal/util/trace"
)
var _ MsgStream = (*mqMsgStream)(nil)
type mqMsgStream struct {
ctx context.Context
client mqclient.Client
......@@ -262,6 +264,8 @@ func (ms *mqMsgStream) Produce(msgPack *MsgPack) error {
return nil
}
// Broadcast put msgPack to all producer in current msgstream
// which ignores repackFunc logic
func (ms *mqMsgStream) Broadcast(msgPack *MsgPack) error {
if msgPack == nil || len(msgPack.Msgs) <= 0 {
log.Debug("Warning: Receive empty msgPack")
......@@ -302,6 +306,47 @@ func (ms *mqMsgStream) Broadcast(msgPack *MsgPack) error {
return nil
}
// BroadcastMark broadcast msg pack to all producers and returns corresponding msg id
// the returned message id serves as marking
func (ms *mqMsgStream) BroadcastMark(msgPack *MsgPack) (map[string][]MessageID, error) {
ids := make(map[string][]MessageID)
if msgPack == nil || len(msgPack.Msgs) <= 0 {
return ids, errors.New("empty msgs")
}
for _, v := range msgPack.Msgs {
sp, spanCtx := MsgSpanFromCtx(v.TraceCtx(), v)
mb, err := v.Marshal(v)
if err != nil {
return ids, err
}
m, err := convertToByteArray(mb)
if err != nil {
return ids, err
}
msg := &mqclient.ProducerMessage{Payload: m, Properties: map[string]string{}}
trace.InjectContextToPulsarMsgProperties(sp.Context(), msg.Properties)
ms.producerLock.Lock()
for channel, producer := range ms.producers {
id, err := producer.Send(spanCtx, msg)
if err != nil {
ms.producerLock.Unlock()
trace.LogError(sp, err)
sp.Finish()
return ids, err
}
ids[channel] = append(ids[channel], id)
}
ms.producerLock.Unlock()
sp.Finish()
}
return ids, nil
}
func (ms *mqMsgStream) Consume() *MsgPack {
for {
select {
......@@ -418,6 +463,8 @@ func (ms *mqMsgStream) Seek(msgPositions []*internalpb.MsgPosition) error {
return nil
}
var _ MsgStream = (*MqTtMsgStream)(nil)
// MqTtMsgStream is a msgstream that contains timeticks
type MqTtMsgStream struct {
mqMsgStream
......
......@@ -13,6 +13,7 @@ package msgstream
import (
"context"
"errors"
"log"
"math/rand"
"os"
......@@ -951,6 +952,105 @@ func TestStream_RmqTtMsgStream_Insert(t *testing.T) {
Close(rocksdbName, inputStream, outputStream, etcdKV)
}
func TestStream_BroadcastMark(t *testing.T) {
pulsarAddress, _ := Params.Load("_PulsarAddress")
c1 := funcutil.RandomString(8)
c2 := funcutil.RandomString(8)
producerChannels := []string{c1, c2}
factory := ProtoUDFactory{}
pulsarClient, err := mqclient.GetPulsarClientInstance(pulsar.ClientOptions{URL: pulsarAddress})
assert.Nil(t, err)
outputStream, err := NewMqMsgStream(context.Background(), 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
assert.Nil(t, err)
// add producer channels
outputStream.AsProducer(producerChannels)
outputStream.Start()
msgPack0 := MsgPack{}
msgPack0.Msgs = append(msgPack0.Msgs, getTimeTickMsg(0))
ids, err := outputStream.BroadcastMark(&msgPack0)
assert.Nil(t, err)
assert.NotNil(t, ids)
assert.Equal(t, len(producerChannels), len(ids))
for _, c := range producerChannels {
ids, ok := ids[c]
assert.True(t, ok)
assert.Equal(t, len(msgPack0.Msgs), len(ids))
}
msgPack1 := MsgPack{}
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1))
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 3))
ids, err = outputStream.BroadcastMark(&msgPack1)
assert.Nil(t, err)
assert.NotNil(t, ids)
assert.Equal(t, len(producerChannels), len(ids))
for _, c := range producerChannels {
ids, ok := ids[c]
assert.True(t, ok)
assert.Equal(t, len(msgPack1.Msgs), len(ids))
}
// edge cases
_, err = outputStream.BroadcastMark(nil)
assert.NotNil(t, err)
msgPack2 := MsgPack{}
msgPack2.Msgs = append(msgPack2.Msgs, &MarshalFailTsMsg{})
_, err = outputStream.BroadcastMark(&msgPack2)
assert.NotNil(t, err)
// mock send fail
for k, p := range outputStream.producers {
outputStream.producers[k] = &mockSendFailProducer{Producer: p}
}
_, err = outputStream.BroadcastMark(&msgPack1)
assert.NotNil(t, err)
outputStream.Close()
}
var _ TsMsg = (*MarshalFailTsMsg)(nil)
type MarshalFailTsMsg struct {
BaseMsg
}
func (t *MarshalFailTsMsg) ID() UniqueID {
return 0
}
func (t *MarshalFailTsMsg) Type() MsgType {
return commonpb.MsgType_Undefined
}
func (t *MarshalFailTsMsg) SourceID() int64 {
return -1
}
func (t *MarshalFailTsMsg) Marshal(_ TsMsg) (MarshalType, error) {
return nil, errors.New("mocked error")
}
func (t *MarshalFailTsMsg) Unmarshal(_ MarshalType) (TsMsg, error) {
return nil, errors.New("mocked error")
}
var _ mqclient.Producer = (*mockSendFailProducer)(nil)
type mockSendFailProducer struct {
mqclient.Producer
}
func (p *mockSendFailProducer) Send(_ context.Context, _ *mqclient.ProducerMessage) (MessageID, error) {
return nil, errors.New("mocked error")
}
/* ========================== Utility functions ========================== */
func repackFunc(msgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) {
result := make(map[int32]*MsgPack)
......
......@@ -15,6 +15,7 @@ import (
"context"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/mqclient"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
......@@ -22,6 +23,7 @@ type UniqueID = typeutil.UniqueID
type Timestamp = typeutil.Timestamp
type IntPrimaryKey = typeutil.IntPrimaryKey
type MsgPosition = internalpb.MsgPosition
type MessageID = mqclient.MessageID
// MsgPack represents a batch of msg in msgstream
type MsgPack struct {
......@@ -46,6 +48,7 @@ type MsgStream interface {
GetProduceChannels() []string
Produce(*MsgPack) error
Broadcast(*MsgPack) error
BroadcastMark(*MsgPack) (map[string][]MessageID, error)
Consume() *MsgPack
Seek(offset []*MsgPosition) error
}
......
......@@ -329,6 +329,10 @@ func (ms *simpleMockMsgStream) Broadcast(pack *msgstream.MsgPack) error {
return nil
}
func (ms *simpleMockMsgStream) BroadcastMark(pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
return map[string][]msgstream.MessageID{}, nil
}
func (ms *simpleMockMsgStream) GetProduceChannels() []string {
return nil
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册