未验证 提交 ee0f753f 编写于 作者: D dragondriver 提交者: GitHub

Fix datarace between GetComponentStates and Register (#11935)

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
上级 074687e3
......@@ -49,6 +49,9 @@ const (
// InvalidFieldID indicates that the field does not exist. It will be set when the field is not found.
InvalidFieldID = int64(-1)
// NotRegisteredID means node is not registered into etcd.
NotRegisteredID = int64(-1)
)
// Endian is type alias of binary.LittleEndian.
......
......@@ -27,6 +27,8 @@ import (
"testing"
"time"
"github.com/milvus-io/milvus/internal/common"
memkv "github.com/milvus-io/milvus/internal/kv/mem"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
......@@ -491,6 +493,12 @@ func TestGetSegmentInfo(t *testing.T) {
func TestGetComponentStates(t *testing.T) {
svr := &Server{}
resp, err := svr.GetComponentStates(context.Background())
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.Equal(t, common.NotRegisteredID, resp.State.NodeID)
svr.session = &sessionutil.Session{}
svr.session.UpdateRegistered(true)
type testCase struct {
state ServerState
code internalpb.StateCode
......
......@@ -23,6 +23,8 @@ import (
"sync/atomic"
"time"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/milvus-io/milvus/internal/log"
......@@ -384,9 +386,15 @@ func (s *Server) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPath
// GetComponentStates returns DataCoord's current state
func (s *Server) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
nodeID := common.NotRegisteredID
if s.session != nil && s.session.Registered() {
nodeID = s.session.ServerID // or Params.NodeID
}
resp := &internalpb.ComponentStates{
State: &internalpb.ComponentInfo{
NodeID: Params.NodeID,
// NodeID: Params.NodeID, // will race with Server.Register()
NodeID: nodeID,
Role: "datacoord",
StateCode: 0,
},
......
......@@ -32,6 +32,8 @@ import (
"sync/atomic"
"time"
"github.com/milvus-io/milvus/internal/common"
v3rpc "go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
clientv3 "go.etcd.io/etcd/client/v3"
......@@ -473,9 +475,14 @@ func (node *DataNode) WatchDmChannels(ctx context.Context, in *datapb.WatchDmCha
// GetComponentStates will return current state of DataNode
func (node *DataNode) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
log.Debug("DataNode current state", zap.Any("State", node.State.Load()))
nodeID := common.NotRegisteredID
if node.session != nil && node.session.Registered() {
nodeID = node.session.ServerID
}
states := &internalpb.ComponentStates{
State: &internalpb.ComponentInfo{
NodeID: Params.NodeID,
// NodeID: Params.NodeID, // will race with DataNode.Register()
NodeID: nodeID,
Role: node.Role,
StateCode: node.State.Load().(internalpb.StateCode),
},
......
......@@ -28,6 +28,8 @@ import (
"testing"
"time"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/msgstream"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/types"
......@@ -584,3 +586,17 @@ func TestWatchChannel(t *testing.T) {
})
}
// TestDataNode_GetComponentStates checks the NodeID reported by
// DataNode.GetComponentStates before and after the node's session is marked
// as registered.
func TestDataNode_GetComponentStates(t *testing.T) {
	n := &DataNode{}
	n.State.Store(internalpb.StateCode_Healthy)

	// No session attached yet: the NotRegisteredID placeholder is expected.
	resp, err := n.GetComponentStates(context.Background())
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
	assert.Equal(t, common.NotRegisteredID, resp.State.NodeID)

	// Registered session: the reported NodeID must now be the session's
	// ServerID instead of the placeholder.
	n.session = &sessionutil.Session{}
	n.session.UpdateRegistered(true)
	resp, err = n.GetComponentStates(context.Background())
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
	assert.Equal(t, n.session.ServerID, resp.State.NodeID)
}
......@@ -26,6 +26,8 @@ import (
"sync/atomic"
"time"
"github.com/milvus-io/milvus/internal/common"
"go.etcd.io/etcd/api/v3/mvccpb"
"go.uber.org/zap"
......@@ -59,8 +61,6 @@ var _ types.IndexCoord = (*IndexCoord)(nil)
type IndexCoord struct {
stateCode atomic.Value
ID UniqueID
loopCtx context.Context
loopCancel func()
loopWg sync.WaitGroup
......@@ -193,13 +193,6 @@ func (i *IndexCoord) Init() error {
return
}
i.ID, err = i.idAllocator.AllocOne()
if err != nil {
log.Error("IndexCoord idAllocator allocOne failed", zap.Error(err))
initErr = err
return
}
option := &miniokv.Option{
Address: Params.MinIOAddress,
AccessKeyID: Params.MinIOAccessKeyID,
......@@ -302,8 +295,14 @@ func (i *IndexCoord) isHealthy() bool {
// GetComponentStates gets the component states of IndexCoord.
func (i *IndexCoord) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
log.Debug("get IndexCoord component states ...")
nodeID := common.NotRegisteredID
if i.session != nil && i.session.Registered() {
nodeID = i.session.ServerID
}
stateInfo := &internalpb.ComponentInfo{
NodeID: i.ID,
NodeID: nodeID,
Role: "IndexCoord",
StateCode: i.stateCode.Load().(internalpb.StateCode),
}
......@@ -515,19 +514,19 @@ func (i *IndexCoord) GetIndexFilePaths(ctx context.Context, req *indexpb.GetInde
// GetMetrics gets the metrics info of IndexCoord.
func (i *IndexCoord) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
log.Debug("IndexCoord.GetMetrics",
zap.Int64("node_id", i.ID),
zap.Int64("node_id", i.session.ServerID),
zap.String("req", req.Request))
if !i.isHealthy() {
log.Warn("IndexCoord.GetMetrics failed",
zap.Int64("node_id", i.ID),
zap.Int64("node_id", i.session.ServerID),
zap.String("req", req.Request),
zap.Error(errIndexCoordIsUnhealthy(i.ID)))
zap.Error(errIndexCoordIsUnhealthy(i.session.ServerID)))
return &milvuspb.GetMetricsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: msgIndexCoordIsUnhealthy(i.ID),
Reason: msgIndexCoordIsUnhealthy(i.session.ServerID),
},
Response: "",
}, nil
......@@ -536,7 +535,7 @@ func (i *IndexCoord) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsReq
metricType, err := metricsinfo.ParseMetricType(req.Request)
if err != nil {
log.Error("IndexCoord.GetMetrics failed to parse metric type",
zap.Int64("node_id", i.ID),
zap.Int64("node_id", i.session.ServerID),
zap.String("req", req.Request),
zap.Error(err))
......@@ -563,7 +562,7 @@ func (i *IndexCoord) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsReq
metrics, err := getSystemInfoMetrics(ctx, req, i)
log.Debug("IndexCoord.GetMetrics",
zap.Int64("node_id", i.ID),
zap.Int64("node_id", i.session.ServerID),
zap.String("req", req.Request),
zap.String("metric_type", metricType),
zap.Any("metrics", metrics), // TODO(dragondriver): necessary? may be very large
......@@ -575,7 +574,7 @@ func (i *IndexCoord) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsReq
}
log.Debug("IndexCoord.GetMetrics failed, request metric type is not implemented yet",
zap.Int64("node_id", i.ID),
zap.Int64("node_id", i.session.ServerID),
zap.String("req", req.Request),
zap.String("metric_type", metricType))
......
......@@ -23,6 +23,8 @@ import (
"testing"
"time"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
grpcindexnode "github.com/milvus-io/milvus/internal/distributed/indexnode"
......@@ -231,3 +233,17 @@ func TestIndexCoord_watchNodeLoop(t *testing.T) {
assert.True(t, flag)
}
// TestIndexCoord_GetComponentStates checks the NodeID reported by
// IndexCoord.GetComponentStates before and after the coordinator's session is
// marked as registered.
func TestIndexCoord_GetComponentStates(t *testing.T) {
	n := &IndexCoord{}
	n.stateCode.Store(internalpb.StateCode_Healthy)

	// No session attached yet: the NotRegisteredID placeholder is expected.
	resp, err := n.GetComponentStates(context.Background())
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
	assert.Equal(t, common.NotRegisteredID, resp.State.NodeID)

	// Registered session: the reported NodeID must now be the session's
	// ServerID instead of the placeholder.
	n.session = &sessionutil.Session{}
	n.session.UpdateRegistered(true)
	resp, err = n.GetComponentStates(context.Background())
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
	assert.Equal(t, n.session.ServerID, resp.State.NodeID)
}
......@@ -38,6 +38,8 @@ import (
"time"
"unsafe"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
......@@ -285,8 +287,13 @@ func (i *IndexNode) CreateIndex(ctx context.Context, request *indexpb.CreateInde
// GetComponentStates gets the component states of IndexNode.
func (i *IndexNode) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
log.Debug("get IndexNode components states ...")
nodeID := common.NotRegisteredID
if i.session != nil && i.session.Registered() {
nodeID = i.session.ServerID
}
stateInfo := &internalpb.ComponentInfo{
NodeID: Params.NodeID,
// NodeID: Params.NodeID, // will race with i.Register()
NodeID: nodeID,
Role: "NodeImpl",
StateCode: i.stateCode.Load().(internalpb.StateCode),
}
......
......@@ -24,12 +24,15 @@ import (
"testing"
"time"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/log"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/proto/internalpb"
......@@ -811,3 +814,17 @@ func TestIndexNode_InitError(t *testing.T) {
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
})
}
// TestIndexNode_GetComponentStates checks the NodeID reported by
// IndexNode.GetComponentStates before and after the node's session is marked
// as registered.
func TestIndexNode_GetComponentStates(t *testing.T) {
	n := &IndexNode{}
	n.stateCode.Store(internalpb.StateCode_Healthy)

	// No session attached yet: the NotRegisteredID placeholder is expected.
	resp, err := n.GetComponentStates(context.Background())
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
	assert.Equal(t, common.NotRegisteredID, resp.State.NodeID)

	// Registered session: the reported NodeID must now be the session's
	// ServerID instead of the placeholder.
	n.session = &sessionutil.Session{}
	n.session.UpdateRegistered(true)
	resp, err = n.GetComponentStates(context.Background())
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
	assert.Equal(t, n.session.ServerID, resp.State.NodeID)
}
......@@ -23,6 +23,8 @@ import (
"os"
"strconv"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/trace"
......@@ -65,8 +67,13 @@ func (node *Proxy) GetComponentStates(ctx context.Context) (*internalpb.Componen
}
return stats, errors.New(errMsg)
}
nodeID := common.NotRegisteredID
if node.session != nil && node.session.Registered() {
nodeID = node.session.ServerID
}
info := &internalpb.ComponentInfo{
NodeID: Params.ProxyID,
// NodeID: Params.ProxyID, // will race with Proxy.Register()
NodeID: nodeID,
Role: typeutil.ProxyRole,
StateCode: code,
}
......
......@@ -29,6 +29,8 @@ import (
"testing"
"time"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/common"
......@@ -2480,3 +2482,17 @@ func Test_GetCompactionStateWithPlans(t *testing.T) {
assert.Nil(t, err)
})
}
// TestProxy_GetComponentStates checks the NodeID reported by
// Proxy.GetComponentStates before and after the proxy's session is marked as
// registered.
func TestProxy_GetComponentStates(t *testing.T) {
	n := &Proxy{}
	n.stateCode.Store(internalpb.StateCode_Healthy)

	// No session attached yet: the NotRegisteredID placeholder is expected.
	resp, err := n.GetComponentStates(context.Background())
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
	assert.Equal(t, common.NotRegisteredID, resp.State.NodeID)

	// Registered session: the reported NodeID must now be the session's
	// ServerID instead of the placeholder.
	n.session = &sessionutil.Session{}
	n.session.UpdateRegistered(true)
	resp, err = n.GetComponentStates(context.Background())
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
	assert.Equal(t, n.session.ServerID, resp.State.NodeID)
}
......@@ -21,6 +21,8 @@ import (
"errors"
"fmt"
"github.com/milvus-io/milvus/internal/common"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
......@@ -33,8 +35,13 @@ import (
// GetComponentStates return information about whether the coord is healthy
func (qc *QueryCoord) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
nodeID := common.NotRegisteredID
if qc.session != nil && qc.session.Registered() {
nodeID = qc.session.ServerID
}
serviceComponentInfo := &internalpb.ComponentInfo{
NodeID: Params.QueryCoordID,
// NodeID: Params.QueryCoordID, // will race with QueryCoord.Register()
NodeID: nodeID,
StateCode: qc.stateCode.Load().(internalpb.StateCode),
}
......
......@@ -23,6 +23,9 @@ import (
"testing"
"time"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/proto/commonpb"
......@@ -694,3 +697,17 @@ func Test_GrpcGetQueryChannelFail(t *testing.T) {
assert.NotNil(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, res.Status.ErrorCode)
}
// TestQueryCoord_GetComponentStates checks the NodeID reported by
// QueryCoord.GetComponentStates before and after the coordinator's session is
// marked as registered.
func TestQueryCoord_GetComponentStates(t *testing.T) {
	n := &QueryCoord{}
	n.stateCode.Store(internalpb.StateCode_Healthy)

	// No session attached yet: the NotRegisteredID placeholder is expected.
	resp, err := n.GetComponentStates(context.Background())
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
	assert.Equal(t, common.NotRegisteredID, resp.State.NodeID)

	// Registered session: the reported NodeID must now be the session's
	// ServerID instead of the placeholder.
	n.session = &sessionutil.Session{}
	n.session.UpdateRegistered(true)
	resp, err = n.GetComponentStates(context.Background())
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
	assert.Equal(t, n.session.ServerID, resp.State.NodeID)
}
......@@ -21,6 +21,8 @@ import (
"sync/atomic"
"time"
"github.com/milvus-io/milvus/internal/common"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/kv"
......@@ -1194,9 +1196,15 @@ func (c *Core) GetComponentStates(ctx context.Context) (*internalpb.ComponentSta
code := c.stateCode.Load().(internalpb.StateCode)
log.Debug("GetComponentStates", zap.String("State Code", internalpb.StateCode_name[int32(code)]))
nodeID := common.NotRegisteredID
if c.session != nil && c.session.Registered() {
nodeID = c.session.ServerID
}
return &internalpb.ComponentStates{
State: &internalpb.ComponentInfo{
NodeID: c.session.ServerID,
// NodeID: c.session.ServerID, // will race with Core.Register()
NodeID: nodeID,
Role: typeutil.RootCoordRole,
StateCode: code,
ExtraInfo: nil,
......@@ -1207,7 +1215,7 @@ func (c *Core) GetComponentStates(ctx context.Context) (*internalpb.ComponentSta
},
SubcomponentStates: []*internalpb.ComponentInfo{
{
NodeID: c.session.ServerID,
NodeID: nodeID,
Role: typeutil.RootCoordRole,
StateCode: code,
ExtraInfo: nil,
......
......@@ -2715,3 +2715,17 @@ func TestRootCoord_CheckZeroShardsNum(t *testing.T) {
err = core.Stop()
assert.Nil(t, err)
}
// TestCore_GetComponentStates checks the NodeID reported by
// Core.GetComponentStates before and after the root coordinator's session is
// marked as registered.
func TestCore_GetComponentStates(t *testing.T) {
	n := &Core{}
	n.stateCode.Store(internalpb.StateCode_Healthy)

	// No session attached yet: the NotRegisteredID placeholder is expected.
	resp, err := n.GetComponentStates(context.Background())
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
	assert.Equal(t, common.NotRegisteredID, resp.State.NodeID)

	// Registered session: the reported NodeID must now be the session's
	// ServerID instead of the placeholder.
	n.session = &sessionutil.Session{}
	n.session.UpdateRegistered(true)
	resp, err = n.GetComponentStates(context.Background())
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
	assert.Equal(t, n.session.ServerID, resp.State.NodeID)
}
......@@ -7,6 +7,7 @@ import (
"fmt"
"path"
"strconv"
"sync/atomic"
"time"
"github.com/milvus-io/milvus/internal/log"
......@@ -58,6 +59,8 @@ type Session struct {
leaseID *clientv3.LeaseID
metaRoot string
registered atomic.Value
}
// NewSession is a helper to build Session object.
......@@ -70,6 +73,8 @@ func NewSession(ctx context.Context, metaRoot string, etcdEndpoints []string) *S
metaRoot: metaRoot,
}
session.UpdateRegistered(false)
connectEtcdFn := func() error {
log.Debug("Session try to connect to etcd")
etcdCli, err := clientv3.New(clientv3.Config{Endpoints: etcdEndpoints, DialTimeout: 5 * time.Second})
......@@ -112,6 +117,7 @@ func (s *Session) Init(serverName, address string, exclusive bool) {
panic(err)
}
s.liveCh = s.processKeepAliveResponse(ch)
s.UpdateRegistered(true)
}
func (s *Session) getServerID() (int64, error) {
......@@ -403,3 +409,17 @@ func (s *Session) Revoke(timeout time.Duration) {
// ignores resp & error, just do best effort to revoke
_, _ = s.etcdCli.Revoke(ctx, *s.leaseID)
}
// UpdateRegistered records whether the session is registered by storing b in
// the session's registered flag, which Registered reads back.
func (s *Session) UpdateRegistered(b bool) {
s.registered.Store(b)
}
// Registered reports whether the session has been marked as registered into
// etcd. It returns false when no boolean flag has been stored yet.
func (s *Session) Registered() bool {
	// A failed assertion (nothing stored, or a non-bool value) leaves the
	// zero value false, which is exactly the "not registered" answer.
	registered, _ := s.registered.Load().(bool)
	return registered
}
......@@ -238,3 +238,11 @@ func TestSessionRevoke(t *testing.T) {
s.Revoke(time.Second)
})
}
// TestSession_Registered verifies that Registered reports false on a Session
// whose flag was never set, and afterwards mirrors the value passed to
// UpdateRegistered.
func TestSession_Registered(t *testing.T) {
	session := &Session{}
	// Nothing stored yet: Registered must fall back to false.
	assert.False(t, session.Registered())

	session.UpdateRegistered(false)
	assert.False(t, session.Registered())
	session.UpdateRegistered(true)
	assert.True(t, session.Registered())
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册