未验证 提交 4e83f378 编写于 作者: Y yah01 提交者: GitHub

Improve load performance (#17273)

Load segment's fields concurrently
Signed-off-by: yah01 <yang.cen@zilliz.com>
上级 7c52a8c5
......@@ -339,7 +339,7 @@ func loadIndexForSegment(ctx context.Context, node *QueryNode, segmentID UniqueI
},
}
err = loader.loadSegment(req, segmentTypeSealed)
err = loader.LoadSegment(req, segmentTypeSealed)
if err != nil {
return err
}
......
......@@ -71,28 +71,27 @@ func (loader *segmentLoader) getFieldType(segment *Segment, fieldID FieldID) (sc
return coll.getFieldType(fieldID)
}
func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest, segmentType segmentType) error {
func (loader *segmentLoader) LoadSegment(req *querypb.LoadSegmentsRequest, segmentType segmentType) error {
if req.Base == nil {
return fmt.Errorf("nil base message when load segment, collectionID = %d", req.CollectionID)
}
// no segment needs to load, return
if len(req.Infos) == 0 {
segmentNum := len(req.Infos)
if segmentNum == 0 {
return nil
}
log.Info("segmentLoader start loading...",
zap.Any("collectionID", req.CollectionID),
zap.Any("numOfSegments", len(req.Infos)),
zap.Any("loadType", segmentType),
)
zap.Any("segmentNum", segmentNum),
zap.Any("loadType", segmentType))
// check memory limit
concurrencyLevel := loader.cpuPool.Cap()
if len(req.Infos) > 0 && len(req.Infos[0].BinlogPaths) > 0 {
concurrencyLevel /= len(req.Infos[0].BinlogPaths)
if concurrencyLevel <= 0 {
concurrencyLevel = 1
}
if concurrencyLevel > segmentNum {
concurrencyLevel = segmentNum
}
for ; concurrencyLevel > 1; concurrencyLevel /= 2 {
err := loader.checkSegmentSize(req.CollectionID, req.Infos, concurrencyLevel)
......@@ -166,9 +165,13 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest, segme
return nil
}
// start to load
// Make sure we can always benefit from concurrency, and not spawn too many idle goroutines
err = funcutil.ProcessFuncParallel(len(req.Infos),
log.Debug("start to load segments in parallel",
zap.Int("segmentNum", segmentNum),
zap.Int("concurrencyLevel", concurrencyLevel))
err = funcutil.ProcessFuncParallel(segmentNum,
concurrencyLevel,
loadSegmentFunc, "loadSegmentFunc")
if err != nil {
......@@ -217,8 +220,10 @@ func (loader *segmentLoader) loadSegmentInternal(segment *Segment,
if segment.getType() == segmentTypeSealed {
fieldID2IndexInfo := make(map[int64]*querypb.FieldIndexInfo)
for _, indexInfo := range loadInfo.IndexInfos {
fieldID := indexInfo.FieldID
fieldID2IndexInfo[fieldID] = indexInfo
if indexInfo != nil && indexInfo.EnableIndex {
fieldID := indexInfo.FieldID
fieldID2IndexInfo[fieldID] = indexInfo
}
}
indexedFieldInfos := make(map[int64]*IndexedFieldInfo)
......@@ -239,12 +244,14 @@ func (loader *segmentLoader) loadSegmentInternal(segment *Segment,
if err := loader.loadIndexedFieldData(segment, indexedFieldInfos); err != nil {
return err
}
if err := loader.loadSealedSegmentFields(segment, fieldBinlogs); err != nil {
return err
}
} else {
fieldBinlogs = loadInfo.BinlogPaths
}
if err := loader.loadFiledBinlogData(segment, fieldBinlogs); err != nil {
return err
if err := loader.loadGrowingSegmentFields(segment, fieldBinlogs); err != nil {
return err
}
}
if pkFieldID == common.InvalidFieldID {
......@@ -275,7 +282,7 @@ func (loader *segmentLoader) filterPKStatsBinlogs(fieldBinlogs []*datapb.FieldBi
return result
}
func (loader *segmentLoader) loadFiledBinlogData(segment *Segment, fieldBinlogs []*datapb.FieldBinlog) error {
func (loader *segmentLoader) loadGrowingSegmentFields(segment *Segment, fieldBinlogs []*datapb.FieldBinlog) error {
if len(fieldBinlogs) <= 0 {
return nil
}
......@@ -290,7 +297,7 @@ func (loader *segmentLoader) loadFiledBinlogData(segment *Segment, fieldBinlogs
loadFutures = append(loadFutures, futures...)
}
// wait for async load result
// wait for async load results
blobs := make([]*storage.Blob, len(loadFutures))
for index, future := range loadFutures {
if !future.OK() {
......@@ -328,14 +335,64 @@ func (loader *segmentLoader) loadFiledBinlogData(segment *Segment, fieldBinlogs
}
return loader.loadGrowingSegments(segment, rowIDData.(*storage.Int64FieldData).Data, utss, insertData)
case segmentTypeSealed:
return loader.loadSealedSegments(segment, insertData)
default:
err := errors.New(fmt.Sprintln("illegal segment type when load segment, collectionID = ", segment.collectionID))
err := fmt.Errorf("illegal segmentType=%v when load segment, collectionID=%v", segmentType, segment.collectionID)
return err
}
}
// loadSealedSegmentFields loads every given field binlog of a sealed segment
// concurrently, scheduling one async load task per field, and blocks until
// all fields are loaded. It returns the first error encountered, if any.
func (loader *segmentLoader) loadSealedSegmentFields(segment *Segment, fields []*datapb.FieldBinlog) error {
	// Load fields concurrently: each field gets its own future.
	futures := make([]*concurrency.Future, 0, len(fields))
	for _, field := range fields {
		future := loader.loadSealedFieldAsync(segment, field)
		futures = append(futures, future)
	}

	// Wait for all field loads to finish; fail fast on the first error.
	err := concurrency.AwaitAll(futures...)
	if err != nil {
		return err
	}

	// NOTE: fixed log message typo ("log" -> "load").
	log.Info("load field binlogs done",
		zap.Int64("collection", segment.collectionID),
		zap.Int64("segment", segment.segmentID),
		zap.Any("fields", fields))

	return nil
}
// loadSealedFieldAsync schedules loading one field of a sealed segment onto
// the CPU worker pool and returns a future that resolves when the field's
// binlogs have been fetched, deserialized, and written into the segment.
func (loader *segmentLoader) loadSealedFieldAsync(segment *Segment, field *datapb.FieldBinlog) *concurrency.Future {
	codec := storage.InsertCodec{}

	// Acquire a CPU worker before pulling binlogs into memory, so that we
	// avoid consuming too much memory while no worker is ready to process it.
	return loader.cpuPool.Submit(func() (interface{}, error) {
		binlogFutures := loader.loadFieldBinlogsAsync(field)

		blobs := make([]*storage.Blob, len(binlogFutures))
		for i, f := range binlogFutures {
			if !f.OK() {
				return nil, f.Err()
			}
			blobs[i] = f.Value().(*storage.Blob)
		}

		_, _, insertData, err := codec.Deserialize(blobs)
		if err != nil {
			log.Warn(err.Error())
			return nil, err
		}

		return nil, loader.loadSealedSegments(segment, insertData)
	})
}
// Load the field's binlogs from KV storage into memory asynchronously, returning one future per binlog
func (loader *segmentLoader) loadFieldBinlogsAsync(field *datapb.FieldBinlog) []*concurrency.Future {
futures := make([]*concurrency.Future, 0, len(field.Binlogs))
......@@ -361,21 +418,13 @@ func (loader *segmentLoader) loadFieldBinlogsAsync(field *datapb.FieldBinlog) []
func (loader *segmentLoader) loadIndexedFieldData(segment *Segment, vecFieldInfos map[int64]*IndexedFieldInfo) error {
for fieldID, fieldInfo := range vecFieldInfos {
if fieldInfo.indexInfo == nil || !fieldInfo.indexInfo.EnableIndex {
fieldBinlog := fieldInfo.fieldBinlog
err := loader.loadFiledBinlogData(segment, []*datapb.FieldBinlog{fieldBinlog})
if err != nil {
return err
}
log.Debug("load vector field's binlog data done", zap.Int64("segmentID", segment.ID()), zap.Int64("fieldID", fieldID))
} else {
indexInfo := fieldInfo.indexInfo
err := loader.loadFieldIndexData(segment, indexInfo)
if err != nil {
return err
}
log.Debug("load field's index data done", zap.Int64("segmentID", segment.ID()), zap.Int64("fieldID", fieldID))
indexInfo := fieldInfo.indexInfo
err := loader.loadFieldIndexData(segment, indexInfo)
if err != nil {
return err
}
log.Debug("load field's index data done", zap.Int64("segmentID", segment.ID()), zap.Int64("fieldID", fieldID))
segment.setIndexedFieldInfo(fieldID, fieldInfo)
}
......@@ -767,7 +816,11 @@ func newSegmentLoader(
panic(err)
}
ioPool, err := concurrency.NewPool(cpuNum*2, ants.WithPreAlloc(true))
ioPoolSize := cpuNum * 2
if ioPoolSize < 32 {
ioPoolSize = 32
}
ioPool, err := concurrency.NewPool(ioPoolSize, ants.WithPreAlloc(true))
if err != nil {
log.Error("failed to create goroutine pool for segment loader",
zap.Error(err))
......
......@@ -70,7 +70,7 @@ func TestSegmentLoader_loadSegment(t *testing.T) {
},
}
err = loader.loadSegment(req, segmentTypeSealed)
err = loader.LoadSegment(req, segmentTypeSealed)
assert.NoError(t, err)
})
......@@ -101,7 +101,7 @@ func TestSegmentLoader_loadSegment(t *testing.T) {
},
}
err = loader.loadSegment(req, segmentTypeSealed)
err = loader.LoadSegment(req, segmentTypeSealed)
assert.Error(t, err)
})
......@@ -114,7 +114,7 @@ func TestSegmentLoader_loadSegment(t *testing.T) {
req := &querypb.LoadSegmentsRequest{}
err = loader.loadSegment(req, segmentTypeSealed)
err = loader.LoadSegment(req, segmentTypeSealed)
assert.Error(t, err)
})
}
......@@ -182,7 +182,7 @@ func TestSegmentLoader_loadSegmentFieldsData(t *testing.T) {
binlog, _, err := saveBinLog(ctx, defaultCollectionID, defaultPartitionID, defaultSegmentID, defaultMsgLength, schema)
assert.NoError(t, err)
err = loader.loadFiledBinlogData(segment, binlog)
err = loader.loadSealedSegmentFields(segment, binlog)
assert.NoError(t, err)
}
......@@ -235,7 +235,7 @@ func TestSegmentLoader_invalid(t *testing.T) {
},
}
err = loader.loadSegment(req, segmentTypeSealed)
err = loader.LoadSegment(req, segmentTypeSealed)
assert.Error(t, err)
})
......@@ -273,7 +273,7 @@ func TestSegmentLoader_invalid(t *testing.T) {
},
},
}
err = loader.loadSegment(req, segmentTypeSealed)
err = loader.LoadSegment(req, segmentTypeSealed)
assert.Error(t, err)
})
......@@ -298,7 +298,7 @@ func TestSegmentLoader_invalid(t *testing.T) {
},
}
err = loader.loadSegment(req, commonpb.SegmentState_Dropped)
err = loader.LoadSegment(req, commonpb.SegmentState_Dropped)
assert.Error(t, err)
})
}
......@@ -416,7 +416,7 @@ func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) {
},
}
err = loader.loadSegment(req1, segmentTypeSealed)
err = loader.LoadSegment(req1, segmentTypeSealed)
assert.NoError(t, err)
segment1, err := loader.metaReplica.getSegmentByID(segmentID1, segmentTypeSealed)
......@@ -442,7 +442,7 @@ func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) {
},
}
err = loader.loadSegment(req2, segmentTypeSealed)
err = loader.LoadSegment(req2, segmentTypeSealed)
assert.NoError(t, err)
segment2, err := loader.metaReplica.getSegmentByID(segmentID2, segmentTypeSealed)
......@@ -476,7 +476,7 @@ func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) {
},
}
err = loader.loadSegment(req1, segmentTypeGrowing)
err = loader.LoadSegment(req1, segmentTypeGrowing)
assert.NoError(t, err)
segment1, err := loader.metaReplica.getSegmentByID(segmentID1, segmentTypeGrowing)
......@@ -502,7 +502,7 @@ func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) {
},
}
err = loader.loadSegment(req2, segmentTypeGrowing)
err = loader.LoadSegment(req2, segmentTypeGrowing)
assert.NoError(t, err)
segment2, err := loader.metaReplica.getSegmentByID(segmentID2, segmentTypeGrowing)
......@@ -562,7 +562,7 @@ func TestSegmentLoader_testLoadSealedSegmentWithIndex(t *testing.T) {
},
}
err = loader.loadSegment(req, segmentTypeSealed)
err = loader.LoadSegment(req, segmentTypeSealed)
assert.NoError(t, err)
segment, err := node.metaReplica.getSegmentByID(segmentID, segmentTypeSealed)
......
......@@ -218,7 +218,7 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) {
zap.Int64("collectionID", collectionID),
zap.Int64s("unFlushedSegmentIDs", unFlushedSegmentIDs),
)
err = w.node.loader.loadSegment(req, segmentTypeGrowing)
err = w.node.loader.LoadSegment(req, segmentTypeGrowing)
if err != nil {
log.Warn(err.Error())
return err
......@@ -524,7 +524,7 @@ func (l *loadSegmentsTask) PreExecute(ctx context.Context) error {
func (l *loadSegmentsTask) Execute(ctx context.Context) error {
// TODO: support db
log.Info("LoadSegmentTask Execute start", zap.Int64("msgID", l.req.Base.MsgID))
err := l.node.loader.loadSegment(l.req, segmentTypeSealed)
err := l.node.loader.LoadSegment(l.req, segmentTypeSealed)
if err != nil {
log.Warn(err.Error())
return err
......
......@@ -46,8 +46,8 @@ func (future *Future) Value() interface{} {
return future.value
}
// True if error occurred,
// false otherwise.
// False if error occurred,
// true otherwise.
func (future *Future) OK() bool {
<-future.ch
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册