From 6e70ce3f66745e491eb87336cdcd23cbb8ba5a4c Mon Sep 17 00:00:00 2001 From: sunby Date: Thu, 8 Apr 2021 15:08:34 +0800 Subject: [PATCH] Add cluster unit tests Signed-off-by: sunby --- internal/dataservice/cluster.go | 36 +++++++++++--------- internal/dataservice/cluster_test.go | 51 +++++++++++++++++++++++++++- internal/dataservice/mock.go | 23 +++++++++---- 3 files changed, 87 insertions(+), 23 deletions(-) diff --git a/internal/dataservice/cluster.go b/internal/dataservice/cluster.go index f21dd6d53..a9df8ee31 100644 --- a/internal/dataservice/cluster.go +++ b/internal/dataservice/cluster.go @@ -25,7 +25,7 @@ type dataNode struct { channelNum int } type dataNodeCluster struct { - mu sync.RWMutex + sync.RWMutex finishCh chan struct{} nodes []*dataNode } @@ -42,8 +42,8 @@ func newDataNodeCluster(finishCh chan struct{}) *dataNodeCluster { } func (c *dataNodeCluster) Register(dataNode *dataNode) { - c.mu.Lock() - defer c.mu.Unlock() + c.Lock() + defer c.Unlock() if c.checkDataNodeNotExist(dataNode.address.ip, dataNode.address.port) { c.nodes = append(c.nodes, dataNode) if len(c.nodes) == Params.DataNodeNum { @@ -62,23 +62,25 @@ func (c *dataNodeCluster) checkDataNodeNotExist(ip string, port int64) bool { } func (c *dataNodeCluster) GetNumOfNodes() int { + c.RLock() + defer c.RUnlock() return len(c.nodes) } func (c *dataNodeCluster) GetNodeIDs() []int64 { - c.mu.RLock() - defer c.mu.RUnlock() - ret := make([]int64, len(c.nodes)) - for i, node := range c.nodes { - ret[i] = node.id + c.RLock() + defer c.RUnlock() + ret := make([]int64, 0, len(c.nodes)) + for _, node := range c.nodes { + ret = append(ret, node.id) } return ret } func (c *dataNodeCluster) WatchInsertChannels(channels []string) { ctx := context.TODO() - c.mu.Lock() - defer c.mu.Unlock() + c.Lock() + defer c.Unlock() var groups [][]string if len(channels) < len(c.nodes) { groups = make([][]string, len(channels)) @@ -108,8 +110,8 @@ func (c *dataNodeCluster) WatchInsertChannels(channels []string) { } func (c *dataNodeCluster) GetDataNodeStates(ctx context.Context) ([]*internalpb.ComponentInfo, error) { - c.mu.RLock() - defer c.mu.RUnlock() + c.RLock() + defer c.RUnlock() ret := make([]*internalpb.ComponentInfo, 0) for _, node := range c.nodes { states, err := node.client.GetComponentStates(ctx) @@ -124,8 +126,8 @@ func (c *dataNodeCluster) GetDataNodeStates(ctx context.Context) ([]*internalpb. func (c *dataNodeCluster) FlushSegment(request *datapb.FlushSegmentsRequest) { ctx := context.TODO() - c.mu.RLock() - defer c.mu.RUnlock() + c.Lock() + defer c.Unlock() for _, node := range c.nodes { if _, err := node.client.FlushSegments(ctx, request); err != nil { log.Error("flush segment err", zap.Stringer("dataNode", node), zap.Error(err)) @@ -135,6 +137,8 @@ func (c *dataNodeCluster) FlushSegment(request *datapb.FlushSegmentsRequest) { } func (c *dataNodeCluster) ShutDownClients() { + c.Lock() + defer c.Unlock() for _, node := range c.nodes { if err := node.client.Stop(); err != nil { log.Error("stop client error", zap.Stringer("dataNode", node), zap.Error(err)) @@ -145,8 +149,8 @@ func (c *dataNodeCluster) ShutDownClients() { // Clear only for test func (c *dataNodeCluster) Clear() { - c.mu.Lock() - defer c.mu.Unlock() + c.Lock() + defer c.Unlock() c.finishCh = make(chan struct{}) c.nodes = make([]*dataNode, 0) } diff --git a/internal/dataservice/cluster_test.go b/internal/dataservice/cluster_test.go index 11ce0184e..19def2e17 100644 --- a/internal/dataservice/cluster_test.go +++ b/internal/dataservice/cluster_test.go @@ -4,8 +4,52 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" + "golang.org/x/net/context" ) +func TestDataNodeClusterRegister(t *testing.T) { + Params.Init() + Params.DataNodeNum = 3 + ch := make(chan struct{}) + cluster := newDataNodeCluster(ch) + ids := make([]int64, 0, Params.DataNodeNum) + for i := 0; i < Params.DataNodeNum; i++ { + c := newMockDataNodeClient(int64(i)) + err := c.Init() + assert.Nil(t, err) + err = c.Start() + assert.Nil(t, err) + cluster.Register(&dataNode{ + id: int64(i), + address: struct { + ip string + port int64 + }{"localhost", int64(9999 + i)}, + client: c, + channelNum: 0, + }) + ids = append(ids, int64(i)) + } + _, ok := <-ch + assert.False(t, ok) + assert.EqualValues(t, Params.DataNodeNum, cluster.GetNumOfNodes()) + assert.EqualValues(t, ids, cluster.GetNodeIDs()) + states, err := cluster.GetDataNodeStates(context.TODO()) + assert.Nil(t, err) + assert.EqualValues(t, Params.DataNodeNum, len(states)) + for _, s := range states { + assert.EqualValues(t, internalpb.StateCode_Healthy, s.StateCode) + } + cluster.ShutDownClients() + states, err = cluster.GetDataNodeStates(context.TODO()) + assert.Nil(t, err) + assert.EqualValues(t, Params.DataNodeNum, len(states)) + for _, s := range states { + assert.EqualValues(t, internalpb.StateCode_Abnormal, s.StateCode) + } +} + func TestWatchChannels(t *testing.T) { Params.Init() Params.DataNodeNum = 3 @@ -23,13 +67,18 @@ func TestWatchChannels(t *testing.T) { cluster := newDataNodeCluster(make(chan struct{})) for _, c := range cases { for i := 0; i < Params.DataNodeNum; i++ { + c := newMockDataNodeClient(int64(i)) + err := c.Init() + assert.Nil(t, err) + err = c.Start() + assert.Nil(t, err) cluster.Register(&dataNode{ id: int64(i), address: struct { ip string port int64 }{"localhost", int64(9999 + i)}, - client: newMockDataNodeClient(), + client: c, channelNum: 0, }) } diff --git a/internal/dataservice/mock.go b/internal/dataservice/mock.go index aacb55cfd..eade3dbab 100644 --- a/internal/dataservice/mock.go +++ b/internal/dataservice/mock.go @@ -53,6 +53,15 @@ func newTestSchema() *schemapb.CollectionSchema { } type mockDataNodeClient struct { + id int64 + state internalpb.StateCode +} + +func newMockDataNodeClient(id int64) *mockDataNodeClient { + return &mockDataNodeClient{ + id: id, + state: internalpb.StateCode_Initializing, + } } func (c *mockDataNodeClient) Init() error { @@ -60,22 +69,23 @@ func (c *mockDataNodeClient) Init() error { } func (c *mockDataNodeClient) Start() error { + c.state = internalpb.StateCode_Healthy return nil } func (c *mockDataNodeClient) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) { - //TODO - return nil, nil + return &internalpb.ComponentStates{ + State: &internalpb.ComponentInfo{ + NodeID: c.id, + StateCode: c.state, + }, + }, nil } func (c *mockDataNodeClient) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { return nil, nil } -func newMockDataNodeClient() *mockDataNodeClient { - return &mockDataNodeClient{} -} - func (c *mockDataNodeClient) WatchDmChannels(ctx context.Context, in *datapb.WatchDmChannelsRequest) (*commonpb.Status, error) { return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil } @@ -85,5 +95,6 @@ func (c *mockDataNodeClient) FlushSegments(ctx context.Context, in *datapb.Flush } func (c *mockDataNodeClient) Stop() error { + c.state = internalpb.StateCode_Abnormal return nil } -- GitLab