diff --git a/internal/dataservice/meta.go b/internal/dataservice/meta.go index b69cc746ef3ff0f3e93b61c0ee29dc3784175729..61b67842b4eef047b2df87acb037469b52028c30 100644 --- a/internal/dataservice/meta.go +++ b/internal/dataservice/meta.go @@ -183,7 +183,7 @@ func (meta *meta) UpdateSegment(segmentInfo *datapb.SegmentInfo) error { func (meta *meta) DropSegment(segmentID UniqueID) error { meta.ddLock.Lock() - meta.ddLock.Unlock() + defer meta.ddLock.Unlock() if _, ok := meta.segID2Info[segmentID]; !ok { return newErrSegmentNotFound(segmentID) diff --git a/internal/dataservice/segment_allocator.go b/internal/dataservice/segment_allocator.go index d15d3b66a049acaf54abd262e46e353aa65bbdbf..a85dd4111eba8314a23d4e28f7a36b986c6f4141 100644 --- a/internal/dataservice/segment_allocator.go +++ b/internal/dataservice/segment_allocator.go @@ -71,7 +71,7 @@ type ( } ) -func newSegmentAllocator(meta *meta, allocator allocator) (*segmentAllocatorImpl, error) { +func newSegmentAllocator(meta *meta, allocator allocator) *segmentAllocatorImpl { segmentAllocator := &segmentAllocatorImpl{ mt: meta, segments: make(map[UniqueID]*segmentStatus), @@ -80,7 +80,7 @@ func newSegmentAllocator(meta *meta, allocator allocator) (*segmentAllocatorImpl segmentThresholdFactor: Params.SegmentSizeFactor, allocator: allocator, } - return segmentAllocator, nil + return segmentAllocator } func (allocator *segmentAllocatorImpl) OpenSegment(segmentInfo *datapb.SegmentInfo) error { diff --git a/internal/dataservice/segment_allocator_test.go b/internal/dataservice/segment_allocator_test.go index d0e248edc0df55192373f890f87225e60ad07b4e..9f81783ef6dd86111bc4833f6b7a20e1fe07ae51 100644 --- a/internal/dataservice/segment_allocator_test.go +++ b/internal/dataservice/segment_allocator_test.go @@ -17,8 +17,7 @@ func TestAllocSegment(t *testing.T) { mockAllocator := newMockAllocator() meta, err := newMemoryMeta(mockAllocator) assert.Nil(t, err) - segAllocator, err := newSegmentAllocator(meta, mockAllocator) - assert.Nil(t, err) + segAllocator := newSegmentAllocator(meta, mockAllocator) schema := newTestSchema() collID, err := mockAllocator.allocID() @@ -68,8 +67,7 @@ func TestSealSegment(t *testing.T) { mockAllocator := newMockAllocator() meta, err := newMemoryMeta(mockAllocator) assert.Nil(t, err) - segAllocator, err := newSegmentAllocator(meta, mockAllocator) - assert.Nil(t, err) + segAllocator := newSegmentAllocator(meta, mockAllocator) schema := newTestSchema() collID, err := mockAllocator.allocID() @@ -105,8 +103,7 @@ func TestExpireSegment(t *testing.T) { mockAllocator := newMockAllocator() meta, err := newMemoryMeta(mockAllocator) assert.Nil(t, err) - segAllocator, err := newSegmentAllocator(meta, mockAllocator) - assert.Nil(t, err) + segAllocator := newSegmentAllocator(meta, mockAllocator) schema := newTestSchema() collID, err := mockAllocator.allocID() diff --git a/internal/dataservice/server.go b/internal/dataservice/server.go index a5537e11cba78e2a2621378d20f7e4f5342993ab..fcb2c5cce87aaeca457e041f04769ea6afd50192 100644 --- a/internal/dataservice/server.go +++ b/internal/dataservice/server.go @@ -134,10 +134,7 @@ func (s *Server) Start() error { return err } s.statsHandler = newStatsHandler(s.meta) - s.segAllocator, err = newSegmentAllocator(s.meta, s.allocator) - if err != nil { - return err - } + s.segAllocator = newSegmentAllocator(s.meta, s.allocator) s.ddHandler = newDDHandler(s.meta, s.segAllocator) s.initSegmentInfoChannel() if err = s.loadMetaFromMaster(); err != nil { diff --git a/internal/dataservice/watcher_test.go b/internal/dataservice/watcher_test.go index e476f9f02ab7749eccb41d638d51db489f8471c0..11e8fbe5f8872a9c0488439e34faa2eb40310216 100644 --- a/internal/dataservice/watcher_test.go +++ b/internal/dataservice/watcher_test.go @@ -21,7 +21,7 @@ func TestDataNodeTTWatcher(t *testing.T) { allocator := newMockAllocator() meta, err := newMemoryMeta(allocator) assert.Nil(t, err) - segAllocator, err := newSegmentAllocator(meta, allocator) + segAllocator := newSegmentAllocator(meta, allocator) assert.Nil(t, err) watcher := newDataNodeTimeTickWatcher(meta, segAllocator, cluster) diff --git a/internal/msgstream/pulsarms/pulsar_msgstream.go b/internal/msgstream/pulsarms/pulsar_msgstream.go index 5f9986861e04825ee525f45e7f3411694bc132f9..aedca490ca89b13c2faf547881e1e55fd5ed4c51 100644 --- a/internal/msgstream/pulsarms/pulsar_msgstream.go +++ b/internal/msgstream/pulsarms/pulsar_msgstream.go @@ -747,7 +747,7 @@ func checkTimeTickMsg(msg map[Consumer]Timestamp, for consumer := range msg { mu.RLock() v := msg[consumer] - mu.Unlock() + mu.RUnlock() if v != maxTime { isChannelReady[consumer] = false } else { diff --git a/internal/msgstream/rmqms/factory.go b/internal/msgstream/rmqms/factory.go new file mode 100644 index 0000000000000000000000000000000000000000..7e0e3902f79ea0a358d1f9394865a43de87a681b --- /dev/null +++ b/internal/msgstream/rmqms/factory.go @@ -0,0 +1,28 @@ +package rmqms + +import ( + "context" + + "github.com/zilliztech/milvus-distributed/internal/msgstream" +) + +type Factory struct { + dispatcherFactory msgstream.ProtoUDFactory + address string + receiveBufSize int64 + pulsarBufSize int64 +} + +func (f *Factory) NewMsgStream(ctx context.Context) (msgstream.MsgStream, error) { + return newRmqMsgStream(ctx, f.receiveBufSize, f.dispatcherFactory.NewUnmarshalDispatcher()) +} + +func NewFactory(address string, receiveBufSize int64, pulsarBufSize int64) *Factory { + f := &Factory{ + dispatcherFactory: msgstream.ProtoUDFactory{}, + address: address, + receiveBufSize: receiveBufSize, + pulsarBufSize: pulsarBufSize, + } + return f +} diff --git a/internal/msgstream/rmqms/rmq_msgstream.go b/internal/msgstream/rmqms/rmq_msgstream.go index ec8cfa3a39618cae65fdefde1062e448d7b00c43..b2033f4464dc6fd3daae28c467137d31c6833185 100644 --- a/internal/msgstream/rmqms/rmq_msgstream.go +++ b/internal/msgstream/rmqms/rmq_msgstream.go @@ -16,6 +16,17 @@ import ( "github.com/zilliztech/milvus-distributed/internal/msgstream" ) +type TsMsg = msgstream.TsMsg +type MsgPack = msgstream.MsgPack +type MsgType = msgstream.MsgType +type UniqueID = msgstream.UniqueID +type BaseMsg = msgstream.BaseMsg +type Timestamp = msgstream.Timestamp +type IntPrimaryKey = msgstream.IntPrimaryKey +type TimeTickMsg = msgstream.TimeTickMsg +type QueryNodeStatsMsg = msgstream.QueryNodeStatsMsg +type RepackFunc = msgstream.RepackFunc + type RmqMsgStream struct { isServing int64 ctx context.Context @@ -23,7 +34,6 @@ type RmqMsgStream struct { serverLoopCtx context.Context serverLoopCancel func() - rmq *rocksmq.RocksMQ repackFunc msgstream.RepackFunc consumers []rocksmq.Consumer producers []string @@ -35,17 +45,18 @@ type RmqMsgStream struct { streamCancel func() } -func NewRmqMsgStream(ctx context.Context, rmq *rocksmq.RocksMQ, receiveBufSize int64) *RmqMsgStream { +func newRmqMsgStream(ctx context.Context, receiveBufSize int64, + unmarshal msgstream.UnmarshalDispatcher) (*RmqMsgStream, error) { streamCtx, streamCancel := context.WithCancel(ctx) receiveBuf := make(chan *msgstream.MsgPack, receiveBufSize) stream := &RmqMsgStream{ ctx: streamCtx, - rmq: nil, receiveBuf: receiveBuf, + unmarshal: unmarshal, streamCancel: streamCancel, } - return stream + return stream, nil } func (ms *RmqMsgStream) Start() { @@ -59,25 +70,32 @@ func (ms *RmqMsgStream) Start() { func (ms *RmqMsgStream) Close() { } -func (ms *RmqMsgStream) CreateProducers(channels []string) error { +type propertiesReaderWriter struct { + ppMap map[string]string +} + +func (ms *RmqMsgStream) SetRepackFunc(repackFunc RepackFunc) { + ms.repackFunc = repackFunc +} + +func (ms *RmqMsgStream) AsProducer(channels []string) { for _, channel := range channels { // TODO(yhz): Here may allow to create an existing channel - if err := ms.rmq.CreateChannel(channel); err != nil { - return err + if err := rocksmq.Rmq.CreateChannel(channel); err != nil { + errMsg := "Failed to create producer " + channel + ", error = " + err.Error() + panic(errMsg) } } - return nil } -func (ms *RmqMsgStream) CreateConsumers(channels []string, groupName string) error { +func (ms *RmqMsgStream) AsConsumer(channels []string, groupName string) { for _, channelName := range channels { - if err := ms.rmq.CreateConsumerGroup(groupName, channelName); err != nil { - return err + if err := rocksmq.Rmq.CreateConsumerGroup(groupName, channelName); err != nil { + panic(err.Error()) } msgNum := make(chan int) ms.consumers = append(ms.consumers, rocksmq.Consumer{GroupName: groupName, ChannelName: channelName, MsgNum: msgNum}) } - return nil } func (ms *RmqMsgStream) Produce(pack *msgstream.MsgPack) error { @@ -172,7 +190,30 @@ func (ms *RmqMsgStream) Produce(pack *msgstream.MsgPack) error { } msg := make([]rocksmq.ProducerMessage, 0) msg = append(msg, *rocksmq.NewProducerMessage(m)) - if err := ms.rmq.Produce(ms.producers[k], msg); err != nil { + if err := rocksmq.Rmq.Produce(ms.producers[k], msg); err != nil { + return err + } + } + } + return nil +} + +func (ms *RmqMsgStream) Broadcast(msgPack *MsgPack) error { + producerLen := len(ms.producers) + for _, v := range msgPack.Msgs { + mb, err := v.Marshal(v) + if err != nil { + return err + } + m, err := msgstream.ConvertToByteArray(mb) + if err != nil { + return err + } + msg := make([]rocksmq.ProducerMessage, 0) + msg = append(msg, *rocksmq.NewProducerMessage(m)) + + for i := 0; i < producerLen; i++ { + if err := rocksmq.Rmq.Produce(ms.producers[i], msg); err != nil { return err } } @@ -221,7 +262,7 @@ func (ms *RmqMsgStream) bufMsgPackToChannel() { } msgNum := value.Interface().(int) - rmqMsg, err := ms.rmq.Consume(ms.consumers[chosen].GroupName, ms.consumers[chosen].ChannelName, msgNum) + rmqMsg, err := rocksmq.Rmq.Consume(ms.consumers[chosen].GroupName, ms.consumers[chosen].ChannelName, msgNum) if err != nil { log.Printf("Failed to consume message in rocksmq, error = %v", err) continue @@ -261,5 +302,23 @@ func (ms *RmqMsgStream) bufMsgPackToChannel() { } func (ms *RmqMsgStream) Chan() <-chan *msgstream.MsgPack { - return nil + return ms.receiveBuf +} + +func (ms *RmqMsgStream) Seek(offset *msgstream.MsgPosition) error { + for i := 0; i < len(ms.consumers); i++ { + if ms.consumers[i].ChannelName == offset.ChannelName { + messageID, err := strconv.ParseInt(offset.MsgID, 10, 64) + if err != nil { + return err + } + err = rocksmq.Rmq.Seek(ms.consumers[i].GroupName, ms.consumers[i].ChannelName, messageID) + if err != nil { + return err + } + return nil + } + } + + return errors.New("msgStream seek fail") }