未验证 提交 0aaa87a7 编写于 作者: X Xiaofan 提交者: GitHub

Fix MsgStream inconsistent seek (#12042)

Signed-off-by: Nxiaofan-luan <xiaofan.luan@zilliz.com>
上级 99397535
......@@ -17,7 +17,6 @@
package msgstream
import (
"bytes"
"context"
"errors"
"fmt"
......@@ -133,19 +132,18 @@ func (ms *mqMsgStream) AsConsumer(channels []string, subName string) {
}
// Create consumer to receive message from channels, with initial position
// if initial position is set to latest, last message in the channel is exclusive
func (ms *mqMsgStream) AsConsumerWithPosition(channels []string, subName string, position mqclient.SubscriptionInitialPosition) {
for _, channel := range channels {
if _, ok := ms.consumers[channel]; ok {
continue
}
fn := func() error {
receiveChannel := make(chan mqclient.Message, ms.bufSize)
pc, err := ms.client.Subscribe(mqclient.ConsumerOptions{
Topic: channel,
SubscriptionName: subName,
Type: mqclient.KeyShared,
Type: mqclient.Exclusive,
SubscriptionInitialPosition: position,
MessageChannel: receiveChannel,
})
if err != nil {
return err
......@@ -599,7 +597,7 @@ func (ms *mqMsgStream) Next(ctx context.Context, channelName string) (TsMsg, err
}
// Seek reset the subscription associated with this consumer to a specific position
// Seek reset the subscription associated with this consumer to a specific position, the seek position is exclusive
// User has to ensure mq_msgstream is not closed before seek, and the seek position is already written.
func (ms *mqMsgStream) Seek(msgPositions []*internalpb.MsgPosition) error {
for _, mp := range msgPositions {
......@@ -612,25 +610,12 @@ func (ms *mqMsgStream) Seek(msgPositions []*internalpb.MsgPosition) error {
return err
}
log.Debug("MsgStream begin to seek", zap.Any("MessageID", mp.MsgID))
err = consumer.Seek(messageID)
err = consumer.Seek(messageID, false)
if err != nil {
log.Debug("Failed to seek", zap.Error(err))
return err
}
log.Debug("MsgStream seek finished", zap.Any("MessageID", messageID))
if consumer.ConsumeAfterSeek() {
log.Debug("MsgStream start to pop one message after seek")
msg, ok := <-consumer.Chan()
if !ok {
return errors.New("consumer closed")
}
log.Debug("MsgStream finish to pop one message after seek")
consumer.Ack(msg)
if !bytes.Equal(msg.ID().Serialize(), messageID.Serialize()) {
err = fmt.Errorf("seek msg not correct")
log.Error("msMsgStream seek", zap.Error(err))
}
}
}
return nil
}
......@@ -708,13 +693,11 @@ func (ms *MqTtMsgStream) AsConsumerWithPosition(channels []string, subName strin
continue
}
fn := func() error {
receiveChannel := make(chan mqclient.Message, ms.bufSize)
pc, err := ms.client.Subscribe(mqclient.ConsumerOptions{
Topic: channel,
SubscriptionName: subName,
Type: mqclient.KeyShared,
Type: mqclient.Exclusive,
SubscriptionInitialPosition: position,
MessageChannel: receiveChannel,
})
if err != nil {
return err
......@@ -955,7 +938,7 @@ func (ms *MqTtMsgStream) Seek(msgPositions []*internalpb.MsgPosition) error {
if err != nil {
return err
}
err = consumer.Seek(seekMsgID)
err = consumer.Seek(seekMsgID, true)
if err != nil {
return err
}
......@@ -975,16 +958,10 @@ func (ms *MqTtMsgStream) Seek(msgPositions []*internalpb.MsgPosition) error {
return fmt.Errorf("Failed to seek, error %s", err.Error())
}
ms.addConsumer(consumer, mp.ChannelName)
ms.chanMsgPos[consumer] = mp
// rmq seek behavior (position, ...)
// pulsar seek behavior [position, ...)
// skip one tt for pulsar
ms.chanMsgPos[consumer] = (proto.Clone(mp)).(*MsgPosition)
runLoop := false
if consumer.ConsumeAfterSeek() {
runLoop = true
}
// skip all data before current tt
runLoop := true
for runLoop {
select {
case <-ms.ctx.Done():
......
......@@ -34,6 +34,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/common"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
......@@ -738,11 +739,12 @@ func TestStream_PulsarTtMsgStream_NoSeek(t *testing.T) {
assert.Equal(t, o3.BeginTs, p3.BeginTs)
}
func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
pulsarAddress, _ := Params.Load("_PulsarAddress")
c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8)
producerChannels := []string{c1, c2}
consumerChannels := []string{c1, c2}
c1 := funcutil.RandomString(8)
producerChannels := []string{c1}
consumerChannels := []string{c1}
consumerSubName := funcutil.RandomString(8)
msgPack0 := MsgPack{}
......@@ -750,6 +752,7 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
msgPack1 := MsgPack{}
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1))
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 3))
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 19))
msgPack2 := MsgPack{}
......@@ -763,7 +766,14 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
msgPack4.Msgs = append(msgPack4.Msgs, getTimeTickMsg(11))
msgPack5 := MsgPack{}
msgPack5.Msgs = append(msgPack5.Msgs, getTimeTickMsg(15))
msgPack5.Msgs = append(msgPack5.Msgs, getTsMsg(commonpb.MsgType_Insert, 12))
msgPack5.Msgs = append(msgPack5.Msgs, getTsMsg(commonpb.MsgType_Insert, 13))
msgPack6 := MsgPack{}
msgPack6.Msgs = append(msgPack6.Msgs, getTimeTickMsg(15))
msgPack7 := MsgPack{}
msgPack7.Msgs = append(msgPack7.Msgs, getTimeTickMsg(20))
inputStream := getPulsarInputStream(pulsarAddress, producerChannels)
outputStream := getPulsarTtOutputStream(pulsarAddress, consumerChannels, consumerSubName)
......@@ -778,18 +788,66 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
assert.Nil(t, err)
err = inputStream.Broadcast(&msgPack4)
assert.Nil(t, err)
err = inputStream.Produce(&msgPack5)
assert.Nil(t, err)
err = inputStream.Broadcast(&msgPack6)
assert.Nil(t, err)
err = inputStream.Broadcast(&msgPack7)
assert.Nil(t, err)
outputStream.Consume()
receivedMsg := outputStream.Consume()
assert.Equal(t, len(receivedMsg.Msgs), 2)
assert.Equal(t, receivedMsg.BeginTs, uint64(0))
assert.Equal(t, receivedMsg.EndTs, uint64(5))
assert.Equal(t, receivedMsg.StartPositions[0].Timestamp, uint64(0))
assert.Equal(t, receivedMsg.EndPositions[0].Timestamp, uint64(5))
receivedMsg2 := outputStream.Consume()
assert.Equal(t, len(receivedMsg2.Msgs), 1)
assert.Equal(t, receivedMsg2.BeginTs, uint64(5))
assert.Equal(t, receivedMsg2.EndTs, uint64(11))
assert.Equal(t, receivedMsg2.StartPositions[0].Timestamp, uint64(5))
assert.Equal(t, receivedMsg2.EndPositions[0].Timestamp, uint64(11))
receivedMsg3 := outputStream.Consume()
assert.Equal(t, len(receivedMsg3.Msgs), 3)
assert.Equal(t, receivedMsg3.BeginTs, uint64(11))
assert.Equal(t, receivedMsg3.EndTs, uint64(15))
assert.Equal(t, receivedMsg3.StartPositions[0].Timestamp, uint64(11))
assert.Equal(t, receivedMsg3.EndPositions[0].Timestamp, uint64(15))
receivedMsg4 := outputStream.Consume()
assert.Equal(t, len(receivedMsg4.Msgs), 1)
assert.Equal(t, receivedMsg4.BeginTs, uint64(15))
assert.Equal(t, receivedMsg4.EndTs, uint64(20))
assert.Equal(t, receivedMsg4.StartPositions[0].Timestamp, uint64(15))
assert.Equal(t, receivedMsg4.EndPositions[0].Timestamp, uint64(20))
outputStream.Close()
outputStream = getPulsarTtOutputStreamAndSeek(pulsarAddress, receivedMsg.EndPositions)
err = inputStream.Broadcast(&msgPack5)
assert.Nil(t, err)
outputStream = getPulsarTtOutputStreamAndSeek(pulsarAddress, receivedMsg3.StartPositions)
seekMsg := outputStream.Consume()
assert.Equal(t, len(seekMsg.Msgs), 3)
result := []uint64{14, 12, 13}
for i, msg := range seekMsg.Msgs {
assert.Equal(t, msg.BeginTs(), result[i])
}
seekMsg2 := outputStream.Consume()
assert.Equal(t, len(seekMsg2.Msgs), 1)
for _, msg := range seekMsg2.Msgs {
assert.Equal(t, msg.BeginTs(), uint64(19))
}
//outputStream.Close()
outputStream = getPulsarTtOutputStreamAndSeek(pulsarAddress, receivedMsg3.EndPositions)
seekMsg = outputStream.Consume()
assert.Equal(t, len(seekMsg.Msgs), 1)
for _, msg := range seekMsg.Msgs {
assert.Equal(t, msg.BeginTs(), uint64(14))
assert.Equal(t, msg.BeginTs(), uint64(19))
}
inputStream.Close()
outputStream.Close()
}
......@@ -1061,12 +1119,12 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) {
c := funcutil.RandomString(8)
producerChannels := []string{c}
consumerChannels := []string{c}
consumerSubName := funcutil.RandomString(8)
msgPack := &MsgPack{}
inputStream := getPulsarInputStream(pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(pulsarAddress, consumerChannels, consumerSubName)
defer inputStream.Close()
outputStream := getPulsarOutputStream(pulsarAddress, consumerChannels, funcutil.RandomString(8))
defer outputStream.Close()
for i := 0; i < 10; i++ {
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
......@@ -1080,16 +1138,15 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) {
assert.Equal(t, result.Msgs[0].ID(), int64(i))
seekPosition = result.EndPositions[0]
}
outputStream.Close()
factory := ProtoUDFactory{}
pulsarClient, _ := mqclient.GetPulsarClientInstance(pulsar.ClientOptions{URL: pulsarAddress})
outputStream2, _ := NewMqMsgStream(context.Background(), 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
outputStream2.AsConsumer(consumerChannels, consumerSubName)
outputStream2.AsConsumer(consumerChannels, funcutil.RandomString(8))
defer outputStream2.Close()
messageID, _ := pulsar.DeserializeMessageID(seekPosition.MsgID)
// try to seek to not written position
patchMessageID(&messageID, 11)
patchMessageID(&messageID, 13)
p := []*internalpb.MsgPosition{
{
......@@ -1100,13 +1157,78 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) {
},
}
go func() {
time.Sleep(1 * time.Second)
outputStream2.Close()
}()
err = outputStream2.Seek(p)
assert.Nil(t, err)
outputStream2.Start()
for i := 10; i < 20; i++ {
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
}
err = inputStream.Produce(msgPack)
assert.Nil(t, err)
result := outputStream2.Consume()
assert.Equal(t, result.Msgs[0].ID(), int64(1))
}
func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
rocksdbName := "/tmp/rocksmq_tt_msg_seekInvalid"
etcdKV := initRmq(rocksdbName)
c := funcutil.RandomString(8)
producerChannels := []string{c}
consumerChannels := []string{c}
consumerSubName := funcutil.RandomString(8)
inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName)
msgPack := &MsgPack{}
for i := 0; i < 10; i++ {
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
}
err := inputStream.Produce(msgPack)
assert.Nil(t, err)
var seekPosition *internalpb.MsgPosition
for i := 0; i < 10; i++ {
result := outputStream.Consume()
assert.Equal(t, result.Msgs[0].ID(), int64(i))
seekPosition = result.EndPositions[0]
}
outputStream.Close()
factory := ProtoUDFactory{}
rmqClient2, _ := mqclient.NewRmqClient(client.ClientOptions{Server: rocksmq.Rmq})
outputStream2, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient2, factory.NewUnmarshalDispatcher())
outputStream2.AsConsumer(consumerChannels, funcutil.RandomString(8))
id := common.Endian.Uint64(seekPosition.MsgID) + 10
bs := make([]byte, 8)
common.Endian.PutUint64(bs, id)
p := []*internalpb.MsgPosition{
{
ChannelName: seekPosition.ChannelName,
Timestamp: seekPosition.Timestamp,
MsgGroup: seekPosition.MsgGroup,
MsgID: bs,
},
}
err = outputStream2.Seek(p)
assert.Error(t, err)
assert.Nil(t, err)
outputStream2.Start()
for i := 10; i < 20; i++ {
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
}
err = inputStream.Produce(msgPack)
assert.Nil(t, err)
result := outputStream2.Consume()
assert.Equal(t, result.Msgs[0].ID(), int64(1))
Close(rocksdbName, inputStream, outputStream2, etcdKV)
}
func TestStream_MqMsgStream_SeekLatest(t *testing.T) {
......@@ -1332,6 +1454,118 @@ func TestStream_RmqTtMsgStream_Insert(t *testing.T) {
Close(rocksdbName, inputStream, outputStream, etcdKV)
}
func TestStream_RmqTtMsgStream_Seek(t *testing.T) {
rocksdbName := "/tmp/rocksmq_tt_msg_seek"
etcdKV := initRmq(rocksdbName)
c1 := funcutil.RandomString(8)
producerChannels := []string{c1}
consumerChannels := []string{c1}
consumerSubName := funcutil.RandomString(8)
msgPack0 := MsgPack{}
msgPack0.Msgs = append(msgPack0.Msgs, getTimeTickMsg(0))
msgPack1 := MsgPack{}
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1))
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 3))
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 19))
msgPack2 := MsgPack{}
msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(5))
msgPack3 := MsgPack{}
msgPack3.Msgs = append(msgPack3.Msgs, getTsMsg(commonpb.MsgType_Insert, 14))
msgPack3.Msgs = append(msgPack3.Msgs, getTsMsg(commonpb.MsgType_Insert, 9))
msgPack4 := MsgPack{}
msgPack4.Msgs = append(msgPack4.Msgs, getTimeTickMsg(11))
msgPack5 := MsgPack{}
msgPack5.Msgs = append(msgPack5.Msgs, getTsMsg(commonpb.MsgType_Insert, 12))
msgPack5.Msgs = append(msgPack5.Msgs, getTsMsg(commonpb.MsgType_Insert, 13))
msgPack6 := MsgPack{}
msgPack6.Msgs = append(msgPack6.Msgs, getTimeTickMsg(15))
msgPack7 := MsgPack{}
msgPack7.Msgs = append(msgPack7.Msgs, getTimeTickMsg(20))
inputStream, outputStream := initRmqTtStream(producerChannels, consumerChannels, consumerSubName)
err := inputStream.Broadcast(&msgPack0)
assert.Nil(t, err)
err = inputStream.Produce(&msgPack1)
assert.Nil(t, err)
err = inputStream.Broadcast(&msgPack2)
assert.Nil(t, err)
err = inputStream.Produce(&msgPack3)
assert.Nil(t, err)
err = inputStream.Broadcast(&msgPack4)
assert.Nil(t, err)
err = inputStream.Produce(&msgPack5)
assert.Nil(t, err)
err = inputStream.Broadcast(&msgPack6)
assert.Nil(t, err)
err = inputStream.Broadcast(&msgPack7)
assert.Nil(t, err)
receivedMsg := outputStream.Consume()
assert.Equal(t, len(receivedMsg.Msgs), 2)
assert.Equal(t, receivedMsg.BeginTs, uint64(0))
assert.Equal(t, receivedMsg.EndTs, uint64(5))
assert.Equal(t, receivedMsg.StartPositions[0].Timestamp, uint64(0))
assert.Equal(t, receivedMsg.EndPositions[0].Timestamp, uint64(5))
receivedMsg2 := outputStream.Consume()
assert.Equal(t, len(receivedMsg2.Msgs), 1)
assert.Equal(t, receivedMsg2.BeginTs, uint64(5))
assert.Equal(t, receivedMsg2.EndTs, uint64(11))
assert.Equal(t, receivedMsg2.StartPositions[0].Timestamp, uint64(5))
assert.Equal(t, receivedMsg2.EndPositions[0].Timestamp, uint64(11))
receivedMsg3 := outputStream.Consume()
assert.Equal(t, len(receivedMsg3.Msgs), 3)
assert.Equal(t, receivedMsg3.BeginTs, uint64(11))
assert.Equal(t, receivedMsg3.EndTs, uint64(15))
assert.Equal(t, receivedMsg3.StartPositions[0].Timestamp, uint64(11))
assert.Equal(t, receivedMsg3.EndPositions[0].Timestamp, uint64(15))
receivedMsg4 := outputStream.Consume()
assert.Equal(t, len(receivedMsg4.Msgs), 1)
assert.Equal(t, receivedMsg4.BeginTs, uint64(15))
assert.Equal(t, receivedMsg4.EndTs, uint64(20))
assert.Equal(t, receivedMsg4.StartPositions[0].Timestamp, uint64(15))
assert.Equal(t, receivedMsg4.EndPositions[0].Timestamp, uint64(20))
outputStream.Close()
factory := ProtoUDFactory{}
rmqClient, _ := mqclient.NewRmqClient(client.ClientOptions{Server: rocksmq.Rmq})
outputStream, _ = NewMqTtMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
consumerSubName = funcutil.RandomString(8)
outputStream.AsConsumer(consumerChannels, consumerSubName)
outputStream.Seek(receivedMsg3.StartPositions)
outputStream.Start()
seekMsg := outputStream.Consume()
assert.Equal(t, len(seekMsg.Msgs), 3)
result := []uint64{14, 12, 13}
for i, msg := range seekMsg.Msgs {
assert.Equal(t, msg.BeginTs(), result[i])
}
seekMsg2 := outputStream.Consume()
assert.Equal(t, len(seekMsg2.Msgs), 1)
for _, msg := range seekMsg2.Msgs {
assert.Equal(t, msg.BeginTs(), uint64(19))
}
Close(rocksdbName, inputStream, outputStream, etcdKV)
}
func TestStream_BroadcastMark(t *testing.T) {
pulsarAddress, _ := Params.Load("_PulsarAddress")
c1 := funcutil.RandomString(8)
......@@ -1722,7 +1956,7 @@ func getPulsarTtOutputStreamAndSeek(pulsarAddress string, positions []*MsgPositi
for _, c := range positions {
consumerName = append(consumerName, c.ChannelName)
}
outputStream.AsConsumer(consumerName, positions[0].MsgGroup)
outputStream.AsConsumer(consumerName, funcutil.RandomString(8))
outputStream.Seek(positions)
outputStream.Start()
return outputStream
......
......@@ -43,6 +43,8 @@ func TestQueryNodeFlowGraph_consumerFlowGraph(t *testing.T) {
err = fg.consumerFlowGraph(defaultVChannel, defaultSubName)
assert.NoError(t, err)
fg.close()
}
func TestQueryNodeFlowGraph_seekQueryNodeFlowGraph(t *testing.T) {
......@@ -74,4 +76,6 @@ func TestQueryNodeFlowGraph_seekQueryNodeFlowGraph(t *testing.T) {
}
err = fg.seekQueryNodeFlowGraph(position)
assert.Error(t, err)
fg.close()
}
......@@ -154,7 +154,8 @@ func (node *QueryNode) AddQueryChannel(ctx context.Context, in *queryPb.AddQuery
}
return status, err
}
log.Debug("querynode seek query channel: ", zap.Any("consumeChannels", consumeChannels))
log.Debug("querynode seek query channel: ", zap.Any("consumeChannels", consumeChannels),
zap.String("seek position", string(in.SeekPosition.MsgID)))
}
}
......
......@@ -59,10 +59,6 @@ type ConsumerOptions struct {
// Default is `Latest`
SubscriptionInitialPosition
// Message for this consumer
// When a message is received, it will be pushed to this channel for consumption
MessageChannel chan Message
// Set receive channel size
BufSize int64
......@@ -80,14 +76,11 @@ type Consumer interface {
Chan() <-chan Message
// Seek to the uniqueID position
Seek(MessageID) error //nolint:govet
Seek(MessageID, bool) error //nolint:govet
// Make sure that msg is received. Only used in pulsar
Ack(Message)
// ConsumeAfterSeek defines the behavior whether to consume after seeking is done
ConsumeAfterSeek() bool
// Close consumer
Close()
}
......@@ -124,7 +124,7 @@ func Consume2(ctx context.Context, t *testing.T, pc *pulsarClient, topic string,
assert.NotNil(t, consumer)
defer consumer.Close()
err = consumer.Seek(msgID)
err = consumer.Seek(msgID, true)
assert.Nil(t, err)
// skip the last received message
......@@ -376,7 +376,7 @@ func TestPulsarClient_Consume2(t *testing.T) {
log.Info("main done")
}
func TestPulsarClient_Seek(t *testing.T) {
func TestPulsarClient_SeekPosition(t *testing.T) {
pulsarAddress, _ := Params.Load("_PulsarAddress")
pc, err := GetPulsarClientInstance(pulsar.ClientOptions{URL: pulsarAddress})
defer pc.Close()
......@@ -393,14 +393,15 @@ func TestPulsarClient_Seek(t *testing.T) {
assert.NotNil(t, producer)
log.Info("Produce start")
var id MessageID
ids := []MessageID{}
arr := []int{1, 2, 3}
for _, v := range arr {
msg := &ProducerMessage{
Payload: IntToBytes(v),
Properties: map[string]string{},
}
id, err = producer.Send(ctx, msg)
id, err := producer.Send(ctx, msg)
ids = append(ids, id)
assert.Nil(t, err)
}
......@@ -415,16 +416,99 @@ func TestPulsarClient_Seek(t *testing.T) {
assert.Nil(t, err)
assert.NotNil(t, consumer)
defer consumer.Close()
seekID := id.(*pulsarID).messageID
seekID := ids[2].(*pulsarID).messageID
consumer.Seek(seekID)
msgChan := consumer.Chan()
select {
case msg := <-msgChan:
assert.Equal(t, seekID.BatchIdx(), msg.ID().BatchIdx())
assert.Equal(t, seekID.LedgerID(), msg.ID().LedgerID())
assert.Equal(t, seekID.EntryID(), msg.ID().EntryID())
assert.Equal(t, seekID.PartitionIdx(), msg.ID().PartitionIdx())
assert.Equal(t, 3, BytesToInt(msg.Payload()))
case <-time.After(2 * time.Second):
log.Info("after 2 seconds")
assert.FailNow(t, "should not wait")
}
seekID = ids[1].(*pulsarID).messageID
consumer.Seek(seekID)
msgChan = consumer.Chan()
select {
case msg := <-msgChan:
assert.Equal(t, seekID.BatchIdx(), msg.ID().BatchIdx())
assert.Equal(t, seekID.LedgerID(), msg.ID().LedgerID())
assert.Equal(t, seekID.EntryID(), msg.ID().EntryID())
assert.Equal(t, seekID.PartitionIdx(), msg.ID().PartitionIdx())
assert.Equal(t, 2, BytesToInt(msg.Payload()))
case <-time.After(2 * time.Second):
assert.FailNow(t, "should not wait")
}
}
func TestPulsarClient_SeekLatest(t *testing.T) {
pulsarAddress, _ := Params.Load("_PulsarAddress")
pc, err := GetPulsarClientInstance(pulsar.ClientOptions{URL: pulsarAddress})
defer pc.Close()
assert.NoError(t, err)
assert.NotNil(t, pc)
rand.Seed(time.Now().UnixNano())
ctx := context.Background()
topic := fmt.Sprintf("test-topic-%d", rand.Int())
subName := fmt.Sprintf("test-subname-%d", rand.Int())
producer, err := pc.CreateProducer(ProducerOptions{Topic: topic})
assert.Nil(t, err)
assert.NotNil(t, producer)
log.Info("Produce start")
arr := []int{1, 2, 3}
for _, v := range arr {
msg := &ProducerMessage{
Payload: IntToBytes(v),
Properties: map[string]string{},
}
_, err = producer.Send(ctx, msg)
assert.Nil(t, err)
}
log.Info("Produced")
consumer, err := pc.client.Subscribe(pulsar.ConsumerOptions{
Topic: topic,
SubscriptionName: subName,
Type: pulsar.KeyShared,
SubscriptionInitialPosition: pulsar.SubscriptionPositionLatest,
})
assert.Nil(t, err)
assert.NotNil(t, consumer)
defer consumer.Close()
msgChan := consumer.Chan()
loop := true
for loop {
select {
case msg := <-msgChan:
consumer.Ack(msg)
v := BytesToInt(msg.Payload())
log.Info("RECV", zap.Any("v", v))
assert.Equal(t, v, 4)
loop = false
case <-time.After(2 * time.Second):
log.Info("after 2 seconds")
msg := &ProducerMessage{
Payload: IntToBytes(4),
Properties: map[string]string{},
}
_, err = producer.Send(ctx, msg)
assert.Nil(t, err)
}
}
}
......
......@@ -27,6 +27,7 @@ type PulsarConsumer struct {
AtLatest bool
closeCh chan struct{}
once sync.Once
skip bool
}
func (pc *PulsarConsumer) Subscription() string {
......@@ -58,7 +59,11 @@ func (pc *PulsarConsumer) Chan() <-chan Message {
log.Debug("pulsar consumer channel closed")
return
}
pc.msgChannel <- &pulsarMessage{msg: msg}
if !pc.skip {
pc.msgChannel <- &pulsarMessage{msg: msg}
} else {
pc.skip = false
}
case <-pc.closeCh: // workaround for pulsar consumer.receiveCh not closed
close(pc.msgChannel)
return
......@@ -72,20 +77,17 @@ func (pc *PulsarConsumer) Chan() <-chan Message {
// Seek seek consume position to the pointed messageID,
// the pointed messageID will be consumed after the seek in pulsar
func (pc *PulsarConsumer) Seek(id MessageID) error {
func (pc *PulsarConsumer) Seek(id MessageID, inclusive bool) error {
messageID := id.(*pulsarID).messageID
err := pc.c.Seek(messageID)
if err == nil {
pc.hasSeek = true
// skip the first message when consume
pc.skip = !inclusive
}
return err
}
// ConsumeAfterSeek defines pulsar consumer SHOULD consume after seek
func (pc *PulsarConsumer) ConsumeAfterSeek() bool {
return true
}
func (pc *PulsarConsumer) Ack(message Message) {
pm := message.(*pulsarMessage)
pc.c.Ack(pm.msg)
......
......@@ -136,7 +136,7 @@ func TestRmqClient_Subscribe(t *testing.T) {
msgID := rmqmsg.ID()
rID := msgID.(*rmqID)
assert.NotZero(t, rID)
err = consumer.Seek(msgID)
err = consumer.Seek(msgID, true)
assert.Nil(t, err)
}
}
......
......@@ -23,6 +23,7 @@ type RmqConsumer struct {
msgChannel chan Message
closeCh chan struct{}
once sync.Once
skip bool
}
// Subscription returns the subscription name of this consumer
......@@ -43,7 +44,11 @@ func (rc *RmqConsumer) Chan() <-chan Message {
close(rc.msgChannel)
return
}
rc.msgChannel <- &rmqMessage{msg: msg}
if !rc.skip {
rc.msgChannel <- &rmqMessage{msg: msg}
} else {
rc.skip = false
}
case <-rc.closeCh:
close(rc.msgChannel)
return
......@@ -56,16 +61,13 @@ func (rc *RmqConsumer) Chan() <-chan Message {
}
// Seek is used to seek the position in rocksmq topic
func (rc *RmqConsumer) Seek(id MessageID) error {
func (rc *RmqConsumer) Seek(id MessageID, inclusive bool) error {
msgID := id.(*rmqID).messageID
// skip the first message when consume
rc.skip = !inclusive
return rc.c.Seek(msgID)
}
// ConsumeAfterSeek defines rmq consumer should NOT consume after seek
func (rc *RmqConsumer) ConsumeAfterSeek() bool {
return false
}
// Ack is used to ask a rocksmq message
func (rc *RmqConsumer) Ack(message Message) {
}
......
......@@ -123,36 +123,46 @@ func (c *client) consume(consumer *consumer) {
select {
case <-c.closeCh:
return
case _, ok := <-consumer.initCh:
if !ok {
return
}
c.deliver(consumer, 100)
case _, ok := <-consumer.MsgMutex():
if !ok {
// consumer MsgMutex closed, goroutine exit
log.Debug("Consumer MsgMutex closed")
return
}
c.deliver(consumer, 100)
}
}
}
for {
n := cap(consumer.messageCh) - len(consumer.messageCh)
if n < 100 { // batch min size
n = 100
}
msgs, err := consumer.client.server.Consume(consumer.topic, consumer.consumerName, n)
if err != nil {
log.Debug("Consumer's goroutine cannot consume from (" + consumer.topic +
"," + consumer.consumerName + "): " + err.Error())
break
}
// no more msgs
if len(msgs) == 0 {
break
}
for _, msg := range msgs {
consumer.messageCh <- Message{
MsgID: msg.MsgID,
Payload: msg.Payload,
Topic: consumer.Topic(),
}
}
func (c *client) deliver(consumer *consumer, batchMin int) {
for {
n := cap(consumer.messageCh) - len(consumer.messageCh)
if n < batchMin { // batch min size
n = batchMin
}
msgs, err := consumer.client.server.Consume(consumer.topic, consumer.consumerName, n)
if err != nil {
log.Warn("Consumer's goroutine cannot consume from (" + consumer.topic + "," + consumer.consumerName + "): " + err.Error())
break
}
// no more msgs
if len(msgs) == 0 {
break
}
for _, msg := range msgs {
select {
case consumer.messageCh <- Message{
MsgID: msg.MsgID,
Payload: msg.Payload,
Topic: consumer.Topic()}:
case <-c.closeCh:
return
}
}
}
......
......@@ -13,6 +13,7 @@ package rocksmq
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
......@@ -120,6 +121,76 @@ func TestClient_Subscribe(t *testing.T) {
assert.NoError(t, err)
}
func TestClient_SeekLatest(t *testing.T) {
rmqPath := "/tmp/milvus/seekLatest"
rmq := newRocksMQ(rmqPath)
defer removePath(rmqPath)
client, err := NewClient(ClientOptions{
Server: rmq,
})
assert.NoError(t, err)
defer client.Close()
topicName := newTopicName()
opt := ConsumerOptions{
Topic: topicName,
SubscriptionName: newConsumerName(),
SubscriptionInitialPosition: SubscriptionPositionEarliest,
}
consumer1, err := client.Subscribe(opt)
assert.NoError(t, err)
assert.NotNil(t, consumer1)
producer, err := client.CreateProducer(ProducerOptions{
Topic: topicName,
})
assert.NotNil(t, producer)
assert.NoError(t, err)
msg := &ProducerMessage{
Payload: make([]byte, 10),
}
id, err := producer.Send(msg)
assert.Nil(t, err)
msgChan := consumer1.Chan()
msgRead, ok := <-msgChan
assert.Equal(t, ok, true)
assert.Equal(t, msgRead.MsgID, id)
consumer1.Close()
opt1 := ConsumerOptions{
Topic: topicName,
SubscriptionName: newConsumerName(),
SubscriptionInitialPosition: SubscriptionPositionLatest,
}
consumer2, err := client.Subscribe(opt1)
assert.NoError(t, err)
assert.NotNil(t, consumer2)
msgChan = consumer2.Chan()
loop := true
for loop {
select {
case msg := <-msgChan:
assert.Equal(t, len(msg.Payload), 8)
loop = false
case <-time.After(2 * time.Second):
msg := &ProducerMessage{
Payload: make([]byte, 8),
}
_, err = producer.Send(msg)
assert.Nil(t, err)
}
}
producer1, err := client.CreateProducer(ProducerOptions{
Topic: newTopicName(),
})
assert.NotNil(t, producer1)
assert.NoError(t, err)
}
func TestClient_consume(t *testing.T) {
rmqPath := "/tmp/milvus/test_client3"
rmq := newRocksMQ(rmqPath)
......@@ -148,8 +219,11 @@ func TestClient_consume(t *testing.T) {
msg := &ProducerMessage{
Payload: make([]byte, 10),
}
_, err = producer.Send(msg)
id, err := producer.Send(msg)
assert.Nil(t, err)
<-consumer.Chan()
msgChan := consumer.Chan()
msgConsume, ok := <-msgChan
assert.Equal(t, ok, true)
assert.Equal(t, id, msgConsume.MsgID)
}
......@@ -28,6 +28,7 @@ type consumer struct {
startOnce sync.Once
msgMutex chan struct{}
initCh chan struct{}
messageCh chan Message
}
......@@ -48,13 +49,16 @@ func newConsumer(c *client, options ConsumerOptions) (*consumer, error) {
if options.MessageChannel == nil {
messageCh = make(chan Message, 1)
}
// only used for
initCh := make(chan struct{}, 1)
initCh <- struct{}{}
return &consumer{
topic: options.Topic,
client: c,
consumerName: options.SubscriptionName,
options: options,
msgMutex: make(chan struct{}, 1),
initCh: initCh,
messageCh: messageCh,
}, nil
}
......
......@@ -18,6 +18,7 @@ import (
"math"
"path"
"strconv"
"strings"
"sync"
"time"
......@@ -444,7 +445,7 @@ func (rmq *rocksmq) Produce(topicName string, messages []ProducerMessage) ([]Uni
idStart, idEnd, err := rmq.idAllocator.Alloc(uint32(msgLen))
if err != nil {
log.Debug("RocksMQ: alloc id failed.")
log.Error("RocksMQ: alloc id failed.", zap.Error(err))
return []UniqueID{}, err
}
......@@ -488,7 +489,6 @@ func (rmq *rocksmq) Produce(topicName string, messages []ProducerMessage) ([]Uni
kvValues := make(map[string]string)
if beginIDValue == "0" {
log.Debug("RocksMQ: overwrite " + kvChannelBeginID + " with " + strconv.FormatInt(idStart, 10))
kvValues[kvChannelBeginID] = strconv.FormatInt(idStart, 10)
}
......@@ -614,18 +614,14 @@ func (rmq *rocksmq) Consume(topicName string, groupName string, n int) ([]Consum
log.Debug("RocksMQ: fixChannelName " + topicName + " failed")
return nil, err
}
dataKey := fixChanName + "/" + currentID
// msgID is DefaultMessageID means this is the first consume operation
// currentID may be not valid if the deprecated values has been removed, when
// we move currentID to first location.
// Note that we assume currentId is always correct and not larger than the latest endID.
if iter.Seek([]byte(dataKey)); currentID != DefaultMessageID && iter.Valid() {
iter.Next()
var dataKey string
if currentID == DefaultMessageID {
dataKey = fixChanName + "/"
} else {
newKey := fixChanName + "/"
iter.Seek([]byte(newKey))
dataKey = fixChanName + "/" + currentID
}
iter.Seek([]byte(dataKey))
offset := 0
for ; iter.Valid() && offset < n; iter.Next() {
......@@ -666,9 +662,8 @@ func (rmq *rocksmq) Consume(topicName string, groupName string, n int) ([]Consum
consumedIDs = append(consumedIDs, msg.MsgID)
}
newID := consumedIDs[len(consumedIDs)-1]
err = rmq.seek(topicName, groupName, newID)
err = rmq.moveConsumePos(topicName, groupName, newID+1)
if err != nil {
log.Debug("RocksMQ: Seek(" + groupName + "," + topicName + "," + strconv.FormatInt(newID, 10) + ") failed")
return nil, err
}
......@@ -690,13 +685,11 @@ func (rmq *rocksmq) seek(topicName string, groupName string, msgID UniqueID) err
log.Warn("RocksMQ: channel " + key + " not exists")
return fmt.Errorf("ConsumerGroup %s, channel %s not exists", groupName, topicName)
}
storeKey, err := combKey(topicName, msgID)
if err != nil {
log.Warn("RocksMQ: combKey(" + topicName + "," + strconv.FormatInt(msgID, 10) + ") failed")
return err
}
opts := gorocksdb.NewDefaultReadOptions()
defer opts.Destroy()
val, err := rmq.store.Get(opts, []byte(storeKey))
......@@ -705,14 +698,22 @@ func (rmq *rocksmq) seek(topicName string, groupName string, msgID UniqueID) err
log.Warn("RocksMQ: get " + storeKey + " failed")
return err
}
if !val.Exists() {
//skip seek if key is not found, this is the behavior as pulsar
return nil
}
/* Step II: Save current_id in kv */
err = rmq.kv.Save(key, strconv.FormatInt(msgID, 10))
return rmq.moveConsumePos(topicName, groupName, msgID)
}
func (rmq *rocksmq) moveConsumePos(topicName string, groupName string, msgID UniqueID) error {
key := constructCurrentID(topicName, groupName)
err := rmq.kv.Save(key, strconv.FormatInt(msgID, 10))
if err != nil {
log.Warn("RocksMQ: save " + key + " failed")
return err
}
return nil
}
......@@ -733,7 +734,7 @@ func (rmq *rocksmq) Seek(topicName string, groupName string, msgID UniqueID) err
return rmq.seek(topicName, groupName, msgID)
}
// SeekToLatest updates current id to the msg id of latest message
// SeekToLatest updates current id to the msg id of latest message + 1
func (rmq *rocksmq) SeekToLatest(topicName, groupName string) error {
rmq.storeMu.Lock()
defer rmq.storeMu.Unlock()
......@@ -745,39 +746,34 @@ func (rmq *rocksmq) SeekToLatest(topicName, groupName string) error {
readOpts := gorocksdb.NewDefaultReadOptions()
defer readOpts.Destroy()
readOpts.SetPrefixSameAsStart(true)
iter := rmq.store.NewIterator(readOpts)
defer iter.Close()
fixChanName, _ := fixChannelName(topicName)
iter.Seek([]byte(fixChanName + "/"))
iKey := iter.Key()
// iter.SeekToLast bypass prefix limitation
// use for range until iterator invalid for now
if iter.Valid() {
iter.Next()
for iter.Valid() {
iKey.Free()
iKey = iter.Key()
iter.Next()
}
} else {
// In this case there are no messages, so shouldn't return error
return nil
}
if iKey == nil {
// 0 is the ASC value of "/" + 1
iter.SeekForPrev([]byte(fixChanName + "0"))
// should find the last key we written into, start with fixChanName/
// if not find, start from 0
if !iter.Valid() {
return nil
}
seekMsgID := string(iKey.Data()) // bytes to string, copy
iKey := iter.Key()
seekMsgID := string(iKey.Data())
iKey.Free()
// if find message is not belong to current channel, start from 0
if !strings.Contains(seekMsgID, fixChanName+"/") {
return nil
}
msgID, err := strconv.ParseInt(seekMsgID[FixedChannelNameLen+1:], 10, 64)
if err != nil {
return err
}
err = rmq.kv.Save(key, strconv.FormatInt(msgID, 10))
return err
// current msgID should not be included
return rmq.moveConsumePos(topicName, groupName, msgID+1)
}
// Notify sends a mutex in MsgMutex channel to tell consumers to consume
......
......@@ -151,7 +151,7 @@ func TestRocksmq(t *testing.T) {
assert.Nil(t, err)
defer rmq.Close()
channelName := "channel_a"
channelName := "channel_rocks"
err = rmq.CreateTopic(channelName)
assert.Nil(t, err)
defer rmq.DestroyTopic(channelName)
......@@ -257,6 +257,65 @@ func TestRocksmq_Dummy(t *testing.T) {
}
func TestRocksmq_Seek(t *testing.T) {
suffix := "_seek"
kvPath := rmqPath + kvPathSuffix + suffix
defer os.RemoveAll(kvPath)
idAllocator := InitIDAllocator(kvPath)
rocksdbPath := rmqPath + dbPathSuffix + suffix
defer os.RemoveAll(rocksdbPath)
metaPath := rmqPath + metaPathSuffix + suffix
defer os.RemoveAll(metaPath)
rmq, err := NewRocksMQ(rocksdbPath, idAllocator)
assert.Nil(t, err)
defer rmq.Close()
_, err = NewRocksMQ("", idAllocator)
assert.Error(t, err)
channelName := "channel_seek"
err = rmq.CreateTopic(channelName)
assert.NoError(t, err)
defer rmq.DestroyTopic(channelName)
var seekID UniqueID
var seekID2 UniqueID
for i := 0; i < 100; i++ {
msg := "message_" + strconv.Itoa(i)
pMsg := ProducerMessage{Payload: []byte(msg)}
pMsgs := make([]ProducerMessage, 1)
pMsgs[0] = pMsg
id, err := rmq.Produce(channelName, pMsgs)
if i == 50 {
seekID = id[0]
}
if i == 51 {
seekID2 = id[0]
}
assert.Nil(t, err)
}
groupName1 := "group_dummy"
err = rmq.CreateConsumerGroup(channelName, groupName1)
assert.NoError(t, err)
err = rmq.Seek(channelName, groupName1, seekID)
assert.NoError(t, err)
messages, err := rmq.Consume(channelName, groupName1, 1)
assert.NoError(t, err)
assert.Equal(t, messages[0].MsgID, seekID)
messages, err = rmq.Consume(channelName, groupName1, 1)
assert.NoError(t, err)
assert.Equal(t, messages[0].MsgID, seekID2)
_ = rmq.DestroyConsumerGroup(channelName, groupName1)
}
func TestRocksmq_Loop(t *testing.T) {
ep := etcdEndpoints()
etcdKV, err := etcdkv.NewEtcdKV(ep, "/etcd/test/root")
......@@ -599,21 +658,57 @@ func TestRocksmq_SeekToLatest(t *testing.T) {
err = rmq.SeekToLatest(channelName, groupName)
assert.NoError(t, err)
channelNamePrev := "channel_tes"
err = rmq.CreateTopic(channelNamePrev)
assert.Nil(t, err)
defer rmq.DestroyTopic(channelNamePrev)
pMsgs := make([]ProducerMessage, loopNum)
for i := 0; i < loopNum; i++ {
msg := "message_" + strconv.Itoa(i+loopNum)
msg := "message_" + strconv.Itoa(i)
pMsg := ProducerMessage{Payload: []byte(msg)}
pMsgs[i] = pMsg
}
_, err = rmq.Produce(channelName, pMsgs)
_, err = rmq.Produce(channelNamePrev, pMsgs)
assert.Nil(t, err)
// should hit the case where channel is null
err = rmq.SeekToLatest(channelName, groupName)
assert.NoError(t, err)
ids, err := rmq.Produce(channelName, pMsgs)
assert.Nil(t, err)
// able to read out
cMsgs, err := rmq.Consume(channelName, groupName, loopNum)
assert.Nil(t, err)
assert.Equal(t, len(cMsgs), loopNum)
for i := 0; i < loopNum; i++ {
assert.Equal(t, cMsgs[i].MsgID, ids[i])
}
err = rmq.SeekToLatest(channelName, groupName)
assert.NoError(t, err)
cMsgs, err = rmq.Consume(channelName, groupName, loopNum)
assert.Nil(t, err)
assert.Equal(t, len(cMsgs), 0)
pMsgs = make([]ProducerMessage, loopNum)
for i := 0; i < loopNum; i++ {
msg := "message_" + strconv.Itoa(i+loopNum)
pMsg := ProducerMessage{Payload: []byte(msg)}
pMsgs[i] = pMsg
}
ids, err = rmq.Produce(channelName, pMsgs)
assert.Nil(t, err)
// make sure we only consume the latest message
cMsgs, err = rmq.Consume(channelName, groupName, loopNum)
assert.Nil(t, err)
assert.Equal(t, len(cMsgs), loopNum)
for i := 0; i < loopNum; i++ {
assert.Equal(t, cMsgs[i].MsgID, ids[i])
}
}
func TestRocksmq_Reader(t *testing.T) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册