提交 09922039 编写于 作者: Y yukun 提交者: yefu.chen

Add rmq_msgstream unittest and fix bugs

Signed-off-by: Nyukun <kun.yu@zilliz.com>
上级 e9ee9a27
......@@ -8,21 +8,18 @@ import (
type Factory struct {
dispatcherFactory msgstream.ProtoUDFactory
address string
receiveBufSize int64
pulsarBufSize int64
rmqBufSize int64
}
func (f *Factory) NewMsgStream(ctx context.Context) (msgstream.MsgStream, error) {
return newRmqMsgStream(ctx, f.receiveBufSize, f.dispatcherFactory.NewUnmarshalDispatcher())
return newRmqMsgStream(ctx, f.receiveBufSize, f.rmqBufSize, f.dispatcherFactory.NewUnmarshalDispatcher())
}
func NewFactory(address string, receiveBufSize int64, pulsarBufSize int64) *Factory {
f := &Factory{
dispatcherFactory: msgstream.ProtoUDFactory{},
address: address,
receiveBufSize: receiveBufSize,
pulsarBufSize: pulsarBufSize,
}
return f
}
......@@ -42,18 +42,24 @@ type RmqMsgStream struct {
receiveBuf chan *msgstream.MsgPack
wait *sync.WaitGroup
// tso ticker
streamCancel func()
streamCancel func()
rmqBufSize int64
consumerReflects []reflect.SelectCase
}
func newRmqMsgStream(ctx context.Context, receiveBufSize int64,
func newRmqMsgStream(ctx context.Context, receiveBufSize int64, rmqBufSize int64,
unmarshal msgstream.UnmarshalDispatcher) (*RmqMsgStream, error) {
streamCtx, streamCancel := context.WithCancel(ctx)
receiveBuf := make(chan *msgstream.MsgPack, receiveBufSize)
consumerReflects := make([]reflect.SelectCase, 0)
stream := &RmqMsgStream{
ctx: streamCtx,
receiveBuf: receiveBuf,
unmarshal: unmarshal,
streamCancel: streamCancel,
ctx: streamCtx,
receiveBuf: receiveBuf,
unmarshal: unmarshal,
streamCancel: streamCancel,
rmqBufSize: rmqBufSize,
consumerReflects: consumerReflects,
}
return stream, nil
......@@ -68,6 +74,17 @@ func (ms *RmqMsgStream) Start() {
}
func (ms *RmqMsgStream) Close() {
ms.streamCancel()
for _, producer := range ms.producers {
if producer != "" {
_ = rocksmq.Rmq.DestroyChannel(producer)
}
}
for _, consumer := range ms.consumers {
_ = rocksmq.Rmq.DestroyConsumerGroup(consumer.GroupName, consumer.ChannelName)
close(consumer.MsgNum)
}
}
type propertiesReaderWriter struct {
......@@ -85,16 +102,22 @@ func (ms *RmqMsgStream) AsProducer(channels []string) {
errMsg := "Failed to create producer " + channel + ", error = " + err.Error()
panic(errMsg)
}
ms.producers = append(ms.producers, channel)
}
}
func (ms *RmqMsgStream) AsConsumer(channels []string, groupName string) {
for _, channelName := range channels {
if err := rocksmq.Rmq.CreateConsumerGroup(groupName, channelName); err != nil {
consumer, err := rocksmq.Rmq.CreateConsumerGroup(groupName, channelName)
if err != nil {
panic(err.Error())
}
msgNum := make(chan int)
ms.consumers = append(ms.consumers, rocksmq.Consumer{GroupName: groupName, ChannelName: channelName, MsgNum: msgNum})
consumer.MsgNum = make(chan int, ms.rmqBufSize)
ms.consumers = append(ms.consumers, *consumer)
ms.consumerReflects = append(ms.consumerReflects, reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(consumer.MsgNum),
})
}
}
......@@ -240,12 +263,6 @@ func (ms *RmqMsgStream) Consume() *msgstream.MsgPack {
func (ms *RmqMsgStream) bufMsgPackToChannel() {
defer ms.wait.Done()
cases := make([]reflect.SelectCase, len(ms.consumers))
for i := 0; i < len(ms.consumers); i++ {
ch := ms.consumers[i].MsgNum
cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)}
}
for {
select {
case <-ms.ctx.Done():
......@@ -255,7 +272,7 @@ func (ms *RmqMsgStream) bufMsgPackToChannel() {
tsMsgList := make([]msgstream.TsMsg, 0)
for {
chosen, value, ok := reflect.Select(cases)
chosen, value, ok := reflect.Select(ms.consumerReflects)
if !ok {
log.Printf("channel closed")
return
......
package rmqms
import (
"context"
"fmt"
"log"
"os"
"testing"
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
"github.com/zilliztech/milvus-distributed/internal/util/rocksmq"
"go.etcd.io/etcd/clientv3"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb2"
)
var rocksmqName string = "/tmp/rocksmq"
func repackFunc(msgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) {
result := make(map[int32]*MsgPack)
for i, request := range msgs {
keys := hashKeys[i]
for _, channelID := range keys {
_, ok := result[channelID]
if ok == false {
msgPack := MsgPack{}
result[channelID] = &msgPack
}
result[channelID].Msgs = append(result[channelID].Msgs, request)
}
}
return result, nil
}
func getTsMsg(msgType MsgType, reqID UniqueID, hashValue uint32) TsMsg {
baseMsg := BaseMsg{
BeginTimestamp: 0,
EndTimestamp: 0,
HashValues: []uint32{hashValue},
}
switch msgType {
case commonpb.MsgType_kInsert:
insertRequest := internalpb2.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_kInsert,
MsgID: reqID,
Timestamp: 11,
SourceID: reqID,
},
CollectionName: "Collection",
PartitionName: "Partition",
SegmentID: 1,
ChannelID: "0",
Timestamps: []Timestamp{uint64(reqID)},
RowIDs: []int64{1},
RowData: []*commonpb.Blob{{}},
}
insertMsg := &msgstream.InsertMsg{
BaseMsg: baseMsg,
InsertRequest: insertRequest,
}
return insertMsg
case commonpb.MsgType_kDelete:
deleteRequest := internalpb2.DeleteRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_kDelete,
MsgID: reqID,
Timestamp: 11,
SourceID: reqID,
},
CollectionName: "Collection",
ChannelID: "1",
Timestamps: []Timestamp{1},
PrimaryKeys: []IntPrimaryKey{1},
}
deleteMsg := &msgstream.DeleteMsg{
BaseMsg: baseMsg,
DeleteRequest: deleteRequest,
}
return deleteMsg
case commonpb.MsgType_kSearch:
searchRequest := internalpb2.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_kSearch,
MsgID: reqID,
Timestamp: 11,
SourceID: reqID,
},
Query: nil,
ResultChannelID: "0",
}
searchMsg := &msgstream.SearchMsg{
BaseMsg: baseMsg,
SearchRequest: searchRequest,
}
return searchMsg
case commonpb.MsgType_kSearchResult:
searchResult := internalpb2.SearchResults{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_kSearchResult,
MsgID: reqID,
Timestamp: 1,
SourceID: reqID,
},
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS},
ResultChannelID: "0",
}
searchResultMsg := &msgstream.SearchResultMsg{
BaseMsg: baseMsg,
SearchResults: searchResult,
}
return searchResultMsg
case commonpb.MsgType_kTimeTick:
timeTickResult := internalpb2.TimeTickMsg{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_kTimeTick,
MsgID: reqID,
Timestamp: 1,
SourceID: reqID,
},
}
timeTickMsg := &TimeTickMsg{
BaseMsg: baseMsg,
TimeTickMsg: timeTickResult,
}
return timeTickMsg
case commonpb.MsgType_kQueryNodeStats:
queryNodeSegStats := internalpb2.QueryNodeStats{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_kQueryNodeStats,
SourceID: reqID,
},
}
queryNodeSegStatsMsg := &QueryNodeStatsMsg{
BaseMsg: baseMsg,
QueryNodeStats: queryNodeSegStats,
}
return queryNodeSegStatsMsg
}
return nil
}
func initRmq(name string) *etcdkv.EtcdKV {
etcdAddr := os.Getenv("ETCD_ADDRESS")
if etcdAddr == "" {
etcdAddr = "localhost:2379"
}
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}})
if err != nil {
log.Fatalf("New clientv3 error = %v", err)
}
etcdKV := etcdkv.NewEtcdKV(cli, "/etcd/test/root")
idAllocator := rocksmq.NewGlobalIDAllocator("dummy", etcdKV)
_ = idAllocator.Initialize()
err = rocksmq.InitRmq(name, idAllocator)
if err != nil {
log.Fatalf("InitRmq error = %v", err)
}
return etcdKV
}
func Close(intputStream, outputStream msgstream.MsgStream, etcdKV *etcdkv.EtcdKV) {
intputStream.Close()
outputStream.Close()
etcdKV.Close()
_ = os.RemoveAll(rocksmqName)
}
func initRmqStream(producerChannels []string,
consumerChannels []string,
consumerGroupName string,
opts ...RepackFunc) (msgstream.MsgStream, msgstream.MsgStream) {
factory := msgstream.ProtoUDFactory{}
inputStream, _ := newRmqMsgStream(context.Background(), 100, 100, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels)
for _, opt := range opts {
inputStream.SetRepackFunc(opt)
}
inputStream.Start()
var input msgstream.MsgStream = inputStream
outputStream, _ := newRmqMsgStream(context.Background(), 100, 100, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(consumerChannels, consumerGroupName)
outputStream.Start()
var output msgstream.MsgStream = outputStream
return input, output
}
func receiveMsg(outputStream msgstream.MsgStream, msgCount int) {
receiveCount := 0
for {
result := outputStream.Consume()
if len(result.Msgs) > 0 {
msgs := result.Msgs
for _, v := range msgs {
receiveCount++
fmt.Println("msg type: ", v.Type(), ", msg value: ", v)
}
}
if receiveCount >= msgCount {
break
}
}
}
func TestStream_RmqMsgStream_Insert(t *testing.T) {
producerChannels := []string{"insert1", "insert2"}
consumerChannels := []string{"insert1", "insert2"}
consumerGroupName := "InsertGroup"
msgPack := msgstream.MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kInsert, 1, 1))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kInsert, 3, 3))
etcdKV := initRmq("/tmp/rocksmq_insert")
inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerGroupName)
err := inputStream.Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
receiveMsg(outputStream, len(msgPack.Msgs))
Close(inputStream, outputStream, etcdKV)
}
func TestStream_RmqMsgStream_Delete(t *testing.T) {
producerChannels := []string{"delete"}
consumerChannels := []string{"delete"}
consumerSubName := "subDelete"
msgPack := msgstream.MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kDelete, 1, 1))
etcdKV := initRmq("/tmp/rocksmq_delete")
inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName)
err := inputStream.Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
receiveMsg(outputStream, len(msgPack.Msgs))
Close(inputStream, outputStream, etcdKV)
}
func TestStream_RmqMsgStream_Search(t *testing.T) {
producerChannels := []string{"search"}
consumerChannels := []string{"search"}
consumerSubName := "subSearch"
msgPack := msgstream.MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kSearch, 1, 1))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kSearch, 3, 3))
etcdKV := initRmq("/tmp/rocksmq_search")
inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName)
err := inputStream.Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
receiveMsg(outputStream, len(msgPack.Msgs))
Close(inputStream, outputStream, etcdKV)
}
func TestStream_RmqMsgStream_SearchResult(t *testing.T) {
producerChannels := []string{"searchResult"}
consumerChannels := []string{"searchResult"}
consumerSubName := "subSearchResult"
msgPack := msgstream.MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kSearchResult, 1, 1))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kSearchResult, 3, 3))
etcdKV := initRmq("/tmp/rocksmq_searchresult")
inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName)
err := inputStream.Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
receiveMsg(outputStream, len(msgPack.Msgs))
Close(inputStream, outputStream, etcdKV)
}
func TestStream_RmqMsgStream_TimeTick(t *testing.T) {
producerChannels := []string{"timeTick"}
consumerChannels := []string{"timeTick"}
consumerSubName := "subTimeTick"
msgPack := msgstream.MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kTimeTick, 1, 1))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kTimeTick, 3, 3))
etcdKV := initRmq("/tmp/rocksmq_timetick")
inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName)
err := inputStream.Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
receiveMsg(outputStream, len(msgPack.Msgs))
Close(inputStream, outputStream, etcdKV)
}
func TestStream_RmqMsgStream_BroadCast(t *testing.T) {
producerChannels := []string{"insert1", "insert2"}
consumerChannels := []string{"insert1", "insert2"}
consumerSubName := "subInsert"
msgPack := msgstream.MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kTimeTick, 1, 1))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kTimeTick, 3, 3))
etcdKV := initRmq("/tmp/rocksmq_broadcast")
inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName)
err := inputStream.Broadcast(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
receiveMsg(outputStream, len(consumerChannels)*len(msgPack.Msgs))
Close(inputStream, outputStream, etcdKV)
}
func TestStream_PulsarMsgStream_RepackFunc(t *testing.T) {
producerChannels := []string{"insert1", "insert2"}
consumerChannels := []string{"insert1", "insert2"}
consumerSubName := "subInsert"
msgPack := msgstream.MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kInsert, 1, 1))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kInsert, 3, 3))
etcdKV := initRmq("/tmp/rocksmq_repackfunc")
inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName, repackFunc)
err := inputStream.Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
receiveMsg(outputStream, len(msgPack.Msgs))
Close(inputStream, outputStream, etcdKV)
}
......@@ -76,7 +76,7 @@ type RocksMQ struct {
produceMu sync.Mutex
consumeMu sync.Mutex
notify map[string][]Consumer
notify map[string][]*Consumer
//ctx context.Context
//serverLoopWg sync.WaitGroup
//serverLoopCtx context.Context
......@@ -107,7 +107,7 @@ func NewRocksMQ(name string, idAllocator IDAllocator) (*RocksMQ, error) {
idAllocator: idAllocator,
}
rmq.channels = make(map[string]*Channel)
rmq.notify = make(map[string][]Consumer)
rmq.notify = make(map[string][]*Consumer)
return rmq, nil
}
......@@ -166,17 +166,24 @@ func (rmq *RocksMQ) DestroyChannel(channelName string) error {
return nil
}
func (rmq *RocksMQ) CreateConsumerGroup(groupName string, channelName string) error {
func (rmq *RocksMQ) CreateConsumerGroup(groupName string, channelName string) (*Consumer, error) {
key := groupName + "/" + channelName + "/current_id"
if rmq.checkKeyExist(key) {
return errors.New("ConsumerGroup " + groupName + " already exists.")
return nil, errors.New("ConsumerGroup " + groupName + " already exists.")
}
err := rmq.kv.Save(key, DefaultMessageID)
if err != nil {
return err
return nil, err
}
return nil
//msgNum := make(chan int, 100)
consumer := Consumer{
GroupName: groupName,
ChannelName: channelName,
//MsgNum: msgNum,
}
rmq.notify[channelName] = append(rmq.notify[channelName], &consumer)
return &consumer, nil
}
func (rmq *RocksMQ) DestroyConsumerGroup(groupName string, channelName string) error {
......@@ -243,7 +250,9 @@ func (rmq *RocksMQ) Produce(channelName string, messages []ProducerMessage) erro
}
for _, consumer := range rmq.notify[channelName] {
consumer.MsgNum <- msgLen
if consumer.MsgNum != nil {
consumer.MsgNum <- msgLen
}
}
return nil
}
......
......@@ -61,7 +61,7 @@ func TestRocksMQ(t *testing.T) {
groupName := "test_group"
_ = rmq.DestroyConsumerGroup(groupName, channelName)
err = rmq.CreateConsumerGroup(groupName, channelName)
_, err = rmq.CreateConsumerGroup(groupName, channelName)
assert.Nil(t, err)
cMsgs, err := rmq.Consume(groupName, channelName, 1)
assert.Nil(t, err)
......@@ -122,7 +122,7 @@ func TestRocksMQ_Loop(t *testing.T) {
// Consume loopNum message once
groupName := "test_group"
_ = rmq.DestroyConsumerGroup(groupName, channelName)
err = rmq.CreateConsumerGroup(groupName, channelName)
_, err = rmq.CreateConsumerGroup(groupName, channelName)
assert.Nil(t, err)
cMsgs, err := rmq.Consume(groupName, channelName, loopNum)
assert.Nil(t, err)
......@@ -189,7 +189,7 @@ func TestRocksMQ_Goroutines(t *testing.T) {
groupName := "test_group"
_ = rmq.DestroyConsumerGroup(groupName, channelName)
err = rmq.CreateConsumerGroup(groupName, channelName)
_, err = rmq.CreateConsumerGroup(groupName, channelName)
assert.Nil(t, err)
// Consume one message in each goroutine
for i := 0; i < loopNum; i++ {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册