Unverified commit 41d9ab3d authored by Xiaofan, committed by GitHub

Fix deadlock in shard manager (#23446)

Signed-off-by: xiaofan-luan <xiaofan.luan@zilliz.com>
Parent 52e8460e
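
Why this fixes a deadlock: the old Serviceable() read a lock-free atomic counter, so it could be called from anywhere; the new version takes d.mut.RLock itself, and Go's sync.RWMutex is not reentrant, so code paths that already hold d.mut must call serviceableImpl() instead. A minimal standalone sketch of that hazard and of the Serviceable/serviceableImpl split (names are illustrative, not Milvus code):

package main

import (
	"fmt"
	"sync"
)

type dist struct {
	mut    sync.RWMutex
	loaded bool
}

// serviceable takes the read lock itself; safe for external callers only.
func (d *dist) serviceable() bool {
	d.mut.RLock()
	defer d.mut.RUnlock()
	return d.serviceableImpl()
}

// serviceableImpl assumes the caller already holds d.mut (read or write).
func (d *dist) serviceableImpl() bool {
	return d.loaded
}

func (d *dist) update() {
	d.mut.Lock()
	defer d.mut.Unlock()
	d.loaded = true
	// Calling d.serviceable() here instead of d.serviceableImpl() would
	// deadlock: RLock blocks forever because this goroutine already
	// holds the write lock.
	fmt.Println("serviceable:", d.serviceableImpl())
}

func main() {
	d := &dist{}
	d.update()
	fmt.Println(d.serviceable())
}
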
......@@ -19,6 +19,7 @@ package querynode
import (
"context"
"sync"
"time"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/typeutil"
......@@ -42,9 +43,6 @@ type distribution struct {
// version indicator
version int64
// offlines is the quick health check indicator for offline segments
offlines *atomic.Int32
snapshots *typeutil.ConcurrentMap[int64, *snapshot]
// current is the snapshot for quick usage for search/query
// generated for each change of distribution
......@@ -77,7 +75,6 @@ func NewDistribution(replicaID int64) *distribution {
replicaID: replicaID,
sealedSegments: make(map[UniqueID]SegmentEntry),
snapshots: typeutil.NewConcurrentMap[int64, *snapshot](),
offlines: atomic.NewInt32(0),
current: atomic.NewPointer[snapshot](nil),
}
......@@ -89,16 +86,29 @@ func (d *distribution) getLogger() *log.MLogger {
return log.Ctx(context.Background()).With(zap.Int64("replicaID", d.replicaID))
}
// Serviceable returns whether all recorded segments are in loaded state.
func (d *distribution) Serviceable() bool {
return d.offlines.Load() == 0
d.mut.RLock()
defer d.mut.RUnlock()
return d.serviceableImpl()
}
// serviceableImpl returns whether all recorded segments are in loaded state; hold d.mut before calling it.
func (d *distribution) serviceableImpl() bool {
for _, entry := range d.sealedSegments {
if entry.State != segmentStateLoaded {
return false
}
}
return true
}
// GetCurrent returns current snapshot.
func (d *distribution) GetCurrent(partitions ...int64) (sealed []SnapshotItem, version int64) {
d.mut.RLock()
defer d.mut.RUnlock()
if !d.serviceableImpl() {
return nil, -1
}
current := d.current.Load()
sealed = current.Get(partitions...)
version = current.version
......@@ -142,14 +152,15 @@ func (d *distribution) UpdateDistribution(entries ...SegmentEntry) {
for _, entry := range entries {
old, ok := d.sealedSegments[entry.SegmentID]
d.getLogger().Info("Update distribution", zap.Int64("segmentID", entry.SegmentID),
zap.Int64("node", entry.NodeID),
zap.Bool("segment exist", ok))
if !ok {
d.sealedSegments[entry.SegmentID] = entry
if entry.State == segmentStateOffline {
d.offlines.Add(1)
}
continue
}
d.updateSegment(old, entry)
old.Update(entry)
d.sealedSegments[old.SegmentID] = old
}
d.genSnapshot()
......@@ -160,43 +171,14 @@ func (d *distribution) NodeDown(nodeID int64) {
d.mut.Lock()
defer d.mut.Unlock()
var delta int32
d.getLogger().Info("handle node down", zap.Int64("node", nodeID))
for _, entry := range d.sealedSegments {
if entry.NodeID == nodeID && entry.State != segmentStateOffline {
entry.State = segmentStateOffline
d.sealedSegments[entry.SegmentID] = entry
delta++
d.getLogger().Info("update the segment to offline since nodeDown", zap.Int64("nodeID", nodeID), zap.Int64("segmentID", entry.SegmentID))
}
}
if delta != 0 {
d.offlines.Add(delta)
d.getLogger().Info("distribution updated since nodeDown", zap.Int32("delta", delta), zap.Int32("offlines", d.offlines.Load()), zap.Int64("nodeID", nodeID))
}
}
// updateSegment updates the segment entry value and the offline segment count based on old/new state.
func (d *distribution) updateSegment(old, new SegmentEntry) {
delta := int32(0)
switch {
case old.State != segmentStateLoaded && new.State == segmentStateLoaded:
delta = -1
case old.State == segmentStateLoaded && new.State != segmentStateLoaded:
delta = 1
}
old.Update(new)
d.sealedSegments[old.SegmentID] = old
if delta != 0 {
d.offlines.Add(delta)
d.getLogger().Info("distribution updated since segment update",
zap.Int32("delta", delta),
zap.Int32("offlines", d.offlines.Load()),
zap.Int64("segmentID", new.SegmentID),
zap.Int32("state", int32(new.State)),
)
}
}
// RemoveDistributions removes segment distributions and returns the clear signal channel,
......@@ -204,32 +186,28 @@ func (d *distribution) updateSegment(old, new SegmentEntry) {
func (d *distribution) RemoveDistributions(releaseFn func(), sealedSegments ...SegmentEntry) {
d.mut.Lock()
defer d.mut.Unlock()
var delta int32
for _, sealed := range sealedSegments {
entry, ok := d.sealedSegments[sealed.SegmentID]
d.getLogger().Info("Remove distribution", zap.Int64("segmentID", sealed.SegmentID),
zap.Int64("node", sealed.NodeID),
zap.Bool("segment exist", ok))
if !ok {
continue
}
if entry.NodeID == sealed.NodeID || sealed.NodeID == wildcardNodeID {
if entry.State == segmentStateOffline {
delta--
}
delete(d.sealedSegments, sealed.SegmentID)
}
}
d.offlines.Add(delta)
ts := time.Now()
<-d.genSnapshot()
releaseFn()
d.getLogger().Info("successfully remove distribution", zap.Any("segments", sealedSegments), zap.Duration("time", time.Since(ts)))
}
// genSnapshot converts the current distribution to snapshot format,
// in which the user can just find the nodeID => segmentID list.
// mutex RLock is required before calling this method.
func (d *distribution) genSnapshot() chan struct{} {
nodeSegments := make(map[int64][]SegmentEntry)
for _, entry := range d.sealedSegments {
nodeSegments[entry.NodeID] = append(nodeSegments[entry.NodeID], entry)
......@@ -260,6 +238,7 @@ func (d *distribution) genSnapshot() chan struct{} {
return ch
}
d.getLogger().Info("gen snapshot for version", zap.Any("version", d.version), zap.Any("is serviceable", d.serviceableImpl()))
last.Expire(d.getCleanup(last.version))
return last.cleared
......
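
The net effect of the distribution.go changes above: serviceability is now derived from segment states on demand, under the existing mutex, rather than tracked in a separate offlines counter that had to be kept in sync, and GetCurrent reports the sentinel version -1 when the distribution is not serviceable. A condensed standalone sketch of that pattern, with simplified types (not the Milvus implementation):

package main

import (
	"fmt"
	"sync"
)

type segmentState int32

const (
	stateOffline segmentState = 1
	stateLoaded  segmentState = 3
)

type entry struct {
	segmentID int64
	state     segmentState
}

type distribution struct {
	mut     sync.RWMutex
	sealed  map[int64]entry
	version int64
}

// serviceableImpl derives health from segment states on every call, so
// there is no counter that can drift out of sync; caller holds d.mut.
func (d *distribution) serviceableImpl() bool {
	for _, e := range d.sealed {
		if e.state != stateLoaded {
			return false
		}
	}
	return true
}

// getCurrent reports version -1 as the "not serviceable" sentinel,
// mirroring GetCurrent in this diff.
func (d *distribution) getCurrent() ([]entry, int64) {
	d.mut.RLock()
	defer d.mut.RUnlock()
	if !d.serviceableImpl() {
		return nil, -1
	}
	out := make([]entry, 0, len(d.sealed))
	for _, e := range d.sealed {
		out = append(out, e)
	}
	return out, d.version
}

func main() {
	d := &distribution{sealed: map[int64]entry{1: {1, stateLoaded}}}
	_, v := d.getCurrent()
	fmt.Println(v) // 0: serviceable
	d.sealed[2] = entry{2, stateOffline}
	_, v = d.getCurrent()
	fmt.Println(v) // -1: not serviceable
}
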
......@@ -50,10 +50,12 @@ func (s *DistributionSuite) TestAddDistribution() {
{
NodeID: 1,
SegmentID: 1,
State: segmentStateLoaded,
},
{
NodeID: 1,
SegmentID: 2,
State: segmentStateLoaded,
},
},
expected: []SnapshotItem{
......@@ -63,10 +65,12 @@ func (s *DistributionSuite) TestAddDistribution() {
{
NodeID: 1,
SegmentID: 1,
State: segmentStateLoaded,
},
{
NodeID: 1,
SegmentID: 2,
State: segmentStateLoaded,
},
},
},
......@@ -78,14 +82,17 @@ func (s *DistributionSuite) TestAddDistribution() {
{
NodeID: 1,
SegmentID: 1,
State: segmentStateLoaded,
},
{
NodeID: 2,
SegmentID: 2,
State: segmentStateLoaded,
},
{
NodeID: 1,
SegmentID: 3,
State: segmentStateLoaded,
},
},
expected: []SnapshotItem{
......@@ -95,11 +102,13 @@ func (s *DistributionSuite) TestAddDistribution() {
{
NodeID: 1,
SegmentID: 1,
State: segmentStateLoaded,
},
{
NodeID: 1,
SegmentID: 3,
State: segmentStateLoaded,
},
},
},
......@@ -109,6 +118,7 @@ func (s *DistributionSuite) TestAddDistribution() {
{
NodeID: 2,
SegmentID: 2,
State: segmentStateLoaded,
},
},
},
......@@ -161,13 +171,13 @@ func (s *DistributionSuite) TestRemoveDistribution() {
{
tag: "remove with no read",
presetSealed: []SegmentEntry{
{NodeID: 1, SegmentID: 1},
{NodeID: 2, SegmentID: 2},
{NodeID: 1, SegmentID: 3},
{NodeID: 1, SegmentID: 1, State: segmentStateLoaded},
{NodeID: 2, SegmentID: 2, State: segmentStateLoaded},
{NodeID: 1, SegmentID: 3, State: segmentStateLoaded},
},
removalSealed: []SegmentEntry{
{NodeID: 1, SegmentID: 1},
{NodeID: 1, SegmentID: 1, State: segmentStateLoaded},
},
withMockRead: false,
......@@ -176,13 +186,13 @@ func (s *DistributionSuite) TestRemoveDistribution() {
{
NodeID: 1,
Segments: []SegmentEntry{
{NodeID: 1, SegmentID: 3},
{NodeID: 1, SegmentID: 3, State: segmentStateLoaded},
},
},
{
NodeID: 2,
Segments: []SegmentEntry{
{NodeID: 2, SegmentID: 2},
{NodeID: 2, SegmentID: 2, State: segmentStateLoaded},
},
},
},
......@@ -190,13 +200,13 @@ func (s *DistributionSuite) TestRemoveDistribution() {
{
tag: "remove with wrong nodeID",
presetSealed: []SegmentEntry{
{NodeID: 1, SegmentID: 1},
{NodeID: 2, SegmentID: 2},
{NodeID: 1, SegmentID: 3},
{NodeID: 1, SegmentID: 1, State: segmentStateLoaded},
{NodeID: 2, SegmentID: 2, State: segmentStateLoaded},
{NodeID: 1, SegmentID: 3, State: segmentStateLoaded},
},
removalSealed: []SegmentEntry{
{NodeID: 2, SegmentID: 1},
{NodeID: 2, SegmentID: 1, State: segmentStateLoaded},
},
withMockRead: false,
......@@ -205,14 +215,14 @@ func (s *DistributionSuite) TestRemoveDistribution() {
{
NodeID: 1,
Segments: []SegmentEntry{
{NodeID: 1, SegmentID: 1},
{NodeID: 1, SegmentID: 3},
{NodeID: 1, SegmentID: 1, State: segmentStateLoaded},
{NodeID: 1, SegmentID: 3, State: segmentStateLoaded},
},
},
{
NodeID: 2,
Segments: []SegmentEntry{
{NodeID: 2, SegmentID: 2},
{NodeID: 2, SegmentID: 2, State: segmentStateLoaded},
},
},
},
......@@ -220,13 +230,13 @@ func (s *DistributionSuite) TestRemoveDistribution() {
{
tag: "remove with wildcardNodeID",
presetSealed: []SegmentEntry{
{NodeID: 1, SegmentID: 1},
{NodeID: 2, SegmentID: 2},
{NodeID: 1, SegmentID: 3},
{NodeID: 1, SegmentID: 1, State: segmentStateLoaded},
{NodeID: 2, SegmentID: 2, State: segmentStateLoaded},
{NodeID: 1, SegmentID: 3, State: segmentStateLoaded},
},
removalSealed: []SegmentEntry{
{NodeID: wildcardNodeID, SegmentID: 1},
{NodeID: wildcardNodeID, SegmentID: 1, State: segmentStateLoaded},
},
withMockRead: false,
......@@ -235,13 +245,13 @@ func (s *DistributionSuite) TestRemoveDistribution() {
{
NodeID: 1,
Segments: []SegmentEntry{
{NodeID: 1, SegmentID: 3},
{NodeID: 1, SegmentID: 3, State: segmentStateLoaded},
},
},
{
NodeID: 2,
Segments: []SegmentEntry{
{NodeID: 2, SegmentID: 2},
{NodeID: 2, SegmentID: 2, State: segmentStateLoaded},
},
},
},
......@@ -249,13 +259,13 @@ func (s *DistributionSuite) TestRemoveDistribution() {
{
tag: "remove with read",
presetSealed: []SegmentEntry{
{NodeID: 1, SegmentID: 1},
{NodeID: 2, SegmentID: 2},
{NodeID: 1, SegmentID: 3},
{NodeID: 1, SegmentID: 1, State: segmentStateLoaded},
{NodeID: 2, SegmentID: 2, State: segmentStateLoaded},
{NodeID: 1, SegmentID: 3, State: segmentStateLoaded},
},
removalSealed: []SegmentEntry{
{NodeID: 1, SegmentID: 1},
{NodeID: 1, SegmentID: 1, State: segmentStateLoaded},
},
withMockRead: true,
......@@ -264,13 +274,13 @@ func (s *DistributionSuite) TestRemoveDistribution() {
{
NodeID: 1,
Segments: []SegmentEntry{
{NodeID: 1, SegmentID: 3},
{NodeID: 1, SegmentID: 3, State: segmentStateLoaded},
},
},
{
NodeID: 2,
Segments: []SegmentEntry{
{NodeID: 2, SegmentID: 2},
{NodeID: 2, SegmentID: 2, State: segmentStateLoaded},
},
},
},
......@@ -407,10 +417,12 @@ func (s *DistributionSuite) TestPeek() {
{
NodeID: 1,
SegmentID: 1,
State: segmentStateLoaded,
},
{
NodeID: 1,
SegmentID: 2,
State: segmentStateLoaded,
},
},
expected: []SnapshotItem{
......@@ -420,10 +432,12 @@ func (s *DistributionSuite) TestPeek() {
{
NodeID: 1,
SegmentID: 1,
State: segmentStateLoaded,
},
{
NodeID: 1,
SegmentID: 2,
State: segmentStateLoaded,
},
},
},
......@@ -435,14 +449,17 @@ func (s *DistributionSuite) TestPeek() {
{
NodeID: 1,
SegmentID: 1,
State: segmentStateLoaded,
},
{
NodeID: 2,
SegmentID: 2,
State: segmentStateLoaded,
},
{
NodeID: 1,
SegmentID: 3,
State: segmentStateLoaded,
},
},
expected: []SnapshotItem{
......@@ -452,11 +469,13 @@ func (s *DistributionSuite) TestPeek() {
{
NodeID: 1,
SegmentID: 1,
State: segmentStateLoaded,
},
{
NodeID: 1,
SegmentID: 3,
State: segmentStateLoaded,
},
},
},
......@@ -466,6 +485,7 @@ func (s *DistributionSuite) TestPeek() {
{
NodeID: 2,
SegmentID: 2,
State: segmentStateLoaded,
},
},
},
......
......@@ -128,7 +128,7 @@ func (s *ImplUtilsSuite) TestTransferLoad() {
s.Run("transfer load fail", func() {
cs, ok := s.querynode.ShardClusterService.getShardCluster(defaultChannelName)
s.Require().True(ok)
cs.nodes[100] = &shardNode{
cs.nodes.InsertIfNotPresent(100, &shardNode{
nodeID: 100,
nodeAddr: "test",
client: &mockShardQueryNode{
......@@ -137,7 +137,7 @@ func (s *ImplUtilsSuite) TestTransferLoad() {
Reason: "error",
},
},
}
})
status, err := s.querynode.TransferLoad(ctx, &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
......@@ -161,8 +161,8 @@ func (s *ImplUtilsSuite) TestTransferLoad() {
s.Run("insufficient memory", func() {
cs, ok := s.querynode.ShardClusterService.getShardCluster(defaultChannelName)
s.Require().True(ok)
cs.nodes[100] = &shardNode{
nodeID: 100,
cs.nodes.InsertIfNotPresent(101, &shardNode{
nodeID: 101,
nodeAddr: "test",
client: &mockShardQueryNode{
loadSegmentsResults: &commonpb.Status{
......@@ -170,13 +170,13 @@ func (s *ImplUtilsSuite) TestTransferLoad() {
Reason: "mock InsufficientMemoryToLoad",
},
},
}
})
status, err := s.querynode.TransferLoad(ctx, &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
TargetID: s.querynode.session.ServerID,
},
DstNodeID: 100,
DstNodeID: 101,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: defaultSegmentID,
......@@ -227,7 +227,7 @@ func (s *ImplUtilsSuite) TestTransferRelease() {
s.Run("transfer release fail", func() {
cs, ok := s.querynode.ShardClusterService.getShardCluster(defaultChannelName)
s.Require().True(ok)
cs.nodes[100] = &shardNode{
cs.nodes.InsertIfNotPresent(100, &shardNode{
nodeID: 100,
nodeAddr: "test",
client: &mockShardQueryNode{
......@@ -235,7 +235,7 @@ func (s *ImplUtilsSuite) TestTransferRelease() {
ErrorCode: commonpb.ErrorCode_UnexpectedError,
},
},
}
})
status, err := s.querynode.TransferRelease(ctx, &querypb.ReleaseSegmentsRequest{
Base: &commonpb.MsgBase{
......
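
The test updates above follow the change of the shardNode map from a plain map guarded by sc.mut to a typeutil.ConcurrentMap; the diff uses its Get, GetAndRemove, and InsertIfNotPresent methods. Below is a self-contained, illustrative re-implementation of such a map with those method names; the real typeutil version may differ in detail:

package main

import (
	"fmt"
	"sync"
)

// ConcurrentMap is a tiny generic map safe for concurrent use.
type ConcurrentMap[K comparable, V any] struct {
	mut  sync.RWMutex
	data map[K]V
}

func NewConcurrentMap[K comparable, V any]() *ConcurrentMap[K, V] {
	return &ConcurrentMap[K, V]{data: make(map[K]V)}
}

func (m *ConcurrentMap[K, V]) Get(key K) (V, bool) {
	m.mut.RLock()
	defer m.mut.RUnlock()
	v, ok := m.data[key]
	return v, ok
}

// InsertIfNotPresent stores value only when key is absent, so concurrent
// callers cannot silently overwrite a live entry.
func (m *ConcurrentMap[K, V]) InsertIfNotPresent(key K, value V) {
	m.mut.Lock()
	defer m.mut.Unlock()
	if _, ok := m.data[key]; !ok {
		m.data[key] = value
	}
}

// GetAndRemove deletes the entry and returns it, combining the lookup
// and delete that were previously two separate map operations.
func (m *ConcurrentMap[K, V]) GetAndRemove(key K) (V, bool) {
	m.mut.Lock()
	defer m.mut.Unlock()
	v, ok := m.data[key]
	if ok {
		delete(m.data, key)
	}
	return v, ok
}

func main() {
	nodes := NewConcurrentMap[int64, string]()
	nodes.InsertIfNotPresent(100, "addr-a")
	nodes.InsertIfNotPresent(100, "addr-b") // no-op: key already present
	v, _ := nodes.Get(100)
	fmt.Println(v) // addr-a
	nodes.GetAndRemove(100)
	_, ok := nodes.Get(100)
	fmt.Println(ok) // false
}
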
......@@ -62,7 +62,6 @@ const (
type segmentState int32
const (
segmentStateNone segmentState = 0
segmentStateOffline segmentState = 1
segmentStateLoading segmentState = 2
segmentStateLoaded segmentState = 3
......@@ -145,10 +144,9 @@ type ShardCluster struct {
segmentDetector ShardSegmentDetector
nodeBuilder ShardNodeBuilder
mut sync.RWMutex
leader *shardNode // shard leader node instance
nodes map[int64]*shardNode // online nodes
mut sync.RWMutex
leader *shardNode // shard leader node instance
nodes *typeutil.ConcurrentMap[int64, *shardNode] // online nodes
mutVersion sync.RWMutex
distribution *distribution
......@@ -170,8 +168,7 @@ func NewShardCluster(collectionID int64, replicaID int64, vchannelName string, v
segmentDetector: segmentDetector,
nodeDetector: nodeDetector,
nodeBuilder: nodeBuilder,
nodes: make(map[int64]*shardNode),
nodes: typeutil.NewConcurrentMap[int64, *shardNode](),
closeCh: make(chan struct{}),
}
......@@ -203,11 +200,10 @@ func (sc *ShardCluster) getLogger() *log.MLogger {
)
}
// serviceable returns whether shard cluster could provide query service.
// serviceable returns whether the shard cluster can provide query service; used only in tests.
func (sc *ShardCluster) serviceable() bool {
sc.mutVersion.RLock()
defer sc.mutVersion.RUnlock()
return sc.distribution != nil && sc.distribution.Serviceable()
}
......@@ -218,12 +214,13 @@ func (sc *ShardCluster) addNode(evt nodeEvent) {
sc.mut.Lock()
defer sc.mut.Unlock()
oldNode, ok := sc.nodes[evt.nodeID]
oldNode, ok := sc.nodes.Get(evt.nodeID)
if ok {
if oldNode.nodeAddr == evt.nodeAddr {
log.Warn("ShardCluster add same node, skip", zap.Int64("nodeID", evt.nodeID), zap.String("addr", evt.nodeAddr))
return
}
sc.nodes.GetAndRemove(evt.nodeID)
defer oldNode.client.Stop()
}
......@@ -232,7 +229,7 @@ func (sc *ShardCluster) addNode(evt nodeEvent) {
nodeAddr: evt.nodeAddr,
client: sc.nodeBuilder(evt.nodeID, evt.nodeAddr),
}
sc.nodes[evt.nodeID] = node
sc.nodes.InsertIfNotPresent(evt.nodeID, node)
if evt.isLeader {
sc.leader = node
}
......@@ -245,15 +242,13 @@ func (sc *ShardCluster) removeNode(evt nodeEvent) {
sc.mut.Lock()
defer sc.mut.Unlock()
old, ok := sc.nodes[evt.nodeID]
old, ok := sc.nodes.GetAndRemove(evt.nodeID)
if !ok {
log.Warn("ShardCluster removeNode does not belong to it", zap.Int64("nodeID", evt.nodeID), zap.String("addr", evt.nodeAddr))
return
}
defer old.client.Stop()
delete(sc.nodes, evt.nodeID)
sc.distribution.NodeDown(evt.nodeID)
}
......@@ -451,13 +446,7 @@ func (sc *ShardCluster) watchSegments(evtCh <-chan segmentEvent) {
// getNode returns a shallow copy of shardNode
func (sc *ShardCluster) getNode(nodeID int64) (*shardNode, bool) {
sc.mut.RLock()
defer sc.mut.RUnlock()
return sc.getNodeImpl(nodeID)
}
func (sc *ShardCluster) getNodeImpl(nodeID int64) (*shardNode, bool) {
node, ok := sc.nodes[nodeID]
node, ok := sc.nodes.Get(nodeID)
if !ok {
return nil, false
}
......@@ -485,8 +474,10 @@ func (sc *ShardCluster) getSegment(segmentID int64) (shardSegmentInfo, bool) {
// segmentAllocations returns node to segments mappings.
// calling this function also increases the reference count of related segments.
func (sc *ShardCluster) segmentAllocations(partitionIDs []int64) (map[int64][]int64, int64) {
if !sc.serviceable() {
return nil, 0
sc.mutVersion.RLock()
defer sc.mutVersion.RUnlock()
if sc.distribution == nil {
return nil, -1
}
items, version := sc.distribution.GetCurrent(partitionIDs...)
return lo.SliceToMap(items, func(item SnapshotItem) (int64, []int64) {
......@@ -496,6 +487,9 @@ func (sc *ShardCluster) segmentAllocations(partitionIDs []int64) (map[int64][]in
// finishUsage decreases the inUse count of provided segments
func (sc *ShardCluster) finishUsage(versionID int64) {
if versionID == -1 {
return
}
sc.distribution.FinishUsage(versionID)
}
......@@ -578,7 +572,7 @@ func (sc *ShardCluster) ReleaseSegments(ctx context.Context, req *querypb.Releas
// requires sc.mut read lock held
releaseFn := func() {
// try to release segments from nodes
node, ok := sc.getNodeImpl(req.GetNodeID())
node, ok := sc.getNode(req.GetNodeID())
if !ok {
log.Warn("node not in cluster", zap.Int64("nodeID", req.NodeID))
err = fmt.Errorf("node %d not in cluster ", req.NodeID)
......@@ -608,9 +602,6 @@ func (sc *ShardCluster) ReleaseSegments(ctx context.Context, req *querypb.Releas
// GetStatistics returns the statistics on the shard cluster.
func (sc *ShardCluster) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest, withStreaming getStatisticsWithStreaming) ([]*internalpb.GetStatisticsResponse, error) {
if !sc.serviceable() {
return nil, fmt.Errorf("ShardCluster for %s replicaID %d is not available", sc.vchannelName, sc.replicaID)
}
if !funcutil.SliceContain(req.GetDmlChannels(), sc.vchannelName) {
return nil, fmt.Errorf("ShardCluster for %s does not match request channels :%v", sc.vchannelName, req.GetDmlChannels())
}
......@@ -619,6 +610,10 @@ func (sc *ShardCluster) GetStatistics(ctx context.Context, req *querypb.GetStati
segAllocs, versionID := sc.segmentAllocations(req.GetReq().GetPartitionIDs())
defer sc.finishUsage(versionID)
if versionID == -1 {
return nil, fmt.Errorf("ShardCluster for %s replicaID %d is not available", sc.vchannelName, sc.replicaID)
}
log.Debug("cluster segment distribution", zap.Int("len", len(segAllocs)))
for nodeID, segmentIDs := range segAllocs {
log.Debug("segments distribution", zap.Int64("nodeID", nodeID), zap.Int64s("segments", segmentIDs))
......@@ -698,7 +693,16 @@ func (sc *ShardCluster) GetStatistics(ctx context.Context, req *querypb.GetStati
// Search performs search operation on shard cluster.
func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest, withStreaming searchWithStreaming) ([]*internalpb.SearchResults, error) {
if !sc.serviceable() {
if !funcutil.SliceContain(req.GetDmlChannels(), sc.vchannelName) {
return nil, fmt.Errorf("ShardCluster for %s does not match request channels :%v", sc.vchannelName, req.GetDmlChannels())
}
// get node allocation and maintains the inUse reference count
segAllocs, versionID := sc.segmentAllocations(req.GetReq().GetPartitionIDs())
defer sc.finishUsage(versionID)
// not serviceable
if versionID == -1 {
err := WrapErrShardNotAvailable(sc.replicaID, sc.vchannelName)
log.Warn("failed to search on shard",
zap.Int64("replicaID", sc.replicaID),
......@@ -707,13 +711,6 @@ func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest,
)
return nil, err
}
if !funcutil.SliceContain(req.GetDmlChannels(), sc.vchannelName) {
return nil, fmt.Errorf("ShardCluster for %s does not match request channels :%v", sc.vchannelName, req.GetDmlChannels())
}
// get node allocation and maintains the inUse reference count
segAllocs, versionID := sc.segmentAllocations(req.GetReq().GetPartitionIDs())
defer sc.finishUsage(versionID)
log.Debug("cluster segment distribution", zap.Int("len", len(segAllocs)), zap.Int64s("partitionIDs", req.GetReq().GetPartitionIDs()))
for nodeID, segmentIDs := range segAllocs {
......@@ -807,10 +804,6 @@ func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest,
// Query performs query operation on shard cluster.
func (sc *ShardCluster) Query(ctx context.Context, req *querypb.QueryRequest, withStreaming queryWithStreaming) ([]*internalpb.RetrieveResults, error) {
if !sc.serviceable() {
return nil, WrapErrShardNotAvailable(sc.replicaID, sc.vchannelName)
}
// handles only the dml channel part, segment ids is dispatch by cluster itself
if !funcutil.SliceContain(req.GetDmlChannels(), sc.vchannelName) {
return nil, fmt.Errorf("ShardCluster for %s does not match to request channels :%v", sc.vchannelName, req.GetDmlChannels())
......@@ -820,6 +813,11 @@ func (sc *ShardCluster) Query(ctx context.Context, req *querypb.QueryRequest, wi
segAllocs, versionID := sc.segmentAllocations(req.GetReq().GetPartitionIDs())
defer sc.finishUsage(versionID)
// not serviceable
if versionID == -1 {
return nil, WrapErrShardNotAvailable(sc.replicaID, sc.vchannelName)
}
// concurrent visiting nodes
var wg sync.WaitGroup
reqCtx, cancel := context.WithCancel(ctx)
......
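
Across GetStatistics, Search, and Query, the shard_cluster.go changes above replace the up-front sc.serviceable() check (which took its own lock and created an ordering hazard) with a single pattern: segmentAllocations returns version -1 when the distribution is nil or not serviceable, the deferred finishUsage ignores that sentinel, and the caller converts -1 into a shard-not-available error. A condensed sketch of that call shape, where the helper names are stand-ins rather than the real API:

package main

import (
	"errors"
	"fmt"
)

var errShardNotAvailable = errors.New("shard not available")

// allocate stands in for segmentAllocations: it returns -1 as the
// version when the cluster cannot serve, so no separate pre-check
// under another lock is needed.
func allocate(serviceable bool) (map[int64][]int64, int64) {
	if !serviceable {
		return nil, -1
	}
	return map[int64][]int64{1: {10, 11}}, 7
}

// finishUsage ignores the sentinel, matching the guard added in the diff.
func finishUsage(version int64) {
	if version == -1 {
		return
	}
	fmt.Println("released version", version)
}

// search mirrors the call shape of Search/Query/GetStatistics.
func search(serviceable bool) error {
	segAllocs, version := allocate(serviceable)
	defer finishUsage(version)
	if version == -1 {
		return errShardNotAvailable
	}
	fmt.Println("searching", segAllocs)
	return nil
}

func main() {
	fmt.Println(search(true))
	fmt.Println(search(false))
}
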
......@@ -2040,7 +2040,7 @@ func TestShardCluster_Version(t *testing.T) {
defer sc.Close()
_, v := sc.segmentAllocations(nil)
assert.Equal(t, int64(0), v)
assert.Equal(t, int64(-1), v)
})
t.Run("normal alloc & finish", func(t *testing.T) {
......
......@@ -73,7 +73,6 @@ func (s *snapshot) Expire(cleanup snapshotCleanup) {
// Get returns segment distributions with provided partition ids.
func (s *snapshot) Get(partitions ...int64) []SnapshotItem {
s.inUse.Inc()
return s.filter(partitions...)
}
......@@ -120,8 +119,8 @@ func (s *snapshot) checkCleared(cleanup snapshotCleanup) {
go func() {
<-s.last.cleared
s.last = nil
cleanup()
close(s.cleared)
cleanup()
}()
})
}
......
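
The final hunk swaps the order of cleanup() and close(s.cleared) in snapshot.checkCleared, so goroutines blocked on the cleared channel, such as RemoveDistributions waiting on <-d.genSnapshot(), are unblocked before the cleanup callback runs instead of after it. A minimal, illustrative sketch of that ordering:

package main

import (
	"fmt"
	"sync"
)

func main() {
	cleared := make(chan struct{})
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		<-cleared // a waiter, like <-d.genSnapshot() in RemoveDistributions
		fmt.Println("waiter resumed")
	}()
	cleanup := func() { fmt.Println("cleanup ran") }

	close(cleared) // new order: unblock waiters first...
	cleanup()      // ...then run the cleanup callback
	wg.Wait()
}
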