Unverified commit 785a5a75 authored by congqixia, committed by GitHub

Use segment version instead of ref cnt (#17609)

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
Parent cc3ecc4b
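
At a high level, this change replaces per-segment reference counting in the shard cluster with immutable allocation snapshots: each segment-distribution change installs a new ShardClusterVersion (added near the end of this diff), readers grab the current version's allocation and bump its in-use counter, and later release it by version ID. A minimal sketch of that lifecycle, using a simplified stand-in type rather than the real ShardCluster internals:

package main

import (
	"fmt"

	"go.uber.org/atomic"
)

// version is a simplified stand-in for the ShardClusterVersion below:
// an immutable snapshot of nodeID => segmentIDs plus an in-use counter.
type version struct {
	versionID int64
	segments  map[int64][]int64
	current   *atomic.Bool
	inUse     *atomic.Int64
}

// segmentAllocations returns the snapshot and its version ID, recording usage.
func (v *version) segmentAllocations() (map[int64][]int64, int64) {
	v.inUse.Add(1)
	return v.segments, v.versionID
}

// finishUsage releases one usage; once the version is expired and unused,
// the offline segments of the superseding change can be safely released.
func (v *version) finishUsage() {
	if v.inUse.Add(-1) == 0 && !v.current.Load() {
		fmt.Println("version", v.versionID, "is safe to GC")
	}
}

func main() {
	v := &version{
		versionID: 1,
		segments:  map[int64][]int64{1: {1}, 2: {2}},
		current:   atomic.NewBool(true),
		inUse:     atomic.NewInt64(0),
	}
	allocs, versionID := v.segmentAllocations()
	fmt.Println(allocs, versionID) // map[1:[1] 2:[2]] 1
	v.current.Store(false)         // a newer version was installed (Expire)
	v.finishUsage()                // prints: version 1 is safe to GC
}
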
......@@ -231,6 +231,7 @@ message ReleaseSegmentsRequest {
int64 collectionID = 4;
repeated int64 partitionIDs = 5;
repeated int64 segmentIDs = 6;
DataScope scope = 7; // All, Streaming, Historical
}
message SearchRequest {
......
......@@ -166,10 +166,8 @@ func syncReplicaSegments(ctx context.Context, cluster Cluster, childTasks []task
}
for dmc, leaders := range shardLeaders {
segments, ok := shardSegments[dmc]
if !ok {
continue
}
// invoke sync segments even when there is no segment
segments := shardSegments[dmc]
for _, leader := range leaders {
req := querypb.SyncReplicaSegmentsRequest{
......@@ -187,7 +185,6 @@ func syncReplicaSegments(ctx context.Context, cluster Cluster, childTasks []task
})
}
}
err := cluster.SyncReplicaSegments(ctx, leader.LeaderID, &req)
if err != nil {
return err
......
......@@ -18,9 +18,7 @@ package querynode
import (
"context"
"errors"
"fmt"
"sync"
"go.uber.org/zap"
......@@ -353,11 +351,18 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, in *queryPb.ReleaseS
// collection lock is not needed since we guarantee no query/search will be dispatched from the leader
for _, id := range in.SegmentIDs {
node.metaReplica.removeSegment(id, segmentTypeSealed)
node.metaReplica.removeSegment(id, segmentTypeGrowing)
switch in.GetScope() {
case queryPb.DataScope_Streaming:
node.metaReplica.removeSegment(id, segmentTypeGrowing)
case queryPb.DataScope_Historical:
node.metaReplica.removeSegment(id, segmentTypeSealed)
case queryPb.DataScope_All:
node.metaReplica.removeSegment(id, segmentTypeSealed)
node.metaReplica.removeSegment(id, segmentTypeGrowing)
}
}
log.Info("release segments done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", in.SegmentIDs))
log.Info("release segments done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", in.SegmentIDs), zap.String("Scope", in.GetScope().String()))
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}, nil
......@@ -513,72 +518,36 @@ func (node *QueryNode) Search(ctx context.Context, req *queryPb.SearchRequest) (
var results []*internalpb.SearchResults
var streamingResult *internalpb.SearchResults
var wg sync.WaitGroup
var errCluster error
wg.Add(1) // search cluster
go func() {
defer wg.Done()
// shard leader dispatches request to its shard cluster
oResults, cErr := cluster.Search(searchCtx, req)
if cErr != nil {
log.Warn("search cluster failed", zap.Int64("msgID", msgID), zap.Int64("collectionID", req.Req.GetCollectionID()), zap.Error(cErr))
cancel()
errCluster = cErr
return
}
results = oResults
}()
var errStreaming error
wg.Add(1) // search streaming
go func() {
defer func() {
if errStreaming != nil {
cancel()
}
}()
defer wg.Done()
streamingTask, err2 := newSearchTask(searchCtx, req)
if err2 != nil {
errStreaming = err2
withStreaming := func(ctx context.Context) error {
streamingTask, err := newSearchTask(searchCtx, req)
if err != nil {
return err
}
streamingTask.QS = qs
streamingTask.DataScope = querypb.DataScope_Streaming
err2 = node.scheduler.AddReadTask(searchCtx, streamingTask)
if err2 != nil {
errStreaming = err2
return
err = node.scheduler.AddReadTask(searchCtx, streamingTask)
if err != nil {
return err
}
err2 = streamingTask.WaitToFinish()
if err2 != nil {
errStreaming = err2
return
err = streamingTask.WaitToFinish()
if err != nil {
return err
}
metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
metrics.SearchLabel).Observe(float64(streamingTask.queueDur.Milliseconds()))
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
metrics.SearchLabel).Observe(float64(streamingTask.reduceDur.Milliseconds()))
streamingResult = streamingTask.Ret
}()
wg.Wait()
var mainErr error
if errCluster != nil {
mainErr = errCluster
if errors.Is(errCluster, context.Canceled) {
if errStreaming != nil {
mainErr = errStreaming
}
}
} else if errStreaming != nil {
mainErr = errStreaming
return nil
}
if mainErr != nil {
failRet.Status.Reason = mainErr.Error()
// shard leader dispatches request to its shard cluster
results, errCluster = cluster.Search(searchCtx, req, withStreaming)
if errCluster != nil {
log.Warn("search cluster failed", zap.Int64("msgID", msgID), zap.Int64("collectionID", req.Req.GetCollectionID()), zap.Error(errCluster))
failRet.Status.Reason = errCluster.Error()
return failRet, nil
}
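
The refactor above (mirrored in Query below) drops the hand-rolled pair of goroutines with manual cancel and error merging: the streaming work now travels as a withStreaming closure into cluster.Search, which runs it alongside the fan-out to the worker nodes and fails fast on either error. A hedged sketch of that pattern using golang.org/x/sync/errgroup; the actual shard cluster internals may differ:

package main

import (
	"context"
	"errors"
	"fmt"

	"golang.org/x/sync/errgroup"
)

// search fans out to worker nodes and runs the streaming closure
// concurrently; the shared context cancels the other side on first error.
func search(ctx context.Context, workers []func(context.Context) error,
	withStreaming func(context.Context) error) error {
	g, gctx := errgroup.WithContext(ctx)
	for _, w := range workers {
		w := w // capture loop variable
		g.Go(func() error { return w(gctx) })
	}
	g.Go(func() error { return withStreaming(gctx) })
	return g.Wait()
}

func main() {
	ok := func(context.Context) error { return nil }
	err := search(context.Background(),
		[]func(context.Context) error{ok, ok},
		func(context.Context) error { return errors.New("mock streaming error") })
	fmt.Println(err) // mock streaming error
}
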
......@@ -694,66 +663,34 @@ func (node *QueryNode) Query(ctx context.Context, req *queryPb.QueryRequest) (*i
var results []*internalpb.RetrieveResults
var streamingResult *internalpb.RetrieveResults
var wg sync.WaitGroup
var errCluster error
wg.Add(1)
go func() {
defer wg.Done()
// shard leader dispatches request to its shard cluster
oResults, cErr := cluster.Query(queryCtx, req)
if cErr != nil {
log.Warn("failed to query cluster", zap.Int64("msgID", msgID), zap.Int64("collectionID", req.Req.GetCollectionID()), zap.Error(cErr))
errCluster = cErr
cancel()
return
}
results = oResults
}()
var errStreaming error
wg.Add(1)
go func() {
defer wg.Done()
withStreaming := func(ctx context.Context) error {
streamingTask := newQueryTask(queryCtx, req)
streamingTask.DataScope = querypb.DataScope_Streaming
streamingTask.QS = qs
err2 := node.scheduler.AddReadTask(queryCtx, streamingTask)
defer func() {
errStreaming = err2
if err2 != nil {
cancel()
}
}()
if err2 != nil {
return
err := node.scheduler.AddReadTask(queryCtx, streamingTask)
if err != nil {
return err
}
err2 = streamingTask.WaitToFinish()
if err2 != nil {
return
err = streamingTask.WaitToFinish()
if err != nil {
return err
}
metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
metrics.QueryLabel).Observe(float64(streamingTask.queueDur.Milliseconds()))
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
metrics.QueryLabel).Observe(float64(streamingTask.reduceDur.Milliseconds()))
streamingResult = streamingTask.Ret
}()
wg.Wait()
var mainErr error
if errCluster != nil {
mainErr = errCluster
if errors.Is(errCluster, context.Canceled) {
if errStreaming != nil {
mainErr = errStreaming
}
}
} else if errStreaming != nil {
mainErr = errStreaming
return nil
}
if mainErr != nil {
failRet.Status.Reason = mainErr.Error()
var errCluster error
// shard leader dispatches request to its shard cluster
results, errCluster = cluster.Query(queryCtx, req, withStreaming)
if errCluster != nil {
log.Warn("failed to query cluster", zap.Int64("msgID", msgID), zap.Int64("collectionID", req.Req.GetCollectionID()), zap.Error(errCluster))
failRet.Status.Reason = errCluster.Error()
return failRet, nil
}
......
......@@ -392,10 +392,19 @@ func TestImpl_ReleaseSegments(t *testing.T) {
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
SegmentIDs: []UniqueID{defaultSegmentID},
Scope: queryPb.DataScope_All,
}
_, err = node.ReleaseSegments(ctx, req)
assert.NoError(t, err)
req.Scope = queryPb.DataScope_Streaming
_, err = node.ReleaseSegments(ctx, req)
assert.NoError(t, err)
req.Scope = queryPb.DataScope_Historical
_, err = node.ReleaseSegments(ctx, req)
assert.NoError(t, err)
})
wg.Add(1)
......
This diff is collapsed.
......@@ -79,6 +79,21 @@ func buildMockQueryNode(nodeID int64, addr string) shardQueryNode {
}
}
func segmentEventsToSyncInfo(events []segmentEvent) []*querypb.ReplicaSegmentsInfo {
infos := make([]*querypb.ReplicaSegmentsInfo, 0, len(events))
for _, event := range events {
for _, nodeID := range event.nodeIDs {
infos = append(infos, &querypb.ReplicaSegmentsInfo{
NodeId: nodeID,
SegmentIds: []int64{event.segmentID},
PartitionId: event.partitionID,
})
}
}
return infos
}
func TestShardCluster_Create(t *testing.T) {
collectionID := int64(1)
vchannelName := "dml_1_1_v0"
......@@ -97,6 +112,7 @@ func TestShardCluster_Create(t *testing.T) {
{
nodeID: 1,
nodeAddr: "addr_1",
isLeader: true,
},
{
nodeID: 2,
......@@ -114,6 +130,11 @@ func TestShardCluster_Create(t *testing.T) {
assert.True(t, has)
assert.Equal(t, e.nodeAddr, node.nodeAddr)
}
sc.mut.Lock()
defer sc.mut.Unlock()
require.NotNil(t, sc.leader)
assert.Equal(t, int64(1), sc.leader.nodeID)
})
t.Run("init segments", func(t *testing.T) {
......@@ -461,8 +482,11 @@ func TestShardCluster_segmentEvent(t *testing.T) {
}, buildMockQueryNode)
defer sc.Close()
sc.SyncSegments(segmentEventsToSyncInfo(segmentEvents), segmentStateLoaded)
// make reference greater than 0
allocs := sc.segmentAllocations(nil)
_, versionID := sc.segmentAllocations(nil)
defer sc.finishUsage(versionID)
evtCh <- segmentEvent{
segmentID: 4,
......@@ -515,14 +539,6 @@ func TestShardCluster_segmentEvent(t *testing.T) {
_, has := sc.getSegment(4)
assert.False(t, has)
sc.mut.RLock()
assert.Equal(t, 0, len(sc.legacySegments))
sc.mut.RUnlock()
sc.finishUsage(allocs)
sc.mut.RLock()
assert.Equal(t, 0, len(sc.legacySegments))
sc.mut.RUnlock()
})
t.Run("from loaded, node changed", func(t *testing.T) {
......@@ -558,8 +574,11 @@ func TestShardCluster_segmentEvent(t *testing.T) {
}, buildMockQueryNode)
defer sc.Close()
sc.SyncSegments(segmentEventsToSyncInfo(segmentEvents), segmentStateLoaded)
// make reference greater than 0
allocs := sc.segmentAllocations(nil)
_, versionID := sc.segmentAllocations(nil)
defer sc.finishUsage(versionID)
// bring segment online in the other querynode
evtCh <- segmentEvent{
......@@ -586,18 +605,6 @@ func TestShardCluster_segmentEvent(t *testing.T) {
return has && seg.nodeID == 1 && seg.state == segmentStateLoaded
}, time.Second, time.Millisecond)
sc.mut.RLock()
assert.Equal(t, 2, len(sc.legacySegments))
assert.ElementsMatch(t, []shardSegmentInfo{
{segmentID: 1, nodeID: 1, state: segmentStateLoaded, inUse: 1},
{segmentID: 2, nodeID: 2, state: segmentStateLoaded, inUse: 1},
}, sc.legacySegments)
sc.mut.RUnlock()
sc.finishUsage(allocs)
sc.mut.RLock()
assert.Equal(t, 0, len(sc.legacySegments))
sc.mut.RUnlock()
})
t.Run("from offline", func(t *testing.T) {
......@@ -641,6 +648,9 @@ func TestShardCluster_segmentEvent(t *testing.T) {
evtCh: evtCh,
}, buildMockQueryNode)
defer sc.Close()
sc.SyncSegments(segmentEventsToSyncInfo(nil), segmentStateLoaded)
evtCh <- segmentEvent{
segmentID: 3,
nodeIDs: []int64{3},
......@@ -793,6 +803,7 @@ func TestShardCluster_segmentEvent(t *testing.T) {
}, buildMockQueryNode)
defer sc.Close()
sc.SyncSegments(segmentEventsToSyncInfo(nil), segmentStateLoaded)
// non-existent segment
evtCh <- segmentEvent{
segmentID: 4,
......@@ -955,6 +966,9 @@ func TestShardCluster_SyncSegments(t *testing.T) {
}
var streamingDoNothing = func(context.Context) error { return nil }
var streamingError = func(context.Context) error { return errors.New("mock streaming error") }
func TestShardCluster_Search(t *testing.T) {
collectionID := int64(1)
vchannelName := "dml_1_1_v0"
......@@ -1005,7 +1019,7 @@ func TestShardCluster_Search(t *testing.T) {
_, err := sc.Search(ctx, &querypb.SearchRequest{
DmlChannel: vchannelName,
})
}, streamingDoNothing)
assert.Error(t, err)
})
......@@ -1017,7 +1031,7 @@ func TestShardCluster_Search(t *testing.T) {
_, err := sc.Search(ctx, &querypb.SearchRequest{
DmlChannel: vchannelName + "_suffix",
})
}, streamingDoNothing)
assert.Error(t, err)
})
......@@ -1059,15 +1073,67 @@ func TestShardCluster_Search(t *testing.T) {
}, buildMockQueryNode)
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
require.EqualValues(t, available, sc.state.Load())
result, err := sc.Search(ctx, &querypb.SearchRequest{
DmlChannel: vchannelName,
})
}, streamingDoNothing)
assert.NoError(t, err)
assert.Equal(t, len(nodeEvents), len(result))
})
t.Run("with streaming fail", func(t *testing.T) {
nodeEvents := []nodeEvent{
{
nodeID: 1,
nodeAddr: "addr_1",
},
{
nodeID: 2,
nodeAddr: "addr_2",
},
}
segmentEvents := []segmentEvent{
{
segmentID: 1,
nodeIDs: []int64{1},
state: segmentStateLoaded,
},
{
segmentID: 2,
nodeIDs: []int64{2},
state: segmentStateLoaded,
},
{
segmentID: 3,
nodeIDs: []int64{2},
state: segmentStateLoaded,
},
}
sc := NewShardCluster(collectionID, replicaID, vchannelName,
&mockNodeDetector{
initNodes: nodeEvents,
}, &mockSegmentDetector{
initSegments: segmentEvents,
}, buildMockQueryNode)
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
require.EqualValues(t, available, sc.state.Load())
_, err := sc.Search(ctx, &querypb.SearchRequest{
DmlChannel: vchannelName,
}, func(ctx context.Context) error { return errors.New("mocked") })
assert.Error(t, err)
})
t.Run("partial fail", func(t *testing.T) {
nodeEvents := []nodeEvent{
{
......@@ -1114,11 +1180,14 @@ func TestShardCluster_Search(t *testing.T) {
})
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
require.EqualValues(t, available, sc.state.Load())
_, err := sc.Search(ctx, &querypb.SearchRequest{
DmlChannel: vchannelName,
})
}, streamingDoNothing)
assert.Error(t, err)
})
......@@ -1155,7 +1224,7 @@ func TestShardCluster_Search(t *testing.T) {
//mock meta error
sc.mut.Lock()
sc.segments[3] = &shardSegmentInfo{
sc.segments[3] = shardSegmentInfo{
segmentID: 3,
nodeID: 3, // node does not exist
state: segmentStateLoaded,
......@@ -1163,11 +1232,14 @@ func TestShardCluster_Search(t *testing.T) {
sc.mut.Unlock()
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
require.EqualValues(t, available, sc.state.Load())
_, err := sc.Search(ctx, &querypb.SearchRequest{
DmlChannel: vchannelName,
})
}, streamingDoNothing)
assert.Error(t, err)
})
}
......@@ -1218,11 +1290,14 @@ func TestShardCluster_Query(t *testing.T) {
}, buildMockQueryNode)
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
require.EqualValues(t, unavailable, sc.state.Load())
_, err := sc.Query(ctx, &querypb.QueryRequest{
DmlChannel: vchannelName,
})
}, streamingDoNothing)
assert.Error(t, err)
})
t.Run("query wrong channel", func(t *testing.T) {
......@@ -1230,10 +1305,12 @@ func TestShardCluster_Query(t *testing.T) {
&mockNodeDetector{}, &mockSegmentDetector{}, buildMockQueryNode)
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
_, err := sc.Query(ctx, &querypb.QueryRequest{
DmlChannel: vchannelName + "_suffix",
})
}, streamingDoNothing)
assert.Error(t, err)
})
t.Run("normal query", func(t *testing.T) {
......@@ -1274,16 +1351,18 @@ func TestShardCluster_Query(t *testing.T) {
}, buildMockQueryNode)
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
require.EqualValues(t, available, sc.state.Load())
result, err := sc.Query(ctx, &querypb.QueryRequest{
DmlChannel: vchannelName,
})
}, streamingDoNothing)
assert.NoError(t, err)
assert.Equal(t, len(nodeEvents), len(result))
})
t.Run("partial fail", func(t *testing.T) {
t.Run("with streaming fail", func(t *testing.T) {
nodeEvents := []nodeEvent{
{
nodeID: 1,
......@@ -1318,25 +1397,21 @@ func TestShardCluster_Query(t *testing.T) {
initNodes: nodeEvents,
}, &mockSegmentDetector{
initSegments: segmentEvents,
}, func(nodeID int64, addr string) shardQueryNode {
if nodeID != 2 { // hard code error one
return buildMockQueryNode(nodeID, addr)
}
return &mockShardQueryNode{
searchErr: errors.New("mocked error"),
queryErr: errors.New("mocked error"),
}
})
}, buildMockQueryNode)
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
require.EqualValues(t, available, sc.state.Load())
_, err := sc.Query(ctx, &querypb.QueryRequest{
DmlChannel: vchannelName,
})
}, func(ctx context.Context) error { return errors.New("mocked") })
assert.Error(t, err)
})
t.Run("test meta error", func(t *testing.T) {
t.Run("partial fail", func(t *testing.T) {
nodeEvents := []nodeEvent{
{
nodeID: 1,
......@@ -1359,6 +1434,11 @@ func TestShardCluster_Query(t *testing.T) {
nodeIDs: []int64{2},
state: segmentStateLoaded,
},
{
segmentID: 3,
nodeIDs: []int64{2},
state: segmentStateLoaded,
},
}
sc := NewShardCluster(collectionID, replicaID, vchannelName,
......@@ -1366,35 +1446,28 @@ func TestShardCluster_Query(t *testing.T) {
initNodes: nodeEvents,
}, &mockSegmentDetector{
initSegments: segmentEvents,
}, buildMockQueryNode)
//mock meta error
sc.mut.Lock()
sc.segments[3] = &shardSegmentInfo{
segmentID: 3,
nodeID: 3, // node does not exist
state: segmentStateLoaded,
}
sc.mut.Unlock()
}, func(nodeID int64, addr string) shardQueryNode {
if nodeID != 2 { // hard code error one
return buildMockQueryNode(nodeID, addr)
}
return &mockShardQueryNode{
searchErr: errors.New("mocked error"),
queryErr: errors.New("mocked error"),
}
})
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
require.EqualValues(t, available, sc.state.Load())
_, err := sc.Query(ctx, &querypb.QueryRequest{
DmlChannel: vchannelName,
})
}, streamingDoNothing)
assert.Error(t, err)
})
}
func TestShardCluster_ReferenceCount(t *testing.T) {
collectionID := int64(1)
vchannelName := "dml_1_1_v0"
replicaID := int64(0)
// ctx := context.Background()
t.Run("normal alloc & finish", func(t *testing.T) {
t.Run("test meta error", func(t *testing.T) {
nodeEvents := []nodeEvent{
{
nodeID: 1,
......@@ -1418,36 +1491,52 @@ func TestShardCluster_ReferenceCount(t *testing.T) {
state: segmentStateLoaded,
},
}
sc := NewShardCluster(collectionID, replicaID, vchannelName,
&mockNodeDetector{
initNodes: nodeEvents,
}, &mockSegmentDetector{
initSegments: segmentEvents,
}, buildMockQueryNode)
//mock meta error
sc.mut.Lock()
sc.segments[3] = shardSegmentInfo{
segmentID: 3,
nodeID: 3, // node does not exist
state: segmentStateLoaded,
}
sc.mut.Unlock()
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
allocs := sc.segmentAllocations(nil)
require.EqualValues(t, available, sc.state.Load())
sc.mut.RLock()
for _, segment := range sc.segments {
assert.Greater(t, segment.inUse, int32(0))
}
sc.mut.RUnlock()
_, err := sc.Query(ctx, &querypb.QueryRequest{
DmlChannel: vchannelName,
}, streamingDoNothing)
assert.Error(t, err)
})
assert.True(t, sc.segmentsInUse([]shardSegmentInfo{{nodeID: 1, segmentID: 1}, {nodeID: 2, segmentID: 2}}))
assert.True(t, sc.segmentsInUse([]shardSegmentInfo{{nodeID: 1, segmentID: 1}, {nodeID: 2, segmentID: 2}, {nodeID: 2, segmentID: -1}}))
}
sc.finishUsage(allocs)
sc.mut.RLock()
for _, segment := range sc.segments {
assert.EqualValues(t, segment.inUse, 0)
}
sc.mut.RUnlock()
assert.False(t, sc.segmentsInUse([]shardSegmentInfo{{nodeID: 1, segmentID: 1}, {nodeID: 2, segmentID: 2}}))
assert.False(t, sc.segmentsInUse([]shardSegmentInfo{{nodeID: 1, segmentID: 1}, {nodeID: 2, segmentID: 2}, {nodeID: 2, segmentID: -1}}))
func TestShardCluster_Version(t *testing.T) {
collectionID := int64(1)
vchannelName := "dml_1_1_v0"
replicaID := int64(0)
// ctx := context.Background()
t.Run("alloc with non-serviceable", func(t *testing.T) {
sc := NewShardCluster(collectionID, replicaID, vchannelName,
&mockNodeDetector{}, &mockSegmentDetector{}, buildMockQueryNode)
defer sc.Close()
_, v := sc.segmentAllocations(nil)
assert.Equal(t, int64(0), v)
})
t.Run("alloc & finish with modified alloc", func(t *testing.T) {
t.Run("normal alloc & finish", func(t *testing.T) {
nodeEvents := []nodeEvent{
{
nodeID: 1,
......@@ -1462,7 +1551,7 @@ func TestShardCluster_ReferenceCount(t *testing.T) {
segmentEvents := []segmentEvent{
{
segmentID: 1,
nodeIDs: []int64{2},
nodeIDs: []int64{1},
state: segmentStateLoaded,
},
{
......@@ -1479,27 +1568,18 @@ func TestShardCluster_ReferenceCount(t *testing.T) {
}, buildMockQueryNode)
defer sc.Close()
allocs := sc.segmentAllocations(nil)
sc.SyncSegments(nil, segmentStateLoaded)
_, version := sc.segmentAllocations(nil)
sc.mut.RLock()
for _, segment := range sc.segments {
assert.Greater(t, segment.inUse, int32(0))
}
assert.Equal(t, version, sc.currentVersion.versionID)
assert.Equal(t, int64(1), sc.currentVersion.inUse.Load())
sc.mut.RUnlock()
for node, segments := range allocs {
segments = append(segments, -1) // add non-existent segment
// shall be ignored in finishUsage
allocs[node] = segments
}
assert.NotPanics(t, func() {
sc.finishUsage(allocs)
})
sc.finishUsage(version)
sc.mut.RLock()
for _, segment := range sc.segments {
assert.EqualValues(t, segment.inUse, 0)
}
assert.Equal(t, int64(0), sc.currentVersion.inUse.Load())
sc.mut.RUnlock()
})
......@@ -1537,6 +1617,8 @@ func TestShardCluster_ReferenceCount(t *testing.T) {
}, buildMockQueryNode)
defer sc.Close()
sc.SyncSegments(nil, segmentStateLoaded)
assert.True(t, sc.segmentsOnline([]shardSegmentInfo{{nodeID: 1, segmentID: 1}, {nodeID: 2, segmentID: 2}}))
assert.False(t, sc.segmentsOnline([]shardSegmentInfo{{nodeID: 1, segmentID: 1}, {nodeID: 2, segmentID: 2}, {nodeID: 1, segmentID: 3}}))
......@@ -1597,6 +1679,8 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
}, buildMockQueryNode)
defer sc.Close()
sc.SyncSegments(segmentEventsToSyncInfo(nil), segmentStateLoaded)
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
OnlineSegments: []*querypb.SegmentInfo{
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName, NodeIds: []UniqueID{2}},
......@@ -1700,8 +1784,10 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
}, buildMockQueryNode)
defer sc.Close()
// add rc to all segments
allocs := sc.segmentAllocations(nil)
sc.SyncSegments(segmentEventsToSyncInfo(nil), segmentStateLoaded)
// add in-use count
_, versionID := sc.segmentAllocations(nil)
sig := make(chan struct{})
go func() {
......@@ -1718,11 +1804,6 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
close(sig)
}()
sc.mut.RLock()
// still waiting online
assert.Equal(t, 0, len(sc.handoffs))
sc.mut.RUnlock()
evtCh <- segmentEvent{
eventType: segmentAdd,
segmentID: 3,
......@@ -1734,10 +1815,10 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
assert.Eventually(t, func() bool {
sc.mut.RLock()
defer sc.mut.RUnlock()
return len(sc.handoffs) > 0
return sc.currentVersion.versionID != versionID
}, time.Second, time.Millisecond*10)
tmpAllocs := sc.segmentAllocations(nil)
tmpAllocs, nVersionID := sc.segmentAllocations(nil)
found := false
for _, segments := range tmpAllocs {
if inList(segments, int64(1)) {
......@@ -1747,9 +1828,9 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
}
// segment 1 shall not be allocated again!
assert.False(t, found)
sc.finishUsage(tmpAllocs)
sc.finishUsage(nVersionID)
// rc shall be 0 now
sc.finishUsage(allocs)
sc.finishUsage(versionID)
// wait handoff finished
<-sig
......@@ -1795,8 +1876,10 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
}, buildMockQueryNode)
defer sc.Close()
sc.SyncSegments(segmentEventsToSyncInfo(nil), segmentStateLoaded)
// add rc to all segments
allocs := sc.segmentAllocations(nil)
_, versionID := sc.segmentAllocations(nil)
sig := make(chan struct{})
go func() {
......@@ -1813,11 +1896,6 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
close(sig)
}()
sc.mut.RLock()
// still waiting online
assert.Equal(t, 0, len(sc.handoffs))
sc.mut.RUnlock()
evtCh <- segmentEvent{
eventType: segmentAdd,
segmentID: 1,
......@@ -1829,10 +1907,10 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
assert.Eventually(t, func() bool {
sc.mut.RLock()
defer sc.mut.RUnlock()
return len(sc.handoffs) > 0
return sc.currentVersion.versionID != versionID
}, time.Second, time.Millisecond*10)
tmpAllocs := sc.segmentAllocations(nil)
tmpAllocs, tmpVersionID := sc.segmentAllocations(nil)
for nodeID, segments := range tmpAllocs {
for _, segment := range segments {
if segment == int64(1) {
......@@ -1840,9 +1918,9 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
}
}
}
sc.finishUsage(tmpAllocs)
sc.finishUsage(tmpVersionID)
// rc shall be 0 now
sc.finishUsage(allocs)
sc.finishUsage(versionID)
// wait handoff finished
<-sig
......@@ -1889,6 +1967,8 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
}, buildMockQueryNode)
defer sc.Close()
sc.SyncSegments(segmentEventsToSyncInfo(nil), segmentStateLoaded)
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
OnlineSegments: []*querypb.SegmentInfo{
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName, NodeIds: []UniqueID{2}},
......
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querynode
import (
"sync"
"go.uber.org/atomic"
)
// SegmentsStatus is an alias for map[int64]shardSegmentInfo.
// It provides helper functions to get segment allocations.
type SegmentsStatus map[int64]shardSegmentInfo
// GetAllocations returns the node-to-segments mappings.
func (s SegmentsStatus) GetAllocations(partitionIDs []int64) map[int64][]int64 {
result := make(map[int64][]int64) // nodeID => segmentIDs
// only read operations here, no need to lock
for _, segment := range s {
if len(partitionIDs) > 0 && !inList(partitionIDs, segment.partitionID) {
continue
}
result[segment.nodeID] = append(result[segment.nodeID], segment.segmentID)
}
return result
}
// Clone returns a copy of segments status data.
func (s SegmentsStatus) Clone(filter func(int64) bool) SegmentsStatus {
c := make(map[int64]shardSegmentInfo)
for k, v := range s {
if filter(v.segmentID) {
continue
}
c[k] = v
}
return c
}
// helper filter function that filters nothing
var filterNothing = func(int64) bool { return false }
// ShardClusterVersion maintains a snapshot of sealed segments allocation.
type ShardClusterVersion struct {
versionID int64 // identifier for version
segments SegmentsStatus // nodeID => []segmentID
current *atomic.Bool // is this version current
inUse *atomic.Int64
ch chan struct{} // signal channel to notify safe remove
closeOnce sync.Once
}
// NewShardClusterVersion creates a version with id and allocation.
func NewShardClusterVersion(vID int64, status SegmentsStatus) *ShardClusterVersion {
return &ShardClusterVersion{
versionID: vID,
segments: status,
current: atomic.NewBool(true), // by default new version will be current
inUse: atomic.NewInt64(0),
ch: make(chan struct{}),
}
}
// IsCurrent returns whether this version is current version.
func (v *ShardClusterVersion) IsCurrent() bool {
return v.current.Load()
}
// GetAllocation returns the version's allocation and records one in-use count.
func (v *ShardClusterVersion) GetAllocation(partitionIDs []int64) map[int64][]int64 {
v.inUse.Add(1)
return v.segments.GetAllocations(partitionIDs)
}
// FinishUsage decreases the inUse count and triggers the pending change check.
func (v *ShardClusterVersion) FinishUsage() {
v.inUse.Add(-1)
v.checkSafeGC()
}
// Expire sets the current flag to false for this version.
// Invocations of Expire shall be goroutine-safe.
func (v *ShardClusterVersion) Expire() chan struct{} {
v.current.Store(false)
v.checkSafeGC()
return v.ch
}
// checkSafeGC checks whether this version is safe to GC, i.e. to release the changeInfo offline segments.
func (v *ShardClusterVersion) checkSafeGC() {
if !v.IsCurrent() && v.inUse.Load() == int64(0) {
v.closeOnce.Do(func() {
close(v.ch)
})
}
}
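
Putting the pieces above together, the expected call pattern is: GetAllocation records usage of the current version, Expire flips the current flag once a newer version is installed, and the channel it returns closes only after the last FinishUsage, at which point offline segments may be released. For illustration, a small in-package driver assuming only the identifiers defined in this file plus shardSegmentInfo:

package querynode

import "fmt"

// exampleVersionLifecycle sketches the intended ShardClusterVersion usage.
func exampleVersionLifecycle() {
	v := NewShardClusterVersion(1, SegmentsStatus{
		1: shardSegmentInfo{segmentID: 1, partitionID: 0, nodeID: 1},
	})

	allocs := v.GetAllocation(nil) // nodeID => segmentIDs, inUse is now 1
	fmt.Println(allocs)            // map[1:[1]]

	ch := v.Expire() // a newer version replaced this one
	select {
	case <-ch:
		// not reached: one allocation is still in flight
	default:
	}

	v.FinishUsage() // last user done; checkSafeGC closes ch
	<-ch            // offline segments may be released now
}
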
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querynode
import (
"testing"
"github.com/stretchr/testify/assert"
)
func channelClose(ch chan struct{}) bool {
select {
case <-ch:
return true
default:
return false
}
}
func TestShardClusterVersion(t *testing.T) {
t.Run("new version", func(t *testing.T) {
v := NewShardClusterVersion(1, SegmentsStatus{})
assert.True(t, v.IsCurrent())
assert.Equal(t, int64(1), v.versionID)
assert.Equal(t, int64(0), v.inUse.Load())
})
t.Run("version expired", func(t *testing.T) {
v := NewShardClusterVersion(1, SegmentsStatus{})
assert.True(t, v.IsCurrent())
ch := v.Expire()
assert.False(t, v.IsCurrent())
assert.True(t, channelClose(ch))
})
t.Run("In use check", func(t *testing.T) {
v := NewShardClusterVersion(1, SegmentsStatus{
1: shardSegmentInfo{segmentID: 1, partitionID: 0, nodeID: 1},
2: shardSegmentInfo{segmentID: 2, partitionID: 1, nodeID: 2},
})
allocs := v.GetAllocation([]int64{1})
assert.EqualValues(t, map[int64][]int64{2: {2}}, allocs)
assert.Equal(t, int64(1), v.inUse.Load())
ch := v.Expire()
assert.False(t, channelClose(ch))
v.FinishUsage()
assert.True(t, channelClose(ch))
})
}
......@@ -89,6 +89,15 @@ func (nd *etcdShardNodeDetector) watchNodes(collectionID int64, replicaID int64,
if info.GetCollectionID() != collectionID || info.GetReplicaID() != replicaID {
continue
}
// find the leader id for the shard replica
var leaderID int64
for _, shardReplica := range info.GetShardReplicas() {
if shardReplica.GetDmChannelName() == vchannelName {
leaderID = shardReplica.GetLeaderID()
break
}
}
// generate node event
for _, nodeID := range info.GetNodeIds() {
addr, has := idAddr[nodeID]
if !has {
......@@ -99,6 +108,7 @@ func (nd *etcdShardNodeDetector) watchNodes(collectionID int64, replicaID int64,
nodeID: nodeID,
nodeAddr: addr,
eventType: nodeAdd,
isLeader: nodeID == leaderID,
})
}
}
......
......@@ -61,6 +61,12 @@ func TestEtcdShardNodeDetector_watch(t *testing.T) {
CollectionID: 1,
ReplicaID: 1,
NodeIds: []int64{1, 2},
ShardReplicas: []*milvuspb.ShardReplica{
{
LeaderID: 1,
DmChannelName: "dml1",
},
},
},
},
oldGarbage: map[string]string{
......@@ -71,6 +77,7 @@ func TestEtcdShardNodeDetector_watch(t *testing.T) {
nodeID: 1,
nodeAddr: "1",
eventType: nodeAdd,
isLeader: true,
},
{
nodeID: 2,
......@@ -80,6 +87,7 @@ func TestEtcdShardNodeDetector_watch(t *testing.T) {
},
collectionID: 1,
replicaID: 1,
channel: "dml1",
},
{
name: "normal case with other replica",
......