task.go 28.1 KB
Newer Older
Z
zhenshan.cao 已提交
1 2 3
package proxy

import (
4 5
	"context"
	"errors"
Z
zhenshan.cao 已提交
6
	"log"
N
neza2017 已提交
7 8 9
	"math"
	"strconv"

G
godchen 已提交
10 11 12
	"github.com/opentracing/opentracing-go"
	oplog "github.com/opentracing/opentracing-go/log"

N
neza2017 已提交
13
	"github.com/golang/protobuf/proto"
14
	"github.com/zilliztech/milvus-distributed/internal/allocator"
15 16
	"github.com/zilliztech/milvus-distributed/internal/msgstream"
	"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
Z
zhenshan.cao 已提交
17
	"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
18
	"github.com/zilliztech/milvus-distributed/internal/proto/masterpb"
N
neza2017 已提交
19
	"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
20
	"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
C
cai.zhang 已提交
21
	"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
Z
zhenshan.cao 已提交
22 23 24
)

type task interface {
25 26
	ID() UniqueID       // return ReqID
	SetID(uid UniqueID) // set ReqID
N
neza2017 已提交
27
	Type() internalpb.MsgType
28 29
	BeginTs() Timestamp
	EndTs() Timestamp
Z
zhenshan.cao 已提交
30
	SetTs(ts Timestamp)
Z
zhenshan.cao 已提交
31 32 33 34
	PreExecute() error
	Execute() error
	PostExecute() error
	WaitToFinish() error
35
	Notify(err error)
Z
zhenshan.cao 已提交
36 37
}

38
type BaseInsertTask = msgstream.InsertMsg
39 40

type InsertTask struct {
41
	BaseInsertTask
D
dragondriver 已提交
42
	Condition
43
	result                *servicepb.IntegerRangeResponse
44 45
	manipulationMsgStream *msgstream.PulsarMsgStream
	ctx                   context.Context
46
	rowIDAllocator        *allocator.IDAllocator
47 48
}

49 50 51 52
func (it *InsertTask) SetID(uid UniqueID) {
	it.ReqID = uid
}

53
func (it *InsertTask) SetTs(ts Timestamp) {
N
neza2017 已提交
54 55 56 57 58 59 60
	rowNum := len(it.RowData)
	it.Timestamps = make([]uint64, rowNum)
	for index := range it.Timestamps {
		it.Timestamps[index] = ts
	}
	it.BeginTimestamp = ts
	it.EndTimestamp = ts
61 62 63
}

func (it *InsertTask) BeginTs() Timestamp {
N
neza2017 已提交
64
	return it.BeginTimestamp
65 66 67
}

func (it *InsertTask) EndTs() Timestamp {
N
neza2017 已提交
68
	return it.EndTimestamp
69 70
}

C
cai.zhang 已提交
71
func (it *InsertTask) ID() UniqueID {
72
	return it.ReqID
73 74 75 76 77 78 79
}

func (it *InsertTask) Type() internalpb.MsgType {
	return it.MsgType
}

func (it *InsertTask) PreExecute() error {
G
godchen 已提交
80 81 82 83 84
	span, ctx := opentracing.StartSpanFromContext(it.ctx, "InsertTask preExecute")
	defer span.Finish()
	it.ctx = ctx
	span.SetTag("hash keys", it.ReqID)
	span.SetTag("start time", it.BeginTs())
N
neza2017 已提交
85 86
	collectionName := it.BaseInsertTask.CollectionName
	if err := ValidateCollectionName(collectionName); err != nil {
G
godchen 已提交
87 88
		span.LogFields(oplog.Error(err))
		span.Finish()
N
neza2017 已提交
89 90 91 92
		return err
	}
	partitionTag := it.BaseInsertTask.PartitionTag
	if err := ValidatePartitionTag(partitionTag, true); err != nil {
G
godchen 已提交
93 94
		span.LogFields(oplog.Error(err))
		span.Finish()
N
neza2017 已提交
95 96 97
		return err
	}

98 99 100 101
	return nil
}

func (it *InsertTask) Execute() error {
G
godchen 已提交
102 103 104 105 106
	span, ctx := opentracing.StartSpanFromContext(it.ctx, "InsertTask Execute")
	defer span.Finish()
	it.ctx = ctx
	span.SetTag("hash keys", it.ReqID)
	span.SetTag("start time", it.BeginTs())
107
	collectionName := it.BaseInsertTask.CollectionName
G
godchen 已提交
108
	span.LogFields(oplog.String("collection_name", collectionName))
109
	if !globalMetaCache.Hit(collectionName) {
G
godchen 已提交
110
		err := globalMetaCache.Sync(collectionName)
111
		if err != nil {
G
godchen 已提交
112 113
			span.LogFields(oplog.Error(err))
			span.Finish()
114 115 116 117 118
			return err
		}
	}
	description, err := globalMetaCache.Get(collectionName)
	if err != nil || description == nil {
G
godchen 已提交
119 120
		span.LogFields(oplog.Error(err))
		span.Finish()
121 122 123
		return err
	}
	autoID := description.Schema.AutoID
G
godchen 已提交
124
	span.LogFields(oplog.Bool("auto_id", autoID))
125 126
	var rowIDBegin UniqueID
	var rowIDEnd UniqueID
G
godchen 已提交
127 128
	rowNums := len(it.BaseInsertTask.RowData)
	rowIDBegin, rowIDEnd, _ = it.rowIDAllocator.Alloc(uint32(rowNums))
G
godchen 已提交
129 130 131
	span.LogFields(oplog.Int("rowNums", rowNums),
		oplog.Int("rowIDBegin", int(rowIDBegin)),
		oplog.Int("rowIDEnd", int(rowIDEnd)))
G
godchen 已提交
132 133 134 135 136 137 138
	it.BaseInsertTask.RowIDs = make([]UniqueID, rowNums)
	for i := rowIDBegin; i < rowIDEnd; i++ {
		offset := i - rowIDBegin
		it.BaseInsertTask.RowIDs[offset] = i
	}

	if autoID {
N
neza2017 已提交
139 140 141
		if it.HashValues == nil || len(it.HashValues) == 0 {
			it.HashValues = make([]uint32, 0)
		}
G
godchen 已提交
142 143
		for _, rowID := range it.RowIDs {
			hashValue, _ := typeutil.Hash32Int64(rowID)
N
neza2017 已提交
144
			it.HashValues = append(it.HashValues, hashValue)
145 146 147
		}
	}

148
	var tsMsg msgstream.TsMsg = &it.BaseInsertTask
149 150 151
	msgPack := &msgstream.MsgPack{
		BeginTs: it.BeginTs(),
		EndTs:   it.EndTs(),
X
xige-16 已提交
152
		Msgs:    make([]msgstream.TsMsg, 1),
153
	}
G
godchen 已提交
154 155
	tsMsg.SetMsgContext(ctx)
	span.LogFields(oplog.String("send msg", "send msg"))
X
xige-16 已提交
156
	msgPack.Msgs[0] = tsMsg
157
	err = it.manipulationMsgStream.Produce(msgPack)
G
godchen 已提交
158

159 160 161 162
	it.result = &servicepb.IntegerRangeResponse{
		Status: &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_SUCCESS,
		},
163 164
		Begin: rowIDBegin,
		End:   rowIDEnd,
165 166 167 168
	}
	if err != nil {
		it.result.Status.ErrorCode = commonpb.ErrorCode_UNEXPECTED_ERROR
		it.result.Status.Reason = err.Error()
G
godchen 已提交
169
		span.LogFields(oplog.Error(err))
170
	}
171 172 173 174
	return nil
}

func (it *InsertTask) PostExecute() error {
G
godchen 已提交
175 176
	span, _ := opentracing.StartSpanFromContext(it.ctx, "InsertTask postExecute")
	defer span.Finish()
177 178 179 180
	return nil
}

type CreateCollectionTask struct {
D
dragondriver 已提交
181
	Condition
182 183
	internalpb.CreateCollectionRequest
	masterClient masterpb.MasterClient
184
	result       *commonpb.Status
Z
zhenshan.cao 已提交
185
	ctx          context.Context
N
neza2017 已提交
186
	schema       *schemapb.CollectionSchema
187 188
}

C
cai.zhang 已提交
189
func (cct *CreateCollectionTask) ID() UniqueID {
190
	return cct.ReqID
191 192
}

193 194 195 196
func (cct *CreateCollectionTask) SetID(uid UniqueID) {
	cct.ReqID = uid
}

197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213
func (cct *CreateCollectionTask) Type() internalpb.MsgType {
	return cct.MsgType
}

func (cct *CreateCollectionTask) BeginTs() Timestamp {
	return cct.Timestamp
}

func (cct *CreateCollectionTask) EndTs() Timestamp {
	return cct.Timestamp
}

func (cct *CreateCollectionTask) SetTs(ts Timestamp) {
	cct.Timestamp = ts
}

func (cct *CreateCollectionTask) PreExecute() error {
N
neza2017 已提交
214 215 216 217 218 219 220 221 222
	if int64(len(cct.schema.Fields)) > Params.MaxFieldNum() {
		return errors.New("maximum field's number should be limited to " + strconv.FormatInt(Params.MaxFieldNum(), 10))
	}

	// validate collection name
	if err := ValidateCollectionName(cct.schema.Name); err != nil {
		return err
	}

N
neza2017 已提交
223 224 225 226 227 228 229 230
	if err := ValidateDuplicatedFieldName(cct.schema.Fields); err != nil {
		return err
	}

	if err := ValidatePrimaryKey(cct.schema); err != nil {
		return err
	}

N
neza2017 已提交
231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262
	// validate field name
	for _, field := range cct.schema.Fields {
		if err := ValidateFieldName(field.Name); err != nil {
			return err
		}
		if field.DataType == schemapb.DataType_VECTOR_FLOAT || field.DataType == schemapb.DataType_VECTOR_BINARY {
			exist := false
			var dim int64 = 0
			for _, param := range field.TypeParams {
				if param.Key == "dim" {
					exist = true
					tmp, err := strconv.ParseInt(param.Value, 10, 64)
					if err != nil {
						return err
					}
					dim = tmp
					break
				}
			}
			if !exist {
				return errors.New("dimension is not defined in field type params")
			}
			if field.DataType == schemapb.DataType_VECTOR_FLOAT {
				if err := ValidateDimension(dim, false); err != nil {
					return err
				}
			} else {
				if err := ValidateDimension(dim, true); err != nil {
					return err
				}
			}
		}
N
neza2017 已提交
263 264 265
		if err := ValidateVectorFieldMetricType(field); err != nil {
			return err
		}
N
neza2017 已提交
266 267
	}

268
	return nil
Z
zhenshan.cao 已提交
269 270
}

271
func (cct *CreateCollectionTask) Execute() error {
N
neza2017 已提交
272 273
	schemaBytes, _ := proto.Marshal(cct.schema)
	cct.CreateCollectionRequest.Schema.Value = schemaBytes
274 275 276
	resp, err := cct.masterClient.CreateCollection(cct.ctx, &cct.CreateCollectionRequest)
	if err != nil {
		log.Printf("create collection failed, error= %v", err)
277
		cct.result = &commonpb.Status{
278
			ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
Z
zhenshan.cao 已提交
279
			Reason:    err.Error(),
280 281
		}
	} else {
282
		cct.result = resp
283 284
	}
	return err
Z
zhenshan.cao 已提交
285 286
}

287 288
func (cct *CreateCollectionTask) PostExecute() error {
	return nil
Z
zhenshan.cao 已提交
289 290
}

291
type DropCollectionTask struct {
D
dragondriver 已提交
292
	Condition
293 294
	internalpb.DropCollectionRequest
	masterClient masterpb.MasterClient
295
	result       *commonpb.Status
296 297 298
	ctx          context.Context
}

C
cai.zhang 已提交
299
func (dct *DropCollectionTask) ID() UniqueID {
300
	return dct.ReqID
301 302
}

303 304 305 306
func (dct *DropCollectionTask) SetID(uid UniqueID) {
	dct.ReqID = uid
}

307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323
func (dct *DropCollectionTask) Type() internalpb.MsgType {
	return dct.MsgType
}

func (dct *DropCollectionTask) BeginTs() Timestamp {
	return dct.Timestamp
}

func (dct *DropCollectionTask) EndTs() Timestamp {
	return dct.Timestamp
}

func (dct *DropCollectionTask) SetTs(ts Timestamp) {
	dct.Timestamp = ts
}

func (dct *DropCollectionTask) PreExecute() error {
N
neza2017 已提交
324 325 326
	if err := ValidateCollectionName(dct.CollectionName.CollectionName); err != nil {
		return err
	}
327 328 329 330 331 332 333
	return nil
}

func (dct *DropCollectionTask) Execute() error {
	resp, err := dct.masterClient.DropCollection(dct.ctx, &dct.DropCollectionRequest)
	if err != nil {
		log.Printf("drop collection failed, error= %v", err)
334
		dct.result = &commonpb.Status{
335 336 337 338
			ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
			Reason:    err.Error(),
		}
	} else {
339
		dct.result = resp
340 341 342 343 344
	}
	return err
}

func (dct *DropCollectionTask) PostExecute() error {
345 346 347
	if globalMetaCache.Hit(dct.CollectionName.CollectionName) {
		return globalMetaCache.Remove(dct.CollectionName.CollectionName)
	}
Z
zhenshan.cao 已提交
348
	return nil
349 350
}

351
type QueryTask struct {
D
dragondriver 已提交
352
	Condition
353 354 355
	internalpb.SearchRequest
	queryMsgStream *msgstream.PulsarMsgStream
	resultBuf      chan []*internalpb.SearchResult
356
	result         *servicepb.QueryResult
357
	ctx            context.Context
N
neza2017 已提交
358
	query          *servicepb.Query
359 360
}

C
cai.zhang 已提交
361
func (qt *QueryTask) ID() UniqueID {
362
	return qt.ReqID
363 364
}

365 366 367 368
func (qt *QueryTask) SetID(uid UniqueID) {
	qt.ReqID = uid
}

369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385
func (qt *QueryTask) Type() internalpb.MsgType {
	return qt.MsgType
}

func (qt *QueryTask) BeginTs() Timestamp {
	return qt.Timestamp
}

func (qt *QueryTask) EndTs() Timestamp {
	return qt.Timestamp
}

func (qt *QueryTask) SetTs(ts Timestamp) {
	qt.Timestamp = ts
}

func (qt *QueryTask) PreExecute() error {
G
godchen 已提交
386 387 388 389 390 391
	span, ctx := opentracing.StartSpanFromContext(qt.ctx, "QueryTask preExecute")
	defer span.Finish()
	qt.ctx = ctx
	span.SetTag("hash keys", qt.ReqID)
	span.SetTag("start time", qt.BeginTs())

392 393
	collectionName := qt.query.CollectionName
	if !globalMetaCache.Hit(collectionName) {
G
godchen 已提交
394
		err := globalMetaCache.Sync(collectionName)
395
		if err != nil {
G
godchen 已提交
396 397
			span.LogFields(oplog.Error(err))
			span.Finish()
398 399 400 401 402
			return err
		}
	}
	_, err := globalMetaCache.Get(collectionName)
	if err != nil { // err is not nil if collection not exists
G
godchen 已提交
403 404
		span.LogFields(oplog.Error(err))
		span.Finish()
405 406 407
		return err
	}

N
neza2017 已提交
408
	if err := ValidateCollectionName(qt.query.CollectionName); err != nil {
G
godchen 已提交
409 410
		span.LogFields(oplog.Error(err))
		span.Finish()
N
neza2017 已提交
411 412 413 414 415
		return err
	}

	for _, tag := range qt.query.PartitionTags {
		if err := ValidatePartitionTag(tag, false); err != nil {
G
godchen 已提交
416 417
			span.LogFields(oplog.Error(err))
			span.Finish()
N
neza2017 已提交
418 419 420
			return err
		}
	}
C
cai.zhang 已提交
421 422 423 424 425 426
	qt.MsgType = internalpb.MsgType_kSearch
	if qt.query.PartitionTags == nil || len(qt.query.PartitionTags) <= 0 {
		qt.query.PartitionTags = []string{Params.defaultPartitionTag()}
	}
	queryBytes, err := proto.Marshal(qt.query)
	if err != nil {
G
godchen 已提交
427 428
		span.LogFields(oplog.Error(err))
		span.Finish()
C
cai.zhang 已提交
429 430 431 432 433
		return err
	}
	qt.Query = &commonpb.Blob{
		Value: queryBytes,
	}
434 435 436 437
	return nil
}

func (qt *QueryTask) Execute() error {
G
godchen 已提交
438 439 440 441 442
	span, ctx := opentracing.StartSpanFromContext(qt.ctx, "QueryTask Execute")
	defer span.Finish()
	qt.ctx = ctx
	span.SetTag("hash keys", qt.ReqID)
	span.SetTag("start time", qt.BeginTs())
443 444 445
	var tsMsg msgstream.TsMsg = &msgstream.SearchMsg{
		SearchRequest: qt.SearchRequest,
		BaseMsg: msgstream.BaseMsg{
N
neza2017 已提交
446
			HashValues:     []uint32{uint32(Params.ProxyID())},
447 448 449 450 451 452 453
			BeginTimestamp: qt.Timestamp,
			EndTimestamp:   qt.Timestamp,
		},
	}
	msgPack := &msgstream.MsgPack{
		BeginTs: qt.Timestamp,
		EndTs:   qt.Timestamp,
X
xige-16 已提交
454
		Msgs:    make([]msgstream.TsMsg, 1),
455
	}
G
godchen 已提交
456
	tsMsg.SetMsgContext(ctx)
X
xige-16 已提交
457
	msgPack.Msgs[0] = tsMsg
C
cai.zhang 已提交
458 459 460
	err := qt.queryMsgStream.Produce(msgPack)
	log.Printf("[Proxy] length of searchMsg: %v", len(msgPack.Msgs))
	if err != nil {
G
godchen 已提交
461 462
		span.LogFields(oplog.Error(err))
		span.Finish()
C
cai.zhang 已提交
463 464 465
		log.Printf("[Proxy] send search request failed: %v", err)
	}
	return err
466 467 468
}

func (qt *QueryTask) PostExecute() error {
G
godchen 已提交
469 470 471 472
	span, _ := opentracing.StartSpanFromContext(qt.ctx, "QueryTask postExecute")
	defer span.Finish()
	span.SetTag("hash keys", qt.ReqID)
	span.SetTag("start time", qt.BeginTs())
473 474 475 476
	for {
		select {
		case <-qt.ctx.Done():
			log.Print("wait to finish failed, timeout!")
G
godchen 已提交
477
			span.LogFields(oplog.String("wait to finish failed, timeout", "wait to finish failed, timeout"))
C
cai.zhang 已提交
478
			return errors.New("wait to finish failed, timeout")
479
		case searchResults := <-qt.resultBuf:
G
godchen 已提交
480
			span.LogFields(oplog.String("receive result", "receive result"))
481
			filterSearchResult := make([]*internalpb.SearchResult, 0)
Z
zhenshan.cao 已提交
482
			var filterReason string
483 484 485
			for _, partialSearchResult := range searchResults {
				if partialSearchResult.Status.ErrorCode == commonpb.ErrorCode_SUCCESS {
					filterSearchResult = append(filterSearchResult, partialSearchResult)
Z
zhenshan.cao 已提交
486 487
				} else {
					filterReason += partialSearchResult.Status.Reason + "\n"
488 489 490
				}
			}

491 492
			availableQueryNodeNum := len(filterSearchResult)
			if availableQueryNodeNum <= 0 {
Z
zhenshan.cao 已提交
493 494 495 496 497 498
				qt.result = &servicepb.QueryResult{
					Status: &commonpb.Status{
						ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
						Reason:    filterReason,
					},
				}
G
godchen 已提交
499
				span.LogFields(oplog.Error(errors.New(filterReason)))
Z
zhenshan.cao 已提交
500
				return errors.New(filterReason)
501
			}
C
cai.zhang 已提交
502

503 504 505 506 507 508 509 510 511 512
			hits := make([][]*servicepb.Hits, 0)
			for _, partialSearchResult := range filterSearchResult {
				if len(partialSearchResult.Hits) <= 0 {
					filterReason += "nq is zero\n"
					continue
				}
				partialHits := make([]*servicepb.Hits, 0)
				for _, bs := range partialSearchResult.Hits {
					partialHit := &servicepb.Hits{}
					err := proto.Unmarshal(bs, partialHit)
N
neza2017 已提交
513
					if err != nil {
B
bigsheeper 已提交
514
						log.Println("unmarshal error")
N
neza2017 已提交
515 516
						return err
					}
517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539
					partialHits = append(partialHits, partialHit)
				}
				hits = append(hits, partialHits)
			}

			availableQueryNodeNum = len(hits)
			if availableQueryNodeNum <= 0 {
				qt.result = &servicepb.QueryResult{
					Status: &commonpb.Status{
						ErrorCode: commonpb.ErrorCode_SUCCESS,
						Reason:    filterReason,
					},
				}
				return nil
			}

			nq := len(hits[0])
			if nq <= 0 {
				qt.result = &servicepb.QueryResult{
					Status: &commonpb.Status{
						ErrorCode: commonpb.ErrorCode_SUCCESS,
						Reason:    filterReason,
					},
N
neza2017 已提交
540
				}
541
				return nil
N
neza2017 已提交
542
			}
C
cai.zhang 已提交
543

544
			topk := len(hits[0][0].IDs)
C
cai.zhang 已提交
545
			qt.result = &servicepb.QueryResult{
546 547 548
				Status: &commonpb.Status{
					ErrorCode: 0,
				},
C
cai.zhang 已提交
549
				Hits: make([][]byte, 0),
550
			}
C
cai.zhang 已提交
551

G
GuoRentong 已提交
552
			const minFloat32 = -1 * float32(math.MaxFloat32)
553 554
			for i := 0; i < nq; i++ {
				locs := make([]int, availableQueryNodeNum)
C
cai.zhang 已提交
555 556 557 558 559 560
				reducedHits := &servicepb.Hits{
					IDs:     make([]int64, 0),
					RowData: make([][]byte, 0),
					Scores:  make([]float32, 0),
				}

561
				for j := 0; j < topk; j++ {
G
GuoRentong 已提交
562
					choice, maxDistance := 0, minFloat32
563
					for q, loc := range locs { // query num, the number of ways to merge
N
neza2017 已提交
564
						distance := hits[q][i].Scores[loc]
G
GuoRentong 已提交
565
						if distance > maxDistance {
566
							choice = q
G
GuoRentong 已提交
567
							maxDistance = distance
568 569 570
						}
					}
					choiceOffset := locs[choice]
571 572
					// check if distance is valid, `invalid` here means very very big,
					// in this process, distance here is the smallest, so the rest of distance are all invalid
G
GuoRentong 已提交
573
					if hits[choice][i].Scores[choiceOffset] <= minFloat32 {
574
						break
575
					}
N
neza2017 已提交
576
					reducedHits.IDs = append(reducedHits.IDs, hits[choice][i].IDs[choiceOffset])
C
cai.zhang 已提交
577 578 579
					if hits[choice][i].RowData != nil && len(hits[choice][i].RowData) > 0 {
						reducedHits.RowData = append(reducedHits.RowData, hits[choice][i].RowData[choiceOffset])
					}
N
neza2017 已提交
580
					reducedHits.Scores = append(reducedHits.Scores, hits[choice][i].Scores[choiceOffset])
581 582
					locs[choice]++
				}
G
GuoRentong 已提交
583 584 585 586 587
				if searchResults[0].MetricType != "IP" {
					for k := range reducedHits.Scores {
						reducedHits.Scores[k] *= -1
					}
				}
N
neza2017 已提交
588 589
				reducedHitsBs, err := proto.Marshal(reducedHits)
				if err != nil {
B
bigsheeper 已提交
590
					log.Println("marshal error")
G
godchen 已提交
591
					span.LogFields(oplog.Error(err))
N
neza2017 已提交
592 593
					return err
				}
C
cai.zhang 已提交
594
				qt.result.Hits = append(qt.result.Hits, reducedHitsBs)
595
			}
C
cai.zhang 已提交
596
			return nil
597 598
		}
	}
D
dragondriver 已提交
599 600
}

601
type HasCollectionTask struct {
D
dragondriver 已提交
602
	Condition
603 604
	internalpb.HasCollectionRequest
	masterClient masterpb.MasterClient
605
	result       *servicepb.BoolResponse
606 607 608
	ctx          context.Context
}

C
cai.zhang 已提交
609
func (hct *HasCollectionTask) ID() UniqueID {
610
	return hct.ReqID
611 612
}

613 614 615 616
func (hct *HasCollectionTask) SetID(uid UniqueID) {
	hct.ReqID = uid
}

617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633
func (hct *HasCollectionTask) Type() internalpb.MsgType {
	return hct.MsgType
}

func (hct *HasCollectionTask) BeginTs() Timestamp {
	return hct.Timestamp
}

func (hct *HasCollectionTask) EndTs() Timestamp {
	return hct.Timestamp
}

func (hct *HasCollectionTask) SetTs(ts Timestamp) {
	hct.Timestamp = ts
}

func (hct *HasCollectionTask) PreExecute() error {
N
neza2017 已提交
634 635 636
	if err := ValidateCollectionName(hct.CollectionName.CollectionName); err != nil {
		return err
	}
637 638 639 640 641 642 643
	return nil
}

func (hct *HasCollectionTask) Execute() error {
	resp, err := hct.masterClient.HasCollection(hct.ctx, &hct.HasCollectionRequest)
	if err != nil {
		log.Printf("has collection failed, error= %v", err)
644
		hct.result = &servicepb.BoolResponse{
645 646 647 648 649 650 651
			Status: &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
				Reason:    "internal error",
			},
			Value: false,
		}
	} else {
652
		hct.result = resp
653 654 655 656 657 658 659 660 661
	}
	return err
}

func (hct *HasCollectionTask) PostExecute() error {
	return nil
}

type DescribeCollectionTask struct {
D
dragondriver 已提交
662
	Condition
663 664
	internalpb.DescribeCollectionRequest
	masterClient masterpb.MasterClient
665
	result       *servicepb.CollectionDescription
666 667 668
	ctx          context.Context
}

C
cai.zhang 已提交
669
func (dct *DescribeCollectionTask) ID() UniqueID {
670
	return dct.ReqID
671 672
}

673 674 675 676
func (dct *DescribeCollectionTask) SetID(uid UniqueID) {
	dct.ReqID = uid
}

677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693
func (dct *DescribeCollectionTask) Type() internalpb.MsgType {
	return dct.MsgType
}

func (dct *DescribeCollectionTask) BeginTs() Timestamp {
	return dct.Timestamp
}

func (dct *DescribeCollectionTask) EndTs() Timestamp {
	return dct.Timestamp
}

func (dct *DescribeCollectionTask) SetTs(ts Timestamp) {
	dct.Timestamp = ts
}

func (dct *DescribeCollectionTask) PreExecute() error {
N
neza2017 已提交
694 695 696
	if err := ValidateCollectionName(dct.CollectionName.CollectionName); err != nil {
		return err
	}
697 698 699 700
	return nil
}

func (dct *DescribeCollectionTask) Execute() error {
701
	var err error
G
godchen 已提交
702
	dct.result, err = dct.masterClient.DescribeCollection(dct.ctx, &dct.DescribeCollectionRequest)
G
godchen 已提交
703 704 705 706
	if err != nil {
		return err
	}
	err = globalMetaCache.Update(dct.CollectionName.CollectionName, dct.result)
707 708 709 710 711 712 713 714
	return err
}

func (dct *DescribeCollectionTask) PostExecute() error {
	return nil
}

type ShowCollectionsTask struct {
D
dragondriver 已提交
715
	Condition
716 717
	internalpb.ShowCollectionRequest
	masterClient masterpb.MasterClient
718
	result       *servicepb.StringListResponse
719 720 721
	ctx          context.Context
}

C
cai.zhang 已提交
722
func (sct *ShowCollectionsTask) ID() UniqueID {
723
	return sct.ReqID
724 725
}

726 727 728 729
func (sct *ShowCollectionsTask) SetID(uid UniqueID) {
	sct.ReqID = uid
}

730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753
func (sct *ShowCollectionsTask) Type() internalpb.MsgType {
	return sct.MsgType
}

func (sct *ShowCollectionsTask) BeginTs() Timestamp {
	return sct.Timestamp
}

func (sct *ShowCollectionsTask) EndTs() Timestamp {
	return sct.Timestamp
}

func (sct *ShowCollectionsTask) SetTs(ts Timestamp) {
	sct.Timestamp = ts
}

func (sct *ShowCollectionsTask) PreExecute() error {
	return nil
}

func (sct *ShowCollectionsTask) Execute() error {
	resp, err := sct.masterClient.ShowCollections(sct.ctx, &sct.ShowCollectionRequest)
	if err != nil {
		log.Printf("show collections failed, error= %v", err)
754
		sct.result = &servicepb.StringListResponse{
755 756 757 758 759 760
			Status: &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
				Reason:    "internal error",
			},
		}
	} else {
761
		sct.result = resp
762 763 764 765 766 767 768
	}
	return err
}

func (sct *ShowCollectionsTask) PostExecute() error {
	return nil
}
N
neza2017 已提交
769 770 771 772 773 774 775 776 777 778 779 780 781

type CreatePartitionTask struct {
	Condition
	internalpb.CreatePartitionRequest
	masterClient masterpb.MasterClient
	result       *commonpb.Status
	ctx          context.Context
}

func (cpt *CreatePartitionTask) ID() UniqueID {
	return cpt.ReqID
}

782 783 784 785
func (cpt *CreatePartitionTask) SetID(uid UniqueID) {
	cpt.ReqID = uid
}

N
neza2017 已提交
786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802
func (cpt *CreatePartitionTask) Type() internalpb.MsgType {
	return cpt.MsgType
}

func (cpt *CreatePartitionTask) BeginTs() Timestamp {
	return cpt.Timestamp
}

func (cpt *CreatePartitionTask) EndTs() Timestamp {
	return cpt.Timestamp
}

func (cpt *CreatePartitionTask) SetTs(ts Timestamp) {
	cpt.Timestamp = ts
}

func (cpt *CreatePartitionTask) PreExecute() error {
N
neza2017 已提交
803 804 805 806 807 808 809 810 811 812
	collName, partitionTag := cpt.PartitionName.CollectionName, cpt.PartitionName.Tag

	if err := ValidateCollectionName(collName); err != nil {
		return err
	}

	if err := ValidatePartitionTag(partitionTag, true); err != nil {
		return err
	}

N
neza2017 已提交
813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836
	return nil
}

func (cpt *CreatePartitionTask) Execute() (err error) {
	cpt.result, err = cpt.masterClient.CreatePartition(cpt.ctx, &cpt.CreatePartitionRequest)
	return err
}

func (cpt *CreatePartitionTask) PostExecute() error {
	return nil
}

type DropPartitionTask struct {
	Condition
	internalpb.DropPartitionRequest
	masterClient masterpb.MasterClient
	result       *commonpb.Status
	ctx          context.Context
}

func (dpt *DropPartitionTask) ID() UniqueID {
	return dpt.ReqID
}

837 838 839 840
func (dpt *DropPartitionTask) SetID(uid UniqueID) {
	dpt.ReqID = uid
}

N
neza2017 已提交
841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857
func (dpt *DropPartitionTask) Type() internalpb.MsgType {
	return dpt.MsgType
}

func (dpt *DropPartitionTask) BeginTs() Timestamp {
	return dpt.Timestamp
}

func (dpt *DropPartitionTask) EndTs() Timestamp {
	return dpt.Timestamp
}

func (dpt *DropPartitionTask) SetTs(ts Timestamp) {
	dpt.Timestamp = ts
}

func (dpt *DropPartitionTask) PreExecute() error {
N
neza2017 已提交
858 859 860 861 862 863 864 865 866 867
	collName, partitionTag := dpt.PartitionName.CollectionName, dpt.PartitionName.Tag

	if err := ValidateCollectionName(collName); err != nil {
		return err
	}

	if err := ValidatePartitionTag(partitionTag, true); err != nil {
		return err
	}

N
neza2017 已提交
868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891
	return nil
}

func (dpt *DropPartitionTask) Execute() (err error) {
	dpt.result, err = dpt.masterClient.DropPartition(dpt.ctx, &dpt.DropPartitionRequest)
	return err
}

func (dpt *DropPartitionTask) PostExecute() error {
	return nil
}

type HasPartitionTask struct {
	Condition
	internalpb.HasPartitionRequest
	masterClient masterpb.MasterClient
	result       *servicepb.BoolResponse
	ctx          context.Context
}

func (hpt *HasPartitionTask) ID() UniqueID {
	return hpt.ReqID
}

892 893 894 895
func (hpt *HasPartitionTask) SetID(uid UniqueID) {
	hpt.ReqID = uid
}

N
neza2017 已提交
896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912
func (hpt *HasPartitionTask) Type() internalpb.MsgType {
	return hpt.MsgType
}

func (hpt *HasPartitionTask) BeginTs() Timestamp {
	return hpt.Timestamp
}

func (hpt *HasPartitionTask) EndTs() Timestamp {
	return hpt.Timestamp
}

func (hpt *HasPartitionTask) SetTs(ts Timestamp) {
	hpt.Timestamp = ts
}

func (hpt *HasPartitionTask) PreExecute() error {
N
neza2017 已提交
913 914 915 916 917 918 919 920 921
	collName, partitionTag := hpt.PartitionName.CollectionName, hpt.PartitionName.Tag

	if err := ValidateCollectionName(collName); err != nil {
		return err
	}

	if err := ValidatePartitionTag(partitionTag, true); err != nil {
		return err
	}
N
neza2017 已提交
922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945
	return nil
}

func (hpt *HasPartitionTask) Execute() (err error) {
	hpt.result, err = hpt.masterClient.HasPartition(hpt.ctx, &hpt.HasPartitionRequest)
	return err
}

func (hpt *HasPartitionTask) PostExecute() error {
	return nil
}

type DescribePartitionTask struct {
	Condition
	internalpb.DescribePartitionRequest
	masterClient masterpb.MasterClient
	result       *servicepb.PartitionDescription
	ctx          context.Context
}

func (dpt *DescribePartitionTask) ID() UniqueID {
	return dpt.ReqID
}

946 947 948 949
func (dpt *DescribePartitionTask) SetID(uid UniqueID) {
	dpt.ReqID = uid
}

N
neza2017 已提交
950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966
func (dpt *DescribePartitionTask) Type() internalpb.MsgType {
	return dpt.MsgType
}

func (dpt *DescribePartitionTask) BeginTs() Timestamp {
	return dpt.Timestamp
}

func (dpt *DescribePartitionTask) EndTs() Timestamp {
	return dpt.Timestamp
}

func (dpt *DescribePartitionTask) SetTs(ts Timestamp) {
	dpt.Timestamp = ts
}

func (dpt *DescribePartitionTask) PreExecute() error {
N
neza2017 已提交
967 968 969 970 971 972 973 974 975
	collName, partitionTag := dpt.PartitionName.CollectionName, dpt.PartitionName.Tag

	if err := ValidateCollectionName(collName); err != nil {
		return err
	}

	if err := ValidatePartitionTag(partitionTag, true); err != nil {
		return err
	}
N
neza2017 已提交
976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999
	return nil
}

func (dpt *DescribePartitionTask) Execute() (err error) {
	dpt.result, err = dpt.masterClient.DescribePartition(dpt.ctx, &dpt.DescribePartitionRequest)
	return err
}

func (dpt *DescribePartitionTask) PostExecute() error {
	return nil
}

type ShowPartitionsTask struct {
	Condition
	internalpb.ShowPartitionRequest
	masterClient masterpb.MasterClient
	result       *servicepb.StringListResponse
	ctx          context.Context
}

func (spt *ShowPartitionsTask) ID() UniqueID {
	return spt.ReqID
}

1000 1001 1002 1003
func (spt *ShowPartitionsTask) SetID(uid UniqueID) {
	spt.ReqID = uid
}

N
neza2017 已提交
1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020
func (spt *ShowPartitionsTask) Type() internalpb.MsgType {
	return spt.MsgType
}

func (spt *ShowPartitionsTask) BeginTs() Timestamp {
	return spt.Timestamp
}

func (spt *ShowPartitionsTask) EndTs() Timestamp {
	return spt.Timestamp
}

func (spt *ShowPartitionsTask) SetTs(ts Timestamp) {
	spt.Timestamp = ts
}

func (spt *ShowPartitionsTask) PreExecute() error {
N
neza2017 已提交
1021 1022 1023
	if err := ValidateCollectionName(spt.CollectionName.CollectionName); err != nil {
		return err
	}
N
neza2017 已提交
1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034
	return nil
}

func (spt *ShowPartitionsTask) Execute() (err error) {
	spt.result, err = spt.masterClient.ShowPartitions(spt.ctx, &spt.ShowPartitionRequest)
	return err
}

func (spt *ShowPartitionsTask) PostExecute() error {
	return nil
}
1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199

// CreateIndexTask forwards a create-index request to the master after
// validating the collection and field names; the master's status is kept in
// result.
type CreateIndexTask struct {
	Condition
	internalpb.CreateIndexRequest
	masterClient masterpb.MasterClient
	result       *commonpb.Status
	ctx          context.Context
}

// ID returns the request ID.
func (cit *CreateIndexTask) ID() UniqueID { return cit.ReqID }

// SetID overwrites the request ID.
func (cit *CreateIndexTask) SetID(uid UniqueID) { cit.ReqID = uid }

// Type returns the request's message type.
func (cit *CreateIndexTask) Type() internalpb.MsgType { return cit.MsgType }

// BeginTs returns the request timestamp.
func (cit *CreateIndexTask) BeginTs() Timestamp { return cit.Timestamp }

// EndTs returns the request timestamp; begin and end are the same value.
func (cit *CreateIndexTask) EndTs() Timestamp { return cit.Timestamp }

// SetTs overwrites the request timestamp.
func (cit *CreateIndexTask) SetTs(ts Timestamp) { cit.Timestamp = ts }

// PreExecute validates the collection name, then the field name.
func (cit *CreateIndexTask) PreExecute() error {
	if err := ValidateCollectionName(cit.CollectionName); err != nil {
		return err
	}
	return ValidateFieldName(cit.FieldName)
}

// Execute stores the master's reply in cit.result and returns the RPC error.
func (cit *CreateIndexTask) Execute() (err error) {
	resp, err := cit.masterClient.CreateIndex(cit.ctx, &cit.CreateIndexRequest)
	cit.result = resp
	return err
}

// PostExecute has nothing to do.
func (cit *CreateIndexTask) PostExecute() error { return nil }

// DescribeIndexTask forwards a describe-index request to the master after
// validating the collection and field names; the master's reply is kept in
// result.
type DescribeIndexTask struct {
	Condition
	internalpb.DescribeIndexRequest
	masterClient masterpb.MasterClient
	result       *servicepb.DescribeIndexResponse
	ctx          context.Context
}

// ID returns the request ID.
func (dit *DescribeIndexTask) ID() UniqueID { return dit.ReqID }

// SetID overwrites the request ID.
func (dit *DescribeIndexTask) SetID(uid UniqueID) { dit.ReqID = uid }

// Type returns the request's message type.
func (dit *DescribeIndexTask) Type() internalpb.MsgType { return dit.MsgType }

// BeginTs returns the request timestamp.
func (dit *DescribeIndexTask) BeginTs() Timestamp { return dit.Timestamp }

// EndTs returns the request timestamp; begin and end are the same value.
func (dit *DescribeIndexTask) EndTs() Timestamp { return dit.Timestamp }

// SetTs overwrites the request timestamp.
func (dit *DescribeIndexTask) SetTs(ts Timestamp) { dit.Timestamp = ts }

// PreExecute validates the collection name, then the field name.
func (dit *DescribeIndexTask) PreExecute() error {
	if err := ValidateCollectionName(dit.CollectionName); err != nil {
		return err
	}
	return ValidateFieldName(dit.FieldName)
}

// Execute stores the master's reply in dit.result and returns the RPC error.
func (dit *DescribeIndexTask) Execute() (err error) {
	resp, err := dit.masterClient.DescribeIndex(dit.ctx, &dit.DescribeIndexRequest)
	dit.result = resp
	return err
}

// PostExecute has nothing to do.
func (dit *DescribeIndexTask) PostExecute() error { return nil }

// DescribeIndexProgressTask forwards a describe-index-progress request to the
// master after validating the collection and field names; the master's reply
// (a BoolResponse) is kept in result.
type DescribeIndexProgressTask struct {
	Condition
	internalpb.DescribeIndexProgressRequest
	masterClient masterpb.MasterClient
	result       *servicepb.BoolResponse
	ctx          context.Context
}

// ID returns the request ID.
func (dipt *DescribeIndexProgressTask) ID() UniqueID { return dipt.ReqID }

// SetID overwrites the request ID.
func (dipt *DescribeIndexProgressTask) SetID(uid UniqueID) { dipt.ReqID = uid }

// Type returns the request's message type.
func (dipt *DescribeIndexProgressTask) Type() internalpb.MsgType { return dipt.MsgType }

// BeginTs returns the request timestamp.
func (dipt *DescribeIndexProgressTask) BeginTs() Timestamp { return dipt.Timestamp }

// EndTs returns the request timestamp; begin and end are the same value.
func (dipt *DescribeIndexProgressTask) EndTs() Timestamp { return dipt.Timestamp }

// SetTs overwrites the request timestamp.
func (dipt *DescribeIndexProgressTask) SetTs(ts Timestamp) { dipt.Timestamp = ts }

// PreExecute validates the collection name, then the field name.
func (dipt *DescribeIndexProgressTask) PreExecute() error {
	if err := ValidateCollectionName(dipt.CollectionName); err != nil {
		return err
	}
	return ValidateFieldName(dipt.FieldName)
}

// Execute stores the master's reply in dipt.result and returns the RPC error.
func (dipt *DescribeIndexProgressTask) Execute() (err error) {
	resp, err := dipt.masterClient.DescribeIndexProgress(dipt.ctx, &dipt.DescribeIndexProgressRequest)
	dipt.result = resp
	return err
}

// PostExecute has nothing to do.
func (dipt *DescribeIndexProgressTask) PostExecute() error { return nil }