未验证 提交 3a6db2fa 编写于 作者: C congqixia 提交者: GitHub

Fix handling segment change logic (#16695)

Dispatch segmentChangeInfo to ShardCluster leader
Hold segment remove before search is done
Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
上级 bb6cd4b4
......@@ -58,7 +58,6 @@ import (
"github.com/milvus-io/milvus/internal/util"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
......@@ -452,12 +451,7 @@ func (node *QueryNode) watchChangeInfo() {
log.Warn("Unmarshal SealedSegmentsChangeInfo failed", zap.Any("error", err.Error()))
continue
}
go func() {
err = node.removeSegments(info)
if err != nil {
log.Warn("cleanup segments failed", zap.Any("error", err.Error()))
}
}()
go node.handleSealedSegmentsChangeInfo(info)
default:
// do nothing
}
......@@ -466,58 +460,47 @@ func (node *QueryNode) watchChangeInfo() {
}
}
func (node *QueryNode) waitChangeInfo(segmentChangeInfos *querypb.SealedSegmentsChangeInfo) error {
fn := func() error {
/*
for _, info := range segmentChangeInfos.Infos {
canDoLoadBalance := true
// make sure all query channel already received segment location changes
// Check online segments:
for _, segmentInfo := range info.OnlineSegments {
if node.queryService.hasQueryCollection(segmentInfo.CollectionID) {
qc, err := node.queryService.getQueryCollection(segmentInfo.CollectionID)
if err != nil {
canDoLoadBalance = false
break
}
if info.OnlineNodeID == Params.QueryNodeCfg.QueryNodeID && !qc.globalSegmentManager.hasGlobalSealedSegment(segmentInfo.SegmentID) {
canDoLoadBalance = false
break
}
}
}
// Check offline segments:
for _, segmentInfo := range info.OfflineSegments {
if node.queryService.hasQueryCollection(segmentInfo.CollectionID) {
qc, err := node.queryService.getQueryCollection(segmentInfo.CollectionID)
if err != nil {
canDoLoadBalance = false
break
}
if info.OfflineNodeID == Params.QueryNodeCfg.QueryNodeID && qc.globalSegmentManager.hasGlobalSealedSegment(segmentInfo.SegmentID) {
canDoLoadBalance = false
break
}
}
}
if canDoLoadBalance {
return nil
}
return errors.New(fmt.Sprintln("waitChangeInfo failed, infoID = ", segmentChangeInfos.Base.GetMsgID()))
}
*/
return nil
// handleSealedSegmentsChangeInfo dispatches each SegmentChangeInfo inside the
// SealedSegmentsChangeInfo to the ShardCluster owning its vchannel.
// Entries that fail validation (no segments, or mixed channels) are logged and
// skipped; a nil info is safe because GetInfos returns nil on a nil receiver.
func (node *QueryNode) handleSealedSegmentsChangeInfo(info *querypb.SealedSegmentsChangeInfo) {
	for _, line := range info.GetInfos() {
		vchannel, err := validateChangeChannel(line)
		if err != nil {
			log.Warn("failed to validate vchannel for SegmentChangeInfo", zap.Error(err))
			continue
		}
		// The handoff result was previously discarded; surface failures so
		// operators can see when segment change dispatch does not take effect.
		if err := node.ShardClusterService.HandoffVChannelSegments(vchannel, line); err != nil {
			log.Warn("failed to handoff vchannel segments", zap.String("vchannel", vchannel), zap.Error(err))
		}
	}
}
// validateChangeChannel checks that every segment in the SegmentChangeInfo
// (online and offline alike) belongs to one single DmChannel and returns that
// channel name.
//
// It returns an error when the info carries no segments at all, or when
// segments from more than one channel are mixed into the same info.
func validateChangeChannel(info *querypb.SegmentChangeInfo) (string, error) {
	if len(info.GetOnlineSegments()) == 0 && len(info.GetOfflineSegments()) == 0 {
		return "", errors.New("SegmentChangeInfo with no segments info")
	}

	var channelName string
	// checkSegments learns the channel name from the first segment seen and
	// rejects any later segment reporting a different DmChannel.
	checkSegments := func(segments []*querypb.SegmentInfo) error {
		for _, segment := range segments {
			if channelName == "" {
				channelName = segment.GetDmChannel()
			}
			if segment.GetDmChannel() != channelName {
				return fmt.Errorf("found multiple channel names in one SegmentChangeInfo, channel1: %s, channel2: %s", channelName, segment.GetDmChannel())
			}
		}
		return nil
	}

	if err := checkSegments(info.GetOnlineSegments()); err != nil {
		return "", err
	}
	if err := checkSegments(info.GetOfflineSegments()); err != nil {
		return "", err
	}
	return channelName, nil
}
// remove the segments since it's already compacted or balanced to other QueryNodes
func (node *QueryNode) removeSegments(segmentChangeInfos *querypb.SealedSegmentsChangeInfo) error {
err := node.waitChangeInfo(segmentChangeInfos)
if err != nil {
return err
}
node.streaming.replica.queryLock()
node.historical.replica.queryLock()
......
......@@ -31,6 +31,7 @@ import (
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.etcd.io/etcd/server/v3/embed"
"github.com/milvus-io/milvus/internal/util/dependency"
......@@ -329,17 +330,6 @@ func genSimpleQueryNodeToTestWatchChangeInfo(ctx context.Context) (*QueryNode, e
return node, nil
}
// TestQueryNode_waitChangeInfo verifies waitChangeInfo completes without error
// for a simple change info on a freshly constructed query node.
func TestQueryNode_waitChangeInfo(t *testing.T) {
	baseCtx, cancel := context.WithCancel(context.Background())
	defer cancel()

	qn, err := genSimpleQueryNodeToTestWatchChangeInfo(baseCtx)
	assert.NoError(t, err)

	assert.NoError(t, qn.waitChangeInfo(genSimpleChangeInfo()))
}
func TestQueryNode_adjustByChangeInfo(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
......@@ -534,3 +524,139 @@ func TestQueryNode_watchService(t *testing.T) {
assert.True(t, flag)
})
}
// TestQueryNode_validateChangeChannel is a table-driven test covering
// validateChangeChannel with empty, single-channel and mixed-channel inputs.
func TestQueryNode_validateChangeChannel(t *testing.T) {
	type testCase struct {
		name                string
		info                *querypb.SegmentChangeInfo
		expectedError       bool
		expectedChannelName string
	}

	testCases := []testCase{
		// no segments at all: must be rejected
		{
			name:          "empty info",
			info:          &querypb.SegmentChangeInfo{},
			expectedError: true,
		},
		// online + offline on the same channel: accepted
		{
			name: "normal segment change info",
			info: &querypb.SegmentChangeInfo{
				OnlineSegments: []*querypb.SegmentInfo{
					{DmChannel: defaultDMLChannel},
				},
				OfflineSegments: []*querypb.SegmentInfo{
					{DmChannel: defaultDMLChannel},
				},
			},
			expectedError:       false,
			expectedChannelName: defaultDMLChannel,
		},
		// only online segments present: accepted
		{
			name: "empty offline change info",
			info: &querypb.SegmentChangeInfo{
				OnlineSegments: []*querypb.SegmentInfo{
					{DmChannel: defaultDMLChannel},
				},
			},
			expectedError:       false,
			expectedChannelName: defaultDMLChannel,
		},
		// only offline segments present: accepted
		{
			name: "empty online change info",
			info: &querypb.SegmentChangeInfo{
				OfflineSegments: []*querypb.SegmentInfo{
					{DmChannel: defaultDMLChannel},
				},
			},
			expectedError:       false,
			expectedChannelName: defaultDMLChannel,
		},
		// two channels inside the online list: rejected
		{
			name: "different channel in online",
			info: &querypb.SegmentChangeInfo{
				OnlineSegments: []*querypb.SegmentInfo{
					{DmChannel: defaultDMLChannel},
					{DmChannel: "other_channel"},
				},
			},
			expectedError: true,
		},
		// online and offline lists disagree on the channel: rejected
		{
			name: "different channel in offline",
			info: &querypb.SegmentChangeInfo{
				OnlineSegments: []*querypb.SegmentInfo{
					{DmChannel: defaultDMLChannel},
				},
				OfflineSegments: []*querypb.SegmentInfo{
					{DmChannel: "other_channel"},
				},
			},
			expectedError: true,
		},
	}

	for _, test := range testCases {
		t.Run(test.name, func(t *testing.T) {
			channelName, err := validateChangeChannel(test.info)
			if !test.expectedError {
				assert.NoError(t, err)
				assert.Equal(t, test.expectedChannelName, channelName)
				return
			}
			assert.Error(t, err)
		})
	}
}
// TestQueryNode_handleSealedSegmentsChangeInfo ensures the handler never
// panics for nil, empty, valid and channel-mismatched change infos.
func TestQueryNode_handleSealedSegmentsChangeInfo(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	node, err := genSimpleQueryNode(ctx)
	require.NoError(t, err)

	t.Run("empty info", func(t *testing.T) {
		assert.NotPanics(t, func() {
			node.handleSealedSegmentsChangeInfo(&querypb.SealedSegmentsChangeInfo{})
		})
		assert.NotPanics(t, func() {
			node.handleSealedSegmentsChangeInfo(nil)
		})
	})

	t.Run("normal segment change info", func(t *testing.T) {
		// online and offline segments share one channel: valid payload
		normalInfo := &querypb.SealedSegmentsChangeInfo{
			Infos: []*querypb.SegmentChangeInfo{
				{
					OnlineSegments: []*querypb.SegmentInfo{
						{DmChannel: defaultDMLChannel},
					},
					OfflineSegments: []*querypb.SegmentInfo{
						{DmChannel: defaultDMLChannel},
					},
				},
			},
		}
		assert.NotPanics(t, func() {
			node.handleSealedSegmentsChangeInfo(normalInfo)
		})
	})

	t.Run("bad change info", func(t *testing.T) {
		// channel mismatch between online and offline: handler must skip, not panic
		badInfo := &querypb.SealedSegmentsChangeInfo{
			Infos: []*querypb.SegmentChangeInfo{
				{
					OnlineSegments: []*querypb.SegmentInfo{
						{DmChannel: defaultDMLChannel},
					},
					OfflineSegments: []*querypb.SegmentInfo{
						{DmChannel: "other_channel"},
					},
				},
			},
		}
		assert.NotPanics(t, func() {
			node.handleSealedSegmentsChangeInfo(badInfo)
		})
	})
}
......@@ -27,6 +27,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/errorutil"
"go.uber.org/atomic"
"go.uber.org/zap"
)
......@@ -78,6 +79,7 @@ type segmentEvent struct {
// shardQueryNode is the interface a ShardCluster uses to talk to a member
// query node: distributed search/query, segment release, and shutdown.
type shardQueryNode interface {
	Search(context.Context, *querypb.SearchRequest) (*internalpb.SearchResults, error)
	Query(context.Context, *querypb.QueryRequest) (*internalpb.RetrieveResults, error)
	ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error)
	Stop() error
}
......@@ -363,7 +365,6 @@ func (sc *ShardCluster) watchNodes(evtCh <-chan nodeEvent) {
for {
select {
case evt, ok := <-evtCh:
log.Debug("node event", zap.Any("evt", evt))
if !ok {
log.Warn("ShardCluster node channel closed", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID))
return
......@@ -514,6 +515,8 @@ func (sc *ShardCluster) HandoffSegments(info *querypb.SegmentChangeInfo) error {
offlineSegments = append(offlineSegments, seg.GetSegmentID())
}
sc.waitSegmentsNotInUse(offlineSegments)
removes := make(map[int64][]int64) // nodeID => []segmentIDs
// remove offline segments record
for _, seg := range info.OfflineSegments {
// filter out segments not maintained in this cluster
......@@ -521,11 +524,39 @@ func (sc *ShardCluster) HandoffSegments(info *querypb.SegmentChangeInfo) error {
continue
}
sc.removeSegment(segmentEvent{segmentID: seg.GetSegmentID(), nodeID: seg.GetNodeID()})
removes[seg.GetNodeID()] = append(removes[seg.GetNodeID()], seg.SegmentID)
}
var errs errorutil.ErrorList
// notify querynode(s) to release segments
for nodeID, segmentIDs := range removes {
node, ok := sc.getNode(nodeID)
if !ok {
log.Warn("node not in cluster", zap.Int64("nodeID", nodeID), zap.Int64("collectionID", sc.collectionID), zap.String("vchannel", sc.vchannelName))
errs = append(errs, fmt.Errorf("node not in cluster nodeID %d", nodeID))
continue
}
state, err := node.client.ReleaseSegments(context.Background(), &querypb.ReleaseSegmentsRequest{
CollectionID: sc.collectionID,
SegmentIDs: segmentIDs,
})
if err != nil {
errs = append(errs, err)
continue
}
if state.GetErrorCode() != commonpb.ErrorCode_Success {
errs = append(errs, fmt.Errorf("Release segments failed with reason: %s", state.GetReason()))
}
}
// finish handoff and remove it from pending list
sc.finishHandoff(token)
if len(errs) > 0 {
return errs
}
return nil
}
......
......@@ -137,3 +137,14 @@ func (s *ShardClusterService) SyncReplicaSegments(vchannelName string, distribut
return nil
}
// HandoffVChannelSegments dispatches SegmentChangeInfo to related ShardCluster with VChannel
func (s *ShardClusterService) HandoffVChannelSegments(vchannel string, info *querypb.SegmentChangeInfo) error {
	value, has := s.clusters.Load(vchannel)
	if !has {
		// not leader for this channel, ignore without error
		return nil
	}
	// s.clusters only ever stores *ShardCluster values — presumably; TODO confirm at Store sites
	return value.(*ShardCluster).HandoffSegments(info)
}
......@@ -88,3 +88,23 @@ func TestShardClusterService_SyncReplicaSegments(t *testing.T) {
assert.Equal(t, segmentStateLoaded, segment.state)
})
}
// TestShardClusterService_HandoffVChannelSegments checks that handoff returns
// no error when the service has no cluster for the channel, and that it does
// not panic once a shard cluster has been registered.
func TestShardClusterService_HandoffVChannelSegments(t *testing.T) {
	qn, err := genSimpleQueryNode(context.Background())
	require.NoError(t, err)
	// back the cluster service with the embedded etcd server used by this test package
	client := v3client.New(embedetcdServer.Server)
	defer client.Close()
	session := sessionutil.NewSession(context.Background(), "/by-dev/sessions/unittest/querynode/", client)
	clusterService := newShardClusterService(client, session, qn)
	// no cluster registered for the channel yet: treated as "not leader", nil error expected
	err = clusterService.HandoffVChannelSegments(defaultDMLChannel, &querypb.SegmentChangeInfo{})
	assert.NoError(t, err)
	clusterService.addShardCluster(defaultCollectionID, defaultReplicaID, defaultDMLChannel)
	//TODO change shardCluster to interface to mock test behavior
	assert.NotPanics(t, func() {
		// cluster now registered: handoff of an empty change info must still succeed
		err = clusterService.HandoffVChannelSegments(defaultDMLChannel, &querypb.SegmentChangeInfo{})
		assert.NoError(t, err)
	})
}
......@@ -22,6 +22,7 @@ import (
"testing"
"time"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/stretchr/testify/assert"
......@@ -47,10 +48,12 @@ func (m *mockSegmentDetector) watchSegments(collectionID int64, replicaID int64,
}
// mockShardQueryNode is a test double for shardQueryNode: each method returns
// the canned result/error pair configured in the matching field pair.
// The duplicated field declarations left over from the diff are removed —
// a Go struct cannot declare the same field twice.
type mockShardQueryNode struct {
	searchResult          *internalpb.SearchResults
	searchErr             error
	queryResult           *internalpb.RetrieveResults
	queryErr              error
	releaseSegmentsResult *commonpb.Status
	releaseSegmentsErr    error
}
func (m *mockShardQueryNode) Search(_ context.Context, _ *querypb.SearchRequest) (*internalpb.SearchResults, error) {
......@@ -61,6 +64,10 @@ func (m *mockShardQueryNode) Query(_ context.Context, _ *querypb.QueryRequest) (
return m.queryResult, m.queryErr
}
// ReleaseSegments returns the canned release status and error configured on the mock.
func (m *mockShardQueryNode) ReleaseSegments(_ context.Context, _ *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
	return m.releaseSegmentsResult, m.releaseSegmentsErr
}
// Stop implements shardQueryNode; the mock has nothing to shut down.
func (m *mockShardQueryNode) Stop() error {
	return nil
}
......@@ -1336,7 +1343,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
}, buildMockQueryNode)
defer sc.Close()
sc.HandoffSegments(&querypb.SegmentChangeInfo{
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
OnlineSegments: []*querypb.SegmentInfo{
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
},
......@@ -1344,6 +1351,10 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
{SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
},
})
if err != nil {
t.Log(err.Error())
}
assert.NoError(t, err)
sc.mut.RLock()
_, has := sc.segments[1]
......@@ -1383,7 +1394,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
}, buildMockQueryNode)
defer sc.Close()
sc.HandoffSegments(&querypb.SegmentChangeInfo{
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
OnlineSegments: []*querypb.SegmentInfo{
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
{SegmentID: 4, NodeID: 2, CollectionID: otherCollectionID, DmChannel: otherVchannelName},
......@@ -1393,6 +1404,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
{SegmentID: 5, NodeID: 2, CollectionID: otherCollectionID, DmChannel: otherVchannelName},
},
})
assert.NoError(t, err)
sc.mut.RLock()
_, has := sc.segments[3]
......@@ -1439,7 +1451,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
sig := make(chan struct{})
go func() {
sc.HandoffSegments(&querypb.SegmentChangeInfo{
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
OnlineSegments: []*querypb.SegmentInfo{
{SegmentID: 3, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
},
......@@ -1448,6 +1460,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
},
})
assert.NoError(t, err)
close(sig)
}()
......@@ -1493,4 +1506,113 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
assert.False(t, has)
})
t.Run("handoff from non-exist node", func(t *testing.T) {
nodeEvents := []nodeEvent{
{
nodeID: 1,
nodeAddr: "addr_1",
},
{
nodeID: 2,
nodeAddr: "addr_2",
},
}
segmentEvents := []segmentEvent{
{
segmentID: 1,
nodeID: 1,
state: segmentStateLoaded,
},
{
segmentID: 2,
nodeID: 2,
state: segmentStateLoaded,
},
}
evtCh := make(chan segmentEvent, 10)
sc := NewShardCluster(collectionID, replicaID, vchannelName,
&mockNodeDetector{
initNodes: nodeEvents,
}, &mockSegmentDetector{
initSegments: segmentEvents,
evtCh: evtCh,
}, buildMockQueryNode)
defer sc.Close()
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
OnlineSegments: []*querypb.SegmentInfo{
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
},
OfflineSegments: []*querypb.SegmentInfo{
{SegmentID: 1, NodeID: 3, CollectionID: collectionID, DmChannel: vchannelName},
},
})
assert.Error(t, err)
})
t.Run("release failed", func(t *testing.T) {
nodeEvents := []nodeEvent{
{
nodeID: 1,
nodeAddr: "addr_1",
},
{
nodeID: 2,
nodeAddr: "addr_2",
},
}
segmentEvents := []segmentEvent{
{
segmentID: 1,
nodeID: 1,
state: segmentStateLoaded,
},
{
segmentID: 2,
nodeID: 2,
state: segmentStateLoaded,
},
}
evtCh := make(chan segmentEvent, 10)
mqn := &mockShardQueryNode{}
sc := NewShardCluster(collectionID, replicaID, vchannelName,
&mockNodeDetector{
initNodes: nodeEvents,
}, &mockSegmentDetector{
initSegments: segmentEvents,
evtCh: evtCh,
}, func(nodeID int64, addr string) shardQueryNode {
return mqn
})
defer sc.Close()
mqn.releaseSegmentsErr = errors.New("mocked error")
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
OfflineSegments: []*querypb.SegmentInfo{
{SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
},
})
assert.Error(t, err)
mqn.releaseSegmentsErr = nil
mqn.releaseSegmentsResult = &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mocked error",
}
err = sc.HandoffSegments(&querypb.SegmentChangeInfo{
OfflineSegments: []*querypb.SegmentInfo{
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
},
})
assert.Error(t, err)
})
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册