diff --git a/internal/querynode/query_node.go b/internal/querynode/query_node.go index 3049a59e8b171e9e408095def63a1895fbb18f45..6669fbe2c5e240df8d969babf1f4710ede48270c 100644 --- a/internal/querynode/query_node.go +++ b/internal/querynode/query_node.go @@ -58,7 +58,6 @@ import ( "github.com/milvus-io/milvus/internal/util" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/paramtable" - "github.com/milvus-io/milvus/internal/util/retry" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/typeutil" ) @@ -452,12 +451,7 @@ func (node *QueryNode) watchChangeInfo() { log.Warn("Unmarshal SealedSegmentsChangeInfo failed", zap.Any("error", err.Error())) continue } - go func() { - err = node.removeSegments(info) - if err != nil { - log.Warn("cleanup segments failed", zap.Any("error", err.Error())) - } - }() + go node.handleSealedSegmentsChangeInfo(info) default: // do nothing } @@ -466,58 +460,47 @@ func (node *QueryNode) watchChangeInfo() { } } -func (node *QueryNode) waitChangeInfo(segmentChangeInfos *querypb.SealedSegmentsChangeInfo) error { - fn := func() error { - /* - for _, info := range segmentChangeInfos.Infos { - canDoLoadBalance := true - // make sure all query channel already received segment location changes - // Check online segments: - for _, segmentInfo := range info.OnlineSegments { - if node.queryService.hasQueryCollection(segmentInfo.CollectionID) { - qc, err := node.queryService.getQueryCollection(segmentInfo.CollectionID) - if err != nil { - canDoLoadBalance = false - break - } - if info.OnlineNodeID == Params.QueryNodeCfg.QueryNodeID && !qc.globalSegmentManager.hasGlobalSealedSegment(segmentInfo.SegmentID) { - canDoLoadBalance = false - break - } - } - } - // Check offline segments: - for _, segmentInfo := range info.OfflineSegments { - if node.queryService.hasQueryCollection(segmentInfo.CollectionID) { - qc, err := node.queryService.getQueryCollection(segmentInfo.CollectionID) - if err != nil { - canDoLoadBalance = false - break - } - if info.OfflineNodeID == Params.QueryNodeCfg.QueryNodeID && qc.globalSegmentManager.hasGlobalSealedSegment(segmentInfo.SegmentID) { - canDoLoadBalance = false - break - } - } - } - if canDoLoadBalance { - return nil - } - return errors.New(fmt.Sprintln("waitChangeInfo failed, infoID = ", segmentChangeInfos.Base.GetMsgID())) - } - */ - return nil +func (node *QueryNode) handleSealedSegmentsChangeInfo(info *querypb.SealedSegmentsChangeInfo) { + for _, line := range info.GetInfos() { + vchannel, err := validateChangeChannel(line) + if err != nil { + log.Warn("failed to validate vchannel for SegmentChangeInfo", zap.Error(err)) + continue + } + + node.ShardClusterService.HandoffVChannelSegments(vchannel, line) + } +} + +func validateChangeChannel(info *querypb.SegmentChangeInfo) (string, error) { + if len(info.GetOnlineSegments()) == 0 && len(info.GetOfflineSegments()) == 0 { + return "", errors.New("SegmentChangeInfo with no segments info") + } + + var channelName string + + for _, segment := range info.GetOnlineSegments() { + if channelName == "" { + channelName = segment.GetDmChannel() + } + if segment.GetDmChannel() != channelName { + return "", fmt.Errorf("found multilple channel name in one SegmentChangeInfo, channel1: %s, channel 2:%s", channelName, segment.GetDmChannel()) + } + } + for _, segment := range info.GetOfflineSegments() { + if channelName == "" { + channelName = segment.GetDmChannel() + } + if segment.GetDmChannel() != channelName { + return "", fmt.Errorf("found multilple channel name in one SegmentChangeInfo, channel1: %s, channel 2:%s", channelName, segment.GetDmChannel()) + } } - return retry.Do(node.queryNodeLoopCtx, fn, retry.Attempts(50)) + return channelName, nil } // remove the segments since it's already compacted or balanced to other QueryNodes func (node *QueryNode) removeSegments(segmentChangeInfos *querypb.SealedSegmentsChangeInfo) error { - err := node.waitChangeInfo(segmentChangeInfos) - if err != nil { - return err - } node.streaming.replica.queryLock() node.historical.replica.queryLock() diff --git a/internal/querynode/query_node_test.go b/internal/querynode/query_node_test.go index 1093de4fd4b9e548b5be8a11f33b9c30ef4a3517..27d9349427ddefce1ce84926209f02af05b59d92 100644 --- a/internal/querynode/query_node_test.go +++ b/internal/querynode/query_node_test.go @@ -31,6 +31,7 @@ import ( "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.etcd.io/etcd/server/v3/embed" "github.com/milvus-io/milvus/internal/util/dependency" @@ -329,17 +330,6 @@ func genSimpleQueryNodeToTestWatchChangeInfo(ctx context.Context) (*QueryNode, e return node, nil } -func TestQueryNode_waitChangeInfo(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx) - assert.NoError(t, err) - - err = node.waitChangeInfo(genSimpleChangeInfo()) - assert.NoError(t, err) -} - func TestQueryNode_adjustByChangeInfo(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -534,3 +524,139 @@ func TestQueryNode_watchService(t *testing.T) { assert.True(t, flag) }) } + +func TestQueryNode_validateChangeChannel(t *testing.T) { + + type testCase struct { + name string + info *querypb.SegmentChangeInfo + expectedError bool + expectedChannelName string + } + + cases := []testCase{ + { + name: "empty info", + info: &querypb.SegmentChangeInfo{}, + expectedError: true, + }, + { + name: "normal segment change info", + info: &querypb.SegmentChangeInfo{ + OnlineSegments: []*querypb.SegmentInfo{ + {DmChannel: defaultDMLChannel}, + }, + OfflineSegments: []*querypb.SegmentInfo{ + {DmChannel: defaultDMLChannel}, + }, + }, + expectedError: false, + expectedChannelName: defaultDMLChannel, + }, + { + name: "empty offline change info", + info: &querypb.SegmentChangeInfo{ + OnlineSegments: []*querypb.SegmentInfo{ + {DmChannel: defaultDMLChannel}, + }, + }, + expectedError: false, + expectedChannelName: defaultDMLChannel, + }, + { + name: "empty online change info", + info: &querypb.SegmentChangeInfo{ + OfflineSegments: []*querypb.SegmentInfo{ + {DmChannel: defaultDMLChannel}, + }, + }, + expectedError: false, + expectedChannelName: defaultDMLChannel, + }, + { + name: "different channel in online", + info: &querypb.SegmentChangeInfo{ + OnlineSegments: []*querypb.SegmentInfo{ + {DmChannel: defaultDMLChannel}, + {DmChannel: "other_channel"}, + }, + }, + expectedError: true, + }, + { + name: "different channel in offline", + info: &querypb.SegmentChangeInfo{ + OnlineSegments: []*querypb.SegmentInfo{ + {DmChannel: defaultDMLChannel}, + }, + OfflineSegments: []*querypb.SegmentInfo{ + {DmChannel: "other_channel"}, + }, + }, + expectedError: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + channelName, err := validateChangeChannel(tc.info) + if tc.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expectedChannelName, channelName) + } + }) + } +} + +func TestQueryNode_handleSealedSegmentsChangeInfo(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + qn, err := genSimpleQueryNode(ctx) + require.NoError(t, err) + + t.Run("empty info", func(t *testing.T) { + assert.NotPanics(t, func() { + qn.handleSealedSegmentsChangeInfo(&querypb.SealedSegmentsChangeInfo{}) + }) + assert.NotPanics(t, func() { + qn.handleSealedSegmentsChangeInfo(nil) + }) + }) + + t.Run("normal segment change info", func(t *testing.T) { + assert.NotPanics(t, func() { + qn.handleSealedSegmentsChangeInfo(&querypb.SealedSegmentsChangeInfo{ + Infos: []*querypb.SegmentChangeInfo{ + { + OnlineSegments: []*querypb.SegmentInfo{ + {DmChannel: defaultDMLChannel}, + }, + OfflineSegments: []*querypb.SegmentInfo{ + {DmChannel: defaultDMLChannel}, + }, + }, + }, + }) + }) + }) + + t.Run("bad change info", func(t *testing.T) { + assert.NotPanics(t, func() { + qn.handleSealedSegmentsChangeInfo(&querypb.SealedSegmentsChangeInfo{ + Infos: []*querypb.SegmentChangeInfo{ + { + OnlineSegments: []*querypb.SegmentInfo{ + {DmChannel: defaultDMLChannel}, + }, + OfflineSegments: []*querypb.SegmentInfo{ + {DmChannel: "other_channel"}, + }, + }, + }, + }) + }) + + }) +} diff --git a/internal/querynode/shard_cluster.go b/internal/querynode/shard_cluster.go index 69bc21d663d60cf50c7728f0a60ed69289376d00..67425643f97fe13772c8d5766847e41dbe37be2b 100644 --- a/internal/querynode/shard_cluster.go +++ b/internal/querynode/shard_cluster.go @@ -27,6 +27,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/errorutil" "go.uber.org/atomic" "go.uber.org/zap" ) @@ -78,6 +79,7 @@ type segmentEvent struct { type shardQueryNode interface { Search(context.Context, *querypb.SearchRequest) (*internalpb.SearchResults, error) Query(context.Context, *querypb.QueryRequest) (*internalpb.RetrieveResults, error) + ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) Stop() error } @@ -363,7 +365,6 @@ func (sc *ShardCluster) watchNodes(evtCh <-chan nodeEvent) { for { select { case evt, ok := <-evtCh: - log.Debug("node event", zap.Any("evt", evt)) if !ok { log.Warn("ShardCluster node channel closed", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID)) return @@ -514,6 +515,8 @@ func (sc *ShardCluster) HandoffSegments(info *querypb.SegmentChangeInfo) error { offlineSegments = append(offlineSegments, seg.GetSegmentID()) } sc.waitSegmentsNotInUse(offlineSegments) + + removes := make(map[int64][]int64) // nodeID => []segmentIDs // remove offline segments record for _, seg := range info.OfflineSegments { // filter out segments not maintained in this cluster @@ -521,11 +524,39 @@ func (sc *ShardCluster) HandoffSegments(info *querypb.SegmentChangeInfo) error { continue } sc.removeSegment(segmentEvent{segmentID: seg.GetSegmentID(), nodeID: seg.GetNodeID()}) + + removes[seg.GetNodeID()] = append(removes[seg.GetNodeID()], seg.SegmentID) + } + + var errs errorutil.ErrorList + + // notify querynode(s) to release segments + for nodeID, segmentIDs := range removes { + node, ok := sc.getNode(nodeID) + if !ok { + log.Warn("node not in cluster", zap.Int64("nodeID", nodeID), zap.Int64("collectionID", sc.collectionID), zap.String("vchannel", sc.vchannelName)) + errs = append(errs, fmt.Errorf("node not in cluster nodeID %d", nodeID)) + continue + } + state, err := node.client.ReleaseSegments(context.Background(), &querypb.ReleaseSegmentsRequest{ + CollectionID: sc.collectionID, + SegmentIDs: segmentIDs, + }) + if err != nil { + errs = append(errs, err) + continue + } + if state.GetErrorCode() != commonpb.ErrorCode_Success { + errs = append(errs, fmt.Errorf("Release segments failed with reason: %s", state.GetReason())) + } } // finish handoff and remove it from pending list sc.finishHandoff(token) + if len(errs) > 0 { + return errs + } return nil } diff --git a/internal/querynode/shard_cluster_service.go b/internal/querynode/shard_cluster_service.go index f7d776975600025c337d758f91fb794eac5b2df4..a36c3dc9dc95855d958c36c96ceb091e51ea57ca 100644 --- a/internal/querynode/shard_cluster_service.go +++ b/internal/querynode/shard_cluster_service.go @@ -137,3 +137,14 @@ func (s *ShardClusterService) SyncReplicaSegments(vchannelName string, distribut return nil } + +// HandoffVChannelSegments dispatches SegmentChangeInfo to related ShardCluster with VChannel +func (s *ShardClusterService) HandoffVChannelSegments(vchannel string, info *querypb.SegmentChangeInfo) error { + raw, ok := s.clusters.Load(vchannel) + if !ok { + // not leader for this channel, ignore without error + return nil + } + sc := raw.(*ShardCluster) + return sc.HandoffSegments(info) +} diff --git a/internal/querynode/shard_cluster_service_test.go b/internal/querynode/shard_cluster_service_test.go index c43c49b6bfa329ac9db98af9da79a1d19c9fa5db..3be4d72e2e7791621f6ac75b3a2ffd1638363e96 100644 --- a/internal/querynode/shard_cluster_service_test.go +++ b/internal/querynode/shard_cluster_service_test.go @@ -88,3 +88,23 @@ func TestShardClusterService_SyncReplicaSegments(t *testing.T) { assert.Equal(t, segmentStateLoaded, segment.state) }) } + +func TestShardClusterService_HandoffVChannelSegments(t *testing.T) { + qn, err := genSimpleQueryNode(context.Background()) + require.NoError(t, err) + + client := v3client.New(embedetcdServer.Server) + defer client.Close() + session := sessionutil.NewSession(context.Background(), "/by-dev/sessions/unittest/querynode/", client) + clusterService := newShardClusterService(client, session, qn) + + err = clusterService.HandoffVChannelSegments(defaultDMLChannel, &querypb.SegmentChangeInfo{}) + assert.NoError(t, err) + + clusterService.addShardCluster(defaultCollectionID, defaultReplicaID, defaultDMLChannel) + //TODO change shardCluster to interface to mock test behavior + assert.NotPanics(t, func() { + err = clusterService.HandoffVChannelSegments(defaultDMLChannel, &querypb.SegmentChangeInfo{}) + assert.NoError(t, err) + }) +} diff --git a/internal/querynode/shard_cluster_test.go b/internal/querynode/shard_cluster_test.go index 332dc72c4b7058b393395efb8093f69fb650769e..e6797412eff9b4fedbcbb7f8c8cb9933e6c7eab2 100644 --- a/internal/querynode/shard_cluster_test.go +++ b/internal/querynode/shard_cluster_test.go @@ -22,6 +22,7 @@ import ( "testing" "time" + "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/stretchr/testify/assert" @@ -47,10 +48,12 @@ func (m *mockSegmentDetector) watchSegments(collectionID int64, replicaID int64, } type mockShardQueryNode struct { - searchResult *internalpb.SearchResults - searchErr error - queryResult *internalpb.RetrieveResults - queryErr error + searchResult *internalpb.SearchResults + searchErr error + queryResult *internalpb.RetrieveResults + queryErr error + releaseSegmentsResult *commonpb.Status + releaseSegmentsErr error } func (m *mockShardQueryNode) Search(_ context.Context, _ *querypb.SearchRequest) (*internalpb.SearchResults, error) { @@ -61,6 +64,10 @@ func (m *mockShardQueryNode) Query(_ context.Context, _ *querypb.QueryRequest) ( return m.queryResult, m.queryErr } +func (m *mockShardQueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) { + return m.releaseSegmentsResult, m.releaseSegmentsErr +} + func (m *mockShardQueryNode) Stop() error { return nil } @@ -1336,7 +1343,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) { }, buildMockQueryNode) defer sc.Close() - sc.HandoffSegments(&querypb.SegmentChangeInfo{ + err := sc.HandoffSegments(&querypb.SegmentChangeInfo{ OnlineSegments: []*querypb.SegmentInfo{ {SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName}, }, @@ -1344,6 +1351,10 @@ func TestShardCluster_HandoffSegments(t *testing.T) { {SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName}, }, }) + if err != nil { + t.Log(err.Error()) + } + assert.NoError(t, err) sc.mut.RLock() _, has := sc.segments[1] @@ -1383,7 +1394,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) { }, buildMockQueryNode) defer sc.Close() - sc.HandoffSegments(&querypb.SegmentChangeInfo{ + err := sc.HandoffSegments(&querypb.SegmentChangeInfo{ OnlineSegments: []*querypb.SegmentInfo{ {SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName}, {SegmentID: 4, NodeID: 2, CollectionID: otherCollectionID, DmChannel: otherVchannelName}, @@ -1393,6 +1404,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) { {SegmentID: 5, NodeID: 2, CollectionID: otherCollectionID, DmChannel: otherVchannelName}, }, }) + assert.NoError(t, err) sc.mut.RLock() _, has := sc.segments[3] @@ -1439,7 +1451,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) { sig := make(chan struct{}) go func() { - sc.HandoffSegments(&querypb.SegmentChangeInfo{ + err := sc.HandoffSegments(&querypb.SegmentChangeInfo{ OnlineSegments: []*querypb.SegmentInfo{ {SegmentID: 3, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName}, }, @@ -1448,6 +1460,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) { }, }) + assert.NoError(t, err) close(sig) }() @@ -1493,4 +1506,113 @@ func TestShardCluster_HandoffSegments(t *testing.T) { assert.False(t, has) }) + + t.Run("handoff from non-exist node", func(t *testing.T) { + nodeEvents := []nodeEvent{ + { + nodeID: 1, + nodeAddr: "addr_1", + }, + { + nodeID: 2, + nodeAddr: "addr_2", + }, + } + + segmentEvents := []segmentEvent{ + { + segmentID: 1, + nodeID: 1, + state: segmentStateLoaded, + }, + { + segmentID: 2, + nodeID: 2, + state: segmentStateLoaded, + }, + } + evtCh := make(chan segmentEvent, 10) + sc := NewShardCluster(collectionID, replicaID, vchannelName, + &mockNodeDetector{ + initNodes: nodeEvents, + }, &mockSegmentDetector{ + initSegments: segmentEvents, + evtCh: evtCh, + }, buildMockQueryNode) + defer sc.Close() + + err := sc.HandoffSegments(&querypb.SegmentChangeInfo{ + OnlineSegments: []*querypb.SegmentInfo{ + {SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName}, + }, + OfflineSegments: []*querypb.SegmentInfo{ + {SegmentID: 1, NodeID: 3, CollectionID: collectionID, DmChannel: vchannelName}, + }, + }) + + assert.Error(t, err) + }) + + t.Run("release failed", func(t *testing.T) { + nodeEvents := []nodeEvent{ + { + nodeID: 1, + nodeAddr: "addr_1", + }, + { + nodeID: 2, + nodeAddr: "addr_2", + }, + } + + segmentEvents := []segmentEvent{ + { + segmentID: 1, + nodeID: 1, + state: segmentStateLoaded, + }, + { + segmentID: 2, + nodeID: 2, + state: segmentStateLoaded, + }, + } + evtCh := make(chan segmentEvent, 10) + mqn := &mockShardQueryNode{} + sc := NewShardCluster(collectionID, replicaID, vchannelName, + &mockNodeDetector{ + initNodes: nodeEvents, + }, &mockSegmentDetector{ + initSegments: segmentEvents, + evtCh: evtCh, + }, func(nodeID int64, addr string) shardQueryNode { + return mqn + }) + defer sc.Close() + + mqn.releaseSegmentsErr = errors.New("mocked error") + + err := sc.HandoffSegments(&querypb.SegmentChangeInfo{ + OfflineSegments: []*querypb.SegmentInfo{ + {SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName}, + }, + }) + + assert.Error(t, err) + + mqn.releaseSegmentsErr = nil + mqn.releaseSegmentsResult = &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "mocked error", + } + + err = sc.HandoffSegments(&querypb.SegmentChangeInfo{ + OfflineSegments: []*querypb.SegmentInfo{ + {SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName}, + }, + }) + + assert.Error(t, err) + + }) }