未验证 提交 2146af1f 编写于 作者: B bigsheeper 提交者: GitHub

Return insufficient memory error when load failed (#21574)

Signed-off-by: Nbigsheeper <yihao.dai@zilliz.com>
上级 c550427e
......@@ -23,11 +23,23 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
)
// TODO(dragondriver): add more common error type
// ErrInsufficientMemory returns insufficient memory error.
var ErrInsufficientMemory = errors.New("InsufficientMemoryToLoad")
// InSufficientMemoryStatus returns insufficient memory status.
func InSufficientMemoryStatus(collectionName string) *commonpb.Status {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_InsufficientMemoryToLoad,
Reason: fmt.Sprintf("deny to load, insufficient memory, please allocate more resources, collectionName: %s", collectionName),
}
}
func errInvalidNumRows(numRows uint32) error {
return fmt.Errorf("invalid num_rows: %d", numRows)
}
......
......@@ -17,10 +17,14 @@
package proxy
import (
"errors"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/log"
)
......@@ -150,3 +154,11 @@ func Test_errProxyIsUnhealthy(t *testing.T) {
zap.Error(errProxyIsUnhealthy(id)))
}
}
func Test_ErrInsufficientMemory(t *testing.T) {
err := fmt.Errorf("%w, mock insufficient memory error", ErrInsufficientMemory)
assert.True(t, errors.Is(err, ErrInsufficientMemory))
status := InSufficientMemoryStatus("collection1")
assert.Equal(t, commonpb.ErrorCode_InsufficientMemoryToLoad, status.GetErrorCode())
}
......@@ -18,6 +18,7 @@ package proxy
import (
"context"
"errors"
"fmt"
"os"
"strconv"
......@@ -1445,6 +1446,11 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get
zap.Strings("partition_name", request.PartitionNames),
zap.Error(err))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc()
if errors.Is(err, ErrInsufficientMemory) {
return &milvuspb.GetLoadingProgressResponse{
Status: InSufficientMemoryStatus(request.GetCollectionName()),
}
}
return &milvuspb.GetLoadingProgressResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
......@@ -1574,12 +1580,22 @@ func (node *Proxy) GetLoadState(ctx context.Context, request *milvuspb.GetLoadSt
var progress int64
if len(request.GetPartitionNames()) == 0 {
if progress, err = getCollectionProgress(ctx, node.queryCoord, request.GetBase(), collectionID); err != nil {
if errors.Is(err, ErrInsufficientMemory) {
return &milvuspb.GetLoadStateResponse{
Status: InSufficientMemoryStatus(request.GetCollectionName()),
}, nil
}
successResponse.State = commonpb.LoadState_LoadStateNotLoad
return successResponse, nil
}
} else {
if progress, err = getPartitionProgress(ctx, node.queryCoord, request.GetBase(),
request.GetPartitionNames(), request.GetCollectionName(), collectionID); err != nil {
if errors.Is(err, ErrInsufficientMemory) {
return &milvuspb.GetLoadStateResponse{
Status: InSufficientMemoryStatus(request.GetCollectionName()),
}, nil
}
successResponse.State = commonpb.LoadState_LoadStateNotLoad
return successResponse, nil
}
......
......@@ -4299,4 +4299,31 @@ func TestProxy_GetLoadState(t *testing.T) {
assert.Equal(t, commonpb.ErrorCode_Success, progressResp.Status.ErrorCode)
assert.Equal(t, int64(50), progressResp.Progress)
}
t.Run("test insufficient memory", func(t *testing.T) {
q := NewQueryCoordMock(SetQueryCoordShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_InsufficientMemoryToLoad},
}, nil
}), SetQueryCoordShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_InsufficientMemoryToLoad},
}, nil
}))
q.state.Store(commonpb.StateCode_Healthy)
proxy := &Proxy{queryCoord: q}
proxy.stateCode.Store(commonpb.StateCode_Healthy)
stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_InsufficientMemoryToLoad, stateResp.Status.ErrorCode)
progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_InsufficientMemoryToLoad, progressResp.Status.ErrorCode)
progressResp, err = proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo", PartitionNames: []string{"p1"}})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_InsufficientMemoryToLoad, progressResp.Status.ErrorCode)
})
}
......@@ -1002,6 +1002,11 @@ func getCollectionProgress(ctx context.Context, queryCoord types.QueryCoord,
return 0, err
}
if resp.Status.ErrorCode == commonpb.ErrorCode_InsufficientMemoryToLoad {
log.Warn("detected insufficientMemoryError when getCollectionProgress", zap.Int64("collection_id", collectionID), zap.String("reason", resp.GetStatus().GetReason()))
return 0, ErrInsufficientMemory
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
log.Warn("fail to show collections", zap.Int64("collection_id", collectionID),
zap.String("reason", resp.Status.Reason))
......@@ -1043,6 +1048,11 @@ func getPartitionProgress(ctx context.Context, queryCoord types.QueryCoord,
zap.Error(err))
return 0, err
}
if resp.GetStatus().GetErrorCode() == commonpb.ErrorCode_InsufficientMemoryToLoad {
log.Warn("detected insufficientMemoryError when getPartitionProgress", zap.Int64("collection_id", collectionID),
zap.String("collection_name", collectionName), zap.Strings("partition_names", partitionNames), zap.String("reason", resp.GetStatus().GetReason()))
return 0, ErrInsufficientMemory
}
if len(resp.InMemoryPercentages) != len(partitionIDs) {
errMsg := "fail to show partitions from the querycoord, invalid data num"
log.Warn(errMsg, zap.Int64("collection_id", collectionID),
......
......@@ -186,6 +186,8 @@ func (job *LoadCollectionJob) Execute() error {
zap.Int64("collectionID", req.GetCollectionID()),
)
meta.GlobalFailedLoadCache.Remove(req.GetCollectionID())
// Clear stale replicas
err := job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
if err != nil {
......@@ -393,6 +395,8 @@ func (job *LoadPartitionJob) Execute() error {
zap.Int64s("partitionIDs", req.GetPartitionIDs()),
)
meta.GlobalFailedLoadCache.Remove(req.GetCollectionID())
// Clear stale replicas
err := job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
if err != nil {
......
......@@ -136,6 +136,7 @@ func (suite *JobSuite) SetupTest() {
suite.scheduler = NewScheduler()
suite.scheduler.Start(context.Background())
meta.GlobalFailedLoadCache = meta.NewFailedLoadCache()
}
func (suite *JobSuite) TearDownTest() {
......
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package meta
import (
"sync"
"time"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/log"
. "github.com/milvus-io/milvus/internal/util/typeutil"
)
const expireTime = 24 * time.Hour
var GlobalFailedLoadCache *FailedLoadCache
type failInfo struct {
count int
err error
lastTime time.Time
}
type FailedLoadCache struct {
mu sync.RWMutex
records map[UniqueID]map[commonpb.ErrorCode]*failInfo
}
func NewFailedLoadCache() *FailedLoadCache {
return &FailedLoadCache{
records: make(map[UniqueID]map[commonpb.ErrorCode]*failInfo),
}
}
func (l *FailedLoadCache) Get(collectionID UniqueID) *commonpb.Status {
l.mu.RLock()
defer l.mu.RUnlock()
status := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
if _, ok := l.records[collectionID]; !ok {
return status
}
if len(l.records[collectionID]) == 0 {
return status
}
var max = 0
for code, info := range l.records[collectionID] {
if info.count > max {
max = info.count
status.ErrorCode = code
status.Reason = info.err.Error()
}
}
log.Warn("FailedLoadCache hits failed record", zap.Int64("collectionID", collectionID),
zap.String("errCode", status.GetErrorCode().String()), zap.String("reason", status.GetReason()))
return status
}
func (l *FailedLoadCache) Put(collectionID UniqueID, errCode commonpb.ErrorCode, err error) {
if errCode == commonpb.ErrorCode_Success {
return
}
l.mu.Lock()
defer l.mu.Unlock()
if _, ok := l.records[collectionID]; !ok {
l.records[collectionID] = make(map[commonpb.ErrorCode]*failInfo)
}
if _, ok := l.records[collectionID][errCode]; !ok {
l.records[collectionID][errCode] = &failInfo{}
}
l.records[collectionID][errCode].count++
l.records[collectionID][errCode].err = err
l.records[collectionID][errCode].lastTime = time.Now()
log.Warn("FailedLoadCache put failed record", zap.Int64("collectionID", collectionID),
zap.String("errCode", errCode.String()), zap.Error(err))
}
func (l *FailedLoadCache) Remove(collectionID UniqueID) {
l.mu.Lock()
defer l.mu.Unlock()
delete(l.records, collectionID)
log.Info("FailedLoadCache removes cache", zap.Int64("collectionID", collectionID))
}
func (l *FailedLoadCache) TryExpire() {
l.mu.Lock()
defer l.mu.Unlock()
for col, infos := range l.records {
for code, info := range infos {
if time.Since(info.lastTime) > expireTime {
delete(l.records[col], code)
}
}
if len(l.records[col]) == 0 {
delete(l.records, col)
log.Info("FailedLoadCache expires cache", zap.Int64("collectionID", col))
}
}
}
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package meta
import (
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
)
func TestFailedLoadCache(t *testing.T) {
GlobalFailedLoadCache = NewFailedLoadCache()
colID := int64(0)
errCode := commonpb.ErrorCode_InsufficientMemoryToLoad
mockErr := fmt.Errorf("mock insufficient memory reason")
GlobalFailedLoadCache.Put(colID, commonpb.ErrorCode_Success, nil)
res := GlobalFailedLoadCache.Get(colID)
assert.Equal(t, commonpb.ErrorCode_Success, res.GetErrorCode())
GlobalFailedLoadCache.Put(colID, errCode, mockErr)
res = GlobalFailedLoadCache.Get(colID)
assert.Equal(t, errCode, res.GetErrorCode())
GlobalFailedLoadCache.Remove(colID)
res = GlobalFailedLoadCache.Get(colID)
assert.Equal(t, commonpb.ErrorCode_Success, res.GetErrorCode())
GlobalFailedLoadCache.Put(colID, errCode, mockErr)
GlobalFailedLoadCache.mu.Lock()
GlobalFailedLoadCache.records[colID][errCode].lastTime = time.Now().Add(-expireTime * 2)
GlobalFailedLoadCache.mu.Unlock()
GlobalFailedLoadCache.TryExpire()
res = GlobalFailedLoadCache.Get(colID)
assert.Equal(t, commonpb.ErrorCode_Success, res.GetErrorCode())
}
......@@ -234,6 +234,9 @@ func (s *Server) Init() error {
// Init observers
s.initObserver()
// Init load status cache
meta.GlobalFailedLoadCache = meta.NewFailedLoadCache()
log.Info("QueryCoord init success")
return err
}
......
......@@ -56,6 +56,7 @@ func (s *Server) ShowCollections(ctx context.Context, req *querypb.ShowCollectio
Status: utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg, ErrNotHealthy),
}, nil
}
defer meta.GlobalFailedLoadCache.TryExpire()
isGetAll := false
collectionSet := typeutil.NewUniqueSet(req.GetCollectionIDs()...)
......@@ -86,6 +87,13 @@ func (s *Server) ShowCollections(ctx context.Context, req *querypb.ShowCollectio
// ignore it
continue
}
status := meta.GlobalFailedLoadCache.Get(collectionID)
if status.ErrorCode != commonpb.ErrorCode_Success {
log.Warn("show collection failed", zap.String("errCode", status.GetErrorCode().String()), zap.String("reason", status.GetReason()))
return &querypb.ShowCollectionsResponse{
Status: status,
}, nil
}
err := fmt.Errorf("collection %d has not been loaded to memory or load failed", collectionID)
log.Warn("show collection failed", zap.Error(err))
return &querypb.ShowCollectionsResponse{
......@@ -114,6 +122,7 @@ func (s *Server) ShowPartitions(ctx context.Context, req *querypb.ShowPartitions
Status: utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg, ErrNotHealthy),
}, nil
}
defer meta.GlobalFailedLoadCache.TryExpire()
// TODO(yah01): now, for load collection, the percentage of partition is equal to the percentage of collection,
// we can calculates the real percentage of partitions
......@@ -163,6 +172,13 @@ func (s *Server) ShowPartitions(ctx context.Context, req *querypb.ShowPartitions
}
if isReleased {
status := meta.GlobalFailedLoadCache.Get(req.GetCollectionID())
if status.ErrorCode != commonpb.ErrorCode_Success {
log.Warn("show collection failed", zap.String("errCode", status.GetErrorCode().String()), zap.String("reason", status.GetReason()))
return &querypb.ShowPartitionsResponse{
Status: status,
}, nil
}
msg := fmt.Sprintf("collection %v has not been loaded into QueryNode", req.GetCollectionID())
log.Warn(msg)
return &querypb.ShowPartitionsResponse{
......@@ -251,6 +267,8 @@ func (s *Server) ReleaseCollection(ctx context.Context, req *querypb.ReleaseColl
log.Info("collection released")
metrics.QueryCoordReleaseCount.WithLabelValues(metrics.SuccessLabel).Inc()
metrics.QueryCoordReleaseLatency.WithLabelValues().Observe(float64(tr.ElapseSpan().Milliseconds()))
meta.GlobalFailedLoadCache.Remove(req.GetCollectionID())
return successStatus, nil
}
......@@ -333,6 +351,8 @@ func (s *Server) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart
metrics.QueryCoordReleaseCount.WithLabelValues(metrics.SuccessLabel).Inc()
metrics.QueryCoordReleaseLatency.WithLabelValues().Observe(float64(tr.ElapseSpan().Milliseconds()))
meta.GlobalFailedLoadCache.Remove(req.GetCollectionID())
return successStatus, nil
}
......
......@@ -19,6 +19,7 @@ package querycoordv2
import (
"context"
"encoding/json"
"fmt"
"testing"
"time"
......@@ -141,6 +142,7 @@ func (suite *ServiceSuite) SetupTest() {
suite.meta,
suite.targetMgr,
)
meta.GlobalFailedLoadCache = meta.NewFailedLoadCache()
suite.server = &Server{
kv: suite.kv,
......@@ -185,6 +187,18 @@ func (suite *ServiceSuite) TestShowCollections() {
suite.Len(resp.CollectionIDs, 1)
suite.Equal(collection, resp.CollectionIDs[0])
// Test insufficient memory
colBak := suite.meta.CollectionManager.GetCollection(collection)
err = suite.meta.CollectionManager.RemoveCollection(collection)
suite.NoError(err)
meta.GlobalFailedLoadCache.Put(collection, commonpb.ErrorCode_InsufficientMemoryToLoad, fmt.Errorf("mock insufficient memory reason"))
resp, err = server.ShowCollections(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_InsufficientMemoryToLoad, resp.GetStatus().GetErrorCode())
meta.GlobalFailedLoadCache.Remove(collection)
err = suite.meta.CollectionManager.PutCollection(colBak)
suite.NoError(err)
// Test when server is not healthy
server.UpdateStateCode(commonpb.StateCode_Initializing)
resp, err = server.ShowCollections(ctx, req)
......@@ -225,6 +239,32 @@ func (suite *ServiceSuite) TestShowPartitions() {
for _, partition := range partitions[0:1] {
suite.Contains(resp.PartitionIDs, partition)
}
// Test insufficient memory
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection {
colBak := suite.meta.CollectionManager.GetCollection(collection)
err = suite.meta.CollectionManager.RemoveCollection(collection)
suite.NoError(err)
meta.GlobalFailedLoadCache.Put(collection, commonpb.ErrorCode_InsufficientMemoryToLoad, fmt.Errorf("mock insufficient memory reason"))
resp, err = server.ShowPartitions(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_InsufficientMemoryToLoad, resp.GetStatus().GetErrorCode())
meta.GlobalFailedLoadCache.Remove(collection)
err = suite.meta.CollectionManager.PutCollection(colBak)
suite.NoError(err)
} else {
partitionID := partitions[0]
parBak := suite.meta.CollectionManager.GetPartition(partitionID)
err = suite.meta.CollectionManager.RemovePartition(partitionID)
suite.NoError(err)
meta.GlobalFailedLoadCache.Put(collection, commonpb.ErrorCode_InsufficientMemoryToLoad, fmt.Errorf("mock insufficient memory reason"))
resp, err = server.ShowPartitions(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_InsufficientMemoryToLoad, resp.GetStatus().GetErrorCode())
meta.GlobalFailedLoadCache.Remove(collection)
err = suite.meta.CollectionManager.PutPartition(parBak)
suite.NoError(err)
}
}
// Test when server is not healthy
......
......@@ -19,6 +19,7 @@ package task
import (
"context"
"errors"
"fmt"
"sync"
"time"
......@@ -146,8 +147,12 @@ func (ex *Executor) processMergeTask(mergeTask *LoadSegmentsTask) {
action := task.Actions()[mergeTask.steps[0]]
defer func() {
canceled := task.canceled.Load()
for i := range mergeTask.tasks {
mergeTask.tasks[i].SetErr(task.Err())
if canceled {
mergeTask.tasks[i].Cancel()
}
ex.removeTask(mergeTask.tasks[i], mergeTask.steps[i])
}
}()
......@@ -184,6 +189,12 @@ func (ex *Executor) processMergeTask(mergeTask *LoadSegmentsTask) {
log.Warn("failed to load segment, it may be a false failure", zap.Error(err))
return
}
if status.ErrorCode == commonpb.ErrorCode_InsufficientMemoryToLoad {
log.Warn("insufficient memory to load segment", zap.String("err", status.GetReason()))
task.SetErr(fmt.Errorf("%w, err:%s", ErrInsufficientMemory, status.GetReason()))
task.Cancel()
return
}
if status.ErrorCode != commonpb.ErrorCode_Success {
log.Warn("failed to load segment", zap.String("reason", status.GetReason()))
return
......
......@@ -23,6 +23,7 @@ import (
"runtime"
"sync"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/proto/datapb"
......@@ -53,8 +54,8 @@ var (
// or the target channel is not in TargetManager
ErrTaskStale = errors.New("TaskStale")
// No enough memory to load segment
ErrResourceNotEnough = errors.New("ResourceNotEnough")
// ErrInsufficientMemory returns insufficient memory error.
ErrInsufficientMemory = errors.New("InsufficientMemoryToLoad")
ErrFailedResponse = errors.New("RpcFailed")
ErrTaskAlreadyDone = errors.New("TaskAlreadyDone")
......@@ -658,6 +659,16 @@ func (scheduler *taskScheduler) RemoveByNode(node int64) {
}
}
func (scheduler *taskScheduler) recordSegmentTaskError(task *SegmentTask) {
var errCode commonpb.ErrorCode
if errors.Is(task.Err(), ErrInsufficientMemory) {
errCode = commonpb.ErrorCode_InsufficientMemoryToLoad
} else {
errCode = commonpb.ErrorCode_UnexpectedError
}
meta.GlobalFailedLoadCache.Put(task.collectionID, errCode, task.Err())
}
func (scheduler *taskScheduler) remove(task Task) {
log := log.With(
zap.Int64("taskID", task.ID()),
......@@ -675,6 +686,10 @@ func (scheduler *taskScheduler) remove(task Task) {
index := NewReplicaSegmentIndex(task)
delete(scheduler.segmentTasks, index)
log = log.With(zap.Int64("segmentID", task.SegmentID()))
if task.Err() != nil {
log.Warn("task scheduler recordSegmentTaskError", zap.Error(task.err))
scheduler.recordSegmentTaskError(task)
}
case *ChannelTask:
index := replicaChannelIndex{task.ReplicaID(), task.Channel()}
......
......@@ -142,6 +142,7 @@ func (suite *TaskSuite) SetupTest() {
suite.scheduler.AddExecutor(1)
suite.scheduler.AddExecutor(2)
suite.scheduler.AddExecutor(3)
meta.GlobalFailedLoadCache = meta.NewFailedLoadCache()
}
func (suite *TaskSuite) BeforeTest(suiteName, testName string) {
......
......@@ -27,6 +27,8 @@ var (
ErrShardNotAvailable = errors.New("ShardNotAvailable")
// ErrTsLagTooLarge serviceable and guarantee lag too large.
ErrTsLagTooLarge = errors.New("Timestamp lag too large")
// ErrInsufficientMemory returns insufficient memory error.
ErrInsufficientMemory = errors.New("InsufficientMemoryToLoad")
)
// WrapErrShardNotAvailable wraps ErrShardNotAvailable with replica id and channel name.
......
......@@ -520,6 +520,9 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
}
if errors.Is(err, ErrInsufficientMemory) {
status.ErrorCode = commonpb.ErrorCode_InsufficientMemoryToLoad
}
log.Warn(err.Error())
return status, nil
}
......
......@@ -2,6 +2,8 @@ package querynode
import (
"context"
"errors"
"fmt"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/log"
......@@ -38,6 +40,13 @@ func (node *QueryNode) TransferLoad(ctx context.Context, req *querypb.LoadSegmen
req.NeedTransfer = false
err := shardCluster.LoadSegments(ctx, req)
if err != nil {
if errors.Is(err, ErrInsufficientMemory) {
log.Warn("insufficient memory when shard cluster load segments", zap.Error(err))
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_InsufficientMemoryToLoad,
Reason: fmt.Sprintf("insufficient memory when shard cluster load segments, err:%s", err.Error()),
}, nil
}
log.Warn("shard cluster failed to load segments", zap.Error(err))
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
......
......@@ -157,6 +157,39 @@ func (s *ImplUtilsSuite) TestTransferLoad() {
s.NoError(err)
s.Equal(commonpb.ErrorCode_UnexpectedError, status.GetErrorCode())
})
s.Run("insufficient memory", func() {
cs, ok := s.querynode.ShardClusterService.getShardCluster(defaultChannelName)
s.Require().True(ok)
cs.nodes[100] = &shardNode{
nodeID: 100,
nodeAddr: "test",
client: &mockShardQueryNode{
loadSegmentsResults: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_InsufficientMemoryToLoad,
Reason: "mock InsufficientMemoryToLoad",
},
},
}
status, err := s.querynode.TransferLoad(ctx, &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
TargetID: s.querynode.session.ServerID,
},
DstNodeID: 100,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: defaultSegmentID,
InsertChannel: defaultChannelName,
CollectionID: defaultCollectionID,
PartitionID: defaultPartitionID,
},
},
})
s.NoError(err)
s.Equal(commonpb.ErrorCode_InsufficientMemoryToLoad, status.GetErrorCode())
})
}
func (s *ImplUtilsSuite) TestTransferRelease() {
......
......@@ -952,7 +952,8 @@ func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoad
zap.Uint64("diskUsageAfterLoad", toMB(usedLocalSizeAfterLoad)))
if memLoadingUsage > uint64(float64(totalMem)*Params.QueryNodeCfg.OverloadedMemoryThresholdPercentage.GetAsFloat()) {
return fmt.Errorf("load segment failed, OOM if load, collectionID = %d, maxSegmentSize = %v MB, concurrency = %d, usedMemAfterLoad = %v MB, totalMem = %v MB, thresholdFactor = %f",
return fmt.Errorf("%w, load segment failed, OOM if load, collectionID = %d, maxSegmentSize = %v MB, concurrency = %d, usedMemAfterLoad = %v MB, totalMem = %v MB, thresholdFactor = %f",
ErrInsufficientMemory,
collectionID,
toMB(maxSegmentSize),
concurrency,
......
......@@ -643,6 +643,10 @@ func (sc *ShardCluster) LoadSegments(ctx context.Context, req *querypb.LoadSegme
log.Warn("failed to dispatch load segment request", zap.Error(err))
return err
}
if resp.GetErrorCode() == commonpb.ErrorCode_InsufficientMemoryToLoad {
log.Warn("insufficient memory when follower load segment", zap.String("reason", resp.GetReason()))
return fmt.Errorf("%w, reason:%s", ErrInsufficientMemory, resp.GetReason())
}
if resp.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("follower load segment failed", zap.String("reason", resp.GetReason()))
return fmt.Errorf("follower %d failed to load segment, reason %s", req.DstNodeID, resp.GetReason())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册