// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package importutil

import (
	"bufio"
	"context"
	"fmt"
	"strconv"

	"go.uber.org/zap"

	"github.com/milvus-io/milvus-proto/go-api/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/schemapb"
	"github.com/milvus-io/milvus/internal/allocator"
	"github.com/milvus-io/milvus/internal/common"
	"github.com/milvus-io/milvus/internal/log"
	"github.com/milvus-io/milvus/internal/proto/datapb"
	"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
	"github.com/milvus-io/milvus/internal/querycoordv2/params"
	"github.com/milvus-io/milvus/internal/storage"
	"github.com/milvus-io/milvus/internal/util/retry"
	"github.com/milvus-io/milvus/internal/util/timerecord"
)

const (
	JSONFileExt  = ".json"
	NumpyFileExt = ".npy"

	// supposed size of a single block, to control binlog file size; the max binlog file size is no more than 2*SingleBlockSize
	SingleBlockSize = 16 * 1024 * 1024 // 16MB

	// this limitation is to avoid this OOM risk:
	// sometimes the system's max segment size is a large number, and a single segment's field data might cause OOM.
	// flush the segment when its data reaches this limitation, and let compaction compact it later.
	MaxSegmentSizeInMemory = 512 * 1024 * 1024 // 512MB

	// this limitation is to avoid this OOM risk:
	// if the shard number is large, although each single segment is small, there are lots of in-memory segments,
	// and the total memory size might cause OOM.
	MaxTotalSizeInMemory = 2 * 1024 * 1024 * 1024 // 2GB

	// progress percent value of persist state
	ProgressValueForPersist = 90

	// keywords of import task information
	FailedReason    = "failed_reason"
	Files           = "files"
	CollectionName  = "collection"
	PartitionName   = "partition"
	PersistTimeCost = "persist_cost"
	ProgressPercent = "progress_percent"
)
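
// Illustrative arithmetic for the limits above (a sketch added for clarity,
// not part of the original code): parsers emit blocks of roughly
// SingleBlockSize (16MB), so a generated binlog file stays below
// 2*SingleBlockSize = 32MB. A segment is flushed once its in-memory data
// reaches MaxSegmentSizeInMemory (512MB); with e.g. 4 shards the worst-case
// in-memory total is about 4 * 512MB = 2GB, matching MaxTotalSizeInMemory.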

// ReportImportAttempts is the maximum # of attempts to retry when import fails.
var ReportImportAttempts uint = 10

type ImportFlushFunc func(fields map[storage.FieldID]storage.FieldData, shardID int) error
type AssignSegmentFunc func(shardID int) (int64, string, error)
type CreateBinlogsFunc func(fields map[storage.FieldID]storage.FieldData, segmentID int64) ([]*datapb.FieldBinlog, []*datapb.FieldBinlog, error)
type SaveSegmentFunc func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, segmentID int64, targetChName string, rowCount int64) error

type WorkingSegment struct {
	segmentID    int64                 // segment ID
	shardID      int                   // shard id
	targetChName string                // target dml channel
	rowCount     int64                 // accumulated row count
	memSize      int                   // total memory size of all binlogs
	fieldsInsert []*datapb.FieldBinlog // persisted binlogs
	fieldsStats  []*datapb.FieldBinlog // stats of persisted binlogs
}

type ImportWrapper struct {
	ctx              context.Context            // for canceling parse process
	cancel           context.CancelFunc         // for canceling parse process
	collectionSchema *schemapb.CollectionSchema // collection schema
	shardNum         int32                      // sharding number of the collection
	segmentSize      int64                      // maximum size of a segment (unit: byte) defined by dataCoord.segment.maxSize (milvus.yml)
	rowIDAllocator   *allocator.IDAllocator     // autoid allocator
	chunkManager     storage.ChunkManager

	assignSegmentFunc AssignSegmentFunc // function to prepare a new segment
	createBinlogsFunc CreateBinlogsFunc // function to create binlogs for a segment
	saveSegmentFunc   SaveSegmentFunc   // function to persist a segment

	importResult         *rootcoordpb.ImportResult                 // import result
	reportFunc           func(res *rootcoordpb.ImportResult) error // report import state to rootcoord
	reportImportAttempts uint                                      // retry attempts when the report function gets an error

	workingSegments map[int]*WorkingSegment // a map of shard ID to working segment
	progressPercent int64                   // working progress percent
}

func NewImportWrapper(ctx context.Context, collectionSchema *schemapb.CollectionSchema, shardNum int32, segmentSize int64,
	idAlloc *allocator.IDAllocator, cm storage.ChunkManager, importResult *rootcoordpb.ImportResult,
	reportFunc func(res *rootcoordpb.ImportResult) error) *ImportWrapper {
	if collectionSchema == nil {
		log.Error("import wrapper: collection schema is nil")
		return nil
	}

	params.Params.InitOnce()

	// ignore the RowID field and Timestamp field
	realSchema := &schemapb.CollectionSchema{
		Name:        collectionSchema.GetName(),
		Description: collectionSchema.GetDescription(),
		AutoID:      collectionSchema.GetAutoID(),
		Fields:      make([]*schemapb.FieldSchema, 0),
	}
	for i := 0; i < len(collectionSchema.Fields); i++ {
		schema := collectionSchema.Fields[i]
		if schema.GetName() == common.RowIDFieldName || schema.GetName() == common.TimeStampFieldName {
			continue
		}
		realSchema.Fields = append(realSchema.Fields, schema)
	}

	ctx, cancel := context.WithCancel(ctx)

	wrapper := &ImportWrapper{
		ctx:                  ctx,
		cancel:               cancel,
		collectionSchema:     realSchema,
		shardNum:             shardNum,
		segmentSize:          segmentSize,
		rowIDAllocator:       idAlloc,
		chunkManager:         cm,
		importResult:         importResult,
		reportFunc:           reportFunc,
		reportImportAttempts: ReportImportAttempts,
		workingSegments:      make(map[int]*WorkingSegment),
	}

	return wrapper
}

func (p *ImportWrapper) SetCallbackFunctions(assignSegmentFunc AssignSegmentFunc, createBinlogsFunc CreateBinlogsFunc, saveSegmentFunc SaveSegmentFunc) error {
	if assignSegmentFunc == nil {
		log.Error("import wrapper: callback function AssignSegmentFunc is nil")
		return fmt.Errorf("callback function AssignSegmentFunc is nil")
	}

	if createBinlogsFunc == nil {
		log.Error("import wrapper: callback function CreateBinlogsFunc is nil")
		return fmt.Errorf("callback function CreateBinlogsFunc is nil")
	}

	if saveSegmentFunc == nil {
		log.Error("import wrapper: callback function SaveSegmentFunc is nil")
		return fmt.Errorf("callback function SaveSegmentFunc is nil")
	}

	p.assignSegmentFunc = assignSegmentFunc
	p.createBinlogsFunc = createBinlogsFunc
	p.saveSegmentFunc = saveSegmentFunc
	return nil
}
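
// exampleImportFlow is a minimal usage sketch added for illustration; it is
// not part of the original file. It shows the expected wiring order: build the
// wrapper, register the three callbacks, then call Import. The callback bodies,
// channel name, and file name below are hypothetical stubs, and ImportOptions
// is assumed to expose the OnlyValidate flag used elsewhere in this file.
func exampleImportFlow(ctx context.Context, schema *schemapb.CollectionSchema,
	idAlloc *allocator.IDAllocator, cm storage.ChunkManager,
	result *rootcoordpb.ImportResult) error {
	reportFunc := func(res *rootcoordpb.ImportResult) error {
		return nil // would forward the result to rootcoord
	}
	wrapper := NewImportWrapper(ctx, schema, 2, 512*1024*1024, idAlloc, cm, result, reportFunc)
	if wrapper == nil {
		return fmt.Errorf("invalid collection schema")
	}

	assignSegment := func(shardID int) (int64, string, error) {
		return 1, "by-dev-rootcoord-dml_0", nil // hypothetical segment ID and dml channel
	}
	createBinlogs := func(fields map[storage.FieldID]storage.FieldData, segmentID int64) ([]*datapb.FieldBinlog, []*datapb.FieldBinlog, error) {
		return nil, nil, nil // would write binlogs and return insert/stats log paths
	}
	saveSegment := func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog,
		segmentID int64, targetChName string, rowCount int64) error {
		return nil // would notify the coordinator to seal the segment
	}
	if err := wrapper.SetCallbackFunctions(assignSegment, createBinlogs, saveSegment); err != nil {
		return err
	}

	// OnlyValidate: parse and validate without generating any binlog data
	return wrapper.Import([]string{"rows.json"}, ImportOptions{OnlyValidate: true})
}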

// Cancel can be used to cancel the parse process
func (p *ImportWrapper) Cancel() error {
	p.cancel()
	return nil
}

// fileValidation verifies the input paths
// if all the files are json type, returns true
// if all the files are numpy type, returns false; duplicate file names are not allowed
func (p *ImportWrapper) fileValidation(filePaths []string) (bool, error) {
	// use this map to check duplicate file names (only for numpy files)
	fileNames := make(map[string]struct{})

	totalSize := int64(0)
	rowBased := false
	for i := 0; i < len(filePaths); i++ {
		filePath := filePaths[i]
		name, fileType := GetFileNameAndExt(filePath)

		// only allow json file or numpy file
		if fileType != JSONFileExt && fileType != NumpyFileExt {
			log.Error("import wrapper: unsupported file type", zap.String("filePath", filePath))
			return false, fmt.Errorf("unsupported file type: '%s'", filePath)
		}

		// we use the first file to determine row-based or column-based
		if i == 0 && fileType == JSONFileExt {
			rowBased = true
		}

		// check file type
		// row-based mode only supports json files, column-based mode only supports numpy files
		if rowBased {
			if fileType != JSONFileExt {
				log.Error("import wrapper: unsupported file type for row-based mode", zap.String("filePath", filePath))
				return rowBased, fmt.Errorf("unsupported file type for row-based mode: '%s'", filePath)
			}
		} else {
			if fileType != NumpyFileExt {
				log.Error("import wrapper: unsupported file type for column-based mode", zap.String("filePath", filePath))
				return rowBased, fmt.Errorf("unsupported file type for column-based mode: '%s'", filePath)
			}
		}
			}
		}

		// check duplicate file name
		_, ok := fileNames[name]
		if ok {
			log.Error("import wrapper: duplicate file name", zap.String("filePath", filePath))
			return rowBased, fmt.Errorf("duplicate file: '%s'", filePath)
		}
		fileNames[name] = struct{}{}

		// check file size, a single file size cannot exceed ImportMaxFileSize
		size, err := p.chunkManager.Size(p.ctx, filePath)
		if err != nil {
			log.Error("import wrapper: failed to get file size", zap.String("filePath", filePath), zap.Error(err))
			return rowBased, fmt.Errorf("failed to get file size of '%s', error: %w", filePath, err)
		}

		// empty file
		if size == 0 {
			log.Error("import wrapper: file size is zero", zap.String("filePath", filePath))
			return rowBased, fmt.Errorf("the file '%s' size is zero", filePath)
		}

		if size > params.Params.CommonCfg.ImportMaxFileSize {
			log.Error("import wrapper: file size exceeds the maximum size", zap.String("filePath", filePath),
				zap.Int64("fileSize", size), zap.Int64("MaxFileSize", params.Params.CommonCfg.ImportMaxFileSize))
			return rowBased, fmt.Errorf("the file '%s' size exceeds the maximum size: %d bytes", filePath, params.Params.CommonCfg.ImportMaxFileSize)
		}
		totalSize += size
	}

	return rowBased, nil
}
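
// Hypothetical fileValidation outcomes (an illustrative sketch, not part of
// the original file; the file names are made up and assumed to exist, be
// non-empty, and fit within ImportMaxFileSize):
//
//	fileValidation([]string{"a.json", "b.json"})   // rowBased=true,  nil error
//	fileValidation([]string{"id.npy", "vec.npy"})  // rowBased=false, nil error
//	fileValidation([]string{"a.json", "vec.npy"})  // error: row-based mode only accepts json
//	fileValidation([]string{"vec.npy", "vec.npy"}) // error: duplicate file name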

// Import is the entry of import operation
// filePaths and options are from ImportTask
// if options.OnlyValidate is true, this process only does validation; no data is generated and flushFunc will not be called
func (p *ImportWrapper) Import(filePaths []string, options ImportOptions) error {
	log.Info("import wrapper: begin import", zap.Any("filePaths", filePaths), zap.Any("options", options))

	// data restore function to import milvus native binlog files (for backup/restore tools)
	// the backup/restore tool provides two paths for a partition: the first is the binlog path, the second is the deltalog path
	if options.IsBackup && p.isBinlogImport(filePaths) {
		return p.doBinlogImport(filePaths, options.TsStartPoint, options.TsEndPoint)
	}

	// normal logic for import general data files
	rowBased, err := p.fileValidation(filePaths)
	if err != nil {
		return err
	}

	tr := timerecord.NewTimeRecorder("Import task")
	if rowBased {
		// parse and consume row-based files
		// for row-based files, the JSONRowConsumer will generate autoid for primary key, and split rows into segments
		// according to shard number, so the flushFunc will be called in the JSONRowConsumer
		for i := 0; i < len(filePaths); i++ {
			filePath := filePaths[i]
			_, fileType := GetFileNameAndExt(filePath)
			log.Info("import wrapper: row-based file", zap.Any("filePath", filePath), zap.Any("fileType", fileType))

			if fileType == JSONFileExt {
				err = p.parseRowBasedJSON(filePath, options.OnlyValidate)
				if err != nil {
					log.Error("import wrapper: failed to parse row-based json file", zap.Error(err), zap.String("filePath", filePath))
					return err
				}
			} // no need to check the else branch, since fileValidation() already does this

			// trigger gc after each file finished
			triggerGC()
		}
	} else {
		// parse and consume column-based files (currently supports numpy)
		// for column-based files, the NumpyParser will generate autoid for primary key, and split rows into segments
		// according to shard number, so the flushFunc will be called in the NumpyParser
		flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
			printFieldsDataInfo(fields, "import wrapper: prepare to flush binlog data", filePaths)
			return p.flushFunc(fields, shardID)
		}
		parser, err := NewNumpyParser(p.ctx, p.collectionSchema, p.rowIDAllocator, p.shardNum, SingleBlockSize,
			p.chunkManager, flushFunc, p.updateProgressPercent)
		if err != nil {
			return err
		}

		err = parser.Parse(filePaths)
		if err != nil {
			return err
		}

		p.importResult.AutoIds = append(p.importResult.AutoIds, parser.IDRange()...)

		// trigger gc after parse finished
		triggerGC()
	}

	return p.reportPersisted(p.reportImportAttempts, tr)
}

// reportPersisted notifies the rootcoord to mark the task state as ImportPersisted
func (p *ImportWrapper) reportPersisted(reportAttempts uint, tr *timerecord.TimeRecorder) error {
	// force close all segments
	err := p.closeAllWorkingSegments()
	if err != nil {
		return err
	}

	if tr != nil {
		ts := tr.Elapse("persist finished").Seconds()
		p.importResult.Infos = append(p.importResult.Infos,
			&commonpb.KeyValuePair{Key: PersistTimeCost, Value: strconv.FormatFloat(ts, 'f', 2, 64)})
	}

	// report file process state
	p.importResult.State = commonpb.ImportState_ImportPersisted
	progressValue := strconv.Itoa(ProgressValueForPersist)
	UpdateKVInfo(&p.importResult.Infos, ProgressPercent, progressValue)

	log.Info("import wrapper: report import result", zap.Any("importResult", p.importResult))
	// the persist state is valuable; retry more times to avoid failing the task only because of a network error
	reportErr := retry.Do(p.ctx, func() error {
		return p.reportFunc(p.importResult)
	}, retry.Attempts(reportAttempts))
	if reportErr != nil {
		log.Warn("import wrapper: failed to report import state to RootCoord", zap.Error(reportErr))
		return reportErr
	}
	return nil
}
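
// After a successful reportPersisted, importResult.Infos carries entries like
// the hypothetical sketch below (values are illustrative):
//
//	{Key: "persist_cost", Value: "12.50"}  // PersistTimeCost, elapsed seconds
//	{Key: "progress_percent", Value: "90"} // ProgressValueForPersist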

// isBinlogImport judges whether this is a binlog import operation
// For internal usage by the restore tool: https://github.com/zilliztech/milvus-backup
// This tool exports data from a milvus service, and calls the bulkload interface to import native data into another milvus service.
// This tool provides two paths: one is the insert log path of a partition, the other is the delta log path of this partition.
// This method checks the filePaths: if the paths exist and are not files, we treat it as a binlog import.
func (p *ImportWrapper) isBinlogImport(filePaths []string) bool {
	// must contain the insert log path; the delta log path is optional and can be an empty string
	if len(filePaths) != 2 {
		log.Info("import wrapper: paths count is not 2, not binlog import", zap.Int("len", len(filePaths)))
		return false
	}

	checkFunc := func(filePath string) bool {
		// if it contains a file extension, it is not a path
		_, fileType := GetFileNameAndExt(filePath)
		if len(fileType) != 0 {
			log.Info("import wrapper: not a path, not binlog import", zap.String("filePath", filePath), zap.String("fileType", fileType))
			return false
		}
		return true
	}

	// the first path is insert log path
	filePath := filePaths[0]
	if len(filePath) == 0 {
		log.Info("import wrapper: the first path is empty string, not binlog import")
		return false
	}

	if !checkFunc(filePath) {
		return false
	}

	// the second path is delta log path
	filePath = filePaths[1]
	if len(filePath) > 0 && !checkFunc(filePath) {
		return false
	}

	log.Info("import wrapper: do binlog import")
	return true
}
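
// Hypothetical isBinlogImport inputs (an illustrative sketch, not part of the
// original file; the directory layout is made up):
//
//	isBinlogImport([]string{"backup/insert_log/", "backup/delta_log/"}) // true
//	isBinlogImport([]string{"backup/insert_log/", ""})                  // true, delta log path may be empty
//	isBinlogImport([]string{"rows.json", "backup/delta_log/"})          // false, first path has a file extension
//	isBinlogImport([]string{"backup/insert_log/"})                      // false, exactly 2 paths are required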

// doBinlogImport is the entry of binlog import operation
func (p *ImportWrapper) doBinlogImport(filePaths []string, tsStartPoint uint64, tsEndPoint uint64) error {
	tr := timerecord.NewTimeRecorder("Import task")

	flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
		printFieldsDataInfo(fields, "import wrapper: prepare to flush binlog data", filePaths)
		return p.flushFunc(fields, shardID)
	}
	parser, err := NewBinlogParser(p.ctx, p.collectionSchema, p.shardNum, SingleBlockSize, p.chunkManager, flushFunc,
		p.updateProgressPercent, tsStartPoint, tsEndPoint)
	if err != nil {
		return err
	}

	err = parser.Parse(filePaths)
	if err != nil {
		return err
	}

	return p.reportPersisted(p.reportImportAttempts, tr)
}

// parseRowBasedJSON is the entry of row-based json import operation
func (p *ImportWrapper) parseRowBasedJSON(filePath string, onlyValidate bool) error {
	tr := timerecord.NewTimeRecorder("json row-based parser: " + filePath)

	// for minio storage, chunkManager will download the file into local memory
	// for local storage, chunkManager opens the file directly
	file, err := p.chunkManager.Reader(p.ctx, filePath)
	if err != nil {
		return err
	}
	defer file.Close()

	size, err := p.chunkManager.Size(p.ctx, filePath)
	if err != nil {
		return err
	}

	// parse file
	reader := bufio.NewReader(file)
	parser := NewJSONParser(p.ctx, p.collectionSchema, p.updateProgressPercent)

	// if only validate, we input an empty flushFunc so that the consumer does nothing but validation.
	var flushFunc ImportFlushFunc
	if onlyValidate {
		flushFunc = func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
			return nil
		}
	} else {
		flushFunc = func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
			var filePaths = []string{filePath}
			printFieldsDataInfo(fields, "import wrapper: prepare to flush binlogs", filePaths)
			return p.flushFunc(fields, shardID)
		}
	}

	consumer, err := NewJSONRowConsumer(p.collectionSchema, p.rowIDAllocator, p.shardNum, SingleBlockSize, flushFunc)
	if err != nil {
		return err
	}

	err = parser.ParseRows(&IOReader{r: reader, fileSize: size}, consumer)
	if err != nil {
		return err
	}

	// for row-based files, auto-id is generated within JSONRowConsumer
	p.importResult.AutoIds = append(p.importResult.AutoIds, consumer.IDRange()...)

	tr.Elapse("parsed")
	return nil
}
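
// A row-based JSON file consumed by parseRowBasedJSON wraps all rows in a
// top-level "rows" list, as in the sketch below (the field names are
// hypothetical and must match the collection schema):
//
//	{
//	  "rows": [
//	    {"book_id": 1, "book_intro": [0.1, 0.2]},
//	    {"book_id": 2, "book_intro": [0.3, 0.4]}
//	  ]
//	}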

// flushFunc is the callback function for parsers to generate segments and save binlog files
func (p *ImportWrapper) flushFunc(fields map[storage.FieldID]storage.FieldData, shardID int) error {
	// if fields data is empty, do nothing
	var rowNum int
	memSize := 0
	for _, field := range fields {
		rowNum = field.RowNum()
		memSize += field.GetMemorySize()
		break
	}
	if rowNum <= 0 {
		log.Warn("import wrapper: fields data is empty", zap.Int("shardID", shardID))
		return nil
	}

	// if there is no segment for this shard, create a new one
	// if the segment exists and its size almost exceed segmentSize, close it and create a new one
	var segment *WorkingSegment
	segment, ok := p.workingSegments[shardID]
	if ok {
		// the segment already exists, check its size; if the size exceeds (or nearly exceeds) segmentSize, close the segment
		if int64(segment.memSize)+int64(memSize) >= p.segmentSize {
			err := p.closeWorkingSegment(segment)
			if err != nil {
				return err
			}
			segment = nil
			p.workingSegments[shardID] = nil
		}
	}

	if segment == nil {
		// create a new segment
		segID, channelName, err := p.assignSegmentFunc(shardID)
		if err != nil {
			log.Error("import wrapper: failed to assign a new segment", zap.Error(err), zap.Int("shardID", shardID))
			return fmt.Errorf("failed to assign a new segment for shard id %d, error: %w", shardID, err)
		}

		segment = &WorkingSegment{
			segmentID:    segID,
			shardID:      shardID,
			targetChName: channelName,
			rowCount:     int64(0),
			memSize:      0,
			fieldsInsert: make([]*datapb.FieldBinlog, 0),
			fieldsStats:  make([]*datapb.FieldBinlog, 0),
		}
		p.workingSegments[shardID] = segment
	}

	// save binlogs
	fieldsInsert, fieldsStats, err := p.createBinlogsFunc(fields, segment.segmentID)
	if err != nil {
		log.Error("import wrapper: failed to save binlogs", zap.Error(err), zap.Int("shardID", shardID),
			zap.Int64("segmentID", segment.segmentID), zap.String("targetChannel", segment.targetChName))
		return fmt.Errorf("failed to save binlogs, shard id %d, segment id %d, channel '%s', error: %w",
			shardID, segment.segmentID, segment.targetChName, err)
	}

	segment.fieldsInsert = append(segment.fieldsInsert, fieldsInsert...)
	segment.fieldsStats = append(segment.fieldsStats, fieldsStats...)
	segment.rowCount += int64(rowNum)
	segment.memSize += memSize

	// report working progress percent value to rootcoord
	// if failed to report, ignore the error; the percent value might be improper but the task can still succeed
	progressValue := strconv.Itoa(int(p.progressPercent))
	UpdateKVInfo(&p.importResult.Infos, ProgressPercent, progressValue)
	reportErr := retry.Do(p.ctx, func() error {
		return p.reportFunc(p.importResult)
	}, retry.Attempts(p.reportImportAttempts))
	if reportErr != nil {
		log.Warn("import wrapper: fail to report working progress percent value to RootCoord", zap.Error(reportErr))
	}

	return nil
}
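
// Rotation arithmetic sketch (illustrative, not from the original code):
// parsers hand flushFunc roughly SingleBlockSize (16MB) of data per call, so
// with a segmentSize of 512MB a shard's working segment is sealed after about
// 512 / 16 = 32 flushes, and the next flush assigns a fresh segment for that shard.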

// closeWorkingSegment marks a segment to be sealed
func (p *ImportWrapper) closeWorkingSegment(segment *WorkingSegment) error {
	log.Info("import wrapper: adding segment to the correct DataNode flow graph and saving binlog paths",
		zap.Int("shardID", segment.shardID),
		zap.Int64("segmentID", segment.segmentID),
		zap.String("targetChannel", segment.targetChName),
		zap.Int64("rowCount", segment.rowCount),
		zap.Int("insertLogCount", len(segment.fieldsInsert)),
		zap.Int("statsLogCount", len(segment.fieldsStats)))

	err := p.saveSegmentFunc(segment.fieldsInsert, segment.fieldsStats, segment.segmentID, segment.targetChName, segment.rowCount)
	if err != nil {
		log.Error("import wrapper: failed to seal segment",
			zap.Error(err),
			zap.Int("shardID", segment.shardID),
			zap.Int64("segmentID", segment.segmentID),
			zap.String("targetChannel", segment.targetChName))
		return fmt.Errorf("failed to seal segment, shard id %d, segment id %d, channel '%s', error: %w",
			segment.shardID, segment.segmentID, segment.targetChName, err)
	}

	return nil
}

// closeAllWorkingSegments marks all segments to be sealed at the end of the import operation
func (p *ImportWrapper) closeAllWorkingSegments() error {
	for _, segment := range p.workingSegments {
		err := p.closeWorkingSegment(segment)
		if err != nil {
			return err
		}
	}
	p.workingSegments = make(map[int]*WorkingSegment)

	return nil
}

func (p *ImportWrapper) updateProgressPercent(percent int64) {
	// ignore illegal percent value
	if percent < 0 || percent > 100 {
		return
	}
	p.progressPercent = percent
}