// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package importutil

import (
	"bufio"
	"context"
	"fmt"
	"strconv"

	"go.uber.org/zap"

	"github.com/milvus-io/milvus-proto/go-api/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/schemapb"
	"github.com/milvus-io/milvus/internal/allocator"
	"github.com/milvus-io/milvus/internal/common"
	"github.com/milvus-io/milvus/internal/log"
	"github.com/milvus-io/milvus/internal/proto/datapb"
	"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
	"github.com/milvus-io/milvus/internal/querycoordv2/params"
	"github.com/milvus-io/milvus/internal/storage"
	"github.com/milvus-io/milvus/internal/util/retry"
	"github.com/milvus-io/milvus/internal/util/timerecord"
)

const (
	JSONFileExt  = ".json"
	NumpyFileExt = ".npy"

	// supposed size of a single block, to control a binlog file size; the max binlog file size is no more than 2*SingleBlockSize
	SingleBlockSize = 16 * 1024 * 1024 // 16MB

	// this limitation is to avoid an OOM risk:
	// sometimes the system's segment max size is a large number, and a single segment's field data might cause OOM.
	// flush the segment when its data reaches this limitation, and let compaction compact it later.
	MaxSegmentSizeInMemory = 512 * 1024 * 1024 // 512MB

	// this limitation is to avoid an OOM risk:
	// if the shard number is large, each single segment may be small, but there are many in-memory segments,
	// and the total memory size might cause OOM.
	MaxTotalSizeInMemory = 2 * 1024 * 1024 * 1024 // 2GB

	// progress percent value of persist state
	ProgressValueForPersist = 90

	// keys of import task information
	FailedReason    = "failed_reason"
	Files           = "files"
	CollectionName  = "collection"
	PartitionName   = "partition"
	PersistTimeCost = "persist_cost"
	ProgressPercent = "progress_percent"
)

// ReportImportAttempts is the maximum # of attempts to retry when import fails.
var ReportImportAttempts uint = 10

type ImportFlushFunc func(fields map[storage.FieldID]storage.FieldData, shardID int) error
type AssignSegmentFunc func(shardID int) (int64, string, error)
type CreateBinlogsFunc func(fields map[storage.FieldID]storage.FieldData, segmentID int64) ([]*datapb.FieldBinlog, []*datapb.FieldBinlog, error)
type SaveSegmentFunc func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, segmentID int64, targetChName string, rowCount int64) error

type WorkingSegment struct {
	segmentID    int64                 // segment ID
	shardID      int                   // shard id
	targetChName string                // target dml channel
	rowCount     int64                 // accumulated row count
	memSize      int                   // total memory size of all binlogs
	fieldsInsert []*datapb.FieldBinlog // persisted binlogs
	fieldsStats  []*datapb.FieldBinlog // stats of persisted binlogs
}

type ImportWrapper struct {
	ctx              context.Context            // for canceling parse process
	cancel           context.CancelFunc         // for canceling parse process
	collectionSchema *schemapb.CollectionSchema // collection schema
	shardNum         int32                      // sharding number of the collection
	segmentSize      int64                      // maximum size of a segment (unit: byte) defined by dataCoord.segment.maxSize (milvus.yml)
	rowIDAllocator   *allocator.IDAllocator     // autoid allocator
	chunkManager     storage.ChunkManager

	assignSegmentFunc AssignSegmentFunc // function to prepare a new segment
	createBinlogsFunc CreateBinlogsFunc // function to create binlog for a segment
	saveSegmentFunc   SaveSegmentFunc   // function to persist a segment

	importResult         *rootcoordpb.ImportResult                 // import result
	reportFunc           func(res *rootcoordpb.ImportResult) error // report import state to rootcoord
	reportImportAttempts uint                                      // attempt count if the report function gets an error

	workingSegments map[int]*WorkingSegment // a map of shard ID to working segment
	progressPercent int64                   // working progress percent
}

func NewImportWrapper(ctx context.Context, collectionSchema *schemapb.CollectionSchema, shardNum int32, segmentSize int64,
	idAlloc *allocator.IDAllocator, cm storage.ChunkManager, importResult *rootcoordpb.ImportResult,
	reportFunc func(res *rootcoordpb.ImportResult) error) *ImportWrapper {
	if collectionSchema == nil {
		log.Error("import wrapper: collection schema is nil")
		return nil
	}

	// ignore the RowID field and Timestamp field
	realSchema := &schemapb.CollectionSchema{
		Name:        collectionSchema.GetName(),
		Description: collectionSchema.GetDescription(),
		AutoID:      collectionSchema.GetAutoID(),
		Fields:      make([]*schemapb.FieldSchema, 0),
	}
	for i := 0; i < len(collectionSchema.Fields); i++ {
		schema := collectionSchema.Fields[i]
		if schema.GetName() == common.RowIDFieldName || schema.GetName() == common.TimeStampFieldName {
			continue
		}
		realSchema.Fields = append(realSchema.Fields, schema)
	}

	ctx, cancel := context.WithCancel(ctx)

	wrapper := &ImportWrapper{
		ctx:                  ctx,
		cancel:               cancel,
		collectionSchema:     realSchema,
		shardNum:             shardNum,
		segmentSize:          segmentSize,
		rowIDAllocator:       idAlloc,
		chunkManager:         cm,
		importResult:         importResult,
		reportFunc:           reportFunc,
		reportImportAttempts: ReportImportAttempts,
		workingSegments:      make(map[int]*WorkingSegment),
	}

	return wrapper
}
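
// A minimal construction sketch (hypothetical caller code; idAlloc, cm,
// importResult and reportFunc stand for real dependencies, they are not
// defined in this file):
//
//	wrapper := NewImportWrapper(ctx, collectionSchema, 2, 512*1024*1024,
//		idAlloc, cm, importResult, reportFunc)
//	if wrapper == nil {
//		// the collection schema was nil
//	}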

func (p *ImportWrapper) SetCallbackFunctions(assignSegmentFunc AssignSegmentFunc, createBinlogsFunc CreateBinlogsFunc, saveSegmentFunc SaveSegmentFunc) error {
	if assignSegmentFunc == nil {
		log.Error("import wrapper: callback function AssignSegmentFunc is nil")
		return fmt.Errorf("callback function AssignSegmentFunc is nil")
	}

	if createBinlogsFunc == nil {
		log.Error("import wrapper: callback function CreateBinlogsFunc is nil")
		return fmt.Errorf("callback function CreateBinlogsFunc is nil")
	}

	if saveSegmentFunc == nil {
		log.Error("import wrapper: callback function SaveSegmentFunc is nil")
		return fmt.Errorf("callback function SaveSegmentFunc is nil")
	}

	p.assignSegmentFunc = assignSegmentFunc
	p.createBinlogsFunc = createBinlogsFunc
	p.saveSegmentFunc = saveSegmentFunc
	return nil
}
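
// A sketch of wiring no-op callbacks (hypothetical, e.g. for a validation-only
// run or a unit test; a real DataNode wires functions that talk to DataCoord
// and the object storage):
//
//	err := wrapper.SetCallbackFunctions(
//		func(shardID int) (int64, string, error) { return 1, "by-dev-dml-ch-0", nil },
//		func(fields map[storage.FieldID]storage.FieldData, segmentID int64) ([]*datapb.FieldBinlog, []*datapb.FieldBinlog, error) {
//			return nil, nil, nil
//		},
//		func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, segmentID int64,
//			targetChName string, rowCount int64) error {
//			return nil
//		})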

// Cancel can be used to cancel the parse process
func (p *ImportWrapper) Cancel() error {
	p.cancel()
	return nil
}

// fileValidation verifies the input paths
// if all the files are of JSON type, it returns true
// if all the files are of numpy type, it returns false; duplicate file names are not allowed
func (p *ImportWrapper) fileValidation(filePaths []string) (bool, error) {
	// use this map to check for duplicate file names (only for numpy files)
	fileNames := make(map[string]struct{})

	totalSize := int64(0)
	rowBased := false
	for i := 0; i < len(filePaths); i++ {
		filePath := filePaths[i]
		name, fileType := GetFileNameAndExt(filePath)

		// only allow json file or numpy file
		if fileType != JSONFileExt && fileType != NumpyFileExt {
			log.Error("import wrapper: unsupported file type", zap.String("filePath", filePath))
			return false, fmt.Errorf("unsupported file type: '%s'", filePath)
		}

		// we use the first file to determine row-based or column-based
		if i == 0 && fileType == JSONFileExt {
			rowBased = true
		}

		// check file type
		// row-based mode only supports JSON type, column-based mode only supports numpy type
		if rowBased {
			if fileType != JSONFileExt {
				log.Error("import wrapper: unsupported file type for row-based mode", zap.String("filePath", filePath))
				return rowBased, fmt.Errorf("unsupported file type for row-based mode: '%s'", filePath)
			}
		} else {
			if fileType != NumpyFileExt {
				log.Error("import wrapper: unsupported file type for column-based mode", zap.String("filePath", filePath))
				return rowBased, fmt.Errorf("unsupported file type for column-based mode: '%s'", filePath)
			}
		}

		// check for duplicate files
		_, ok := fileNames[name]
		if ok {
			log.Error("import wrapper: duplicate file name", zap.String("filePath", filePath))
			return rowBased, fmt.Errorf("duplicate file: '%s'", filePath)
		}
		fileNames[name] = struct{}{}

		// check file size; a single file size cannot exceed ImportMaxFileSize
		size, err := p.chunkManager.Size(p.ctx, filePath)
		if err != nil {
			log.Error("import wrapper: failed to get file size", zap.String("filePath", filePath), zap.Error(err))
			return rowBased, fmt.Errorf("failed to get file size of '%s', error:%w", filePath, err)
		}

		// empty file
		if size == 0 {
			log.Error("import wrapper: file size is zero", zap.String("filePath", filePath))
			return rowBased, fmt.Errorf("the file '%s' size is zero", filePath)
		}

		if size > params.Params.CommonCfg.ImportMaxFileSize {
			log.Error("import wrapper: file size exceeds the maximum size", zap.String("filePath", filePath),
				zap.Int64("fileSize", size), zap.Int64("MaxFileSize", params.Params.CommonCfg.ImportMaxFileSize))
			return rowBased, fmt.Errorf("the file '%s' size exceeds the maximum size: %d bytes", filePath, params.Params.CommonCfg.ImportMaxFileSize)
		}
		totalSize += size
	}

	return rowBased, nil
}
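
// For example (hypothetical paths), these inputs pass validation:
//
//	p.fileValidation([]string{"data_1.json", "data_2.json"}) // row-based, returns (true, nil)
//	p.fileValidation([]string{"pk.npy", "vector.npy"})       // column-based, returns (false, nil)
//
// while a mixed list or duplicate numpy file names are rejected:
//
//	p.fileValidation([]string{"data_1.json", "vector.npy"})    // error: unsupported file type for row-based mode
//	p.fileValidation([]string{"a/vector.npy", "b/vector.npy"}) // error: duplicate file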

// Import is the entry point of the import operation
// filePaths and options come from the ImportTask
// if options.OnlyValidate is true, this process only does validation; no data is generated and flushFunc will not be called
func (p *ImportWrapper) Import(filePaths []string, options ImportOptions) error {
	log.Info("import wrapper: begin import", zap.Any("filePaths", filePaths), zap.Any("options", options))

	// data restore function to import milvus native binlog files (for backup/restore tools)
	// the backup/restore tool provides two paths for a partition: the first is the binlog path, the second is the deltalog path
	if options.IsBackup && p.isBinlogImport(filePaths) {
		return p.doBinlogImport(filePaths, options.TsStartPoint, options.TsEndPoint)
	}

	// normal logic for importing general data files
	rowBased, err := p.fileValidation(filePaths)
	if err != nil {
		return err
	}

	tr := timerecord.NewTimeRecorder("Import task")
	if rowBased {
		// parse and consume row-based files
		// for row-based files, the JSONRowConsumer will generate auto-ids for the primary key, and split rows into segments
		// according to shard number, so the flushFunc will be called in the JSONRowConsumer
		for i := 0; i < len(filePaths); i++ {
			filePath := filePaths[i]
			_, fileType := GetFileNameAndExt(filePath)
			log.Info("import wrapper: row-based file", zap.Any("filePath", filePath), zap.Any("fileType", fileType))

			if fileType == JSONFileExt {
				err = p.parseRowBasedJSON(filePath, options.OnlyValidate)
				if err != nil {
					log.Error("import wrapper: failed to parse row-based json file", zap.Error(err), zap.String("filePath", filePath))
					return err
				}
			} // no need for an else branch, since fileValidation() already does this

			// trigger gc after each file finished
			triggerGC()
		}
	} else {
		// parse and consume column-based files (currently numpy is supported)
		// for column-based files, the NumpyParser will generate auto-ids for the primary key, and split rows into segments
		// according to shard number, so the flushFunc will be called in the NumpyParser
		flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
			printFieldsDataInfo(fields, "import wrapper: prepare to flush binlog data", filePaths)
			return p.flushFunc(fields, shardID)
		}
		parser, err := NewNumpyParser(p.ctx, p.collectionSchema, p.rowIDAllocator, p.shardNum, SingleBlockSize,
			p.chunkManager, flushFunc, p.updateProgressPercent)
		if err != nil {
			return err
		}

		err = parser.Parse(filePaths)
		if err != nil {
			return err
		}

		p.importResult.AutoIds = append(p.importResult.AutoIds, parser.IDRange()...)

		// trigger gc after parse finished
		triggerGC()
	}

	return p.reportPersisted(p.reportImportAttempts, tr)
}
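
// A hedged end-to-end sketch (assuming the wrapper was constructed and its
// callbacks were set as shown above; the file paths are placeholders):
//
//	err := wrapper.Import([]string{"rows_1.json", "rows_2.json"},
//		ImportOptions{OnlyValidate: false, IsBackup: false})
//	if err != nil {
//		// parsing or persisting failed
//	}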

// reportPersisted notifies the RootCoord to mark the task state as ImportPersisted
func (p *ImportWrapper) reportPersisted(reportAttempts uint, tr *timerecord.TimeRecorder) error {
	// force close all segments
	err := p.closeAllWorkingSegments()
	if err != nil {
		return err
	}

	if tr != nil {
		ts := tr.Elapse("persist finished").Seconds()
		p.importResult.Infos = append(p.importResult.Infos,
			&commonpb.KeyValuePair{Key: PersistTimeCost, Value: strconv.FormatFloat(ts, 'f', 2, 64)})
	}

	// report file process state
	p.importResult.State = commonpb.ImportState_ImportPersisted
	progressValue := strconv.Itoa(ProgressValueForPersist)
	UpdateKVInfo(&p.importResult.Infos, ProgressPercent, progressValue)

	log.Info("import wrapper: report import result", zap.Any("importResult", p.importResult))
	// the persist state task is valuable; retry more times to avoid failing it merely because of a network error
	reportErr := retry.Do(p.ctx, func() error {
		return p.reportFunc(p.importResult)
	}, retry.Attempts(reportAttempts))
	if reportErr != nil {
		log.Warn("import wrapper: fail to report import state to RootCoord", zap.Error(reportErr))
		return reportErr
	}
	return nil
}

// isBinlogImport judges whether this is a binlog import operation
// For internal usage by the restore tool: https://github.com/zilliztech/milvus-backup
// This tool exports data from a milvus service, and calls the bulkload interface to import the native data into another milvus service.
// This tool provides two paths: one is the insert log path of a partition, the other is the delta log path of this partition.
// This method checks the filePaths; if a path exists and is not a file, we treat it as a native (binlog) import.
func (p *ImportWrapper) isBinlogImport(filePaths []string) bool {
	// must contain the insert log path; the delta log path may optionally be an empty string
	if len(filePaths) != 2 {
		log.Info("import wrapper: paths count is not 2, not binlog import", zap.Int("len", len(filePaths)))
		return false
	}

	checkFunc := func(filePath string) bool {
		// contains a file extension, so it is not a directory path
		_, fileType := GetFileNameAndExt(filePath)
		if len(fileType) != 0 {
			log.Info("import wrapper: not a path, not binlog import", zap.String("filePath", filePath), zap.String("fileType", fileType))
			return false
		}
		return true
	}

	// the first path is insert log path
	filePath := filePaths[0]
	if len(filePath) == 0 {
		log.Info("import wrapper: the first path is empty string, not binlog import")
		return false
	}

	if !checkFunc(filePath) {
		return false
	}

	// the second path is delta log path
	filePath = filePaths[1]
	if len(filePath) > 0 && !checkFunc(filePath) {
		return false
	}

	log.Info("import wrapper: do binlog import")
	return true
}
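
// For example (hypothetical paths), a backup/restore caller passes exactly two
// extension-less paths, where the delta log path may be an empty string:
//
//	p.isBinlogImport([]string{"backup/insert_log/", ""})                  // true
//	p.isBinlogImport([]string{"backup/insert_log/", "backup/delta_log/"}) // true
//	p.isBinlogImport([]string{"rows.json"})                               // false: not exactly two paths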

// doBinlogImport is the entry of binlog import operation
func (p *ImportWrapper) doBinlogImport(filePaths []string, tsStartPoint uint64, tsEndPoint uint64) error {
	tr := timerecord.NewTimeRecorder("Import task")

	flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
		printFieldsDataInfo(fields, "import wrapper: prepare to flush binlog data", filePaths)
		return p.flushFunc(fields, shardID)
	}
	parser, err := NewBinlogParser(p.ctx, p.collectionSchema, p.shardNum, SingleBlockSize, p.chunkManager, flushFunc,
		p.updateProgressPercent, tsStartPoint, tsEndPoint)
	if err != nil {
		return err
	}

	err = parser.Parse(filePaths)
	if err != nil {
		return err
	}

	return p.reportPersisted(p.reportImportAttempts, tr)
}

// parseRowBasedJSON is the entry point of the row-based JSON import operation
func (p *ImportWrapper) parseRowBasedJSON(filePath string, onlyValidate bool) error {
	tr := timerecord.NewTimeRecorder("json row-based parser: " + filePath)

	// for minio storage, the chunkManager will download the file into local memory
	// for local storage, the chunkManager opens the file directly
	file, err := p.chunkManager.Reader(p.ctx, filePath)
	if err != nil {
		return err
	}
	defer file.Close()

	size, err := p.chunkManager.Size(p.ctx, filePath)
	if err != nil {
		return err
	}

	// parse file
	reader := bufio.NewReader(file)
	parser := NewJSONParser(p.ctx, p.collectionSchema, p.updateProgressPercent)

	// if only validating, we provide an empty flushFunc so that the consumer does nothing but validation.
	var flushFunc ImportFlushFunc
	if onlyValidate {
		flushFunc = func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
			return nil
		}
	} else {
		flushFunc = func(fields map[storage.FieldID]storage.FieldData, shardID int) error {
			var filePaths = []string{filePath}
			printFieldsDataInfo(fields, "import wrapper: prepare to flush binlogs", filePaths)
			return p.flushFunc(fields, shardID)
		}
	}

	consumer, err := NewJSONRowConsumer(p.collectionSchema, p.rowIDAllocator, p.shardNum, SingleBlockSize, flushFunc)
	if err != nil {
		return err
	}

	err = parser.ParseRows(&IOReader{r: reader, fileSize: size}, consumer)
	if err != nil {
		return err
	}

	// for row-based files, auto-id is generated within JSONRowConsumer
	p.importResult.AutoIds = append(p.importResult.AutoIds, consumer.IDRange()...)

	tr.Elapse("parsed")
	return nil
}
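
// A sketch of the row-based JSON layout this parser consumes: a root object
// holding a list of rows, one JSON object per entity (the field names below
// are illustrative and must match the collection schema):
//
//	{
//	  "rows": [
//	    {"book_id": 101, "word_count": 13, "book_intro": [1.1, 1.2]},
//	    {"book_id": 102, "word_count": 25, "book_intro": [2.1, 2.2]}
//	  ]
//	}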

// flushFunc is the callback function for parsers to generate segments and save binlog files
func (p *ImportWrapper) flushFunc(fields map[storage.FieldID]storage.FieldData, shardID int) error {
	// if fields data is empty, do nothing
	var rowNum int
	memSize := 0
	for _, field := range fields {
		rowNum = field.RowNum()
		memSize += field.GetMemorySize()
		break
	}
	if rowNum <= 0 {
		log.Warn("import wrapper: fields data is empty", zap.Int("shardID", shardID))
		return nil
	}

	// if there is no segment for this shard, create a new one
	// if the segment exists and its size almost exceeds segmentSize, close it and create a new one
	var segment *WorkingSegment
	segment, ok := p.workingSegments[shardID]
	if ok {
		// the segment already exists, check its size; if the size exceeds (or nearly reaches) segmentSize, close the segment
		if int64(segment.memSize)+int64(memSize) >= p.segmentSize {
			err := p.closeWorkingSegment(segment)
			if err != nil {
				return err
			}
			segment = nil
			p.workingSegments[shardID] = nil
		}

	}

	if segment == nil {
		// create a new segment
		segID, channelName, err := p.assignSegmentFunc(shardID)
		if err != nil {
			log.Error("import wrapper: failed to assign a new segment", zap.Error(err), zap.Int("shardID", shardID))
			return fmt.Errorf("failed to assign a new segment for shard id %d, error: %w", shardID, err)
		}

		segment = &WorkingSegment{
			segmentID:    segID,
			shardID:      shardID,
			targetChName: channelName,
			rowCount:     int64(0),
			memSize:      0,
			fieldsInsert: make([]*datapb.FieldBinlog, 0),
			fieldsStats:  make([]*datapb.FieldBinlog, 0),
		}
		p.workingSegments[shardID] = segment
	}

	// save binlogs
	fieldsInsert, fieldsStats, err := p.createBinlogsFunc(fields, segment.segmentID)
	if err != nil {
		log.Error("import wrapper: failed to save binlogs", zap.Error(err), zap.Int("shardID", shardID),
			zap.Int64("segmentID", segment.segmentID), zap.String("targetChannel", segment.targetChName))
		return fmt.Errorf("failed to save binlogs, shard id %d, segment id %d, channel '%s', error: %w",
			shardID, segment.segmentID, segment.targetChName, err)
	}

	segment.fieldsInsert = append(segment.fieldsInsert, fieldsInsert...)
	segment.fieldsStats = append(segment.fieldsStats, fieldsStats...)
	segment.rowCount += int64(rowNum)
	segment.memSize += memSize

	// report working progress percent value to rootcoord
	// if the report fails, ignore the error; the percent value might be improper but the task can still succeed
	progressValue := strconv.Itoa(int(p.progressPercent))
	UpdateKVInfo(&p.importResult.Infos, ProgressPercent, progressValue)
	reportErr := retry.Do(p.ctx, func() error {
		return p.reportFunc(p.importResult)
	}, retry.Attempts(p.reportImportAttempts))
	if reportErr != nil {
		log.Warn("import wrapper: fail to report working progress percent value to RootCoord", zap.Error(reportErr))
	}

	return nil
}
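
// As a worked example: with a segmentSize of, say, 512MB and parser blocks of
// roughly SingleBlockSize (16MB), a working segment accumulates about 32
// flushed blocks before flushFunc seals it via closeWorkingSegment and assigns
// a new segment for that shard.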

// closeWorkingSegment marks a segment to be sealed
func (p *ImportWrapper) closeWorkingSegment(segment *WorkingSegment) error {
	log.Info("import wrapper: adding segment to the correct DataNode flow graph and saving binlog paths",
		zap.Int("shardID", segment.shardID),
		zap.Int64("segmentID", segment.segmentID),
		zap.String("targetChannel", segment.targetChName),
		zap.Int64("rowCount", segment.rowCount),
		zap.Int("insertLogCount", len(segment.fieldsInsert)),
		zap.Int("statsLogCount", len(segment.fieldsStats)))

	err := p.saveSegmentFunc(segment.fieldsInsert, segment.fieldsStats, segment.segmentID, segment.targetChName, segment.rowCount)
	if err != nil {
		log.Error("import wrapper: failed to seal segment",
			zap.Error(err),
			zap.Int("shardID", segment.shardID),
			zap.Int64("segmentID", segment.segmentID),
			zap.String("targetChannel", segment.targetChName))
		return fmt.Errorf("failed to seal segment, shard id %d, segment id %d, channel '%s', error: %w",
			segment.shardID, segment.segmentID, segment.targetChName, err)
	}

	return nil
}

// closeAllWorkingSegments marks all segments to be sealed at the end of the import operation
func (p *ImportWrapper) closeAllWorkingSegments() error {
	for _, segment := range p.workingSegments {
		err := p.closeWorkingSegment(segment)
		if err != nil {
			return err
		}
	}
	p.workingSegments = make(map[int]*WorkingSegment)

	return nil
}

func (p *ImportWrapper) updateProgressPercent(percent int64) {
	// ignore illegal percent value
	if percent < 0 || percent > 100 {
		return
	}
	p.progressPercent = percent
}