impl.go 30.9 KB
Newer Older
1 2 3 4 5 6
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
7 8
// with the License. You may obtain a copy of the License at
//
9
//     http://www.apache.org/licenses/LICENSE-2.0
10
//
11 12 13 14 15
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
16

17 18 19 20
package querynode

import (
	"context"
21
	"errors"
22
	"fmt"
23
	"sync"
24 25 26

	"go.uber.org/zap"

27
	"github.com/milvus-io/milvus/internal/common"
X
Xiangyu Wang 已提交
28
	"github.com/milvus-io/milvus/internal/log"
29
	"github.com/milvus-io/milvus/internal/metrics"
X
Xiangyu Wang 已提交
30 31 32
	"github.com/milvus-io/milvus/internal/proto/commonpb"
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/proto/milvuspb"
33
	"github.com/milvus-io/milvus/internal/proto/querypb"
X
Xiangyu Wang 已提交
34
	queryPb "github.com/milvus-io/milvus/internal/proto/querypb"
35
	"github.com/milvus-io/milvus/internal/util/metricsinfo"
36
	"github.com/milvus-io/milvus/internal/util/timerecord"
X
Xiangyu Wang 已提交
37
	"github.com/milvus-io/milvus/internal/util/typeutil"
38 39
)

40
// GetComponentStates returns information about whether the node is healthy
41 42 43 44 45 46
func (node *QueryNode) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
	stats := &internalpb.ComponentStates{
		Status: &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_Success,
		},
	}
47 48 49
	code, ok := node.stateCode.Load().(internalpb.StateCode)
	if !ok {
		errMsg := "unexpected error in type assertion"
50 51
		stats.Status = &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
52
			Reason:    errMsg,
53
		}
G
godchen 已提交
54
		return stats, nil
55 56 57 58
	}
	nodeID := common.NotRegisteredID
	if node.session != nil && node.session.Registered() {
		nodeID = node.session.ServerID
59 60
	}
	info := &internalpb.ComponentInfo{
61
		NodeID:    nodeID,
62 63 64 65
		Role:      typeutil.QueryNodeRole,
		StateCode: code,
	}
	stats.State = info
66
	log.Debug("Get QueryNode component state done", zap.Any("stateCode", info.StateCode))
67 68 69
	return stats, nil
}

70 71
// GetTimeTickChannel returns the time tick channel
// TimeTickChannel contains many time tick messages, which will be sent by query nodes
72 73 74 75 76 77
func (node *QueryNode) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
	return &milvuspb.StringResponse{
		Status: &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_Success,
			Reason:    "",
		},
78
		Value: Params.CommonCfg.QueryCoordTimeTick,
79 80 81
	}, nil
}

82
// GetStatisticsChannel returns the statistics channel
83
// Statistics channel contains statistics infos of query nodes, such as segment infos, memory infos
84 85 86 87 88 89
func (node *QueryNode) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
	return &milvuspb.StringResponse{
		Status: &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_Success,
			Reason:    "",
		},
90
		Value: Params.CommonCfg.QueryNodeStats,
91 92 93
	}, nil
}

94
// AddQueryChannel watch queryChannel of the collection to receive query message
95
func (node *QueryNode) AddQueryChannel(ctx context.Context, in *queryPb.AddQueryChannelRequest) (*commonpb.Status, error) {
96 97
	code := node.stateCode.Load().(internalpb.StateCode)
	if code != internalpb.StateCode_Healthy {
X
Xiaofan 已提交
98
		err := fmt.Errorf("query node %d is not ready", Params.QueryNodeCfg.GetNodeID())
99 100 101 102
		status := &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
			Reason:    err.Error(),
		}
G
godchen 已提交
103
		return status, nil
104
	}
105 106 107 108 109 110 111
	dct := &addQueryChannelTask{
		baseTask: baseTask{
			ctx:  ctx,
			done: make(chan error),
		},
		req:  in,
		node: node,
112
	}
113

114
	err := node.scheduler.queue.Enqueue(dct)
115 116 117 118
	if err != nil {
		status := &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
			Reason:    err.Error(),
119
		}
X
Xiaofan 已提交
120
		log.Warn(err.Error())
G
godchen 已提交
121
		return status, nil
122
	}
X
Xiaofan 已提交
123
	log.Info("addQueryChannelTask Enqueue done",
124 125 126 127
		zap.Int64("collectionID", in.CollectionID),
		zap.String("queryChannel", in.QueryChannel),
		zap.String("queryResultChannel", in.QueryResultChannel),
	)
128

129 130 131 132 133 134
	waitFunc := func() (*commonpb.Status, error) {
		err = dct.WaitToFinish()
		if err != nil {
			status := &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_UnexpectedError,
				Reason:    err.Error(),
135
			}
X
Xiaofan 已提交
136
			log.Warn(err.Error())
G
godchen 已提交
137
			return status, nil
138
		}
X
Xiaofan 已提交
139
		log.Info("addQueryChannelTask WaitToFinish done",
140 141 142 143 144
			zap.Int64("collectionID", in.CollectionID),
			zap.String("queryChannel", in.QueryChannel),
			zap.String("queryResultChannel", in.QueryResultChannel),
		)

145 146 147
		return &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_Success,
		}, nil
148
	}
149

150
	return waitFunc()
151 152
}

153
// RemoveQueryChannel remove queryChannel of the collection to stop receiving query message
154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203
func (node *QueryNode) RemoveQueryChannel(ctx context.Context, in *queryPb.RemoveQueryChannelRequest) (*commonpb.Status, error) {
	// if node.searchService == nil || node.searchService.searchMsgStream == nil {
	// 	errMsg := "null search service or null search result message stream"
	// 	status := &commonpb.Status{
	// 		ErrorCode: commonpb.ErrorCode_UnexpectedError,
	// 		Reason:    errMsg,
	// 	}

	// 	return status, errors.New(errMsg)
	// }

	// searchStream, ok := node.searchService.searchMsgStream.(*pulsarms.PulsarMsgStream)
	// if !ok {
	// 	errMsg := "type assertion failed for search message stream"
	// 	status := &commonpb.Status{
	// 		ErrorCode: commonpb.ErrorCode_UnexpectedError,
	// 		Reason:    errMsg,
	// 	}

	// 	return status, errors.New(errMsg)
	// }

	// resultStream, ok := node.searchService.searchResultMsgStream.(*pulsarms.PulsarMsgStream)
	// if !ok {
	// 	errMsg := "type assertion failed for search result message stream"
	// 	status := &commonpb.Status{
	// 		ErrorCode: commonpb.ErrorCode_UnexpectedError,
	// 		Reason:    errMsg,
	// 	}

	// 	return status, errors.New(errMsg)
	// }

	// // remove request channel
	// consumeChannels := []string{in.RequestChannelID}
	// consumeSubName := Params.MsgChannelSubName
	// // TODO: searchStream.RemovePulsarConsumers(producerChannels)
	// searchStream.AsConsumer(consumeChannels, consumeSubName)

	// // remove result channel
	// producerChannels := []string{in.ResultChannelID}
	// // TODO: resultStream.RemovePulsarProducer(producerChannels)
	// resultStream.AsProducer(producerChannels)

	status := &commonpb.Status{
		ErrorCode: commonpb.ErrorCode_Success,
	}
	return status, nil
}

G
godchen 已提交
204
// WatchDmChannels create consumers on dmChannels to receive Incremental data,which is the important part of real-time query
205
func (node *QueryNode) WatchDmChannels(ctx context.Context, in *queryPb.WatchDmChannelsRequest) (*commonpb.Status, error) {
206 207
	code := node.stateCode.Load().(internalpb.StateCode)
	if code != internalpb.StateCode_Healthy {
X
Xiaofan 已提交
208
		err := fmt.Errorf("query node %d is not ready", Params.QueryNodeCfg.GetNodeID())
209 210 211 212
		status := &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
			Reason:    err.Error(),
		}
G
godchen 已提交
213
		return status, nil
214
	}
215 216 217 218 219 220 221
	dct := &watchDmChannelsTask{
		baseTask: baseTask{
			ctx:  ctx,
			done: make(chan error),
		},
		req:  in,
		node: node,
222 223
	}

224 225
	err := node.scheduler.queue.Enqueue(dct)
	if err != nil {
226 227
		status := &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
228
			Reason:    err.Error(),
229
		}
X
Xiaofan 已提交
230
		log.Warn(err.Error())
G
godchen 已提交
231
		return status, nil
232
	}
X
Xiaofan 已提交
233
	log.Info("watchDmChannelsTask Enqueue done", zap.Int64("collectionID", in.CollectionID), zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()), zap.Int64("replicaID", in.GetReplicaID()))
234
	waitFunc := func() (*commonpb.Status, error) {
235
		err = dct.WaitToFinish()
236
		if err != nil {
237 238 239 240
			status := &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_UnexpectedError,
				Reason:    err.Error(),
			}
X
Xiaofan 已提交
241
			log.Warn(err.Error())
G
godchen 已提交
242
			return status, nil
243
		}
X
Xiaofan 已提交
244
		log.Info("watchDmChannelsTask WaitToFinish done", zap.Int64("collectionID", in.CollectionID), zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
245 246 247
		return &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_Success,
		}, nil
248
	}
249 250

	return waitFunc()
251 252
}

G
godchen 已提交
253
// WatchDeltaChannels create consumers on dmChannels to receive Incremental data,which is the important part of real-time query
254
func (node *QueryNode) WatchDeltaChannels(ctx context.Context, in *queryPb.WatchDeltaChannelsRequest) (*commonpb.Status, error) {
255 256
	code := node.stateCode.Load().(internalpb.StateCode)
	if code != internalpb.StateCode_Healthy {
X
Xiaofan 已提交
257
		err := fmt.Errorf("query node %d is not ready", Params.QueryNodeCfg.GetNodeID())
258 259 260 261
		status := &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
			Reason:    err.Error(),
		}
G
godchen 已提交
262
		return status, nil
263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278
	}
	dct := &watchDeltaChannelsTask{
		baseTask: baseTask{
			ctx:  ctx,
			done: make(chan error),
		},
		req:  in,
		node: node,
	}

	err := node.scheduler.queue.Enqueue(dct)
	if err != nil {
		status := &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
			Reason:    err.Error(),
		}
X
Xiaofan 已提交
279
		log.Warn(err.Error())
G
godchen 已提交
280
		return status, nil
281
	}
X
Xiaofan 已提交
282 283

	log.Info("watchDeltaChannelsTask Enqueue done", zap.Int64("collectionID", in.CollectionID), zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
284 285 286 287 288 289 290 291

	waitFunc := func() (*commonpb.Status, error) {
		err = dct.WaitToFinish()
		if err != nil {
			status := &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_UnexpectedError,
				Reason:    err.Error(),
			}
X
Xiaofan 已提交
292
			log.Warn(err.Error())
G
godchen 已提交
293
			return status, nil
294
		}
X
Xiaofan 已提交
295 296

		log.Info("watchDeltaChannelsTask WaitToFinish done", zap.Int64("collectionID", in.CollectionID), zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
297 298 299 300 301 302
		return &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_Success,
		}, nil
	}

	return waitFunc()
303 304
}

305
// LoadSegments load historical data into query node, historical data can be vector data or index
306
func (node *QueryNode) LoadSegments(ctx context.Context, in *queryPb.LoadSegmentsRequest) (*commonpb.Status, error) {
307 308
	code := node.stateCode.Load().(internalpb.StateCode)
	if code != internalpb.StateCode_Healthy {
X
Xiaofan 已提交
309
		err := fmt.Errorf("query node %d is not ready", Params.QueryNodeCfg.GetNodeID())
310 311 312 313
		status := &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
			Reason:    err.Error(),
		}
G
godchen 已提交
314
		return status, nil
315
	}
316 317 318 319 320 321 322 323 324 325 326 327 328 329 330
	dct := &loadSegmentsTask{
		baseTask: baseTask{
			ctx:  ctx,
			done: make(chan error),
		},
		req:  in,
		node: node,
	}

	err := node.scheduler.queue.Enqueue(dct)
	if err != nil {
		status := &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
			Reason:    err.Error(),
		}
X
Xiaofan 已提交
331
		log.Warn(err.Error())
G
godchen 已提交
332
		return status, nil
333
	}
334 335 336 337
	segmentIDs := make([]UniqueID, 0)
	for _, info := range in.Infos {
		segmentIDs = append(segmentIDs, info.SegmentID)
	}
X
Xiaofan 已提交
338
	log.Info("loadSegmentsTask Enqueue done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", segmentIDs), zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
339

340
	waitFunc := func() (*commonpb.Status, error) {
341 342
		err = dct.WaitToFinish()
		if err != nil {
343 344 345 346
			status := &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_UnexpectedError,
				Reason:    err.Error(),
			}
X
Xiaofan 已提交
347
			log.Warn(err.Error())
G
godchen 已提交
348
			return status, nil
349
		}
X
Xiaofan 已提交
350
		log.Info("loadSegmentsTask WaitToFinish done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", segmentIDs), zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
351 352 353
		return &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_Success,
		}, nil
354
	}
355 356

	return waitFunc()
357 358
}

G
godchen 已提交
359
// ReleaseCollection clears all data related to this collection on the querynode
360
func (node *QueryNode) ReleaseCollection(ctx context.Context, in *queryPb.ReleaseCollectionRequest) (*commonpb.Status, error) {
361 362
	code := node.stateCode.Load().(internalpb.StateCode)
	if code != internalpb.StateCode_Healthy {
X
Xiaofan 已提交
363
		err := fmt.Errorf("query node %d is not ready", Params.QueryNodeCfg.GetNodeID())
364 365 366 367
		status := &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
			Reason:    err.Error(),
		}
G
godchen 已提交
368
		return status, nil
369
	}
370 371 372 373 374 375 376 377 378 379 380 381 382 383 384
	dct := &releaseCollectionTask{
		baseTask: baseTask{
			ctx:  ctx,
			done: make(chan error),
		},
		req:  in,
		node: node,
	}

	err := node.scheduler.queue.Enqueue(dct)
	if err != nil {
		status := &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
			Reason:    err.Error(),
		}
X
Xiaofan 已提交
385
		log.Warn(err.Error())
G
godchen 已提交
386
		return status, nil
387
	}
X
Xiaofan 已提交
388
	log.Info("releaseCollectionTask Enqueue done", zap.Int64("collectionID", in.CollectionID))
389

390
	func() {
391 392
		err = dct.WaitToFinish()
		if err != nil {
X
Xiaofan 已提交
393
			log.Warn(err.Error())
394
			return
395
		}
X
Xiaofan 已提交
396
		log.Info("releaseCollectionTask WaitToFinish done", zap.Int64("collectionID", in.CollectionID))
397
	}()
398 399 400 401 402 403 404

	status := &commonpb.Status{
		ErrorCode: commonpb.ErrorCode_Success,
	}
	return status, nil
}

405
// ReleasePartitions clears all data related to this partition on the querynode
406
func (node *QueryNode) ReleasePartitions(ctx context.Context, in *queryPb.ReleasePartitionsRequest) (*commonpb.Status, error) {
407 408
	code := node.stateCode.Load().(internalpb.StateCode)
	if code != internalpb.StateCode_Healthy {
X
Xiaofan 已提交
409
		err := fmt.Errorf("query node %d is not ready", Params.QueryNodeCfg.GetNodeID())
410 411 412 413
		status := &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
			Reason:    err.Error(),
		}
G
godchen 已提交
414
		return status, nil
415
	}
416 417 418 419 420 421 422 423 424 425 426 427 428 429 430
	dct := &releasePartitionsTask{
		baseTask: baseTask{
			ctx:  ctx,
			done: make(chan error),
		},
		req:  in,
		node: node,
	}

	err := node.scheduler.queue.Enqueue(dct)
	if err != nil {
		status := &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
			Reason:    err.Error(),
		}
X
Xiaofan 已提交
431
		log.Warn(err.Error())
G
godchen 已提交
432
		return status, nil
433
	}
X
Xiaofan 已提交
434
	log.Info("releasePartitionsTask Enqueue done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("partitionIDs", in.PartitionIDs))
435

436
	func() {
437 438
		err = dct.WaitToFinish()
		if err != nil {
X
Xiaofan 已提交
439
			log.Warn(err.Error())
440
			return
441
		}
X
Xiaofan 已提交
442
		log.Info("releasePartitionsTask WaitToFinish done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("partitionIDs", in.PartitionIDs))
443
	}()
444 445 446 447 448 449 450

	status := &commonpb.Status{
		ErrorCode: commonpb.ErrorCode_Success,
	}
	return status, nil
}

451
// ReleaseSegments remove the specified segments from query node according segmentIDs, partitionIDs, and collectionID
452
func (node *QueryNode) ReleaseSegments(ctx context.Context, in *queryPb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
453 454
	code := node.stateCode.Load().(internalpb.StateCode)
	if code != internalpb.StateCode_Healthy {
X
Xiaofan 已提交
455
		err := fmt.Errorf("query node %d is not ready", Params.QueryNodeCfg.GetNodeID())
456 457 458 459
		status := &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
			Reason:    err.Error(),
		}
G
godchen 已提交
460
		return status, nil
461
	}
462 463 464
	status := &commonpb.Status{
		ErrorCode: commonpb.ErrorCode_Success,
	}
465
	// collection lock is not needed since we guarantee not query/search will be dispatch from leader
466
	for _, id := range in.SegmentIDs {
467
		err := node.historical.removeSegment(id)
468 469 470 471 472
		if err != nil {
			// not return, try to release all segments
			status.ErrorCode = commonpb.ErrorCode_UnexpectedError
			status.Reason = err.Error()
		}
473
		err = node.streaming.removeSegment(id)
474
		if err != nil {
475 476
			// not return, try to release all segments
			status.ErrorCode = commonpb.ErrorCode_UnexpectedError
477
			status.Reason = err.Error()
478 479
		}
	}
X
xige-16 已提交
480

X
Xiaofan 已提交
481
	log.Info("release segments done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", in.SegmentIDs))
482 483 484
	return status, nil
}

485
// GetSegmentInfo returns segment information of the collection on the queryNode, and the information includes memSize, numRow, indexName, indexID ...
486
func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *queryPb.GetSegmentInfoRequest) (*queryPb.GetSegmentInfoResponse, error) {
487 488
	code := node.stateCode.Load().(internalpb.StateCode)
	if code != internalpb.StateCode_Healthy {
X
Xiaofan 已提交
489
		err := fmt.Errorf("query node %d is not ready", Params.QueryNodeCfg.GetNodeID())
490 491 492 493 494 495
		res := &queryPb.GetSegmentInfoResponse{
			Status: &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_UnexpectedError,
				Reason:    err.Error(),
			},
		}
G
godchen 已提交
496
		return res, nil
497
	}
498 499 500 501 502 503
	var segmentInfos []*queryPb.SegmentInfo

	segmentIDs := make(map[int64]struct{})
	for _, segmentID := range in.GetSegmentIDs() {
		segmentIDs[segmentID] = struct{}{}
	}
504

505
	// get info from historical
506
	historicalSegmentInfos, err := node.historical.getSegmentInfosByColID(in.CollectionID)
507
	if err != nil {
X
Xiaofan 已提交
508
		log.Warn("GetSegmentInfo: get historical segmentInfo failed", zap.Int64("collectionID", in.CollectionID), zap.Error(err))
509 510 511 512 513 514
		res := &queryPb.GetSegmentInfoResponse{
			Status: &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_UnexpectedError,
				Reason:    err.Error(),
			},
		}
G
godchen 已提交
515
		return res, nil
516
	}
517
	segmentInfos = append(segmentInfos, filterSegmentInfo(historicalSegmentInfos, segmentIDs)...)
518

519
	// get info from streaming
520
	streamingSegmentInfos, err := node.streaming.getSegmentInfosByColID(in.CollectionID)
521
	if err != nil {
X
Xiaofan 已提交
522
		log.Warn("GetSegmentInfo: get streaming segmentInfo failed", zap.Int64("collectionID", in.CollectionID), zap.Error(err))
523 524 525 526 527 528
		res := &queryPb.GetSegmentInfoResponse{
			Status: &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_UnexpectedError,
				Reason:    err.Error(),
			},
		}
G
godchen 已提交
529
		return res, nil
530
	}
531
	segmentInfos = append(segmentInfos, filterSegmentInfo(streamingSegmentInfos, segmentIDs)...)
532

533 534 535 536
	return &queryPb.GetSegmentInfoResponse{
		Status: &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_Success,
		},
537
		Infos: segmentInfos,
538 539
	}, nil
}
540

541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556
// filterSegmentInfo returns segment info which segment id in segmentIDs map
func filterSegmentInfo(segmentInfos []*queryPb.SegmentInfo, segmentIDs map[int64]struct{}) []*queryPb.SegmentInfo {
	// An empty filter means "keep everything".
	if len(segmentIDs) == 0 {
		return segmentInfos
	}
	filtered := make([]*queryPb.SegmentInfo, 0, len(segmentIDs))
	for _, info := range segmentInfos {
		if _, wanted := segmentIDs[info.GetSegmentID()]; wanted {
			filtered = append(filtered, info)
		}
	}
	return filtered
}

557
// isHealthy checks if QueryNode is healthy
558 559 560 561 562
func (node *QueryNode) isHealthy() bool {
	code := node.stateCode.Load().(internalpb.StateCode)
	return code == internalpb.StateCode_Healthy
}

563
// Search performs replica search tasks.
564
func (node *QueryNode) Search(ctx context.Context, req *queryPb.SearchRequest) (*internalpb.SearchResults, error) {
565
	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.SearchLabel, metrics.TotalLabel).Inc()
566 567 568 569 570
	failRet := &internalpb.SearchResults{
		Status: &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
		},
	}
571 572 573 574 575 576

	defer func() {
		if failRet.Status.ErrorCode != commonpb.ErrorCode_Success {
			metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.QueryLabel, metrics.FailLabel).Inc()
		}
	}()
577
	if !node.isHealthy() {
578 579
		failRet.Status.Reason = msgQueryNodeIsUnhealthy(Params.QueryNodeCfg.GetNodeID())
		return failRet, nil
580 581 582 583 584
	}

	log.Debug("Received SearchRequest", zap.String("vchannel", req.GetDmlChannel()), zap.Int64s("segmentIDs", req.GetSegmentIDs()))

	if node.queryShardService == nil {
585 586
		failRet.Status.Reason = "queryShardService is nil"
		return failRet, nil
587 588 589 590
	}

	qs, err := node.queryShardService.getQueryShard(req.GetDmlChannel())
	if err != nil {
591
		log.Warn("Search failed, failed to get query shard", zap.String("dml channel", req.GetDmlChannel()), zap.Error(err))
592 593 594
		failRet.Status.ErrorCode = commonpb.ErrorCode_NotShardLeader
		failRet.Status.Reason = err.Error()
		return failRet, nil
595 596
	}

597 598
	tr := timerecord.NewTimeRecorder(fmt.Sprintf("search %d", req.Req.CollectionID))

599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617
	if req.FromShardLeader {
		historicalTask, err2 := newSearchTask(ctx, req)
		if err2 != nil {
			failRet.Status.Reason = err2.Error()
			return failRet, nil
		}
		historicalTask.QS = qs
		historicalTask.DataScope = querypb.DataScope_Historical
		err2 = node.scheduler.AddReadTask(ctx, historicalTask)
		if err2 != nil {
			failRet.Status.Reason = err2.Error()
			return failRet, nil
		}

		err2 = historicalTask.WaitToFinish()
		if err2 != nil {
			failRet.Status.Reason = err2.Error()
			return failRet, nil
		}
618 619 620 621 622 623 624 625 626

		failRet.Status.ErrorCode = commonpb.ErrorCode_Success
		metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
			metrics.SearchLabel).Observe(float64(historicalTask.queueDur.Milliseconds()))
		metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
			metrics.SearchLabel).Observe(float64(historicalTask.reduceDur.Milliseconds()))
		latency := tr.ElapseSpan()
		metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.SearchLabel).Observe(float64(latency.Milliseconds()))
		metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel).Inc()
627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685
		return historicalTask.Ret, nil
	}

	//from Proxy
	cluster, ok := qs.clusterService.getShardCluster(req.GetDmlChannel())
	if !ok {
		failRet.Status.Reason = fmt.Sprintf("channel %s leader is not here", req.GetDmlChannel())
		return failRet, nil
	}

	searchCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	var results []*internalpb.SearchResults
	var streamingResult *internalpb.SearchResults

	var wg sync.WaitGroup
	var errCluster error

	wg.Add(1) // search cluster
	go func() {
		defer wg.Done()
		// shard leader dispatches request to its shard cluster
		oResults, cErr := cluster.Search(searchCtx, req)
		if cErr != nil {
			log.Warn("search cluster failed", zap.Int64("collectionID", req.Req.GetCollectionID()), zap.Error(cErr))
			cancel()
			errCluster = cErr
			return
		}
		results = oResults
	}()

	var errStreaming error
	wg.Add(1) // search streaming
	go func() {
		defer func() {
			if errStreaming != nil {
				cancel()
			}
		}()

		defer wg.Done()
		streamingTask, err2 := newSearchTask(searchCtx, req)
		if err2 != nil {
			errStreaming = err2
		}
		streamingTask.QS = qs
		streamingTask.DataScope = querypb.DataScope_Streaming
		err2 = node.scheduler.AddReadTask(searchCtx, streamingTask)
		if err2 != nil {
			errStreaming = err2
			return
		}
		err2 = streamingTask.WaitToFinish()
		if err2 != nil {
			errStreaming = err2
			return
		}
686 687 688 689
		metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
			metrics.SearchLabel).Observe(float64(streamingTask.queueDur.Milliseconds()))
		metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
			metrics.SearchLabel).Observe(float64(streamingTask.reduceDur.Milliseconds()))
690 691 692 693 694 695 696 697 698 699 700 701 702 703
		streamingResult = streamingTask.Ret
	}()
	wg.Wait()

	var mainErr error
	if errCluster != nil {
		mainErr = errCluster
		if errors.Is(errCluster, context.Canceled) {
			if errStreaming != nil {
				mainErr = errStreaming
			}
		}
	} else if errStreaming != nil {
		mainErr = errStreaming
704 705
	}

706 707 708 709 710 711 712 713 714 715
	if mainErr != nil {
		failRet.Status.Reason = mainErr.Error()
		return failRet, nil
	}
	results = append(results, streamingResult)
	ret, err2 := reduceSearchResults(results, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
	if err2 != nil {
		failRet.Status.Reason = err2.Error()
		return failRet, nil
	}
716 717 718 719 720

	failRet.Status.ErrorCode = commonpb.ErrorCode_Success
	latency := tr.ElapseSpan()
	metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.SearchLabel).Observe(float64(latency.Milliseconds()))
	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel).Inc()
721
	return ret, nil
722 723 724
}

// Query performs replica query tasks.
725
func (node *QueryNode) Query(ctx context.Context, req *queryPb.QueryRequest) (*internalpb.RetrieveResults, error) {
726
	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel).Inc()
727 728 729 730 731
	failRet := &internalpb.RetrieveResults{
		Status: &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
		},
	}
732 733 734 735 736 737

	defer func() {
		if failRet.Status.ErrorCode != commonpb.ErrorCode_Success {
			metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.SearchLabel, metrics.FailLabel).Inc()
		}
	}()
738
	if !node.isHealthy() {
739 740
		failRet.Status.Reason = msgQueryNodeIsUnhealthy(Params.QueryNodeCfg.GetNodeID())
		return failRet, nil
741 742 743 744
	}
	log.Debug("Received QueryRequest", zap.String("vchannel", req.GetDmlChannel()), zap.Int64s("segmentIDs", req.GetSegmentIDs()))

	if node.queryShardService == nil {
745 746 747 748
		failRet.Status.Reason = "queryShardService is nil"
		return failRet, nil
	}

749 750
	qs, err := node.queryShardService.getQueryShard(req.GetDmlChannel())
	if err != nil {
751
		log.Warn("Query failed, failed to get query shard", zap.String("dml channel", req.GetDmlChannel()), zap.Error(err))
752 753 754 755
		failRet.Status.Reason = err.Error()
		return failRet, nil
	}

756
	tr := timerecord.NewTimeRecorder(fmt.Sprintf("retrieve %d", req.Req.CollectionID))
757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772
	if req.FromShardLeader {
		// construct a queryTask
		queryTask := newQueryTask(ctx, req)
		queryTask.QS = qs
		queryTask.DataScope = querypb.DataScope_Historical
		err2 := node.scheduler.AddReadTask(ctx, queryTask)
		if err2 != nil {
			failRet.Status.Reason = err2.Error()
			return failRet, nil
		}

		err2 = queryTask.WaitToFinish()
		if err2 != nil {
			failRet.Status.Reason = err2.Error()
			return failRet, nil
		}
773 774 775 776 777 778 779 780
		failRet.Status.ErrorCode = commonpb.ErrorCode_Success
		metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
			metrics.QueryLabel).Observe(float64(queryTask.queueDur.Milliseconds()))
		metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
			metrics.QueryLabel).Observe(float64(queryTask.reduceDur.Milliseconds()))
		latency := tr.ElapseSpan()
		metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.QueryLabel).Observe(float64(latency.Milliseconds()))
		metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel).Inc()
781
		return queryTask.Ret, nil
782 783
	}

784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834
	cluster, ok := qs.clusterService.getShardCluster(req.GetDmlChannel())
	if !ok {
		failRet.Status.Reason = fmt.Sprintf("channel %s leader is not here", req.GetDmlChannel())
		return failRet, nil
	}

	// add cancel when error occurs
	queryCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	var results []*internalpb.RetrieveResults
	var streamingResult *internalpb.RetrieveResults
	var wg sync.WaitGroup

	var errCluster error
	wg.Add(1)
	go func() {
		defer wg.Done()
		// shard leader dispatches request to its shard cluster
		oResults, cErr := cluster.Query(queryCtx, req)
		if cErr != nil {
			log.Warn("failed to query cluster", zap.Int64("collectionID", req.Req.GetCollectionID()), zap.Error(cErr))
			log.Info("czs_query_cluster_cancel", zap.Error(cErr))
			errCluster = cErr
			cancel()
			return
		}
		results = oResults
	}()

	var errStreaming error
	wg.Add(1)
	go func() {
		defer wg.Done()
		streamingTask := newQueryTask(queryCtx, req)
		streamingTask.DataScope = querypb.DataScope_Streaming
		streamingTask.QS = qs
		err2 := node.scheduler.AddReadTask(queryCtx, streamingTask)
		defer func() {
			errStreaming = err2
			if err2 != nil {
				cancel()
			}
		}()
		if err2 != nil {
			return
		}
		err2 = streamingTask.WaitToFinish()
		if err2 != nil {
			return
		}
835 836 837 838
		metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
			metrics.QueryLabel).Observe(float64(streamingTask.queueDur.Milliseconds()))
		metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
			metrics.QueryLabel).Observe(float64(streamingTask.reduceDur.Milliseconds()))
839 840 841 842 843 844 845 846 847 848 849 850 851 852
		streamingResult = streamingTask.Ret
	}()
	wg.Wait()

	var mainErr error
	if errCluster != nil {
		mainErr = errCluster
		if errors.Is(errCluster, context.Canceled) {
			if errStreaming != nil {
				mainErr = errStreaming
			}
		}
	} else if errStreaming != nil {
		mainErr = errStreaming
853 854
	}

855 856 857 858 859 860 861 862 863 864
	if mainErr != nil {
		failRet.Status.Reason = mainErr.Error()
		return failRet, nil
	}
	results = append(results, streamingResult)
	ret, err2 := mergeInternalRetrieveResults(results)
	if err2 != nil {
		failRet.Status.Reason = err2.Error()
		return failRet, nil
	}
865 866 867 868
	failRet.Status.ErrorCode = commonpb.ErrorCode_Success
	latency := tr.ElapseSpan()
	metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.QueryLabel).Observe(float64(latency.Milliseconds()))
	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel).Inc()
869
	return ret, nil
870 871
}

872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896
// SyncReplicaSegments syncs replica node & segments states
func (node *QueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.SyncReplicaSegmentsRequest) (*commonpb.Status, error) {
	if !node.isHealthy() {
		return &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
			Reason:    msgQueryNodeIsUnhealthy(Params.QueryNodeCfg.GetNodeID()),
		}, nil
	}

	log.Debug("Received SyncReplicaSegments request", zap.String("vchannelName", req.GetVchannelName()))

	err := node.ShardClusterService.SyncReplicaSegments(req.GetVchannelName(), req.GetReplicaSegments())
	if err != nil {
		log.Warn("failed to sync replica semgents,", zap.String("vchannel", req.GetVchannelName()), zap.Error(err))
		return &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
			Reason:    err.Error(),
		}, nil
	}

	log.Debug("SyncReplicaSegments Done", zap.String("vchannel", req.GetVchannelName()))

	return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil
}

G
godchen 已提交
897
// GetMetrics return system infos of the query node, such as total memory, memory usage, cpu usage ...
898
// TODO(dragondriver): cache the Metrics and set a retention to the cache
899 900 901
func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
	if !node.isHealthy() {
		log.Warn("QueryNode.GetMetrics failed",
X
Xiaofan 已提交
902
			zap.Int64("node_id", Params.QueryNodeCfg.GetNodeID()),
903
			zap.String("req", req.Request),
X
Xiaofan 已提交
904
			zap.Error(errQueryNodeIsUnhealthy(Params.QueryNodeCfg.GetNodeID())))
905 906 907 908

		return &milvuspb.GetMetricsResponse{
			Status: &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_UnexpectedError,
X
Xiaofan 已提交
909
				Reason:    msgQueryNodeIsUnhealthy(Params.QueryNodeCfg.GetNodeID()),
910 911 912 913 914 915 916 917
			},
			Response: "",
		}, nil
	}

	metricType, err := metricsinfo.ParseMetricType(req.Request)
	if err != nil {
		log.Warn("QueryNode.GetMetrics failed to parse metric type",
X
Xiaofan 已提交
918
			zap.Int64("node_id", Params.QueryNodeCfg.GetNodeID()),
919 920 921 922 923 924 925 926 927 928 929 930 931 932
			zap.String("req", req.Request),
			zap.Error(err))

		return &milvuspb.GetMetricsResponse{
			Status: &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_UnexpectedError,
				Reason:    err.Error(),
			},
			Response: "",
		}, nil
	}

	if metricType == metricsinfo.SystemInfoMetrics {
		metrics, err := getSystemInfoMetrics(ctx, req, node)
X
Xiaofan 已提交
933 934
		if err != nil {
			log.Warn("QueryNode.GetMetrics failed",
X
Xiaofan 已提交
935
				zap.Int64("node_id", Params.QueryNodeCfg.GetNodeID()),
X
Xiaofan 已提交
936 937 938 939
				zap.String("req", req.Request),
				zap.String("metric_type", metricType),
				zap.Error(err))
		}
940

G
godchen 已提交
941
		return metrics, nil
942 943 944
	}

	log.Debug("QueryNode.GetMetrics failed, request metric type is not implemented yet",
X
Xiaofan 已提交
945
		zap.Int64("node_id", Params.QueryNodeCfg.GetNodeID()),
946 947 948 949 950 951 952 953 954 955 956
		zap.String("req", req.Request),
		zap.String("metric_type", metricType))

	return &milvuspb.GetMetricsResponse{
		Status: &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
			Reason:    metricsinfo.MsgUnimplementedMetric,
		},
		Response: "",
	}, nil
}