提交 6e70ce3f 编写于 作者: S sunby 提交者: yefu.chen

Add cluster unit tests

Signed-off-by: Nsunby <bingyi.sun@zilliz.com>
上级 1d7195e0
...@@ -25,7 +25,7 @@ type dataNode struct { ...@@ -25,7 +25,7 @@ type dataNode struct {
channelNum int channelNum int
} }
type dataNodeCluster struct { type dataNodeCluster struct {
mu sync.RWMutex sync.RWMutex
finishCh chan struct{} finishCh chan struct{}
nodes []*dataNode nodes []*dataNode
} }
...@@ -42,8 +42,8 @@ func newDataNodeCluster(finishCh chan struct{}) *dataNodeCluster { ...@@ -42,8 +42,8 @@ func newDataNodeCluster(finishCh chan struct{}) *dataNodeCluster {
} }
func (c *dataNodeCluster) Register(dataNode *dataNode) { func (c *dataNodeCluster) Register(dataNode *dataNode) {
c.mu.Lock() c.Lock()
defer c.mu.Unlock() defer c.Unlock()
if c.checkDataNodeNotExist(dataNode.address.ip, dataNode.address.port) { if c.checkDataNodeNotExist(dataNode.address.ip, dataNode.address.port) {
c.nodes = append(c.nodes, dataNode) c.nodes = append(c.nodes, dataNode)
if len(c.nodes) == Params.DataNodeNum { if len(c.nodes) == Params.DataNodeNum {
...@@ -62,23 +62,25 @@ func (c *dataNodeCluster) checkDataNodeNotExist(ip string, port int64) bool { ...@@ -62,23 +62,25 @@ func (c *dataNodeCluster) checkDataNodeNotExist(ip string, port int64) bool {
} }
func (c *dataNodeCluster) GetNumOfNodes() int { func (c *dataNodeCluster) GetNumOfNodes() int {
c.RLock()
defer c.RUnlock()
return len(c.nodes) return len(c.nodes)
} }
func (c *dataNodeCluster) GetNodeIDs() []int64 { func (c *dataNodeCluster) GetNodeIDs() []int64 {
c.mu.RLock() c.RLock()
defer c.mu.RUnlock() defer c.RUnlock()
ret := make([]int64, len(c.nodes)) ret := make([]int64, 0, len(c.nodes))
for i, node := range c.nodes { for _, node := range c.nodes {
ret[i] = node.id ret = append(ret, node.id)
} }
return ret return ret
} }
func (c *dataNodeCluster) WatchInsertChannels(channels []string) { func (c *dataNodeCluster) WatchInsertChannels(channels []string) {
ctx := context.TODO() ctx := context.TODO()
c.mu.Lock() c.Lock()
defer c.mu.Unlock() defer c.Unlock()
var groups [][]string var groups [][]string
if len(channels) < len(c.nodes) { if len(channels) < len(c.nodes) {
groups = make([][]string, len(channels)) groups = make([][]string, len(channels))
...@@ -108,8 +110,8 @@ func (c *dataNodeCluster) WatchInsertChannels(channels []string) { ...@@ -108,8 +110,8 @@ func (c *dataNodeCluster) WatchInsertChannels(channels []string) {
} }
func (c *dataNodeCluster) GetDataNodeStates(ctx context.Context) ([]*internalpb.ComponentInfo, error) { func (c *dataNodeCluster) GetDataNodeStates(ctx context.Context) ([]*internalpb.ComponentInfo, error) {
c.mu.RLock() c.RLock()
defer c.mu.RUnlock() defer c.RUnlock()
ret := make([]*internalpb.ComponentInfo, 0) ret := make([]*internalpb.ComponentInfo, 0)
for _, node := range c.nodes { for _, node := range c.nodes {
states, err := node.client.GetComponentStates(ctx) states, err := node.client.GetComponentStates(ctx)
...@@ -124,8 +126,8 @@ func (c *dataNodeCluster) GetDataNodeStates(ctx context.Context) ([]*internalpb. ...@@ -124,8 +126,8 @@ func (c *dataNodeCluster) GetDataNodeStates(ctx context.Context) ([]*internalpb.
func (c *dataNodeCluster) FlushSegment(request *datapb.FlushSegmentsRequest) { func (c *dataNodeCluster) FlushSegment(request *datapb.FlushSegmentsRequest) {
ctx := context.TODO() ctx := context.TODO()
c.mu.RLock() c.Lock()
defer c.mu.RUnlock() defer c.Unlock()
for _, node := range c.nodes { for _, node := range c.nodes {
if _, err := node.client.FlushSegments(ctx, request); err != nil { if _, err := node.client.FlushSegments(ctx, request); err != nil {
log.Error("flush segment err", zap.Stringer("dataNode", node), zap.Error(err)) log.Error("flush segment err", zap.Stringer("dataNode", node), zap.Error(err))
...@@ -135,6 +137,8 @@ func (c *dataNodeCluster) FlushSegment(request *datapb.FlushSegmentsRequest) { ...@@ -135,6 +137,8 @@ func (c *dataNodeCluster) FlushSegment(request *datapb.FlushSegmentsRequest) {
} }
func (c *dataNodeCluster) ShutDownClients() { func (c *dataNodeCluster) ShutDownClients() {
c.Lock()
defer c.Unlock()
for _, node := range c.nodes { for _, node := range c.nodes {
if err := node.client.Stop(); err != nil { if err := node.client.Stop(); err != nil {
log.Error("stop client error", zap.Stringer("dataNode", node), zap.Error(err)) log.Error("stop client error", zap.Stringer("dataNode", node), zap.Error(err))
...@@ -145,8 +149,8 @@ func (c *dataNodeCluster) ShutDownClients() { ...@@ -145,8 +149,8 @@ func (c *dataNodeCluster) ShutDownClients() {
// Clear only for test // Clear only for test
func (c *dataNodeCluster) Clear() { func (c *dataNodeCluster) Clear() {
c.mu.Lock() c.Lock()
defer c.mu.Unlock() defer c.Unlock()
c.finishCh = make(chan struct{}) c.finishCh = make(chan struct{})
c.nodes = make([]*dataNode, 0) c.nodes = make([]*dataNode, 0)
} }
...@@ -4,8 +4,52 @@ import ( ...@@ -4,8 +4,52 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"golang.org/x/net/context"
) )
func TestDataNodeClusterRegister(t *testing.T) {
Params.Init()
Params.DataNodeNum = 3
ch := make(chan struct{})
cluster := newDataNodeCluster(ch)
ids := make([]int64, 0, Params.DataNodeNum)
for i := 0; i < Params.DataNodeNum; i++ {
c := newMockDataNodeClient(int64(i))
err := c.Init()
assert.Nil(t, err)
err = c.Start()
assert.Nil(t, err)
cluster.Register(&dataNode{
id: int64(i),
address: struct {
ip string
port int64
}{"localhost", int64(9999 + i)},
client: c,
channelNum: 0,
})
ids = append(ids, int64(i))
}
_, ok := <-ch
assert.False(t, ok)
assert.EqualValues(t, Params.DataNodeNum, cluster.GetNumOfNodes())
assert.EqualValues(t, ids, cluster.GetNodeIDs())
states, err := cluster.GetDataNodeStates(context.TODO())
assert.Nil(t, err)
assert.EqualValues(t, Params.DataNodeNum, len(states))
for _, s := range states {
assert.EqualValues(t, internalpb.StateCode_Healthy, s.StateCode)
}
cluster.ShutDownClients()
states, err = cluster.GetDataNodeStates(context.TODO())
assert.Nil(t, err)
assert.EqualValues(t, Params.DataNodeNum, len(states))
for _, s := range states {
assert.EqualValues(t, internalpb.StateCode_Abnormal, s.StateCode)
}
}
func TestWatchChannels(t *testing.T) { func TestWatchChannels(t *testing.T) {
Params.Init() Params.Init()
Params.DataNodeNum = 3 Params.DataNodeNum = 3
...@@ -23,13 +67,18 @@ func TestWatchChannels(t *testing.T) { ...@@ -23,13 +67,18 @@ func TestWatchChannels(t *testing.T) {
cluster := newDataNodeCluster(make(chan struct{})) cluster := newDataNodeCluster(make(chan struct{}))
for _, c := range cases { for _, c := range cases {
for i := 0; i < Params.DataNodeNum; i++ { for i := 0; i < Params.DataNodeNum; i++ {
c := newMockDataNodeClient(int64(i))
err := c.Init()
assert.Nil(t, err)
err = c.Start()
assert.Nil(t, err)
cluster.Register(&dataNode{ cluster.Register(&dataNode{
id: int64(i), id: int64(i),
address: struct { address: struct {
ip string ip string
port int64 port int64
}{"localhost", int64(9999 + i)}, }{"localhost", int64(9999 + i)},
client: newMockDataNodeClient(), client: c,
channelNum: 0, channelNum: 0,
}) })
} }
......
...@@ -53,6 +53,15 @@ func newTestSchema() *schemapb.CollectionSchema { ...@@ -53,6 +53,15 @@ func newTestSchema() *schemapb.CollectionSchema {
} }
type mockDataNodeClient struct { type mockDataNodeClient struct {
id int64
state internalpb.StateCode
}
func newMockDataNodeClient(id int64) *mockDataNodeClient {
return &mockDataNodeClient{
id: id,
state: internalpb.StateCode_Initializing,
}
} }
func (c *mockDataNodeClient) Init() error { func (c *mockDataNodeClient) Init() error {
...@@ -60,22 +69,23 @@ func (c *mockDataNodeClient) Init() error { ...@@ -60,22 +69,23 @@ func (c *mockDataNodeClient) Init() error {
} }
func (c *mockDataNodeClient) Start() error { func (c *mockDataNodeClient) Start() error {
c.state = internalpb.StateCode_Healthy
return nil return nil
} }
func (c *mockDataNodeClient) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) { func (c *mockDataNodeClient) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
//TODO return &internalpb.ComponentStates{
return nil, nil State: &internalpb.ComponentInfo{
NodeID: c.id,
StateCode: c.state,
},
}, nil
} }
func (c *mockDataNodeClient) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { func (c *mockDataNodeClient) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return nil, nil return nil, nil
} }
func newMockDataNodeClient() *mockDataNodeClient {
return &mockDataNodeClient{}
}
func (c *mockDataNodeClient) WatchDmChannels(ctx context.Context, in *datapb.WatchDmChannelsRequest) (*commonpb.Status, error) { func (c *mockDataNodeClient) WatchDmChannels(ctx context.Context, in *datapb.WatchDmChannelsRequest) (*commonpb.Status, error) {
return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil
} }
...@@ -85,5 +95,6 @@ func (c *mockDataNodeClient) FlushSegments(ctx context.Context, in *datapb.Flush ...@@ -85,5 +95,6 @@ func (c *mockDataNodeClient) FlushSegments(ctx context.Context, in *datapb.Flush
} }
func (c *mockDataNodeClient) Stop() error { func (c *mockDataNodeClient) Stop() error {
c.state = internalpb.StateCode_Abnormal
return nil return nil
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册