diff --git a/internal/msgstream/mq_msgstream_test.go b/internal/msgstream/mq_msgstream_test.go index 70ec3ecc0ae950197d800084cce54e93b75fe844..f0d85b753aac30637b2b969ecd64168dc15503a1 100644 --- a/internal/msgstream/mq_msgstream_test.go +++ b/internal/msgstream/mq_msgstream_test.go @@ -18,6 +18,7 @@ import ( "math/rand" "os" "strings" + "sync" "testing" "time" @@ -43,23 +44,31 @@ func TestMain(m *testing.M) { os.Exit(exitCode) } -func Test_NewMqMsgStream(t *testing.T) { +type fixture struct { + t *testing.T + etcdKV *etcdkv.EtcdKV +} + +type parameters struct { + client mqclient.Client +} + +func (f *fixture) setup() []parameters { pulsarAddress, _ := Params.Load("_PulsarAddress") - factory := &ProtoUDFactory{} pulsarClient, err := mqclient.GetPulsarClientInstance(pulsar.ClientOptions{URL: pulsarAddress}) - assert.Nil(t, err) + assert.Nil(f.t, err) - rocksdbName := "/tmp/rocksmq_unittest_" + t.Name() + rocksdbName := "/tmp/rocksmq_unittest_" + f.t.Name() endpoints := os.Getenv("ETCD_ENDPOINTS") if endpoints == "" { endpoints = "localhost:2379" } etcdEndpoints := strings.Split(endpoints, ",") - etcdKV, err := etcdkv.NewEtcdKV(etcdEndpoints, "/etcd/test/root") + f.etcdKV, err = etcdkv.NewEtcdKV(etcdEndpoints, "/etcd/test/root") if err != nil { log.Fatalf("New clientv3 error = %v", err) } - idAllocator := allocator.NewGlobalIDAllocator("dummy", etcdKV) + idAllocator := allocator.NewGlobalIDAllocator("dummy", f.etcdKV) _ = idAllocator.Initialize() err = rocksmq.InitRmq(rocksdbName, idAllocator) if err != nil { @@ -68,23 +77,225 @@ func Test_NewMqMsgStream(t *testing.T) { rmqClient, _ := mqclient.NewRmqClient(client.ClientOptions{Server: rocksmq.Rmq}) - parameters := []struct { - client mqclient.Client - }{ + parameters := []parameters{ {pulsarClient}, {rmqClient}, } + return parameters +} +func (f *fixture) teardown() { + rocksdbName := "/tmp/rocksmq_unittest_" + f.t.Name() + + rocksmq.CloseRocksMQ() + f.etcdKV.Close() + _ = os.RemoveAll(rocksdbName) + _ = os.RemoveAll(rocksdbName + "_meta_kv") +} + +func Test_NewMqMsgStream(t *testing.T) { + f := &fixture{t: t} + parameters := f.setup() + defer f.teardown() + + factory := &ProtoUDFactory{} for i := range parameters { func(client mqclient.Client) { _, err := NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) assert.Nil(t, err) }(parameters[i].client) } +} - rocksmq.CloseRocksMQ() - etcdKV.Close() - _ = os.RemoveAll(rocksdbName) - _ = os.RemoveAll(rocksdbName + "_meta_kv") +// TODO(wxyu): add a mock implement of mqclient.Client, then inject errors to improve coverage +func TestMqMsgStream_AsProducer(t *testing.T) { + f := &fixture{t: t} + parameters := f.setup() + defer f.teardown() + + factory := &ProtoUDFactory{} + for i := range parameters { + func(client mqclient.Client) { + m, err := NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) + assert.Nil(t, err) + + // empty channel name + m.AsProducer([]string{""}) + }(parameters[i].client) + } +} + +// TODO(wxyu): add a mock implement of mqclient.Client, then inject errors to improve coverage +func TestMqMsgStream_AsConsumer(t *testing.T) { + f := &fixture{t: t} + parameters := f.setup() + defer f.teardown() + + factory := &ProtoUDFactory{} + for i := range parameters { + func(client mqclient.Client) { + m, err := NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) + assert.Nil(t, err) + + // repeat calling AsConsumer + m.AsConsumer([]string{"a"}, "b") + m.AsConsumer([]string{"a"}, "b") + }(parameters[i].client) + } +} + +func TestMqMsgStream_ComputeProduceChannelIndexes(t *testing.T) { + f := &fixture{t: t} + parameters := f.setup() + defer f.teardown() + + factory := &ProtoUDFactory{} + for i := range parameters { + func(client mqclient.Client) { + m, err := NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) + assert.Nil(t, err) + + // empty parameters + reBucketValues := m.ComputeProduceChannelIndexes([]TsMsg{}) + assert.Nil(t, reBucketValues) + + // not called AsProducer yet + insertMsg := &InsertMsg{ + BaseMsg: generateBaseMsg(), + InsertRequest: internalpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + MsgID: 1, + Timestamp: 2, + SourceID: 3, + }, + + DbName: "test_db", + CollectionName: "test_collection", + PartitionName: "test_partition", + DbID: 4, + CollectionID: 5, + PartitionID: 6, + SegmentID: 7, + ShardName: "test-channel", + Timestamps: []uint64{2, 1, 3}, + RowData: []*commonpb.Blob{}, + }, + } + reBucketValues = m.ComputeProduceChannelIndexes([]TsMsg{insertMsg}) + assert.Nil(t, reBucketValues) + }(parameters[i].client) + } +} + +func TestMqMsgStream_GetProduceChannels(t *testing.T) { + f := &fixture{t: t} + parameters := f.setup() + defer f.teardown() + + factory := &ProtoUDFactory{} + for i := range parameters { + func(client mqclient.Client) { + m, err := NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) + assert.Nil(t, err) + + // empty if not called AsProducer yet + chs := m.GetProduceChannels() + assert.Equal(t, 0, len(chs)) + + // not empty after AsProducer + m.AsProducer([]string{"a"}) + chs = m.GetProduceChannels() + assert.Equal(t, 1, len(chs)) + }(parameters[i].client) + } +} + +func TestMqMsgStream_Produce(t *testing.T) { + f := &fixture{t: t} + parameters := f.setup() + defer f.teardown() + + factory := &ProtoUDFactory{} + for i := range parameters { + func(client mqclient.Client) { + m, err := NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) + assert.Nil(t, err) + + // Produce before called AsProducer + insertMsg := &InsertMsg{ + BaseMsg: generateBaseMsg(), + InsertRequest: internalpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + MsgID: 1, + Timestamp: 2, + SourceID: 3, + }, + + DbName: "test_db", + CollectionName: "test_collection", + PartitionName: "test_partition", + DbID: 4, + CollectionID: 5, + PartitionID: 6, + SegmentID: 7, + ShardName: "test-channel", + Timestamps: []uint64{2, 1, 3}, + RowData: []*commonpb.Blob{}, + }, + } + msgPack := &MsgPack{ + Msgs: []TsMsg{insertMsg}, + } + err = m.Produce(msgPack) + assert.NotNil(t, err) + }(parameters[i].client) + } +} + +func TestMqMsgStream_Broadcast(t *testing.T) { + f := &fixture{t: t} + parameters := f.setup() + defer f.teardown() + + factory := &ProtoUDFactory{} + for i := range parameters { + func(client mqclient.Client) { + m, err := NewMqMsgStream(context.Background(), 100, 100, client, factory.NewUnmarshalDispatcher()) + assert.Nil(t, err) + + // Broadcast nil pointer + err = m.Broadcast(nil) + assert.Nil(t, err) + }(parameters[i].client) + } +} + +func TestMqMsgStream_Consume(t *testing.T) { + f := &fixture{t: t} + parameters := f.setup() + defer f.teardown() + + factory := &ProtoUDFactory{} + for i := range parameters { + func(client mqclient.Client) { + // Consume return nil when ctx canceled + var wg sync.WaitGroup + ctx, cancel := context.WithCancel(context.Background()) + m, err := NewMqMsgStream(ctx, 100, 100, client, factory.NewUnmarshalDispatcher()) + assert.Nil(t, err) + + wg.Add(1) + go func() { + defer wg.Done() + msgPack := m.Consume() + assert.Nil(t, msgPack) + }() + + cancel() + wg.Wait() + }(parameters[i].client) + } } /* ========================== Pulsar & RocksMQ Tests ========================== */