diff --git a/internal/msgstream/mq_msgstream.go b/internal/msgstream/mq_msgstream.go index 26c4f52414d4bb417b2a03cde1c854877486a6d2..059d3278b55cccc1b7b3bc7fa796b1d9cb4c0ece 100644 --- a/internal/msgstream/mq_msgstream.go +++ b/internal/msgstream/mq_msgstream.go @@ -17,7 +17,6 @@ package msgstream import ( - "bytes" "context" "errors" "fmt" @@ -133,19 +132,18 @@ func (ms *mqMsgStream) AsConsumer(channels []string, subName string) { } // Create consumer to receive message from channels, with initial position +// if initial position is set to latest, last message in the channel is exclusive func (ms *mqMsgStream) AsConsumerWithPosition(channels []string, subName string, position mqclient.SubscriptionInitialPosition) { for _, channel := range channels { if _, ok := ms.consumers[channel]; ok { continue } fn := func() error { - receiveChannel := make(chan mqclient.Message, ms.bufSize) pc, err := ms.client.Subscribe(mqclient.ConsumerOptions{ Topic: channel, SubscriptionName: subName, - Type: mqclient.KeyShared, + Type: mqclient.Exclusive, SubscriptionInitialPosition: position, - MessageChannel: receiveChannel, }) if err != nil { return err @@ -599,7 +597,7 @@ func (ms *mqMsgStream) Next(ctx context.Context, channelName string) (TsMsg, err } -// Seek reset the subscription associated with this consumer to a specific position +// Seek reset the subscription associated with this consumer to a specific position, the seek position is exclusive // User has to ensure mq_msgstream is not closed before seek, and the seek position is already written. func (ms *mqMsgStream) Seek(msgPositions []*internalpb.MsgPosition) error { for _, mp := range msgPositions { @@ -612,25 +610,12 @@ func (ms *mqMsgStream) Seek(msgPositions []*internalpb.MsgPosition) error { return err } log.Debug("MsgStream begin to seek", zap.Any("MessageID", mp.MsgID)) - err = consumer.Seek(messageID) + err = consumer.Seek(messageID, false) if err != nil { log.Debug("Failed to seek", zap.Error(err)) return err } log.Debug("MsgStream seek finished", zap.Any("MessageID", messageID)) - if consumer.ConsumeAfterSeek() { - log.Debug("MsgStream start to pop one message after seek") - msg, ok := <-consumer.Chan() - if !ok { - return errors.New("consumer closed") - } - log.Debug("MsgStream finish to pop one message after seek") - consumer.Ack(msg) - if !bytes.Equal(msg.ID().Serialize(), messageID.Serialize()) { - err = fmt.Errorf("seek msg not correct") - log.Error("msMsgStream seek", zap.Error(err)) - } - } } return nil } @@ -708,13 +693,11 @@ func (ms *MqTtMsgStream) AsConsumerWithPosition(channels []string, subName strin continue } fn := func() error { - receiveChannel := make(chan mqclient.Message, ms.bufSize) pc, err := ms.client.Subscribe(mqclient.ConsumerOptions{ Topic: channel, SubscriptionName: subName, - Type: mqclient.KeyShared, + Type: mqclient.Exclusive, SubscriptionInitialPosition: position, - MessageChannel: receiveChannel, }) if err != nil { return err @@ -955,7 +938,7 @@ func (ms *MqTtMsgStream) Seek(msgPositions []*internalpb.MsgPosition) error { if err != nil { return err } - err = consumer.Seek(seekMsgID) + err = consumer.Seek(seekMsgID, true) if err != nil { return err } @@ -975,16 +958,10 @@ func (ms *MqTtMsgStream) Seek(msgPositions []*internalpb.MsgPosition) error { return fmt.Errorf("Failed to seek, error %s", err.Error()) } ms.addConsumer(consumer, mp.ChannelName) - ms.chanMsgPos[consumer] = mp - - // rmq seek behavior (position, ...) - // pulsar seek behavior [position, ...) - // skip one tt for pulsar + ms.chanMsgPos[consumer] = (proto.Clone(mp)).(*MsgPosition) - runLoop := false - if consumer.ConsumeAfterSeek() { - runLoop = true - } + // skip all data before current tt + runLoop := true for runLoop { select { case <-ms.ctx.Done(): diff --git a/internal/msgstream/mq_msgstream_test.go b/internal/msgstream/mq_msgstream_test.go index 5a80d56b71a191debceb074ed0bba12323e36ddd..aa9417402403ac25eb3923fe64d2f91c45454b81 100644 --- a/internal/msgstream/mq_msgstream_test.go +++ b/internal/msgstream/mq_msgstream_test.go @@ -34,6 +34,7 @@ import ( "github.com/stretchr/testify/require" "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/common" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" @@ -738,11 +739,12 @@ func TestStream_PulsarTtMsgStream_NoSeek(t *testing.T) { assert.Equal(t, o3.BeginTs, p3.BeginTs) } + func TestStream_PulsarTtMsgStream_Seek(t *testing.T) { pulsarAddress, _ := Params.Load("_PulsarAddress") - c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8) - producerChannels := []string{c1, c2} - consumerChannels := []string{c1, c2} + c1 := funcutil.RandomString(8) + producerChannels := []string{c1} + consumerChannels := []string{c1} consumerSubName := funcutil.RandomString(8) msgPack0 := MsgPack{} @@ -750,6 +752,7 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) { msgPack1 := MsgPack{} msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1)) + msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 3)) msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 19)) msgPack2 := MsgPack{} @@ -763,7 +766,14 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) { msgPack4.Msgs = append(msgPack4.Msgs, getTimeTickMsg(11)) msgPack5 := MsgPack{} - msgPack5.Msgs = append(msgPack5.Msgs, getTimeTickMsg(15)) + msgPack5.Msgs = append(msgPack5.Msgs, getTsMsg(commonpb.MsgType_Insert, 12)) + msgPack5.Msgs = append(msgPack5.Msgs, getTsMsg(commonpb.MsgType_Insert, 13)) + + msgPack6 := MsgPack{} + msgPack6.Msgs = append(msgPack6.Msgs, getTimeTickMsg(15)) + + msgPack7 := MsgPack{} + msgPack7.Msgs = append(msgPack7.Msgs, getTimeTickMsg(20)) inputStream := getPulsarInputStream(pulsarAddress, producerChannels) outputStream := getPulsarTtOutputStream(pulsarAddress, consumerChannels, consumerSubName) @@ -778,18 +788,66 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) { assert.Nil(t, err) err = inputStream.Broadcast(&msgPack4) assert.Nil(t, err) + err = inputStream.Produce(&msgPack5) + assert.Nil(t, err) + err = inputStream.Broadcast(&msgPack6) + assert.Nil(t, err) + err = inputStream.Broadcast(&msgPack7) + assert.Nil(t, err) - outputStream.Consume() receivedMsg := outputStream.Consume() + assert.Equal(t, len(receivedMsg.Msgs), 2) + assert.Equal(t, receivedMsg.BeginTs, uint64(0)) + assert.Equal(t, receivedMsg.EndTs, uint64(5)) + + assert.Equal(t, receivedMsg.StartPositions[0].Timestamp, uint64(0)) + assert.Equal(t, receivedMsg.EndPositions[0].Timestamp, uint64(5)) + + receivedMsg2 := outputStream.Consume() + assert.Equal(t, len(receivedMsg2.Msgs), 1) + assert.Equal(t, receivedMsg2.BeginTs, uint64(5)) + assert.Equal(t, receivedMsg2.EndTs, uint64(11)) + assert.Equal(t, receivedMsg2.StartPositions[0].Timestamp, uint64(5)) + assert.Equal(t, receivedMsg2.EndPositions[0].Timestamp, uint64(11)) + + receivedMsg3 := outputStream.Consume() + assert.Equal(t, len(receivedMsg3.Msgs), 3) + assert.Equal(t, receivedMsg3.BeginTs, uint64(11)) + assert.Equal(t, receivedMsg3.EndTs, uint64(15)) + assert.Equal(t, receivedMsg3.StartPositions[0].Timestamp, uint64(11)) + assert.Equal(t, receivedMsg3.EndPositions[0].Timestamp, uint64(15)) + + receivedMsg4 := outputStream.Consume() + assert.Equal(t, len(receivedMsg4.Msgs), 1) + assert.Equal(t, receivedMsg4.BeginTs, uint64(15)) + assert.Equal(t, receivedMsg4.EndTs, uint64(20)) + assert.Equal(t, receivedMsg4.StartPositions[0].Timestamp, uint64(15)) + assert.Equal(t, receivedMsg4.EndPositions[0].Timestamp, uint64(20)) + outputStream.Close() - outputStream = getPulsarTtOutputStreamAndSeek(pulsarAddress, receivedMsg.EndPositions) - err = inputStream.Broadcast(&msgPack5) - assert.Nil(t, err) + outputStream = getPulsarTtOutputStreamAndSeek(pulsarAddress, receivedMsg3.StartPositions) + seekMsg := outputStream.Consume() + assert.Equal(t, len(seekMsg.Msgs), 3) + result := []uint64{14, 12, 13} + for i, msg := range seekMsg.Msgs { + assert.Equal(t, msg.BeginTs(), result[i]) + } + seekMsg2 := outputStream.Consume() + assert.Equal(t, len(seekMsg2.Msgs), 1) + for _, msg := range seekMsg2.Msgs { + assert.Equal(t, msg.BeginTs(), uint64(19)) + } + //outputStream.Close() + outputStream = getPulsarTtOutputStreamAndSeek(pulsarAddress, receivedMsg3.EndPositions) + + seekMsg = outputStream.Consume() + assert.Equal(t, len(seekMsg.Msgs), 1) for _, msg := range seekMsg.Msgs { - assert.Equal(t, msg.BeginTs(), uint64(14)) + assert.Equal(t, msg.BeginTs(), uint64(19)) } + inputStream.Close() outputStream.Close() } @@ -1061,12 +1119,12 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) { c := funcutil.RandomString(8) producerChannels := []string{c} consumerChannels := []string{c} - consumerSubName := funcutil.RandomString(8) msgPack := &MsgPack{} inputStream := getPulsarInputStream(pulsarAddress, producerChannels) - outputStream := getPulsarOutputStream(pulsarAddress, consumerChannels, consumerSubName) - + defer inputStream.Close() + outputStream := getPulsarOutputStream(pulsarAddress, consumerChannels, funcutil.RandomString(8)) + defer outputStream.Close() for i := 0; i < 10; i++ { insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i)) msgPack.Msgs = append(msgPack.Msgs, insertMsg) @@ -1080,16 +1138,15 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) { assert.Equal(t, result.Msgs[0].ID(), int64(i)) seekPosition = result.EndPositions[0] } - outputStream.Close() factory := ProtoUDFactory{} pulsarClient, _ := mqclient.GetPulsarClientInstance(pulsar.ClientOptions{URL: pulsarAddress}) outputStream2, _ := NewMqMsgStream(context.Background(), 100, 100, pulsarClient, factory.NewUnmarshalDispatcher()) - outputStream2.AsConsumer(consumerChannels, consumerSubName) - + outputStream2.AsConsumer(consumerChannels, funcutil.RandomString(8)) + defer outputStream2.Close() messageID, _ := pulsar.DeserializeMessageID(seekPosition.MsgID) // try to seek to not written position - patchMessageID(&messageID, 11) + patchMessageID(&messageID, 13) p := []*internalpb.MsgPosition{ { @@ -1100,13 +1157,78 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) { }, } - go func() { - time.Sleep(1 * time.Second) - outputStream2.Close() - }() + err = outputStream2.Seek(p) + assert.Nil(t, err) + outputStream2.Start() + + for i := 10; i < 20; i++ { + insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i)) + msgPack.Msgs = append(msgPack.Msgs, insertMsg) + } + err = inputStream.Produce(msgPack) + assert.Nil(t, err) + result := outputStream2.Consume() + assert.Equal(t, result.Msgs[0].ID(), int64(1)) +} + +func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) { + rocksdbName := "/tmp/rocksmq_tt_msg_seekInvalid" + etcdKV := initRmq(rocksdbName) + c := funcutil.RandomString(8) + producerChannels := []string{c} + consumerChannels := []string{c} + consumerSubName := funcutil.RandomString(8) + inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName) + + msgPack := &MsgPack{} + for i := 0; i < 10; i++ { + insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i)) + msgPack.Msgs = append(msgPack.Msgs, insertMsg) + } + + err := inputStream.Produce(msgPack) + assert.Nil(t, err) + var seekPosition *internalpb.MsgPosition + for i := 0; i < 10; i++ { + result := outputStream.Consume() + assert.Equal(t, result.Msgs[0].ID(), int64(i)) + seekPosition = result.EndPositions[0] + } + outputStream.Close() + + factory := ProtoUDFactory{} + rmqClient2, _ := mqclient.NewRmqClient(client.ClientOptions{Server: rocksmq.Rmq}) + outputStream2, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient2, factory.NewUnmarshalDispatcher()) + outputStream2.AsConsumer(consumerChannels, funcutil.RandomString(8)) + + id := common.Endian.Uint64(seekPosition.MsgID) + 10 + bs := make([]byte, 8) + common.Endian.PutUint64(bs, id) + p := []*internalpb.MsgPosition{ + { + ChannelName: seekPosition.ChannelName, + Timestamp: seekPosition.Timestamp, + MsgGroup: seekPosition.MsgGroup, + MsgID: bs, + }, + } err = outputStream2.Seek(p) - assert.Error(t, err) + assert.Nil(t, err) + outputStream2.Start() + + for i := 10; i < 20; i++ { + insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i)) + msgPack.Msgs = append(msgPack.Msgs, insertMsg) + } + err = inputStream.Produce(msgPack) + assert.Nil(t, err) + + result := outputStream2.Consume() + assert.Equal(t, result.Msgs[0].ID(), int64(1)) + + Close(rocksdbName, inputStream, outputStream2, etcdKV) + } func TestStream_MqMsgStream_SeekLatest(t *testing.T) { @@ -1332,6 +1454,118 @@ func TestStream_RmqTtMsgStream_Insert(t *testing.T) { Close(rocksdbName, inputStream, outputStream, etcdKV) } +func TestStream_RmqTtMsgStream_Seek(t *testing.T) { + rocksdbName := "/tmp/rocksmq_tt_msg_seek" + etcdKV := initRmq(rocksdbName) + + c1 := funcutil.RandomString(8) + producerChannels := []string{c1} + consumerChannels := []string{c1} + consumerSubName := funcutil.RandomString(8) + + msgPack0 := MsgPack{} + msgPack0.Msgs = append(msgPack0.Msgs, getTimeTickMsg(0)) + + msgPack1 := MsgPack{} + msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1)) + msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 3)) + msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 19)) + + msgPack2 := MsgPack{} + msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(5)) + + msgPack3 := MsgPack{} + msgPack3.Msgs = append(msgPack3.Msgs, getTsMsg(commonpb.MsgType_Insert, 14)) + msgPack3.Msgs = append(msgPack3.Msgs, getTsMsg(commonpb.MsgType_Insert, 9)) + + msgPack4 := MsgPack{} + msgPack4.Msgs = append(msgPack4.Msgs, getTimeTickMsg(11)) + + msgPack5 := MsgPack{} + msgPack5.Msgs = append(msgPack5.Msgs, getTsMsg(commonpb.MsgType_Insert, 12)) + msgPack5.Msgs = append(msgPack5.Msgs, getTsMsg(commonpb.MsgType_Insert, 13)) + + msgPack6 := MsgPack{} + msgPack6.Msgs = append(msgPack6.Msgs, getTimeTickMsg(15)) + + msgPack7 := MsgPack{} + msgPack7.Msgs = append(msgPack7.Msgs, getTimeTickMsg(20)) + + inputStream, outputStream := initRmqTtStream(producerChannels, consumerChannels, consumerSubName) + + err := inputStream.Broadcast(&msgPack0) + assert.Nil(t, err) + err = inputStream.Produce(&msgPack1) + assert.Nil(t, err) + err = inputStream.Broadcast(&msgPack2) + assert.Nil(t, err) + err = inputStream.Produce(&msgPack3) + assert.Nil(t, err) + err = inputStream.Broadcast(&msgPack4) + assert.Nil(t, err) + err = inputStream.Produce(&msgPack5) + assert.Nil(t, err) + err = inputStream.Broadcast(&msgPack6) + assert.Nil(t, err) + err = inputStream.Broadcast(&msgPack7) + assert.Nil(t, err) + + receivedMsg := outputStream.Consume() + assert.Equal(t, len(receivedMsg.Msgs), 2) + assert.Equal(t, receivedMsg.BeginTs, uint64(0)) + assert.Equal(t, receivedMsg.EndTs, uint64(5)) + + assert.Equal(t, receivedMsg.StartPositions[0].Timestamp, uint64(0)) + assert.Equal(t, receivedMsg.EndPositions[0].Timestamp, uint64(5)) + + receivedMsg2 := outputStream.Consume() + assert.Equal(t, len(receivedMsg2.Msgs), 1) + assert.Equal(t, receivedMsg2.BeginTs, uint64(5)) + assert.Equal(t, receivedMsg2.EndTs, uint64(11)) + assert.Equal(t, receivedMsg2.StartPositions[0].Timestamp, uint64(5)) + assert.Equal(t, receivedMsg2.EndPositions[0].Timestamp, uint64(11)) + + receivedMsg3 := outputStream.Consume() + assert.Equal(t, len(receivedMsg3.Msgs), 3) + assert.Equal(t, receivedMsg3.BeginTs, uint64(11)) + assert.Equal(t, receivedMsg3.EndTs, uint64(15)) + assert.Equal(t, receivedMsg3.StartPositions[0].Timestamp, uint64(11)) + assert.Equal(t, receivedMsg3.EndPositions[0].Timestamp, uint64(15)) + + receivedMsg4 := outputStream.Consume() + assert.Equal(t, len(receivedMsg4.Msgs), 1) + assert.Equal(t, receivedMsg4.BeginTs, uint64(15)) + assert.Equal(t, receivedMsg4.EndTs, uint64(20)) + assert.Equal(t, receivedMsg4.StartPositions[0].Timestamp, uint64(15)) + assert.Equal(t, receivedMsg4.EndPositions[0].Timestamp, uint64(20)) + + outputStream.Close() + + factory := ProtoUDFactory{} + + rmqClient, _ := mqclient.NewRmqClient(client.ClientOptions{Server: rocksmq.Rmq}) + outputStream, _ = NewMqTtMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher()) + consumerSubName = funcutil.RandomString(8) + outputStream.AsConsumer(consumerChannels, consumerSubName) + + outputStream.Seek(receivedMsg3.StartPositions) + outputStream.Start() + seekMsg := outputStream.Consume() + assert.Equal(t, len(seekMsg.Msgs), 3) + result := []uint64{14, 12, 13} + for i, msg := range seekMsg.Msgs { + assert.Equal(t, msg.BeginTs(), result[i]) + } + + seekMsg2 := outputStream.Consume() + assert.Equal(t, len(seekMsg2.Msgs), 1) + for _, msg := range seekMsg2.Msgs { + assert.Equal(t, msg.BeginTs(), uint64(19)) + } + + Close(rocksdbName, inputStream, outputStream, etcdKV) +} + func TestStream_BroadcastMark(t *testing.T) { pulsarAddress, _ := Params.Load("_PulsarAddress") c1 := funcutil.RandomString(8) @@ -1722,7 +1956,7 @@ func getPulsarTtOutputStreamAndSeek(pulsarAddress string, positions []*MsgPositi for _, c := range positions { consumerName = append(consumerName, c.ChannelName) } - outputStream.AsConsumer(consumerName, positions[0].MsgGroup) + outputStream.AsConsumer(consumerName, funcutil.RandomString(8)) outputStream.Seek(positions) outputStream.Start() return outputStream diff --git a/internal/querynode/flow_graph_query_node_test.go b/internal/querynode/flow_graph_query_node_test.go index 0a3f6db0873a1e5e99bcb918a70fdb96f5373d52..978e97180a969ea1fe156c27e0cae7eb20c51034 100644 --- a/internal/querynode/flow_graph_query_node_test.go +++ b/internal/querynode/flow_graph_query_node_test.go @@ -43,6 +43,8 @@ func TestQueryNodeFlowGraph_consumerFlowGraph(t *testing.T) { err = fg.consumerFlowGraph(defaultVChannel, defaultSubName) assert.NoError(t, err) + + fg.close() } func TestQueryNodeFlowGraph_seekQueryNodeFlowGraph(t *testing.T) { @@ -74,4 +76,6 @@ func TestQueryNodeFlowGraph_seekQueryNodeFlowGraph(t *testing.T) { } err = fg.seekQueryNodeFlowGraph(position) assert.Error(t, err) + + fg.close() } diff --git a/internal/querynode/impl.go b/internal/querynode/impl.go index 65e321bb5e698b521237d9ba2a5850105729e775..bb553f48ea81c16b8a013ab35deb4cd80b8d38f9 100644 --- a/internal/querynode/impl.go +++ b/internal/querynode/impl.go @@ -154,7 +154,8 @@ func (node *QueryNode) AddQueryChannel(ctx context.Context, in *queryPb.AddQuery } return status, err } - log.Debug("querynode seek query channel: ", zap.Any("consumeChannels", consumeChannels)) + log.Debug("querynode seek query channel: ", zap.Any("consumeChannels", consumeChannels), + zap.String("seek position", string(in.SeekPosition.MsgID))) } } diff --git a/internal/util/mqclient/consumer.go b/internal/util/mqclient/consumer.go index 7e2def0ccc01a8cf46042a1f2ba238db8dba9306..b65f15d6a0702fd7cddc01a81f3a13357a3302a6 100644 --- a/internal/util/mqclient/consumer.go +++ b/internal/util/mqclient/consumer.go @@ -59,10 +59,6 @@ type ConsumerOptions struct { // Default is `Latest` SubscriptionInitialPosition - // Message for this consumer - // When a message is received, it will be pushed to this channel for consumption - MessageChannel chan Message - // Set receive channel size BufSize int64 @@ -80,14 +76,11 @@ type Consumer interface { Chan() <-chan Message // Seek to the uniqueID position - Seek(MessageID) error //nolint:govet + Seek(MessageID, bool) error //nolint:govet // Make sure that msg is received. Only used in pulsar Ack(Message) - // ConsumeAfterSeek defines the behavior whether to consume after seeking is done - ConsumeAfterSeek() bool - // Close consumer Close() } diff --git a/internal/util/mqclient/pulsar_client_test.go b/internal/util/mqclient/pulsar_client_test.go index e20dde41d4301ed6c31351f2f7ce2d819eb1d05f..2ee2be419510bf9f48ac57bd38dae426b7801581 100644 --- a/internal/util/mqclient/pulsar_client_test.go +++ b/internal/util/mqclient/pulsar_client_test.go @@ -124,7 +124,7 @@ func Consume2(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, assert.NotNil(t, consumer) defer consumer.Close() - err = consumer.Seek(msgID) + err = consumer.Seek(msgID, true) assert.Nil(t, err) // skip the last received message @@ -376,7 +376,7 @@ func TestPulsarClient_Consume2(t *testing.T) { log.Info("main done") } -func TestPulsarClient_Seek(t *testing.T) { +func TestPulsarClient_SeekPosition(t *testing.T) { pulsarAddress, _ := Params.Load("_PulsarAddress") pc, err := GetPulsarClientInstance(pulsar.ClientOptions{URL: pulsarAddress}) defer pc.Close() @@ -393,14 +393,15 @@ func TestPulsarClient_Seek(t *testing.T) { assert.NotNil(t, producer) log.Info("Produce start") - var id MessageID + ids := []MessageID{} arr := []int{1, 2, 3} for _, v := range arr { msg := &ProducerMessage{ Payload: IntToBytes(v), Properties: map[string]string{}, } - id, err = producer.Send(ctx, msg) + id, err := producer.Send(ctx, msg) + ids = append(ids, id) assert.Nil(t, err) } @@ -415,16 +416,99 @@ func TestPulsarClient_Seek(t *testing.T) { assert.Nil(t, err) assert.NotNil(t, consumer) defer consumer.Close() - seekID := id.(*pulsarID).messageID + seekID := ids[2].(*pulsarID).messageID consumer.Seek(seekID) msgChan := consumer.Chan() select { case msg := <-msgChan: + assert.Equal(t, seekID.BatchIdx(), msg.ID().BatchIdx()) + assert.Equal(t, seekID.LedgerID(), msg.ID().LedgerID()) + assert.Equal(t, seekID.EntryID(), msg.ID().EntryID()) + assert.Equal(t, seekID.PartitionIdx(), msg.ID().PartitionIdx()) assert.Equal(t, 3, BytesToInt(msg.Payload())) case <-time.After(2 * time.Second): - log.Info("after 2 seconds") + assert.FailNow(t, "should not wait") + } + + seekID = ids[1].(*pulsarID).messageID + consumer.Seek(seekID) + + msgChan = consumer.Chan() + + select { + case msg := <-msgChan: + assert.Equal(t, seekID.BatchIdx(), msg.ID().BatchIdx()) + assert.Equal(t, seekID.LedgerID(), msg.ID().LedgerID()) + assert.Equal(t, seekID.EntryID(), msg.ID().EntryID()) + assert.Equal(t, seekID.PartitionIdx(), msg.ID().PartitionIdx()) + assert.Equal(t, 2, BytesToInt(msg.Payload())) + case <-time.After(2 * time.Second): + assert.FailNow(t, "should not wait") + } +} + +func TestPulsarClient_SeekLatest(t *testing.T) { + pulsarAddress, _ := Params.Load("_PulsarAddress") + pc, err := GetPulsarClientInstance(pulsar.ClientOptions{URL: pulsarAddress}) + defer pc.Close() + assert.NoError(t, err) + assert.NotNil(t, pc) + rand.Seed(time.Now().UnixNano()) + + ctx := context.Background() + topic := fmt.Sprintf("test-topic-%d", rand.Int()) + subName := fmt.Sprintf("test-subname-%d", rand.Int()) + + producer, err := pc.CreateProducer(ProducerOptions{Topic: topic}) + assert.Nil(t, err) + assert.NotNil(t, producer) + + log.Info("Produce start") + + arr := []int{1, 2, 3} + for _, v := range arr { + msg := &ProducerMessage{ + Payload: IntToBytes(v), + Properties: map[string]string{}, + } + _, err = producer.Send(ctx, msg) + assert.Nil(t, err) + } + + log.Info("Produced") + + consumer, err := pc.client.Subscribe(pulsar.ConsumerOptions{ + Topic: topic, + SubscriptionName: subName, + Type: pulsar.KeyShared, + SubscriptionInitialPosition: pulsar.SubscriptionPositionLatest, + }) + assert.Nil(t, err) + assert.NotNil(t, consumer) + defer consumer.Close() + + msgChan := consumer.Chan() + + loop := true + for loop { + select { + case msg := <-msgChan: + consumer.Ack(msg) + v := BytesToInt(msg.Payload()) + log.Info("RECV", zap.Any("v", v)) + assert.Equal(t, v, 4) + loop = false + case <-time.After(2 * time.Second): + log.Info("after 2 seconds") + msg := &ProducerMessage{ + Payload: IntToBytes(4), + Properties: map[string]string{}, + } + _, err = producer.Send(ctx, msg) + assert.Nil(t, err) + } } } diff --git a/internal/util/mqclient/pulsar_consumer.go b/internal/util/mqclient/pulsar_consumer.go index 607e61be88b9080ab9f7b12c15b499011ded1425..ae0a8bc659116036d815d7f9f53a8eefe2dbc3b9 100644 --- a/internal/util/mqclient/pulsar_consumer.go +++ b/internal/util/mqclient/pulsar_consumer.go @@ -27,6 +27,7 @@ type PulsarConsumer struct { AtLatest bool closeCh chan struct{} once sync.Once + skip bool } func (pc *PulsarConsumer) Subscription() string { @@ -58,7 +59,11 @@ func (pc *PulsarConsumer) Chan() <-chan Message { log.Debug("pulsar consumer channel closed") return } - pc.msgChannel <- &pulsarMessage{msg: msg} + if !pc.skip { + pc.msgChannel <- &pulsarMessage{msg: msg} + } else { + pc.skip = false + } case <-pc.closeCh: // workaround for pulsar consumer.receiveCh not closed close(pc.msgChannel) return @@ -72,20 +77,17 @@ func (pc *PulsarConsumer) Chan() <-chan Message { // Seek seek consume position to the pointed messageID, // the pointed messageID will be consumed after the seek in pulsar -func (pc *PulsarConsumer) Seek(id MessageID) error { +func (pc *PulsarConsumer) Seek(id MessageID, inclusive bool) error { messageID := id.(*pulsarID).messageID err := pc.c.Seek(messageID) if err == nil { pc.hasSeek = true + // skip the first message when consume + pc.skip = !inclusive } return err } -// ConsumeAfterSeek defines pulsar consumer SHOULD consume after seek -func (pc *PulsarConsumer) ConsumeAfterSeek() bool { - return true -} - func (pc *PulsarConsumer) Ack(message Message) { pm := message.(*pulsarMessage) pc.c.Ack(pm.msg) diff --git a/internal/util/mqclient/rmq_client_test.go b/internal/util/mqclient/rmq_client_test.go index 2c26162fb7b3b091f4f60ad52446e9a73e2f1178..d31a3082cd875530599639ccd4fd11a7961f3bdf 100644 --- a/internal/util/mqclient/rmq_client_test.go +++ b/internal/util/mqclient/rmq_client_test.go @@ -136,7 +136,7 @@ func TestRmqClient_Subscribe(t *testing.T) { msgID := rmqmsg.ID() rID := msgID.(*rmqID) assert.NotZero(t, rID) - err = consumer.Seek(msgID) + err = consumer.Seek(msgID, true) assert.Nil(t, err) } } diff --git a/internal/util/mqclient/rmq_consumer.go b/internal/util/mqclient/rmq_consumer.go index 62330f7fe4aacad162a355af10707d2be57ebf4b..5f6887cfc1f162985d1a52b96588f33b9270f4db 100644 --- a/internal/util/mqclient/rmq_consumer.go +++ b/internal/util/mqclient/rmq_consumer.go @@ -23,6 +23,7 @@ type RmqConsumer struct { msgChannel chan Message closeCh chan struct{} once sync.Once + skip bool } // Subscription returns the subscription name of this consumer @@ -43,7 +44,11 @@ func (rc *RmqConsumer) Chan() <-chan Message { close(rc.msgChannel) return } - rc.msgChannel <- &rmqMessage{msg: msg} + if !rc.skip { + rc.msgChannel <- &rmqMessage{msg: msg} + } else { + rc.skip = false + } case <-rc.closeCh: close(rc.msgChannel) return @@ -56,16 +61,13 @@ func (rc *RmqConsumer) Chan() <-chan Message { } // Seek is used to seek the position in rocksmq topic -func (rc *RmqConsumer) Seek(id MessageID) error { +func (rc *RmqConsumer) Seek(id MessageID, inclusive bool) error { msgID := id.(*rmqID).messageID + // skip the first message when consume + rc.skip = !inclusive return rc.c.Seek(msgID) } -// ConsumeAfterSeek defines rmq consumer should NOT consume after seek -func (rc *RmqConsumer) ConsumeAfterSeek() bool { - return false -} - // Ack is used to ask a rocksmq message func (rc *RmqConsumer) Ack(message Message) { } diff --git a/internal/util/rocksmq/client/rocksmq/client_impl.go b/internal/util/rocksmq/client/rocksmq/client_impl.go index a40d6af3bbc8b52ec17f8c2ded909e7d6d1c02e8..02ffd41f5a49462897a927a111c4e6c930310da9 100644 --- a/internal/util/rocksmq/client/rocksmq/client_impl.go +++ b/internal/util/rocksmq/client/rocksmq/client_impl.go @@ -123,36 +123,46 @@ func (c *client) consume(consumer *consumer) { select { case <-c.closeCh: return + case _, ok := <-consumer.initCh: + if !ok { + return + } + c.deliver(consumer, 100) case _, ok := <-consumer.MsgMutex(): if !ok { // consumer MsgMutex closed, goroutine exit log.Debug("Consumer MsgMutex closed") return } + c.deliver(consumer, 100) + } + } +} - for { - n := cap(consumer.messageCh) - len(consumer.messageCh) - if n < 100 { // batch min size - n = 100 - } - msgs, err := consumer.client.server.Consume(consumer.topic, consumer.consumerName, n) - if err != nil { - log.Debug("Consumer's goroutine cannot consume from (" + consumer.topic + - "," + consumer.consumerName + "): " + err.Error()) - break - } - - // no more msgs - if len(msgs) == 0 { - break - } - for _, msg := range msgs { - consumer.messageCh <- Message{ - MsgID: msg.MsgID, - Payload: msg.Payload, - Topic: consumer.Topic(), - } - } +func (c *client) deliver(consumer *consumer, batchMin int) { + for { + n := cap(consumer.messageCh) - len(consumer.messageCh) + if n < batchMin { // batch min size + n = batchMin + } + msgs, err := consumer.client.server.Consume(consumer.topic, consumer.consumerName, n) + if err != nil { + log.Warn("Consumer's goroutine cannot consume from (" + consumer.topic + "," + consumer.consumerName + "): " + err.Error()) + break + } + + // no more msgs + if len(msgs) == 0 { + break + } + for _, msg := range msgs { + select { + case consumer.messageCh <- Message{ + MsgID: msg.MsgID, + Payload: msg.Payload, + Topic: consumer.Topic()}: + case <-c.closeCh: + return } } } diff --git a/internal/util/rocksmq/client/rocksmq/client_impl_test.go b/internal/util/rocksmq/client/rocksmq/client_impl_test.go index 8dd4d8ca3af24885377228c704d83b8e255cc734..41f0ea10dcadc09882c75a4032bbdf61d2d0dea4 100644 --- a/internal/util/rocksmq/client/rocksmq/client_impl_test.go +++ b/internal/util/rocksmq/client/rocksmq/client_impl_test.go @@ -13,6 +13,7 @@ package rocksmq import ( "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -120,6 +121,76 @@ func TestClient_Subscribe(t *testing.T) { assert.NoError(t, err) } +func TestClient_SeekLatest(t *testing.T) { + rmqPath := "/tmp/milvus/seekLatest" + rmq := newRocksMQ(rmqPath) + defer removePath(rmqPath) + client, err := NewClient(ClientOptions{ + Server: rmq, + }) + assert.NoError(t, err) + defer client.Close() + + topicName := newTopicName() + opt := ConsumerOptions{ + Topic: topicName, + SubscriptionName: newConsumerName(), + SubscriptionInitialPosition: SubscriptionPositionEarliest, + } + consumer1, err := client.Subscribe(opt) + assert.NoError(t, err) + assert.NotNil(t, consumer1) + + producer, err := client.CreateProducer(ProducerOptions{ + Topic: topicName, + }) + assert.NotNil(t, producer) + assert.NoError(t, err) + msg := &ProducerMessage{ + Payload: make([]byte, 10), + } + id, err := producer.Send(msg) + assert.Nil(t, err) + + msgChan := consumer1.Chan() + msgRead, ok := <-msgChan + assert.Equal(t, ok, true) + assert.Equal(t, msgRead.MsgID, id) + + consumer1.Close() + + opt1 := ConsumerOptions{ + Topic: topicName, + SubscriptionName: newConsumerName(), + SubscriptionInitialPosition: SubscriptionPositionLatest, + } + consumer2, err := client.Subscribe(opt1) + assert.NoError(t, err) + assert.NotNil(t, consumer2) + + msgChan = consumer2.Chan() + loop := true + for loop { + select { + case msg := <-msgChan: + assert.Equal(t, len(msg.Payload), 8) + loop = false + case <-time.After(2 * time.Second): + msg := &ProducerMessage{ + Payload: make([]byte, 8), + } + _, err = producer.Send(msg) + assert.Nil(t, err) + } + } + + producer1, err := client.CreateProducer(ProducerOptions{ + Topic: newTopicName(), + }) + assert.NotNil(t, producer1) + assert.NoError(t, err) +} + func TestClient_consume(t *testing.T) { rmqPath := "/tmp/milvus/test_client3" rmq := newRocksMQ(rmqPath) @@ -148,8 +219,11 @@ func TestClient_consume(t *testing.T) { msg := &ProducerMessage{ Payload: make([]byte, 10), } - _, err = producer.Send(msg) + id, err := producer.Send(msg) assert.Nil(t, err) - <-consumer.Chan() + msgChan := consumer.Chan() + msgConsume, ok := <-msgChan + assert.Equal(t, ok, true) + assert.Equal(t, id, msgConsume.MsgID) } diff --git a/internal/util/rocksmq/client/rocksmq/consumer_impl.go b/internal/util/rocksmq/client/rocksmq/consumer_impl.go index b9471bf53b4dbbe440c68f3e5079eeb93ca0883b..d6fbe1116b7480914e9ee92553c7ed97068f547a 100644 --- a/internal/util/rocksmq/client/rocksmq/consumer_impl.go +++ b/internal/util/rocksmq/client/rocksmq/consumer_impl.go @@ -28,6 +28,7 @@ type consumer struct { startOnce sync.Once msgMutex chan struct{} + initCh chan struct{} messageCh chan Message } @@ -48,13 +49,16 @@ func newConsumer(c *client, options ConsumerOptions) (*consumer, error) { if options.MessageChannel == nil { messageCh = make(chan Message, 1) } - + // only used for + initCh := make(chan struct{}, 1) + initCh <- struct{}{} return &consumer{ topic: options.Topic, client: c, consumerName: options.SubscriptionName, options: options, msgMutex: make(chan struct{}, 1), + initCh: initCh, messageCh: messageCh, }, nil } diff --git a/internal/util/rocksmq/server/rocksmq/rocksmq_impl.go b/internal/util/rocksmq/server/rocksmq/rocksmq_impl.go index 8261dcf74500d17030591c55f8c31811819cfb9b..8a197bbe4e801d8aa96daba2f6121d4e3e7c8e3d 100644 --- a/internal/util/rocksmq/server/rocksmq/rocksmq_impl.go +++ b/internal/util/rocksmq/server/rocksmq/rocksmq_impl.go @@ -18,6 +18,7 @@ import ( "math" "path" "strconv" + "strings" "sync" "time" @@ -444,7 +445,7 @@ func (rmq *rocksmq) Produce(topicName string, messages []ProducerMessage) ([]Uni idStart, idEnd, err := rmq.idAllocator.Alloc(uint32(msgLen)) if err != nil { - log.Debug("RocksMQ: alloc id failed.") + log.Error("RocksMQ: alloc id failed.", zap.Error(err)) return []UniqueID{}, err } @@ -488,7 +489,6 @@ func (rmq *rocksmq) Produce(topicName string, messages []ProducerMessage) ([]Uni kvValues := make(map[string]string) if beginIDValue == "0" { - log.Debug("RocksMQ: overwrite " + kvChannelBeginID + " with " + strconv.FormatInt(idStart, 10)) kvValues[kvChannelBeginID] = strconv.FormatInt(idStart, 10) } @@ -614,18 +614,14 @@ func (rmq *rocksmq) Consume(topicName string, groupName string, n int) ([]Consum log.Debug("RocksMQ: fixChannelName " + topicName + " failed") return nil, err } - dataKey := fixChanName + "/" + currentID - // msgID is DefaultMessageID means this is the first consume operation - // currentID may be not valid if the deprecated values has been removed, when - // we move currentID to first location. - // Note that we assume currentId is always correct and not larger than the latest endID. - if iter.Seek([]byte(dataKey)); currentID != DefaultMessageID && iter.Valid() { - iter.Next() + var dataKey string + if currentID == DefaultMessageID { + dataKey = fixChanName + "/" } else { - newKey := fixChanName + "/" - iter.Seek([]byte(newKey)) + dataKey = fixChanName + "/" + currentID } + iter.Seek([]byte(dataKey)) offset := 0 for ; iter.Valid() && offset < n; iter.Next() { @@ -666,9 +662,8 @@ func (rmq *rocksmq) Consume(topicName string, groupName string, n int) ([]Consum consumedIDs = append(consumedIDs, msg.MsgID) } newID := consumedIDs[len(consumedIDs)-1] - err = rmq.seek(topicName, groupName, newID) + err = rmq.moveConsumePos(topicName, groupName, newID+1) if err != nil { - log.Debug("RocksMQ: Seek(" + groupName + "," + topicName + "," + strconv.FormatInt(newID, 10) + ") failed") return nil, err } @@ -690,13 +685,11 @@ func (rmq *rocksmq) seek(topicName string, groupName string, msgID UniqueID) err log.Warn("RocksMQ: channel " + key + " not exists") return fmt.Errorf("ConsumerGroup %s, channel %s not exists", groupName, topicName) } - storeKey, err := combKey(topicName, msgID) if err != nil { log.Warn("RocksMQ: combKey(" + topicName + "," + strconv.FormatInt(msgID, 10) + ") failed") return err } - opts := gorocksdb.NewDefaultReadOptions() defer opts.Destroy() val, err := rmq.store.Get(opts, []byte(storeKey)) @@ -705,14 +698,22 @@ func (rmq *rocksmq) seek(topicName string, groupName string, msgID UniqueID) err log.Warn("RocksMQ: get " + storeKey + " failed") return err } + if !val.Exists() { + //skip seek if key is not found, this is the behavior as pulsar + return nil + } /* Step II: Save current_id in kv */ - err = rmq.kv.Save(key, strconv.FormatInt(msgID, 10)) + return rmq.moveConsumePos(topicName, groupName, msgID) +} + +func (rmq *rocksmq) moveConsumePos(topicName string, groupName string, msgID UniqueID) error { + key := constructCurrentID(topicName, groupName) + err := rmq.kv.Save(key, strconv.FormatInt(msgID, 10)) if err != nil { log.Warn("RocksMQ: save " + key + " failed") return err } - return nil } @@ -733,7 +734,7 @@ func (rmq *rocksmq) Seek(topicName string, groupName string, msgID UniqueID) err return rmq.seek(topicName, groupName, msgID) } -// SeekToLatest updates current id to the msg id of latest message +// SeekToLatest updates current id to the msg id of latest message + 1 func (rmq *rocksmq) SeekToLatest(topicName, groupName string) error { rmq.storeMu.Lock() defer rmq.storeMu.Unlock() @@ -745,39 +746,34 @@ func (rmq *rocksmq) SeekToLatest(topicName, groupName string) error { readOpts := gorocksdb.NewDefaultReadOptions() defer readOpts.Destroy() - readOpts.SetPrefixSameAsStart(true) iter := rmq.store.NewIterator(readOpts) defer iter.Close() fixChanName, _ := fixChannelName(topicName) - iter.Seek([]byte(fixChanName + "/")) - iKey := iter.Key() - // iter.SeekToLast bypass prefix limitation - // use for range until iterator invalid for now - if iter.Valid() { - iter.Next() - for iter.Valid() { - iKey.Free() - iKey = iter.Key() - iter.Next() - } - } else { - // In this case there are no messages, so shouldn't return error - return nil - } - if iKey == nil { + + // 0 is the ASC value of "/" + 1 + iter.SeekForPrev([]byte(fixChanName + "0")) + + // should find the last key we written into, start with fixChanName/ + // if not find, start from 0 + if !iter.Valid() { return nil } - seekMsgID := string(iKey.Data()) // bytes to string, copy + iKey := iter.Key() + seekMsgID := string(iKey.Data()) iKey.Free() + // if find message is not belong to current channel, start from 0 + if !strings.Contains(seekMsgID, fixChanName+"/") { + return nil + } msgID, err := strconv.ParseInt(seekMsgID[FixedChannelNameLen+1:], 10, 64) if err != nil { return err } - err = rmq.kv.Save(key, strconv.FormatInt(msgID, 10)) - return err + // current msgID should not be included + return rmq.moveConsumePos(topicName, groupName, msgID+1) } // Notify sends a mutex in MsgMutex channel to tell consumers to consume diff --git a/internal/util/rocksmq/server/rocksmq/rocksmq_impl_test.go b/internal/util/rocksmq/server/rocksmq/rocksmq_impl_test.go index 5060eab8e3bf898b0a285c1481fc273721abf660..82a08d0341743326bd8e77802f13e48cd1f13559 100644 --- a/internal/util/rocksmq/server/rocksmq/rocksmq_impl_test.go +++ b/internal/util/rocksmq/server/rocksmq/rocksmq_impl_test.go @@ -151,7 +151,7 @@ func TestRocksmq(t *testing.T) { assert.Nil(t, err) defer rmq.Close() - channelName := "channel_a" + channelName := "channel_rocks" err = rmq.CreateTopic(channelName) assert.Nil(t, err) defer rmq.DestroyTopic(channelName) @@ -257,6 +257,65 @@ func TestRocksmq_Dummy(t *testing.T) { } +func TestRocksmq_Seek(t *testing.T) { + suffix := "_seek" + kvPath := rmqPath + kvPathSuffix + suffix + defer os.RemoveAll(kvPath) + idAllocator := InitIDAllocator(kvPath) + + rocksdbPath := rmqPath + dbPathSuffix + suffix + defer os.RemoveAll(rocksdbPath) + metaPath := rmqPath + metaPathSuffix + suffix + defer os.RemoveAll(metaPath) + + rmq, err := NewRocksMQ(rocksdbPath, idAllocator) + assert.Nil(t, err) + defer rmq.Close() + + _, err = NewRocksMQ("", idAllocator) + assert.Error(t, err) + + channelName := "channel_seek" + err = rmq.CreateTopic(channelName) + assert.NoError(t, err) + defer rmq.DestroyTopic(channelName) + + var seekID UniqueID + var seekID2 UniqueID + for i := 0; i < 100; i++ { + msg := "message_" + strconv.Itoa(i) + pMsg := ProducerMessage{Payload: []byte(msg)} + pMsgs := make([]ProducerMessage, 1) + pMsgs[0] = pMsg + id, err := rmq.Produce(channelName, pMsgs) + if i == 50 { + seekID = id[0] + } + if i == 51 { + seekID2 = id[0] + } + assert.Nil(t, err) + } + + groupName1 := "group_dummy" + + err = rmq.CreateConsumerGroup(channelName, groupName1) + assert.NoError(t, err) + err = rmq.Seek(channelName, groupName1, seekID) + assert.NoError(t, err) + + messages, err := rmq.Consume(channelName, groupName1, 1) + assert.NoError(t, err) + assert.Equal(t, messages[0].MsgID, seekID) + + messages, err = rmq.Consume(channelName, groupName1, 1) + assert.NoError(t, err) + assert.Equal(t, messages[0].MsgID, seekID2) + + _ = rmq.DestroyConsumerGroup(channelName, groupName1) + +} + func TestRocksmq_Loop(t *testing.T) { ep := etcdEndpoints() etcdKV, err := etcdkv.NewEtcdKV(ep, "/etcd/test/root") @@ -599,21 +658,57 @@ func TestRocksmq_SeekToLatest(t *testing.T) { err = rmq.SeekToLatest(channelName, groupName) assert.NoError(t, err) + channelNamePrev := "channel_tes" + err = rmq.CreateTopic(channelNamePrev) + assert.Nil(t, err) + defer rmq.DestroyTopic(channelNamePrev) pMsgs := make([]ProducerMessage, loopNum) for i := 0; i < loopNum; i++ { - msg := "message_" + strconv.Itoa(i+loopNum) + msg := "message_" + strconv.Itoa(i) pMsg := ProducerMessage{Payload: []byte(msg)} pMsgs[i] = pMsg } - _, err = rmq.Produce(channelName, pMsgs) + _, err = rmq.Produce(channelNamePrev, pMsgs) assert.Nil(t, err) + // should hit the case where channel is null err = rmq.SeekToLatest(channelName, groupName) + assert.NoError(t, err) + + ids, err := rmq.Produce(channelName, pMsgs) assert.Nil(t, err) + // able to read out cMsgs, err := rmq.Consume(channelName, groupName, loopNum) assert.Nil(t, err) + assert.Equal(t, len(cMsgs), loopNum) + for i := 0; i < loopNum; i++ { + assert.Equal(t, cMsgs[i].MsgID, ids[i]) + } + + err = rmq.SeekToLatest(channelName, groupName) + assert.NoError(t, err) + + cMsgs, err = rmq.Consume(channelName, groupName, loopNum) + assert.Nil(t, err) assert.Equal(t, len(cMsgs), 0) + + pMsgs = make([]ProducerMessage, loopNum) + for i := 0; i < loopNum; i++ { + msg := "message_" + strconv.Itoa(i+loopNum) + pMsg := ProducerMessage{Payload: []byte(msg)} + pMsgs[i] = pMsg + } + ids, err = rmq.Produce(channelName, pMsgs) + assert.Nil(t, err) + + // make sure we only consume the latest message + cMsgs, err = rmq.Consume(channelName, groupName, loopNum) + assert.Nil(t, err) + assert.Equal(t, len(cMsgs), loopNum) + for i := 0; i < loopNum; i++ { + assert.Equal(t, cMsgs[i].MsgID, ids[i]) + } } func TestRocksmq_Reader(t *testing.T) {