task.go 44.0 KB
Newer Older
1
package proxynode
Z
zhenshan.cao 已提交
2 3

import (
G
godchen 已提交
4
	"context"
5
	"errors"
Z
zhenshan.cao 已提交
6
	"log"
N
neza2017 已提交
7 8 9 10
	"math"
	"strconv"

	"github.com/golang/protobuf/proto"
11
	"github.com/zilliztech/milvus-distributed/internal/allocator"
12 13
	"github.com/zilliztech/milvus-distributed/internal/msgstream"
	"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
14 15
	"github.com/zilliztech/milvus-distributed/internal/proto/datapb"
	"github.com/zilliztech/milvus-distributed/internal/proto/indexpb"
16 17
	"github.com/zilliztech/milvus-distributed/internal/proto/internalpb2"
	"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
18
	"github.com/zilliztech/milvus-distributed/internal/proto/querypb"
N
neza2017 已提交
19
	"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
C
cai.zhang 已提交
20
	"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
Z
zhenshan.cao 已提交
21 22 23
)

type task interface {
24 25
	ID() UniqueID       // return ReqID
	SetID(uid UniqueID) // set ReqID
26
	Type() commonpb.MsgType
27 28
	BeginTs() Timestamp
	EndTs() Timestamp
Z
zhenshan.cao 已提交
29
	SetTs(ts Timestamp)
30
	OnEnqueue() error
Z
zhenshan.cao 已提交
31 32 33 34
	PreExecute() error
	Execute() error
	PostExecute() error
	WaitToFinish() error
35
	Notify(err error)
Z
zhenshan.cao 已提交
36 37
}

38
// BaseInsertTask is an alias of msgstream.InsertMsg; InsertTask embeds it.
type BaseInsertTask = msgstream.InsertMsg
39 40

type InsertTask struct {
41
	BaseInsertTask
D
dragondriver 已提交
42
	Condition
43 44 45
	dataServiceClient DataServiceClient
	result            *milvuspb.InsertResponse
	rowIDAllocator    *allocator.IDAllocator
46 47
}

48 49 50 51
// OnEnqueue does nothing for insert tasks and always returns nil.
func (it *InsertTask) OnEnqueue() error {
	return nil
}

52
func (it *InsertTask) SetID(uid UniqueID) {
53
	it.Base.MsgID = uid
54 55
}

56
func (it *InsertTask) SetTs(ts Timestamp) {
N
neza2017 已提交
57 58 59 60 61 62 63
	rowNum := len(it.RowData)
	it.Timestamps = make([]uint64, rowNum)
	for index := range it.Timestamps {
		it.Timestamps[index] = ts
	}
	it.BeginTimestamp = ts
	it.EndTimestamp = ts
64 65 66
}

func (it *InsertTask) BeginTs() Timestamp {
N
neza2017 已提交
67
	return it.BeginTimestamp
68 69 70
}

func (it *InsertTask) EndTs() Timestamp {
N
neza2017 已提交
71
	return it.EndTimestamp
72 73
}

C
cai.zhang 已提交
74
func (it *InsertTask) ID() UniqueID {
75
	return it.Base.MsgID
76 77
}

78
func (it *InsertTask) Type() commonpb.MsgType {
79
	return it.Base.MsgType
80 81 82
}

func (it *InsertTask) PreExecute() error {
83
	it.Base.MsgType = commonpb.MsgType_kInsert
84
	it.Base.SourceID = Params.ProxyID
85

N
neza2017 已提交
86 87 88 89
	collectionName := it.BaseInsertTask.CollectionName
	if err := ValidateCollectionName(collectionName); err != nil {
		return err
	}
90
	partitionTag := it.BaseInsertTask.PartitionName
N
neza2017 已提交
91 92 93 94
	if err := ValidatePartitionTag(partitionTag, true); err != nil {
		return err
	}

95 96 97 98
	return nil
}

func (it *InsertTask) Execute() error {
99
	collectionName := it.BaseInsertTask.CollectionName
G
godchen 已提交
100 101 102
	collSchema, err := globalMetaCache.GetCollectionSchema(collectionName)
	if err != nil {
		return err
103
	}
G
godchen 已提交
104 105 106
	autoID := collSchema.AutoID
	collID, err := globalMetaCache.GetCollectionID(collectionName)
	if err != nil {
107 108
		return err
	}
G
godchen 已提交
109
	it.CollectionID = collID
110 111 112 113 114 115 116 117 118 119 120
	var partitionID UniqueID
	if len(it.PartitionName) > 0 {
		partitionID, err = globalMetaCache.GetPartitionID(collectionName, it.PartitionName)
		if err != nil {
			return err
		}
	} else {
		partitionID, err = globalMetaCache.GetPartitionID(collectionName, Params.DefaultPartitionTag)
		if err != nil {
			return err
		}
G
godchen 已提交
121 122
	}
	it.PartitionID = partitionID
123 124
	var rowIDBegin UniqueID
	var rowIDEnd UniqueID
G
godchen 已提交
125 126
	rowNums := len(it.BaseInsertTask.RowData)
	rowIDBegin, rowIDEnd, _ = it.rowIDAllocator.Alloc(uint32(rowNums))
127

G
godchen 已提交
128 129 130 131 132 133 134
	it.BaseInsertTask.RowIDs = make([]UniqueID, rowNums)
	for i := rowIDBegin; i < rowIDEnd; i++ {
		offset := i - rowIDBegin
		it.BaseInsertTask.RowIDs[offset] = i
	}

	if autoID {
N
neza2017 已提交
135 136 137
		if it.HashValues == nil || len(it.HashValues) == 0 {
			it.HashValues = make([]uint32, 0)
		}
G
godchen 已提交
138 139
		for _, rowID := range it.RowIDs {
			hashValue, _ := typeutil.Hash32Int64(rowID)
N
neza2017 已提交
140
			it.HashValues = append(it.HashValues, hashValue)
141 142 143
		}
	}

144
	var tsMsg msgstream.TsMsg = &it.BaseInsertTask
145 146 147
	msgPack := &msgstream.MsgPack{
		BeginTs: it.BeginTs(),
		EndTs:   it.EndTs(),
X
xige-16 已提交
148
		Msgs:    make([]msgstream.TsMsg, 1),
149
	}
G
godchen 已提交
150

151
	it.result = &milvuspb.InsertResponse{
152 153 154
		Status: &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_SUCCESS,
		},
155 156
		RowIDBegin: rowIDBegin,
		RowIDEnd:   rowIDEnd,
157
	}
158 159 160

	msgPack.Msgs[0] = tsMsg

G
godchen 已提交
161
	stream, err := globalInsertChannelsMap.getInsertMsgStream(collID)
162
	if err != nil {
G
godchen 已提交
163
		resp, _ := it.dataServiceClient.GetInsertChannels(&datapb.InsertChannelRequest{
164 165 166 167 168 169 170
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_kInsert, // todo
				MsgID:     it.Base.MsgID,            // todo
				Timestamp: 0,                        // todo
				SourceID:  Params.ProxyID,
			},
			DbID:         0, // todo
G
godchen 已提交
171
			CollectionID: collID,
172
		})
G
godchen 已提交
173 174 175 176 177
		if resp == nil {
			return errors.New("get insert channels resp is nil")
		}
		if resp.Status.ErrorCode != commonpb.ErrorCode_SUCCESS {
			return errors.New(resp.Status.Reason)
178
		}
G
godchen 已提交
179
		err = globalInsertChannelsMap.createInsertMsgStream(collID, resp.Values)
180 181 182 183
		if err != nil {
			return err
		}
	}
G
godchen 已提交
184
	stream, err = globalInsertChannelsMap.getInsertMsgStream(collID)
185 186 187 188 189 190 191
	if err != nil {
		it.result.Status.ErrorCode = commonpb.ErrorCode_UNEXPECTED_ERROR
		it.result.Status.Reason = err.Error()
		return err
	}

	err = stream.Produce(msgPack)
192 193 194
	if err != nil {
		it.result.Status.ErrorCode = commonpb.ErrorCode_UNEXPECTED_ERROR
		it.result.Status.Reason = err.Error()
195
		return err
196
	}
197

198 199 200 201 202 203 204 205
	return nil
}

// PostExecute does nothing for insert tasks and always returns nil.
func (it *InsertTask) PostExecute() error {
	return nil
}

type CreateCollectionTask struct {
D
dragondriver 已提交
206
	Condition
207
	*milvuspb.CreateCollectionRequest
208 209 210 211
	masterClient      MasterClient
	dataServiceClient DataServiceClient
	result            *commonpb.Status
	schema            *schemapb.CollectionSchema
212 213
}

214 215 216 217 218
// OnEnqueue replaces cct.Base with a fresh, empty message base.
func (cct *CreateCollectionTask) OnEnqueue() error {
	cct.Base = &commonpb.MsgBase{}
	return nil
}

C
cai.zhang 已提交
219
func (cct *CreateCollectionTask) ID() UniqueID {
220
	return cct.Base.MsgID
221 222
}

223
func (cct *CreateCollectionTask) SetID(uid UniqueID) {
224
	cct.Base.MsgID = uid
225 226
}

227
func (cct *CreateCollectionTask) Type() commonpb.MsgType {
228
	return cct.Base.MsgType
229 230 231
}

func (cct *CreateCollectionTask) BeginTs() Timestamp {
232
	return cct.Base.Timestamp
233 234 235
}

func (cct *CreateCollectionTask) EndTs() Timestamp {
236
	return cct.Base.Timestamp
237 238 239
}

func (cct *CreateCollectionTask) SetTs(ts Timestamp) {
240
	cct.Base.Timestamp = ts
241 242 243
}

func (cct *CreateCollectionTask) PreExecute() error {
244
	cct.Base.MsgType = commonpb.MsgType_kCreateCollection
245
	cct.Base.SourceID = Params.ProxyID
246 247 248 249 250 251 252

	cct.schema = &schemapb.CollectionSchema{}
	err := proto.Unmarshal(cct.Schema, cct.schema)
	if err != nil {
		return err
	}

253 254
	if int64(len(cct.schema.Fields)) > Params.MaxFieldNum {
		return errors.New("maximum field's number should be limited to " + strconv.FormatInt(Params.MaxFieldNum, 10))
N
neza2017 已提交
255 256 257 258 259 260 261
	}

	// validate collection name
	if err := ValidateCollectionName(cct.schema.Name); err != nil {
		return err
	}

N
neza2017 已提交
262 263 264 265 266 267 268 269
	if err := ValidateDuplicatedFieldName(cct.schema.Fields); err != nil {
		return err
	}

	if err := ValidatePrimaryKey(cct.schema); err != nil {
		return err
	}

N
neza2017 已提交
270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301
	// validate field name
	for _, field := range cct.schema.Fields {
		if err := ValidateFieldName(field.Name); err != nil {
			return err
		}
		if field.DataType == schemapb.DataType_VECTOR_FLOAT || field.DataType == schemapb.DataType_VECTOR_BINARY {
			exist := false
			var dim int64 = 0
			for _, param := range field.TypeParams {
				if param.Key == "dim" {
					exist = true
					tmp, err := strconv.ParseInt(param.Value, 10, 64)
					if err != nil {
						return err
					}
					dim = tmp
					break
				}
			}
			if !exist {
				return errors.New("dimension is not defined in field type params")
			}
			if field.DataType == schemapb.DataType_VECTOR_FLOAT {
				if err := ValidateDimension(dim, false); err != nil {
					return err
				}
			} else {
				if err := ValidateDimension(dim, true); err != nil {
					return err
				}
			}
		}
N
neza2017 已提交
302 303 304
		if err := ValidateVectorFieldMetricType(field); err != nil {
			return err
		}
N
neza2017 已提交
305 306
	}

307
	return nil
Z
zhenshan.cao 已提交
308 309
}

310
func (cct *CreateCollectionTask) Execute() error {
311
	var err error
312
	cct.result, err = cct.masterClient.CreateCollection(cct.CreateCollectionRequest)
313 314 315 316
	if err != nil {
		return err
	}
	if cct.result.ErrorCode == commonpb.ErrorCode_SUCCESS {
G
godchen 已提交
317
		collID, err := globalMetaCache.GetCollectionID(cct.CollectionName)
318 319 320
		if err != nil {
			return err
		}
G
godchen 已提交
321
		resp, _ := cct.dataServiceClient.GetInsertChannels(&datapb.InsertChannelRequest{
322 323 324 325 326 327 328
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_kInsert, // todo
				MsgID:     cct.Base.MsgID,           // todo
				Timestamp: 0,                        // todo
				SourceID:  Params.ProxyID,
			},
			DbID:         0, // todo
G
godchen 已提交
329
			CollectionID: collID,
330
		})
G
godchen 已提交
331 332 333 334 335
		if resp == nil {
			return errors.New("get insert channels resp is nil")
		}
		if resp.Status.ErrorCode != commonpb.ErrorCode_SUCCESS {
			return errors.New(resp.Status.Reason)
336
		}
G
godchen 已提交
337
		err = globalInsertChannelsMap.createInsertMsgStream(collID, resp.Values)
338 339 340 341 342
		if err != nil {
			return err
		}
	}
	return nil
Z
zhenshan.cao 已提交
343 344
}

345 346
func (cct *CreateCollectionTask) PostExecute() error {
	return nil
Z
zhenshan.cao 已提交
347 348
}

349
type DropCollectionTask struct {
D
dragondriver 已提交
350
	Condition
351
	*milvuspb.DropCollectionRequest
Z
zhenshan.cao 已提交
352 353
	masterClient MasterClient
	result       *commonpb.Status
354 355
}

356 357 358 359 360
// OnEnqueue replaces dct.Base with a fresh, empty message base.
func (dct *DropCollectionTask) OnEnqueue() error {
	dct.Base = &commonpb.MsgBase{}
	return nil
}

C
cai.zhang 已提交
361
func (dct *DropCollectionTask) ID() UniqueID {
362
	return dct.Base.MsgID
363 364
}

365
func (dct *DropCollectionTask) SetID(uid UniqueID) {
366
	dct.Base.MsgID = uid
367 368
}

369
func (dct *DropCollectionTask) Type() commonpb.MsgType {
370
	return dct.Base.MsgType
371 372 373
}

func (dct *DropCollectionTask) BeginTs() Timestamp {
374
	return dct.Base.Timestamp
375 376 377
}

func (dct *DropCollectionTask) EndTs() Timestamp {
378
	return dct.Base.Timestamp
379 380 381
}

func (dct *DropCollectionTask) SetTs(ts Timestamp) {
382
	dct.Base.Timestamp = ts
383 384 385
}

func (dct *DropCollectionTask) PreExecute() error {
386
	dct.Base.MsgType = commonpb.MsgType_kDropCollection
387
	dct.Base.SourceID = Params.ProxyID
388

389
	if err := ValidateCollectionName(dct.CollectionName); err != nil {
N
neza2017 已提交
390 391
		return err
	}
392 393 394 395
	return nil
}

func (dct *DropCollectionTask) Execute() error {
396 397 398 399
	collID, err := globalMetaCache.GetCollectionID(dct.CollectionName)
	if err != nil {
		return err
	}
B
bigsheeper 已提交
400 401 402
	dct.result, _ = dct.masterClient.DropCollection(dct.DropCollectionRequest)
	if dct.result.ErrorCode != commonpb.ErrorCode_SUCCESS {
		return errors.New(dct.result.Reason)
403
	}
G
godchen 已提交
404 405 406 407 408
	err = globalInsertChannelsMap.closeInsertMsgStream(collID)
	if err != nil {
		return err
	}
	return nil
409 410 411
}

func (dct *DropCollectionTask) PostExecute() error {
G
godchen 已提交
412
	globalMetaCache.RemoveCollection(dct.CollectionName)
Z
zhenshan.cao 已提交
413
	return nil
414 415
}

416
type SearchTask struct {
D
dragondriver 已提交
417
	Condition
418
	internalpb2.SearchRequest
Z
zhenshan.cao 已提交
419
	queryMsgStream msgstream.MsgStream
420
	resultBuf      chan []*internalpb2.SearchResults
421 422 423 424 425 426
	result         *milvuspb.SearchResults
	query          *milvuspb.SearchRequest
}

func (st *SearchTask) OnEnqueue() error {
	return nil
427 428
}

429 430
func (st *SearchTask) ID() UniqueID {
	return st.Base.MsgID
431 432
}

433 434
func (st *SearchTask) SetID(uid UniqueID) {
	st.Base.MsgID = uid
435 436
}

437 438
func (st *SearchTask) Type() commonpb.MsgType {
	return st.Base.MsgType
439 440
}

441 442
func (st *SearchTask) BeginTs() Timestamp {
	return st.Base.Timestamp
443 444
}

445 446
func (st *SearchTask) EndTs() Timestamp {
	return st.Base.Timestamp
447 448
}

449 450
func (st *SearchTask) SetTs(ts Timestamp) {
	st.Base.Timestamp = ts
451 452
}

453
func (st *SearchTask) PreExecute() error {
454
	st.Base.MsgType = commonpb.MsgType_kSearch
455
	st.Base.SourceID = Params.ProxyID
456

457
	collectionName := st.query.CollectionName
G
godchen 已提交
458
	_, err := globalMetaCache.GetCollectionID(collectionName)
459 460 461 462
	if err != nil { // err is not nil if collection not exists
		return err
	}

463
	if err := ValidateCollectionName(st.query.CollectionName); err != nil {
N
neza2017 已提交
464 465 466
		return err
	}

467
	for _, tag := range st.query.PartitionNames {
N
neza2017 已提交
468 469 470 471
		if err := ValidatePartitionTag(tag, false); err != nil {
			return err
		}
	}
472 473
	st.Base.MsgType = commonpb.MsgType_kSearch
	queryBytes, err := proto.Marshal(st.query)
C
cai.zhang 已提交
474 475 476
	if err != nil {
		return err
	}
477
	st.Query = &commonpb.Blob{
C
cai.zhang 已提交
478 479
		Value: queryBytes,
	}
480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498

	st.ResultChannelID = Params.SearchResultChannelNames[0]
	st.DbID = 0 // todo
	collectionID, err := globalMetaCache.GetCollectionID(collectionName)
	if err != nil { // err is not nil if collection not exists
		return err
	}
	st.CollectionID = collectionID
	st.PartitionIDs = make([]UniqueID, 0)
	for _, partitionName := range st.query.PartitionNames {
		partitionID, err := globalMetaCache.GetPartitionID(collectionName, partitionName)
		if err != nil {
			return err
		}
		st.PartitionIDs = append(st.PartitionIDs, partitionID)
	}
	st.Dsl = st.query.Dsl
	st.PlaceholderGroup = st.query.PlaceholderGroup

499 500 501
	return nil
}

502
func (st *SearchTask) Execute() error {
503
	var tsMsg msgstream.TsMsg = &msgstream.SearchMsg{
504
		SearchRequest: st.SearchRequest,
505
		BaseMsg: msgstream.BaseMsg{
506
			HashValues:     []uint32{uint32(Params.ProxyID)},
507 508
			BeginTimestamp: st.Base.Timestamp,
			EndTimestamp:   st.Base.Timestamp,
509 510 511
		},
	}
	msgPack := &msgstream.MsgPack{
512 513
		BeginTs: st.Base.Timestamp,
		EndTs:   st.Base.Timestamp,
X
xige-16 已提交
514
		Msgs:    make([]msgstream.TsMsg, 1),
515
	}
X
xige-16 已提交
516
	msgPack.Msgs[0] = tsMsg
517
	err := st.queryMsgStream.Produce(msgPack)
518
	log.Printf("[NodeImpl] length of searchMsg: %v", len(msgPack.Msgs))
C
cai.zhang 已提交
519
	if err != nil {
520
		log.Printf("[NodeImpl] send search request failed: %v", err)
C
cai.zhang 已提交
521 522
	}
	return err
523 524
}

525
func (st *SearchTask) PostExecute() error {
526 527
	for {
		select {
528
		case <-st.Ctx().Done():
529 530
			log.Print("SearchTask: wait to finish failed, timeout!, taskID:", st.ID())
			return errors.New("SearchTask:wait to finish failed, timeout:" + strconv.FormatInt(st.ID(), 10))
531
		case searchResults := <-st.resultBuf:
532
			// fmt.Println("searchResults: ", searchResults)
533
			filterSearchResult := make([]*internalpb2.SearchResults, 0)
Z
zhenshan.cao 已提交
534
			var filterReason string
535 536 537
			for _, partialSearchResult := range searchResults {
				if partialSearchResult.Status.ErrorCode == commonpb.ErrorCode_SUCCESS {
					filterSearchResult = append(filterSearchResult, partialSearchResult)
538 539 540 541 542 543 544 545 546 547
					// For debugging, please don't delete.
					//for i := 0; i < len(partialSearchResult.Hits); i++ {
					//	testHits := milvuspb.Hits{}
					//	err := proto.Unmarshal(partialSearchResult.Hits[i], &testHits)
					//	if err != nil {
					//		panic(err)
					//	}
					//	fmt.Println(testHits.IDs)
					//	fmt.Println(testHits.Scores)
					//}
Z
zhenshan.cao 已提交
548 549
				} else {
					filterReason += partialSearchResult.Status.Reason + "\n"
550 551 552
				}
			}

553 554
			availableQueryNodeNum := len(filterSearchResult)
			if availableQueryNodeNum <= 0 {
555
				st.result = &milvuspb.SearchResults{
Z
zhenshan.cao 已提交
556 557 558 559 560 561
					Status: &commonpb.Status{
						ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
						Reason:    filterReason,
					},
				}
				return errors.New(filterReason)
562
			}
C
cai.zhang 已提交
563

564
			hits := make([][]*milvuspb.Hits, 0)
565
			for _, partialSearchResult := range filterSearchResult {
B
bigsheeper 已提交
566
				if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 {
567 568 569
					filterReason += "nq is zero\n"
					continue
				}
570
				partialHits := make([]*milvuspb.Hits, 0)
571
				for _, bs := range partialSearchResult.Hits {
572
					partialHit := &milvuspb.Hits{}
573
					err := proto.Unmarshal(bs, partialHit)
N
neza2017 已提交
574
					if err != nil {
B
bigsheeper 已提交
575
						log.Println("unmarshal error")
N
neza2017 已提交
576 577
						return err
					}
578 579 580 581 582 583 584
					partialHits = append(partialHits, partialHit)
				}
				hits = append(hits, partialHits)
			}

			availableQueryNodeNum = len(hits)
			if availableQueryNodeNum <= 0 {
585
				st.result = &milvuspb.SearchResults{
586 587 588 589 590 591 592 593 594 595
					Status: &commonpb.Status{
						ErrorCode: commonpb.ErrorCode_SUCCESS,
						Reason:    filterReason,
					},
				}
				return nil
			}

			nq := len(hits[0])
			if nq <= 0 {
596
				st.result = &milvuspb.SearchResults{
597 598 599 600
					Status: &commonpb.Status{
						ErrorCode: commonpb.ErrorCode_SUCCESS,
						Reason:    filterReason,
					},
N
neza2017 已提交
601
				}
602
				return nil
N
neza2017 已提交
603
			}
C
cai.zhang 已提交
604

B
bigsheeper 已提交
605 606 607 608 609 610 611 612 613 614
			topk := 0
			getMax := func(a, b int) int {
				if a > b {
					return a
				}
				return b
			}
			for _, hit := range hits {
				topk = getMax(topk, len(hit[0].IDs))
			}
615
			st.result = &milvuspb.SearchResults{
616 617 618
				Status: &commonpb.Status{
					ErrorCode: 0,
				},
C
cai.zhang 已提交
619
				Hits: make([][]byte, 0),
620
			}
C
cai.zhang 已提交
621

G
GuoRentong 已提交
622
			const minFloat32 = -1 * float32(math.MaxFloat32)
623 624
			for i := 0; i < nq; i++ {
				locs := make([]int, availableQueryNodeNum)
625
				reducedHits := &milvuspb.Hits{
C
cai.zhang 已提交
626 627 628 629 630
					IDs:     make([]int64, 0),
					RowData: make([][]byte, 0),
					Scores:  make([]float32, 0),
				}

631
				for j := 0; j < topk; j++ {
B
bigsheeper 已提交
632
					valid := false
G
GuoRentong 已提交
633
					choice, maxDistance := 0, minFloat32
634
					for q, loc := range locs { // query num, the number of ways to merge
B
bigsheeper 已提交
635 636 637
						if loc >= len(hits[q][i].IDs) {
							continue
						}
N
neza2017 已提交
638
						distance := hits[q][i].Scores[loc]
C
cai.zhang 已提交
639
						if distance > maxDistance || (math.Abs(float64(distance-maxDistance)) < math.SmallestNonzeroFloat32 && choice != q) {
640
							choice = q
G
GuoRentong 已提交
641
							maxDistance = distance
B
bigsheeper 已提交
642
							valid = true
643 644
						}
					}
B
bigsheeper 已提交
645 646 647
					if !valid {
						break
					}
648
					choiceOffset := locs[choice]
649 650
					// check if distance is valid, `invalid` here means very very big,
					// in this process, distance here is the smallest, so the rest of distance are all invalid
G
GuoRentong 已提交
651
					if hits[choice][i].Scores[choiceOffset] <= minFloat32 {
652
						break
653
					}
N
neza2017 已提交
654
					reducedHits.IDs = append(reducedHits.IDs, hits[choice][i].IDs[choiceOffset])
C
cai.zhang 已提交
655 656 657
					if hits[choice][i].RowData != nil && len(hits[choice][i].RowData) > 0 {
						reducedHits.RowData = append(reducedHits.RowData, hits[choice][i].RowData[choiceOffset])
					}
N
neza2017 已提交
658
					reducedHits.Scores = append(reducedHits.Scores, hits[choice][i].Scores[choiceOffset])
659 660
					locs[choice]++
				}
G
GuoRentong 已提交
661 662 663 664 665
				if searchResults[0].MetricType != "IP" {
					for k := range reducedHits.Scores {
						reducedHits.Scores[k] *= -1
					}
				}
N
neza2017 已提交
666 667
				reducedHitsBs, err := proto.Marshal(reducedHits)
				if err != nil {
B
bigsheeper 已提交
668
					log.Println("marshal error")
N
neza2017 已提交
669 670
					return err
				}
671
				st.result.Hits = append(st.result.Hits, reducedHitsBs)
672
			}
C
cai.zhang 已提交
673
			return nil
674 675
		}
	}
D
dragondriver 已提交
676 677
}

678
type HasCollectionTask struct {
D
dragondriver 已提交
679
	Condition
680
	*milvuspb.HasCollectionRequest
Z
zhenshan.cao 已提交
681
	masterClient MasterClient
682
	result       *milvuspb.BoolResponse
683 684
}

685 686 687 688 689
// OnEnqueue replaces hct.Base with a fresh, empty message base.
func (hct *HasCollectionTask) OnEnqueue() error {
	hct.Base = &commonpb.MsgBase{}
	return nil
}

C
cai.zhang 已提交
690
func (hct *HasCollectionTask) ID() UniqueID {
691
	return hct.Base.MsgID
692 693
}

694
func (hct *HasCollectionTask) SetID(uid UniqueID) {
695
	hct.Base.MsgID = uid
696 697
}

698
func (hct *HasCollectionTask) Type() commonpb.MsgType {
699
	return hct.Base.MsgType
700 701 702
}

func (hct *HasCollectionTask) BeginTs() Timestamp {
703
	return hct.Base.Timestamp
704 705 706
}

func (hct *HasCollectionTask) EndTs() Timestamp {
707
	return hct.Base.Timestamp
708 709 710
}

func (hct *HasCollectionTask) SetTs(ts Timestamp) {
711
	hct.Base.Timestamp = ts
712 713 714
}

func (hct *HasCollectionTask) PreExecute() error {
715
	hct.Base.MsgType = commonpb.MsgType_kHasCollection
716
	hct.Base.SourceID = Params.ProxyID
717

718
	if err := ValidateCollectionName(hct.CollectionName); err != nil {
N
neza2017 已提交
719 720
		return err
	}
721 722 723 724
	return nil
}

func (hct *HasCollectionTask) Execute() error {
725
	var err error
726
	hct.result, err = hct.masterClient.HasCollection(hct.HasCollectionRequest)
G
godchen 已提交
727 728 729 730 731 732
	if hct.result == nil {
		return errors.New("has collection resp is nil")
	}
	if hct.result.Status.ErrorCode != commonpb.ErrorCode_SUCCESS {
		return errors.New(hct.result.Status.Reason)
	}
733 734 735 736 737 738 739 740
	return err
}

// PostExecute does nothing for has-collection tasks and always returns nil.
func (hct *HasCollectionTask) PostExecute() error {
	return nil
}

type DescribeCollectionTask struct {
D
dragondriver 已提交
741
	Condition
742
	*milvuspb.DescribeCollectionRequest
Z
zhenshan.cao 已提交
743
	masterClient MasterClient
744
	result       *milvuspb.DescribeCollectionResponse
745 746
}

747 748 749 750 751
// OnEnqueue replaces dct.Base with a fresh, empty message base.
func (dct *DescribeCollectionTask) OnEnqueue() error {
	dct.Base = &commonpb.MsgBase{}
	return nil
}

C
cai.zhang 已提交
752
func (dct *DescribeCollectionTask) ID() UniqueID {
753
	return dct.Base.MsgID
754 755
}

756
func (dct *DescribeCollectionTask) SetID(uid UniqueID) {
757
	dct.Base.MsgID = uid
758 759
}

760
func (dct *DescribeCollectionTask) Type() commonpb.MsgType {
761
	return dct.Base.MsgType
762 763 764
}

func (dct *DescribeCollectionTask) BeginTs() Timestamp {
765
	return dct.Base.Timestamp
766 767 768
}

func (dct *DescribeCollectionTask) EndTs() Timestamp {
769
	return dct.Base.Timestamp
770 771 772
}

func (dct *DescribeCollectionTask) SetTs(ts Timestamp) {
773
	dct.Base.Timestamp = ts
774 775 776
}

func (dct *DescribeCollectionTask) PreExecute() error {
777
	dct.Base.MsgType = commonpb.MsgType_kDescribeCollection
778
	dct.Base.SourceID = Params.ProxyID
779

780
	if err := ValidateCollectionName(dct.CollectionName); err != nil {
N
neza2017 已提交
781 782
		return err
	}
783 784 785 786
	return nil
}

func (dct *DescribeCollectionTask) Execute() error {
787
	var err error
788
	dct.result, err = dct.masterClient.DescribeCollection(dct.DescribeCollectionRequest)
G
godchen 已提交
789 790
	if dct.result == nil {
		return errors.New("has collection resp is nil")
791
	}
G
godchen 已提交
792 793 794 795
	if dct.result.Status.ErrorCode != commonpb.ErrorCode_SUCCESS {
		return errors.New(dct.result.Status.Reason)
	}
	return err
796 797 798
}

func (dct *DescribeCollectionTask) PostExecute() error {
799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858
	return nil
}

// GetCollectionsStatisticsTask is the task that carries a collection
// statistics request.
type GetCollectionsStatisticsTask struct {
	Condition
	*milvuspb.CollectionStatsRequest
	// dataServiceClient receives the GetCollectionStatistics call in Execute.
	dataServiceClient DataServiceClient
	// result is the response built by Execute from the data service reply.
	result            *milvuspb.CollectionStatsResponse
}

// ID returns the message ID stored in g.Base.
func (g *GetCollectionsStatisticsTask) ID() UniqueID {
	return g.Base.MsgID
}

// SetID stores uid as the message ID in g.Base.
func (g *GetCollectionsStatisticsTask) SetID(uid UniqueID) {
	g.Base.MsgID = uid
}

// Type returns the message type stored in g.Base.
func (g *GetCollectionsStatisticsTask) Type() commonpb.MsgType {
	return g.Base.MsgType
}

// BeginTs returns the timestamp stored in g.Base.
func (g *GetCollectionsStatisticsTask) BeginTs() Timestamp {
	return g.Base.Timestamp
}

// EndTs returns the timestamp stored in g.Base; it equals BeginTs.
func (g *GetCollectionsStatisticsTask) EndTs() Timestamp {
	return g.Base.Timestamp
}

// SetTs stores ts as the timestamp in g.Base.
func (g *GetCollectionsStatisticsTask) SetTs(ts Timestamp) {
	g.Base.Timestamp = ts
}

// OnEnqueue replaces g.Base with a fresh, empty message base.
func (g *GetCollectionsStatisticsTask) OnEnqueue() error {
	g.Base = &commonpb.MsgBase{}
	return nil
}

// PreExecute marks the message as a collection statistics request issued by
// this proxy. Unlike the other tasks in this file, it does not validate the
// collection name.
func (g *GetCollectionsStatisticsTask) PreExecute() error {
	g.Base.MsgType = commonpb.MsgType_kGetCollectionStatistics
	g.Base.SourceID = Params.ProxyID
	return nil
}

func (g *GetCollectionsStatisticsTask) Execute() error {
	collID, err := globalMetaCache.GetCollectionID(g.CollectionName)
	if err != nil {
		return err
	}
	req := &datapb.CollectionStatsRequest{
		Base: &commonpb.MsgBase{
			MsgType:   commonpb.MsgType_kGetCollectionStatistics,
			MsgID:     g.Base.MsgID,
			Timestamp: g.Base.Timestamp,
			SourceID:  g.Base.SourceID,
		},
		CollectionID: collID,
	}

G
godchen 已提交
859 860 861 862 863 864
	result, _ := g.dataServiceClient.GetCollectionStatistics(req)
	if result == nil {
		return errors.New("get collection statistics resp is nil")
	}
	if result.Status.ErrorCode != commonpb.ErrorCode_SUCCESS {
		return errors.New(result.Status.Reason)
865 866 867 868 869 870 871 872 873 874 875 876
	}
	g.result = &milvuspb.CollectionStatsResponse{
		Status: &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_SUCCESS,
			Reason:    "",
		},
		Stats: result.Stats,
	}
	return nil
}

func (g *GetCollectionsStatisticsTask) PostExecute() error {
877 878 879 880
	return nil
}

type ShowCollectionsTask struct {
D
dragondriver 已提交
881
	Condition
882
	*milvuspb.ShowCollectionRequest
Z
zhenshan.cao 已提交
883
	masterClient MasterClient
884
	result       *milvuspb.ShowCollectionResponse
G
godchen 已提交
885
	ctx          context.Context
886 887
}

888 889 890 891 892
// OnEnqueue replaces sct.Base with a fresh, empty message base.
func (sct *ShowCollectionsTask) OnEnqueue() error {
	sct.Base = &commonpb.MsgBase{}
	return nil
}

C
cai.zhang 已提交
893
func (sct *ShowCollectionsTask) ID() UniqueID {
894
	return sct.Base.MsgID
895 896
}

897
func (sct *ShowCollectionsTask) SetID(uid UniqueID) {
898
	sct.Base.MsgID = uid
899 900
}

901
func (sct *ShowCollectionsTask) Type() commonpb.MsgType {
902
	return sct.Base.MsgType
903 904 905
}

func (sct *ShowCollectionsTask) BeginTs() Timestamp {
906
	return sct.Base.Timestamp
907 908 909
}

func (sct *ShowCollectionsTask) EndTs() Timestamp {
910
	return sct.Base.Timestamp
911 912 913
}

func (sct *ShowCollectionsTask) SetTs(ts Timestamp) {
914
	sct.Base.Timestamp = ts
915 916 917
}

func (sct *ShowCollectionsTask) PreExecute() error {
918
	sct.Base.MsgType = commonpb.MsgType_kShowCollections
919
	sct.Base.SourceID = Params.ProxyID
920

921 922 923 924
	return nil
}

func (sct *ShowCollectionsTask) Execute() error {
925
	var err error
926
	sct.result, err = sct.masterClient.ShowCollections(sct.ShowCollectionRequest)
G
godchen 已提交
927 928 929 930 931 932
	if sct.result == nil {
		return errors.New("get collection statistics resp is nil")
	}
	if sct.result.Status.ErrorCode != commonpb.ErrorCode_SUCCESS {
		return errors.New(sct.result.Status.Reason)
	}
933 934 935 936 937 938
	return err
}

// PostExecute does nothing for show-collections tasks and always returns nil.
func (sct *ShowCollectionsTask) PostExecute() error {
	return nil
}
N
neza2017 已提交
939 940 941

type CreatePartitionTask struct {
	Condition
942
	*milvuspb.CreatePartitionRequest
Z
zhenshan.cao 已提交
943
	masterClient MasterClient
N
neza2017 已提交
944 945 946
	result       *commonpb.Status
}

947 948 949 950 951
// OnEnqueue replaces cpt.Base with a fresh, empty message base.
func (cpt *CreatePartitionTask) OnEnqueue() error {
	cpt.Base = &commonpb.MsgBase{}
	return nil
}

N
neza2017 已提交
952
func (cpt *CreatePartitionTask) ID() UniqueID {
953
	return cpt.Base.MsgID
N
neza2017 已提交
954 955
}

956
func (cpt *CreatePartitionTask) SetID(uid UniqueID) {
957
	cpt.Base.MsgID = uid
958 959
}

960
func (cpt *CreatePartitionTask) Type() commonpb.MsgType {
961
	return cpt.Base.MsgType
N
neza2017 已提交
962 963 964
}

func (cpt *CreatePartitionTask) BeginTs() Timestamp {
965
	return cpt.Base.Timestamp
N
neza2017 已提交
966 967 968
}

func (cpt *CreatePartitionTask) EndTs() Timestamp {
969
	return cpt.Base.Timestamp
N
neza2017 已提交
970 971 972
}

func (cpt *CreatePartitionTask) SetTs(ts Timestamp) {
973
	cpt.Base.Timestamp = ts
N
neza2017 已提交
974 975 976
}

func (cpt *CreatePartitionTask) PreExecute() error {
977
	cpt.Base.MsgType = commonpb.MsgType_kCreatePartition
978
	cpt.Base.SourceID = Params.ProxyID
979

980
	collName, partitionTag := cpt.CollectionName, cpt.PartitionName
N
neza2017 已提交
981 982 983 984 985 986 987 988 989

	if err := ValidateCollectionName(collName); err != nil {
		return err
	}

	if err := ValidatePartitionTag(partitionTag, true); err != nil {
		return err
	}

N
neza2017 已提交
990 991 992 993
	return nil
}

func (cpt *CreatePartitionTask) Execute() (err error) {
994
	cpt.result, err = cpt.masterClient.CreatePartition(cpt.CreatePartitionRequest)
G
godchen 已提交
995 996 997 998 999 1000
	if cpt.result == nil {
		return errors.New("get collection statistics resp is nil")
	}
	if cpt.result.ErrorCode != commonpb.ErrorCode_SUCCESS {
		return errors.New(cpt.result.Reason)
	}
N
neza2017 已提交
1001 1002 1003 1004 1005 1006 1007 1008 1009
	return err
}

// PostExecute does nothing for create-partition tasks and always returns nil.
func (cpt *CreatePartitionTask) PostExecute() error {
	return nil
}

type DropPartitionTask struct {
	Condition
1010
	*milvuspb.DropPartitionRequest
Z
zhenshan.cao 已提交
1011
	masterClient MasterClient
N
neza2017 已提交
1012 1013 1014
	result       *commonpb.Status
}

1015 1016 1017 1018 1019
// OnEnqueue replaces dpt.Base with a fresh, empty message base.
func (dpt *DropPartitionTask) OnEnqueue() error {
	dpt.Base = &commonpb.MsgBase{}
	return nil
}

N
neza2017 已提交
1020
func (dpt *DropPartitionTask) ID() UniqueID {
1021
	return dpt.Base.MsgID
N
neza2017 已提交
1022 1023
}

1024
func (dpt *DropPartitionTask) SetID(uid UniqueID) {
1025
	dpt.Base.MsgID = uid
1026 1027
}

1028
func (dpt *DropPartitionTask) Type() commonpb.MsgType {
1029
	return dpt.Base.MsgType
N
neza2017 已提交
1030 1031 1032
}

func (dpt *DropPartitionTask) BeginTs() Timestamp {
1033
	return dpt.Base.Timestamp
N
neza2017 已提交
1034 1035 1036
}

func (dpt *DropPartitionTask) EndTs() Timestamp {
1037
	return dpt.Base.Timestamp
N
neza2017 已提交
1038 1039 1040
}

func (dpt *DropPartitionTask) SetTs(ts Timestamp) {
1041
	dpt.Base.Timestamp = ts
N
neza2017 已提交
1042 1043 1044
}

func (dpt *DropPartitionTask) PreExecute() error {
1045
	dpt.Base.MsgType = commonpb.MsgType_kDropPartition
1046
	dpt.Base.SourceID = Params.ProxyID
1047

1048
	collName, partitionTag := dpt.CollectionName, dpt.PartitionName
N
neza2017 已提交
1049 1050 1051 1052 1053 1054 1055 1056 1057

	if err := ValidateCollectionName(collName); err != nil {
		return err
	}

	if err := ValidatePartitionTag(partitionTag, true); err != nil {
		return err
	}

N
neza2017 已提交
1058 1059 1060 1061
	return nil
}

func (dpt *DropPartitionTask) Execute() (err error) {
1062
	dpt.result, err = dpt.masterClient.DropPartition(dpt.DropPartitionRequest)
G
godchen 已提交
1063 1064 1065 1066 1067 1068
	if dpt.result == nil {
		return errors.New("get collection statistics resp is nil")
	}
	if dpt.result.ErrorCode != commonpb.ErrorCode_SUCCESS {
		return errors.New(dpt.result.Reason)
	}
N
neza2017 已提交
1069 1070 1071 1072 1073 1074 1075 1076 1077
	return err
}

// PostExecute does nothing for this task and always returns nil.
func (dpt *DropPartitionTask) PostExecute() error { return nil }

type HasPartitionTask struct {
	Condition
1078
	*milvuspb.HasPartitionRequest
Z
zhenshan.cao 已提交
1079
	masterClient MasterClient
1080
	result       *milvuspb.BoolResponse
N
neza2017 已提交
1081 1082
}

1083 1084 1085 1086 1087
// OnEnqueue replaces the request's Base with a fresh zero-valued MsgBase and
// returns nil.
func (hpt *HasPartitionTask) OnEnqueue() error {
	hpt.Base = new(commonpb.MsgBase)
	return nil
}

N
neza2017 已提交
1088
func (hpt *HasPartitionTask) ID() UniqueID {
1089
	return hpt.Base.MsgID
N
neza2017 已提交
1090 1091
}

1092
func (hpt *HasPartitionTask) SetID(uid UniqueID) {
1093
	hpt.Base.MsgID = uid
1094 1095
}

1096
func (hpt *HasPartitionTask) Type() commonpb.MsgType {
1097
	return hpt.Base.MsgType
N
neza2017 已提交
1098 1099 1100
}

func (hpt *HasPartitionTask) BeginTs() Timestamp {
1101
	return hpt.Base.Timestamp
N
neza2017 已提交
1102 1103 1104
}

func (hpt *HasPartitionTask) EndTs() Timestamp {
1105
	return hpt.Base.Timestamp
N
neza2017 已提交
1106 1107 1108
}

func (hpt *HasPartitionTask) SetTs(ts Timestamp) {
1109
	hpt.Base.Timestamp = ts
N
neza2017 已提交
1110 1111 1112
}

func (hpt *HasPartitionTask) PreExecute() error {
1113
	hpt.Base.MsgType = commonpb.MsgType_kHasPartition
1114
	hpt.Base.SourceID = Params.ProxyID
1115

1116
	collName, partitionTag := hpt.CollectionName, hpt.PartitionName
N
neza2017 已提交
1117 1118 1119 1120 1121 1122 1123 1124

	if err := ValidateCollectionName(collName); err != nil {
		return err
	}

	if err := ValidatePartitionTag(partitionTag, true); err != nil {
		return err
	}
N
neza2017 已提交
1125 1126 1127 1128
	return nil
}

func (hpt *HasPartitionTask) Execute() (err error) {
1129
	hpt.result, err = hpt.masterClient.HasPartition(hpt.HasPartitionRequest)
G
godchen 已提交
1130 1131 1132 1133 1134 1135
	if hpt.result == nil {
		return errors.New("get collection statistics resp is nil")
	}
	if hpt.result.Status.ErrorCode != commonpb.ErrorCode_SUCCESS {
		return errors.New(hpt.result.Status.Reason)
	}
N
neza2017 已提交
1136 1137 1138 1139 1140 1141 1142 1143 1144
	return err
}

// PostExecute does nothing for this task and always returns nil.
func (hpt *HasPartitionTask) PostExecute() error { return nil }

type ShowPartitionsTask struct {
	Condition
1145
	*milvuspb.ShowPartitionRequest
Z
zhenshan.cao 已提交
1146
	masterClient MasterClient
1147
	result       *milvuspb.ShowPartitionResponse
N
neza2017 已提交
1148 1149
}

1150 1151 1152 1153 1154
// OnEnqueue replaces the request's Base with a fresh zero-valued MsgBase and
// returns nil.
func (spt *ShowPartitionsTask) OnEnqueue() error {
	spt.Base = new(commonpb.MsgBase)
	return nil
}

N
neza2017 已提交
1155
func (spt *ShowPartitionsTask) ID() UniqueID {
1156
	return spt.Base.MsgID
N
neza2017 已提交
1157 1158
}

1159
func (spt *ShowPartitionsTask) SetID(uid UniqueID) {
1160
	spt.Base.MsgID = uid
1161 1162
}

1163
func (spt *ShowPartitionsTask) Type() commonpb.MsgType {
1164
	return spt.Base.MsgType
N
neza2017 已提交
1165 1166 1167
}

func (spt *ShowPartitionsTask) BeginTs() Timestamp {
1168
	return spt.Base.Timestamp
N
neza2017 已提交
1169 1170 1171
}

func (spt *ShowPartitionsTask) EndTs() Timestamp {
1172
	return spt.Base.Timestamp
N
neza2017 已提交
1173 1174 1175
}

func (spt *ShowPartitionsTask) SetTs(ts Timestamp) {
1176
	spt.Base.Timestamp = ts
N
neza2017 已提交
1177 1178 1179
}

func (spt *ShowPartitionsTask) PreExecute() error {
1180
	spt.Base.MsgType = commonpb.MsgType_kShowPartitions
1181
	spt.Base.SourceID = Params.ProxyID
1182

1183
	if err := ValidateCollectionName(spt.CollectionName); err != nil {
N
neza2017 已提交
1184 1185
		return err
	}
N
neza2017 已提交
1186 1187 1188
	return nil
}

1189
func (spt *ShowPartitionsTask) Execute() error {
1190
	var err error
1191
	spt.result, err = spt.masterClient.ShowPartitions(spt.ShowPartitionRequest)
G
godchen 已提交
1192 1193
	if spt.result == nil {
		return errors.New("get collection statistics resp is nil")
G
godchen 已提交
1194
	}
G
godchen 已提交
1195 1196 1197 1198
	if spt.result.Status.ErrorCode != commonpb.ErrorCode_SUCCESS {
		return errors.New(spt.result.Status.Reason)
	}
	return err
N
neza2017 已提交
1199 1200 1201 1202 1203
}

// PostExecute does nothing for this task and always returns nil.
func (spt *ShowPartitionsTask) PostExecute() error { return nil }
1204 1205 1206

type CreateIndexTask struct {
	Condition
1207
	*milvuspb.CreateIndexRequest
Z
zhenshan.cao 已提交
1208
	masterClient MasterClient
1209 1210 1211
	result       *commonpb.Status
}

1212 1213 1214 1215 1216
// OnEnqueue replaces the request's Base with a fresh zero-valued MsgBase and
// returns nil.
func (cit *CreateIndexTask) OnEnqueue() error {
	cit.Base = new(commonpb.MsgBase)
	return nil
}

1217
func (cit *CreateIndexTask) ID() UniqueID {
1218
	return cit.Base.MsgID
1219 1220 1221
}

func (cit *CreateIndexTask) SetID(uid UniqueID) {
1222
	cit.Base.MsgID = uid
1223 1224
}

1225
func (cit *CreateIndexTask) Type() commonpb.MsgType {
1226
	return cit.Base.MsgType
1227 1228 1229
}

func (cit *CreateIndexTask) BeginTs() Timestamp {
1230
	return cit.Base.Timestamp
1231 1232 1233
}

func (cit *CreateIndexTask) EndTs() Timestamp {
1234
	return cit.Base.Timestamp
1235 1236 1237
}

func (cit *CreateIndexTask) SetTs(ts Timestamp) {
1238
	cit.Base.Timestamp = ts
1239 1240 1241
}

func (cit *CreateIndexTask) PreExecute() error {
1242
	cit.Base.MsgType = commonpb.MsgType_kCreateIndex
1243
	cit.Base.SourceID = Params.ProxyID
1244

1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257
	collName, fieldName := cit.CollectionName, cit.FieldName

	if err := ValidateCollectionName(collName); err != nil {
		return err
	}

	if err := ValidateFieldName(fieldName); err != nil {
		return err
	}

	return nil
}

G
godchen 已提交
1258 1259
func (cit *CreateIndexTask) Execute() error {
	var err error
1260
	cit.result, err = cit.masterClient.CreateIndex(cit.CreateIndexRequest)
G
godchen 已提交
1261 1262 1263 1264 1265 1266
	if cit.result == nil {
		return errors.New("get collection statistics resp is nil")
	}
	if cit.result.ErrorCode != commonpb.ErrorCode_SUCCESS {
		return errors.New(cit.result.Reason)
	}
1267 1268 1269 1270 1271 1272 1273 1274 1275
	return err
}

// PostExecute does nothing for this task and always returns nil.
func (cit *CreateIndexTask) PostExecute() error { return nil }

type DescribeIndexTask struct {
	Condition
1276
	*milvuspb.DescribeIndexRequest
Z
zhenshan.cao 已提交
1277
	masterClient MasterClient
1278
	result       *milvuspb.DescribeIndexResponse
1279 1280
}

1281 1282 1283 1284 1285
// OnEnqueue replaces the request's Base with a fresh zero-valued MsgBase and
// returns nil.
func (dit *DescribeIndexTask) OnEnqueue() error {
	dit.Base = new(commonpb.MsgBase)
	return nil
}

1286
func (dit *DescribeIndexTask) ID() UniqueID {
1287
	return dit.Base.MsgID
1288 1289 1290
}

func (dit *DescribeIndexTask) SetID(uid UniqueID) {
1291
	dit.Base.MsgID = uid
1292 1293
}

1294
func (dit *DescribeIndexTask) Type() commonpb.MsgType {
1295
	return dit.Base.MsgType
1296 1297 1298
}

func (dit *DescribeIndexTask) BeginTs() Timestamp {
1299
	return dit.Base.Timestamp
1300 1301 1302
}

func (dit *DescribeIndexTask) EndTs() Timestamp {
1303
	return dit.Base.Timestamp
1304 1305 1306
}

func (dit *DescribeIndexTask) SetTs(ts Timestamp) {
1307
	dit.Base.Timestamp = ts
1308 1309 1310
}

func (dit *DescribeIndexTask) PreExecute() error {
1311
	dit.Base.MsgType = commonpb.MsgType_kDescribeIndex
1312
	dit.Base.SourceID = Params.ProxyID
1313

1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326
	collName, fieldName := dit.CollectionName, dit.FieldName

	if err := ValidateCollectionName(collName); err != nil {
		return err
	}

	if err := ValidateFieldName(fieldName); err != nil {
		return err
	}

	return nil
}

1327
func (dit *DescribeIndexTask) Execute() error {
1328
	var err error
1329
	dit.result, err = dit.masterClient.DescribeIndex(dit.DescribeIndexRequest)
G
godchen 已提交
1330 1331 1332 1333 1334 1335
	if dit.result == nil {
		return errors.New("get collection statistics resp is nil")
	}
	if dit.result.Status.ErrorCode != commonpb.ErrorCode_SUCCESS {
		return errors.New(dit.result.Status.Reason)
	}
1336 1337 1338 1339 1340 1341 1342
	return err
}

// PostExecute does nothing for this task and always returns nil.
func (dit *DescribeIndexTask) PostExecute() error { return nil }

1343
type GetIndexStateTask struct {
1344
	Condition
1345
	*milvuspb.IndexStateRequest
Z
zhenshan.cao 已提交
1346 1347 1348
	indexServiceClient    IndexServiceClient
	masterClientInterface MasterClient
	result                *milvuspb.IndexStateResponse
1349 1350
}

1351 1352 1353 1354 1355 1356
// OnEnqueue replaces the request's Base with a fresh zero-valued MsgBase and
// returns nil.
func (dipt *GetIndexStateTask) OnEnqueue() error {
	dipt.Base = new(commonpb.MsgBase)
	return nil
}

func (dipt *GetIndexStateTask) ID() UniqueID {
1357
	return dipt.Base.MsgID
1358 1359
}

1360
func (dipt *GetIndexStateTask) SetID(uid UniqueID) {
1361
	dipt.Base.MsgID = uid
1362 1363
}

1364
func (dipt *GetIndexStateTask) Type() commonpb.MsgType {
1365
	return dipt.Base.MsgType
1366 1367
}

1368
func (dipt *GetIndexStateTask) BeginTs() Timestamp {
1369
	return dipt.Base.Timestamp
1370 1371
}

1372
func (dipt *GetIndexStateTask) EndTs() Timestamp {
1373
	return dipt.Base.Timestamp
1374 1375
}

1376
func (dipt *GetIndexStateTask) SetTs(ts Timestamp) {
1377
	dipt.Base.Timestamp = ts
1378 1379
}

1380 1381
func (dipt *GetIndexStateTask) PreExecute() error {
	dipt.Base.MsgType = commonpb.MsgType_kGetIndexState
1382
	dipt.Base.SourceID = Params.ProxyID
1383

1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396
	collName, fieldName := dipt.CollectionName, dipt.FieldName

	if err := ValidateCollectionName(collName); err != nil {
		return err
	}

	if err := ValidateFieldName(fieldName); err != nil {
		return err
	}

	return nil
}

1397
func (dipt *GetIndexStateTask) Execute() error {
Z
zhenshan.cao 已提交
1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481
	collectionName := dipt.CollectionName
	collectionID, err := globalMetaCache.GetCollectionID(collectionName)
	if err != nil { // err is not nil if collection not exists
		return err
	}

	showPartitionRequest := &milvuspb.ShowPartitionRequest{
		Base: &commonpb.MsgBase{
			MsgType:   commonpb.MsgType_kShowPartitions,
			MsgID:     dipt.Base.MsgID,
			Timestamp: dipt.Base.Timestamp,
			SourceID:  Params.ProxyID,
		},
		DbName:         dipt.DbName,
		CollectionName: collectionName,
		CollectionID:   collectionID,
	}
	partitions, err := dipt.masterClientInterface.ShowPartitions(showPartitionRequest)
	if err != nil {
		return err
	}

	for _, partitionID := range partitions.PartitionIDs {
		showSegmentsRequest := &milvuspb.ShowSegmentRequest{
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_kShowSegment,
				MsgID:     dipt.Base.MsgID,
				Timestamp: dipt.Base.Timestamp,
				SourceID:  Params.ProxyID,
			},
			CollectionID: collectionID,
			PartitionID:  partitionID,
		}
		segments, err := dipt.masterClientInterface.ShowSegments(showSegmentsRequest)
		if err != nil {
			return err
		}

		getIndexStatesRequest := &indexpb.IndexStatesRequest{
			IndexBuildIDs: make([]UniqueID, 0),
		}
		for _, segmentID := range segments.SegmentIDs {
			describeSegmentRequest := &milvuspb.DescribeSegmentRequest{
				Base: &commonpb.MsgBase{
					MsgType:   commonpb.MsgType_kDescribeSegment,
					MsgID:     dipt.Base.MsgID,
					Timestamp: dipt.Base.Timestamp,
					SourceID:  Params.ProxyID,
				},
				CollectionID: collectionID,
				SegmentID:    segmentID,
			}
			segmentDesc, err := dipt.masterClientInterface.DescribeSegment(describeSegmentRequest)
			if err != nil {
				return err
			}

			getIndexStatesRequest.IndexBuildIDs = append(getIndexStatesRequest.IndexBuildIDs, segmentDesc.BuildID)
		}

		states, err := dipt.indexServiceClient.GetIndexStates(getIndexStatesRequest)
		if err != nil {
			return err
		}

		if states.Status.ErrorCode != commonpb.ErrorCode_SUCCESS {
			dipt.result = &milvuspb.IndexStateResponse{
				Status: states.Status,
				State:  commonpb.IndexState_FAILED,
			}
			return nil
		}

		for _, state := range states.States {
			if state.State != commonpb.IndexState_FINISHED {
				dipt.result = &milvuspb.IndexStateResponse{
					Status: states.Status,
					State:  state.State,
				}
				return nil
			}
		}
	}

1482 1483
	dipt.result = &milvuspb.IndexStateResponse{
		Status: &commonpb.Status{
Z
zhenshan.cao 已提交
1484
			ErrorCode: commonpb.ErrorCode_SUCCESS,
1485 1486 1487 1488
			Reason:    "",
		},
		State: commonpb.IndexState_FINISHED,
	}
Z
zhenshan.cao 已提交
1489

1490
	return nil
1491 1492
}

1493
func (dipt *GetIndexStateTask) PostExecute() error {
1494 1495
	return nil
}
Z
zhenshan.cao 已提交
1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539

// FlushTask carries a Flush request through the proxy's task pipeline.
type FlushTask struct {
	// NOTE(review): Condition is declared outside this chunk; presumably it
	// supplies WaitToFinish/Notify of the task interface — confirm.
	Condition
	// Embedded request; its Base and CollectionNames are used by the methods
	// below.
	*milvuspb.FlushRequest
	// dataServiceClient is the client Execute calls Flush on.
	dataServiceClient DataServiceClient
	// result is set by Execute once every collection has been flushed.
	result            *commonpb.Status
}

// OnEnqueue replaces the request's Base with a fresh zero-valued MsgBase and
// returns nil.
func (ft *FlushTask) OnEnqueue() error {
	ft.Base = new(commonpb.MsgBase)
	return nil
}

// ID returns the MsgID stored in the embedded request's Base.
func (ft *FlushTask) ID() UniqueID {
	return ft.FlushRequest.Base.MsgID
}

// SetID stores uid as the MsgID in the embedded request's Base.
func (ft *FlushTask) SetID(uid UniqueID) {
	ft.FlushRequest.Base.MsgID = uid
}

// Type returns the MsgType stored in the embedded request's Base.
func (ft *FlushTask) Type() commonpb.MsgType {
	return ft.FlushRequest.Base.MsgType
}

// BeginTs returns the timestamp stored in the embedded request's Base.
func (ft *FlushTask) BeginTs() Timestamp {
	return ft.FlushRequest.Base.Timestamp
}

// EndTs returns the timestamp stored in the embedded request's Base.
func (ft *FlushTask) EndTs() Timestamp {
	return ft.FlushRequest.Base.Timestamp
}

// SetTs stores ts as the timestamp in the embedded request's Base.
func (ft *FlushTask) SetTs(ts Timestamp) {
	ft.FlushRequest.Base.Timestamp = ts
}

// PreExecute stamps the request's Base with the kFlush message type and
// Params.ProxyID as the source. It performs no validation and returns nil.
func (ft *FlushTask) PreExecute() error {
	base := ft.Base
	base.MsgType = commonpb.MsgType_kFlush
	base.SourceID = Params.ProxyID
	return nil
}

func (ft *FlushTask) Execute() error {
1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555
	for _, collName := range ft.CollectionNames {
		collID, err := globalMetaCache.GetCollectionID(collName)
		if err != nil {
			return err
		}
		flushReq := &datapb.FlushRequest{
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_kFlush,
				MsgID:     ft.Base.MsgID,
				Timestamp: ft.Base.Timestamp,
				SourceID:  ft.Base.SourceID,
			},
			DbID:         0,
			CollectionID: collID,
		}
		var status *commonpb.Status
G
godchen 已提交
1556 1557 1558
		status, _ = ft.dataServiceClient.Flush(flushReq)
		if status == nil {
			return errors.New("flush resp is nil")
1559 1560 1561 1562
		}
		if status.ErrorCode != commonpb.ErrorCode_SUCCESS {
			return errors.New(status.Reason)
		}
Z
zhenshan.cao 已提交
1563
	}
1564 1565
	ft.result = &commonpb.Status{
		ErrorCode: commonpb.ErrorCode_SUCCESS,
Z
zhenshan.cao 已提交
1566
	}
1567
	return nil
Z
zhenshan.cao 已提交
1568 1569 1570 1571 1572
}

// PostExecute does nothing for this task and always returns nil.
func (ft *FlushTask) PostExecute() error { return nil }
1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627

// LoadCollectionTask carries a LoadCollection request through the proxy's task
// pipeline.
type LoadCollectionTask struct {
	// NOTE(review): Condition is declared outside this chunk; presumably it
	// supplies WaitToFinish/Notify of the task interface — confirm.
	Condition
	// Embedded request; its Base and CollectionName are used by the methods
	// below.
	*milvuspb.LoadCollectionRequest
	// queryserviceClient is the client Execute calls LoadCollection on.
	queryserviceClient QueryServiceClient
	// result holds the status returned by the query service client in Execute.
	result             *commonpb.Status
}

// OnEnqueue replaces the request's Base with a fresh zero-valued MsgBase and
// returns nil.
func (lct *LoadCollectionTask) OnEnqueue() error {
	lct.Base = new(commonpb.MsgBase)
	return nil
}

// ID returns the MsgID stored in the embedded request's Base.
func (lct *LoadCollectionTask) ID() UniqueID {
	return lct.LoadCollectionRequest.Base.MsgID
}

// SetID stores uid as the MsgID in the embedded request's Base.
func (lct *LoadCollectionTask) SetID(uid UniqueID) {
	lct.LoadCollectionRequest.Base.MsgID = uid
}

// Type returns the MsgType stored in the embedded request's Base.
func (lct *LoadCollectionTask) Type() commonpb.MsgType {
	return lct.LoadCollectionRequest.Base.MsgType
}

// BeginTs returns the timestamp stored in the embedded request's Base.
func (lct *LoadCollectionTask) BeginTs() Timestamp {
	return lct.LoadCollectionRequest.Base.Timestamp
}

// EndTs returns the timestamp stored in the embedded request's Base.
func (lct *LoadCollectionTask) EndTs() Timestamp {
	return lct.LoadCollectionRequest.Base.Timestamp
}

// SetTs stores ts as the timestamp in the embedded request's Base.
func (lct *LoadCollectionTask) SetTs(ts Timestamp) {
	lct.LoadCollectionRequest.Base.Timestamp = ts
}

// PreExecute stamps the request's Base with the kLoadCollection message type
// and Params.ProxyID as the source, then validates the collection name. It
// returns the validation error, or nil.
func (lct *LoadCollectionTask) PreExecute() error {
	base := lct.Base
	base.MsgType = commonpb.MsgType_kLoadCollection
	base.SourceID = Params.ProxyID

	if err := ValidateCollectionName(lct.CollectionName); err != nil {
		return err
	}
	return nil
}

func (lct *LoadCollectionTask) Execute() (err error) {
	collID, err := globalMetaCache.GetCollectionID(lct.CollectionName)
	if err != nil {
		return err
	}
1628 1629 1630 1631 1632
	collSchema, err := globalMetaCache.GetCollectionSchema(lct.CollectionName)
	if err != nil {
		return err
	}

1633 1634 1635 1636 1637 1638 1639 1640 1641
	request := &querypb.LoadCollectionRequest{
		Base: &commonpb.MsgBase{
			MsgType:   commonpb.MsgType_kLoadCollection,
			MsgID:     lct.Base.MsgID,
			Timestamp: lct.Base.Timestamp,
			SourceID:  lct.Base.SourceID,
		},
		DbID:         0,
		CollectionID: collID,
1642
		Schema:       collSchema,
1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778
	}
	lct.result, err = lct.queryserviceClient.LoadCollection(request)
	return err
}

// PostExecute does nothing for this task and always returns nil.
func (lct *LoadCollectionTask) PostExecute() error { return nil }

// ReleaseCollectionTask carries a ReleaseCollection request through the
// proxy's task pipeline.
type ReleaseCollectionTask struct {
	// NOTE(review): Condition is declared outside this chunk; presumably it
	// supplies WaitToFinish/Notify of the task interface — confirm.
	Condition
	// Embedded request; its Base and CollectionName are used by the methods
	// below.
	*milvuspb.ReleaseCollectionRequest
	// queryserviceClient is the client Execute calls ReleaseCollection on.
	queryserviceClient QueryServiceClient
	// result holds the status returned by the query service client in Execute.
	result             *commonpb.Status
}

// OnEnqueue replaces the request's Base with a fresh zero-valued MsgBase and
// returns nil.
func (rct *ReleaseCollectionTask) OnEnqueue() error {
	rct.Base = new(commonpb.MsgBase)
	return nil
}

// ID returns the MsgID stored in the embedded request's Base.
func (rct *ReleaseCollectionTask) ID() UniqueID {
	return rct.ReleaseCollectionRequest.Base.MsgID
}

// SetID stores uid as the MsgID in the embedded request's Base.
func (rct *ReleaseCollectionTask) SetID(uid UniqueID) {
	rct.ReleaseCollectionRequest.Base.MsgID = uid
}

// Type returns the MsgType stored in the embedded request's Base.
func (rct *ReleaseCollectionTask) Type() commonpb.MsgType {
	return rct.ReleaseCollectionRequest.Base.MsgType
}

// BeginTs returns the timestamp stored in the embedded request's Base.
func (rct *ReleaseCollectionTask) BeginTs() Timestamp {
	return rct.ReleaseCollectionRequest.Base.Timestamp
}

// EndTs returns the timestamp stored in the embedded request's Base.
func (rct *ReleaseCollectionTask) EndTs() Timestamp {
	return rct.ReleaseCollectionRequest.Base.Timestamp
}

// SetTs stores ts as the timestamp in the embedded request's Base.
func (rct *ReleaseCollectionTask) SetTs(ts Timestamp) {
	rct.ReleaseCollectionRequest.Base.Timestamp = ts
}

// PreExecute stamps the request's Base with the kReleaseCollection message
// type and Params.ProxyID as the source, then validates the collection name.
// It returns the validation error, or nil.
func (rct *ReleaseCollectionTask) PreExecute() error {
	base := rct.Base
	base.MsgType = commonpb.MsgType_kReleaseCollection
	base.SourceID = Params.ProxyID

	if err := ValidateCollectionName(rct.CollectionName); err != nil {
		return err
	}
	return nil
}

// Execute resolves the collection's ID from the meta cache, sends a
// ReleaseCollection request to the query service client, and stores the
// returned status in rct.result.
//
// It returns the lookup error if the cache lookup fails, otherwise whatever
// error the client call returns.
func (rct *ReleaseCollectionTask) Execute() (err error) {
	collectionID, err := globalMetaCache.GetCollectionID(rct.CollectionName)
	if err != nil {
		return err
	}

	msgBase := &commonpb.MsgBase{
		MsgType:   commonpb.MsgType_kReleaseCollection,
		MsgID:     rct.Base.MsgID,
		Timestamp: rct.Base.Timestamp,
		SourceID:  rct.Base.SourceID,
	}
	rct.result, err = rct.queryserviceClient.ReleaseCollection(&querypb.ReleaseCollectionRequest{
		Base:         msgBase,
		DbID:         0,
		CollectionID: collectionID,
	})
	return err
}

// PostExecute does nothing for this task and always returns nil.
func (rct *ReleaseCollectionTask) PostExecute() error { return nil }

// LoadPartitionTask carries a load-partitions request through the proxy's task
// pipeline.
type LoadPartitionTask struct {
	// NOTE(review): Condition is declared outside this chunk; presumably it
	// supplies WaitToFinish/Notify of the task interface — confirm.
	Condition
	// Embedded request; its Base, CollectionName and PartitionNames are used
	// by the methods below.
	*milvuspb.LoadPartitonRequest
	// queryserviceClient is the client Execute calls LoadPartitions on.
	queryserviceClient QueryServiceClient
	// result holds the status returned by the query service client in Execute.
	result             *commonpb.Status
}

// OnEnqueue replaces the request's Base with a fresh zero-valued MsgBase and
// returns nil.
func (lpt *LoadPartitionTask) OnEnqueue() error {
	lpt.Base = new(commonpb.MsgBase)
	return nil
}

// ID returns the MsgID stored in the embedded request's Base.
func (lpt *LoadPartitionTask) ID() UniqueID {
	return lpt.LoadPartitonRequest.Base.MsgID
}

// SetID stores uid as the MsgID in the embedded request's Base.
func (lpt *LoadPartitionTask) SetID(uid UniqueID) {
	lpt.LoadPartitonRequest.Base.MsgID = uid
}

// Type returns the MsgType stored in the embedded request's Base.
func (lpt *LoadPartitionTask) Type() commonpb.MsgType {
	return lpt.LoadPartitonRequest.Base.MsgType
}

// BeginTs returns the timestamp stored in the embedded request's Base.
func (lpt *LoadPartitionTask) BeginTs() Timestamp {
	return lpt.LoadPartitonRequest.Base.Timestamp
}

// EndTs returns the timestamp stored in the embedded request's Base.
func (lpt *LoadPartitionTask) EndTs() Timestamp {
	return lpt.LoadPartitonRequest.Base.Timestamp
}

// SetTs stores ts as the timestamp in the embedded request's Base.
func (lpt *LoadPartitionTask) SetTs(ts Timestamp) {
	lpt.LoadPartitonRequest.Base.Timestamp = ts
}

// PreExecute stamps the request's Base with the kLoadPartition message type
// and Params.ProxyID as the source, then validates the collection name. It
// returns the validation error, or nil. Partition names are not validated
// here.
func (lpt *LoadPartitionTask) PreExecute() error {
	base := lpt.Base
	base.MsgType = commonpb.MsgType_kLoadPartition
	base.SourceID = Params.ProxyID

	if err := ValidateCollectionName(lpt.CollectionName); err != nil {
		return err
	}
	return nil
}

func (lpt *LoadPartitionTask) Execute() (err error) {
	var partitionIDs []int64
	collID, err := globalMetaCache.GetCollectionID(lpt.CollectionName)
	if err != nil {
		return err
	}
1779 1780 1781 1782
	collSchema, err := globalMetaCache.GetCollectionSchema(lpt.CollectionName)
	if err != nil {
		return err
	}
1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799
	for _, partitionName := range lpt.PartitionNames {
		partitionID, err := globalMetaCache.GetPartitionID(lpt.CollectionName, partitionName)
		if err != nil {
			return err
		}
		partitionIDs = append(partitionIDs, partitionID)
	}
	request := &querypb.LoadPartitionRequest{
		Base: &commonpb.MsgBase{
			MsgType:   commonpb.MsgType_kLoadPartition,
			MsgID:     lpt.Base.MsgID,
			Timestamp: lpt.Base.Timestamp,
			SourceID:  lpt.Base.SourceID,
		},
		DbID:         0,
		CollectionID: collID,
		PartitionIDs: partitionIDs,
1800
		Schema:       collSchema,
1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889
	}
	lpt.result, err = lpt.queryserviceClient.LoadPartitions(request)
	return err
}

// PostExecute does nothing for this task and always returns nil.
func (lpt *LoadPartitionTask) PostExecute() error { return nil }

// ReleasePartitionTask carries a release-partitions request through the
// proxy's task pipeline.
type ReleasePartitionTask struct {
	// NOTE(review): Condition is declared outside this chunk; presumably it
	// supplies WaitToFinish/Notify of the task interface — confirm.
	Condition
	// Embedded request; its Base, CollectionName and PartitionNames are used
	// by the methods below.
	*milvuspb.ReleasePartitionRequest
	// queryserviceClient is the client Execute calls ReleasePartitions on.
	queryserviceClient QueryServiceClient
	// result holds the status returned by the query service client in Execute.
	result             *commonpb.Status
}

// OnEnqueue replaces the request's Base with a fresh zero-valued MsgBase and
// returns nil.
func (rpt *ReleasePartitionTask) OnEnqueue() error {
	rpt.Base = new(commonpb.MsgBase)
	return nil
}

// ID returns the MsgID stored in the embedded request's Base.
func (rpt *ReleasePartitionTask) ID() UniqueID {
	return rpt.ReleasePartitionRequest.Base.MsgID
}

// SetID stores uid as the MsgID in the embedded request's Base.
func (rpt *ReleasePartitionTask) SetID(uid UniqueID) {
	rpt.ReleasePartitionRequest.Base.MsgID = uid
}

// Type returns the MsgType stored in the embedded request's Base.
func (rpt *ReleasePartitionTask) Type() commonpb.MsgType {
	return rpt.ReleasePartitionRequest.Base.MsgType
}

// BeginTs returns the timestamp stored in the embedded request's Base.
func (rpt *ReleasePartitionTask) BeginTs() Timestamp {
	return rpt.ReleasePartitionRequest.Base.Timestamp
}

// EndTs returns the timestamp stored in the embedded request's Base.
func (rpt *ReleasePartitionTask) EndTs() Timestamp {
	return rpt.ReleasePartitionRequest.Base.Timestamp
}

// SetTs stores ts as the timestamp in the embedded request's Base.
func (rpt *ReleasePartitionTask) SetTs(ts Timestamp) {
	rpt.ReleasePartitionRequest.Base.Timestamp = ts
}

// PreExecute stamps the request's Base with the kReleasePartition message type
// and Params.ProxyID as the source, then validates the collection name. It
// returns the validation error, or nil. Partition names are not validated
// here.
func (rpt *ReleasePartitionTask) PreExecute() error {
	base := rpt.Base
	base.MsgType = commonpb.MsgType_kReleasePartition
	base.SourceID = Params.ProxyID

	if err := ValidateCollectionName(rpt.CollectionName); err != nil {
		return err
	}
	return nil
}

// Execute resolves the collection's ID and the ID of every partition named in
// the request from the meta cache, sends a ReleasePartitions request carrying
// them to the query service client, and stores the returned status in
// rpt.result.
//
// It returns the first lookup error encountered, otherwise whatever error the
// client call returns.
func (rpt *ReleasePartitionTask) Execute() (err error) {
	var ids []int64
	collectionID, err := globalMetaCache.GetCollectionID(rpt.CollectionName)
	if err != nil {
		return err
	}
	for _, name := range rpt.PartitionNames {
		id, err := globalMetaCache.GetPartitionID(rpt.CollectionName, name)
		if err != nil {
			return err
		}
		ids = append(ids, id)
	}

	msgBase := &commonpb.MsgBase{
		MsgType:   commonpb.MsgType_kReleasePartition,
		MsgID:     rpt.Base.MsgID,
		Timestamp: rpt.Base.Timestamp,
		SourceID:  rpt.Base.SourceID,
	}
	rpt.result, err = rpt.queryserviceClient.ReleasePartitions(&querypb.ReleasePartitionRequest{
		Base:         msgBase,
		DbID:         0,
		CollectionID: collectionID,
		PartitionIDs: ids,
	})
	return err
}

func (rpt *ReleasePartitionTask) PostExecute() error {
	return nil
}