From e3fadc45d43c64e0a7c0ceb391751dc681022c5f Mon Sep 17 00:00:00 2001 From: yukun Date: Thu, 4 Feb 2021 15:23:21 +0800 Subject: [PATCH] Fix for new msgstream interface Signed-off-by: yukun --- internal/dataservice/meta.go | 2 +- internal/dataservice/segment_allocator.go | 4 +- .../dataservice/segment_allocator_test.go | 9 +- internal/dataservice/server.go | 5 +- internal/dataservice/watcher_test.go | 2 +- .../msgstream/pulsarms/pulsar_msgstream.go | 2 +- internal/msgstream/rmqms/factory.go | 28 ++++++ internal/msgstream/rmqms/rmq_msgstream.go | 89 +++++++++++++++---- 8 files changed, 111 insertions(+), 30 deletions(-) create mode 100644 internal/msgstream/rmqms/factory.go diff --git a/internal/dataservice/meta.go b/internal/dataservice/meta.go index b69cc746e..61b67842b 100644 --- a/internal/dataservice/meta.go +++ b/internal/dataservice/meta.go @@ -183,7 +183,7 @@ func (meta *meta) UpdateSegment(segmentInfo *datapb.SegmentInfo) error { func (meta *meta) DropSegment(segmentID UniqueID) error { meta.ddLock.Lock() - meta.ddLock.Unlock() + defer meta.ddLock.Unlock() if _, ok := meta.segID2Info[segmentID]; !ok { return newErrSegmentNotFound(segmentID) diff --git a/internal/dataservice/segment_allocator.go b/internal/dataservice/segment_allocator.go index d15d3b66a..a85dd4111 100644 --- a/internal/dataservice/segment_allocator.go +++ b/internal/dataservice/segment_allocator.go @@ -71,7 +71,7 @@ type ( } ) -func newSegmentAllocator(meta *meta, allocator allocator) (*segmentAllocatorImpl, error) { +func newSegmentAllocator(meta *meta, allocator allocator) *segmentAllocatorImpl { segmentAllocator := &segmentAllocatorImpl{ mt: meta, segments: make(map[UniqueID]*segmentStatus), @@ -80,7 +80,7 @@ func newSegmentAllocator(meta *meta, allocator allocator) (*segmentAllocatorImpl segmentThresholdFactor: Params.SegmentSizeFactor, allocator: allocator, } - return segmentAllocator, nil + return segmentAllocator } func (allocator *segmentAllocatorImpl) OpenSegment(segmentInfo *datapb.SegmentInfo) error { diff --git a/internal/dataservice/segment_allocator_test.go b/internal/dataservice/segment_allocator_test.go index d0e248edc..9f81783ef 100644 --- a/internal/dataservice/segment_allocator_test.go +++ b/internal/dataservice/segment_allocator_test.go @@ -17,8 +17,7 @@ func TestAllocSegment(t *testing.T) { mockAllocator := newMockAllocator() meta, err := newMemoryMeta(mockAllocator) assert.Nil(t, err) - segAllocator, err := newSegmentAllocator(meta, mockAllocator) - assert.Nil(t, err) + segAllocator := newSegmentAllocator(meta, mockAllocator) schema := newTestSchema() collID, err := mockAllocator.allocID() @@ -68,8 +67,7 @@ func TestSealSegment(t *testing.T) { mockAllocator := newMockAllocator() meta, err := newMemoryMeta(mockAllocator) assert.Nil(t, err) - segAllocator, err := newSegmentAllocator(meta, mockAllocator) - assert.Nil(t, err) + segAllocator := newSegmentAllocator(meta, mockAllocator) schema := newTestSchema() collID, err := mockAllocator.allocID() @@ -105,8 +103,7 @@ func TestExpireSegment(t *testing.T) { mockAllocator := newMockAllocator() meta, err := newMemoryMeta(mockAllocator) assert.Nil(t, err) - segAllocator, err := newSegmentAllocator(meta, mockAllocator) - assert.Nil(t, err) + segAllocator := newSegmentAllocator(meta, mockAllocator) schema := newTestSchema() collID, err := mockAllocator.allocID() diff --git a/internal/dataservice/server.go b/internal/dataservice/server.go index a5537e11c..fcb2c5cce 100644 --- a/internal/dataservice/server.go +++ b/internal/dataservice/server.go @@ -134,10 +134,7 @@ func (s *Server) Start() error { return err } s.statsHandler = newStatsHandler(s.meta) - s.segAllocator, err = newSegmentAllocator(s.meta, s.allocator) - if err != nil { - return err - } + s.segAllocator = newSegmentAllocator(s.meta, s.allocator) s.ddHandler = newDDHandler(s.meta, s.segAllocator) s.initSegmentInfoChannel() if err = s.loadMetaFromMaster(); err != nil { diff --git a/internal/dataservice/watcher_test.go b/internal/dataservice/watcher_test.go index e476f9f02..11e8fbe5f 100644 --- a/internal/dataservice/watcher_test.go +++ b/internal/dataservice/watcher_test.go @@ -21,7 +21,7 @@ func TestDataNodeTTWatcher(t *testing.T) { allocator := newMockAllocator() meta, err := newMemoryMeta(allocator) assert.Nil(t, err) - segAllocator, err := newSegmentAllocator(meta, allocator) + segAllocator := newSegmentAllocator(meta, allocator) assert.Nil(t, err) watcher := newDataNodeTimeTickWatcher(meta, segAllocator, cluster) diff --git a/internal/msgstream/pulsarms/pulsar_msgstream.go b/internal/msgstream/pulsarms/pulsar_msgstream.go index 5f9986861..aedca490c 100644 --- a/internal/msgstream/pulsarms/pulsar_msgstream.go +++ b/internal/msgstream/pulsarms/pulsar_msgstream.go @@ -747,7 +747,7 @@ func checkTimeTickMsg(msg map[Consumer]Timestamp, for consumer := range msg { mu.RLock() v := msg[consumer] - mu.Unlock() + mu.RUnlock() if v != maxTime { isChannelReady[consumer] = false } else { diff --git a/internal/msgstream/rmqms/factory.go b/internal/msgstream/rmqms/factory.go new file mode 100644 index 000000000..7e0e3902f --- /dev/null +++ b/internal/msgstream/rmqms/factory.go @@ -0,0 +1,28 @@ +package rmqms + +import ( + "context" + + "github.com/zilliztech/milvus-distributed/internal/msgstream" +) + +type Factory struct { + dispatcherFactory msgstream.ProtoUDFactory + address string + receiveBufSize int64 + pulsarBufSize int64 +} + +func (f *Factory) NewMsgStream(ctx context.Context) (msgstream.MsgStream, error) { + return newRmqMsgStream(ctx, f.receiveBufSize, f.dispatcherFactory.NewUnmarshalDispatcher()) +} + +func NewFactory(address string, receiveBufSize int64, pulsarBufSize int64) *Factory { + f := &Factory{ + dispatcherFactory: msgstream.ProtoUDFactory{}, + address: address, + receiveBufSize: receiveBufSize, + pulsarBufSize: pulsarBufSize, + } + return f +} diff --git a/internal/msgstream/rmqms/rmq_msgstream.go b/internal/msgstream/rmqms/rmq_msgstream.go index ec8cfa3a3..b2033f446 100644 --- a/internal/msgstream/rmqms/rmq_msgstream.go +++ b/internal/msgstream/rmqms/rmq_msgstream.go @@ -16,6 +16,17 @@ import ( "github.com/zilliztech/milvus-distributed/internal/msgstream" ) +type TsMsg = msgstream.TsMsg +type MsgPack = msgstream.MsgPack +type MsgType = msgstream.MsgType +type UniqueID = msgstream.UniqueID +type BaseMsg = msgstream.BaseMsg +type Timestamp = msgstream.Timestamp +type IntPrimaryKey = msgstream.IntPrimaryKey +type TimeTickMsg = msgstream.TimeTickMsg +type QueryNodeStatsMsg = msgstream.QueryNodeStatsMsg +type RepackFunc = msgstream.RepackFunc + type RmqMsgStream struct { isServing int64 ctx context.Context @@ -23,7 +34,6 @@ type RmqMsgStream struct { serverLoopCtx context.Context serverLoopCancel func() - rmq *rocksmq.RocksMQ repackFunc msgstream.RepackFunc consumers []rocksmq.Consumer producers []string @@ -35,17 +45,18 @@ type RmqMsgStream struct { streamCancel func() } -func NewRmqMsgStream(ctx context.Context, rmq *rocksmq.RocksMQ, receiveBufSize int64) *RmqMsgStream { +func newRmqMsgStream(ctx context.Context, receiveBufSize int64, + unmarshal msgstream.UnmarshalDispatcher) (*RmqMsgStream, error) { streamCtx, streamCancel := context.WithCancel(ctx) receiveBuf := make(chan *msgstream.MsgPack, receiveBufSize) stream := &RmqMsgStream{ ctx: streamCtx, - rmq: nil, receiveBuf: receiveBuf, + unmarshal: unmarshal, streamCancel: streamCancel, } - return stream + return stream, nil } func (ms *RmqMsgStream) Start() { @@ -59,25 +70,32 @@ func (ms *RmqMsgStream) Start() { func (ms *RmqMsgStream) Close() { } -func (ms *RmqMsgStream) CreateProducers(channels []string) error { +type propertiesReaderWriter struct { + ppMap map[string]string +} + +func (ms *RmqMsgStream) SetRepackFunc(repackFunc RepackFunc) { + ms.repackFunc = repackFunc +} + +func (ms *RmqMsgStream) AsProducer(channels []string) { for _, channel := range channels { // TODO(yhz): Here may allow to create an existing channel - if err := ms.rmq.CreateChannel(channel); err != nil { - return err + if err := rocksmq.Rmq.CreateChannel(channel); err != nil { + errMsg := "Failed to create producer " + channel + ", error = " + err.Error() + panic(errMsg) } } - return nil } -func (ms *RmqMsgStream) CreateConsumers(channels []string, groupName string) error { +func (ms *RmqMsgStream) AsConsumer(channels []string, groupName string) { for _, channelName := range channels { - if err := ms.rmq.CreateConsumerGroup(groupName, channelName); err != nil { - return err + if err := rocksmq.Rmq.CreateConsumerGroup(groupName, channelName); err != nil { + panic(err.Error()) } msgNum := make(chan int) ms.consumers = append(ms.consumers, rocksmq.Consumer{GroupName: groupName, ChannelName: channelName, MsgNum: msgNum}) } - return nil } func (ms *RmqMsgStream) Produce(pack *msgstream.MsgPack) error { @@ -172,7 +190,30 @@ func (ms *RmqMsgStream) Produce(pack *msgstream.MsgPack) error { } msg := make([]rocksmq.ProducerMessage, 0) msg = append(msg, *rocksmq.NewProducerMessage(m)) - if err := ms.rmq.Produce(ms.producers[k], msg); err != nil { + if err := rocksmq.Rmq.Produce(ms.producers[k], msg); err != nil { + return err + } + } + } + return nil +} + +func (ms *RmqMsgStream) Broadcast(msgPack *MsgPack) error { + producerLen := len(ms.producers) + for _, v := range msgPack.Msgs { + mb, err := v.Marshal(v) + if err != nil { + return err + } + m, err := msgstream.ConvertToByteArray(mb) + if err != nil { + return err + } + msg := make([]rocksmq.ProducerMessage, 0) + msg = append(msg, *rocksmq.NewProducerMessage(m)) + + for i := 0; i < producerLen; i++ { + if err := rocksmq.Rmq.Produce(ms.producers[i], msg); err != nil { return err } } @@ -221,7 +262,7 @@ func (ms *RmqMsgStream) bufMsgPackToChannel() { } msgNum := value.Interface().(int) - rmqMsg, err := ms.rmq.Consume(ms.consumers[chosen].GroupName, ms.consumers[chosen].ChannelName, msgNum) + rmqMsg, err := rocksmq.Rmq.Consume(ms.consumers[chosen].GroupName, ms.consumers[chosen].ChannelName, msgNum) if err != nil { log.Printf("Failed to consume message in rocksmq, error = %v", err) continue @@ -261,5 +302,23 @@ func (ms *RmqMsgStream) bufMsgPackToChannel() { } func (ms *RmqMsgStream) Chan() <-chan *msgstream.MsgPack { - return nil + return ms.receiveBuf +} + +func (ms *RmqMsgStream) Seek(offset *msgstream.MsgPosition) error { + for i := 0; i < len(ms.consumers); i++ { + if ms.consumers[i].ChannelName == offset.ChannelName { + messageID, err := strconv.ParseInt(offset.MsgID, 10, 64) + if err != nil { + return err + } + err = rocksmq.Rmq.Seek(ms.consumers[i].GroupName, ms.consumers[i].ChannelName, messageID) + if err != nil { + return err + } + return nil + } + } + + return errors.New("msgStream seek fail") } -- GitLab