// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package querycoord

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"path/filepath"
	"strconv"
	"sync"

	"github.com/golang/protobuf/proto"
	"go.uber.org/zap"

	"github.com/milvus-io/milvus/internal/kv"
	etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
	minioKV "github.com/milvus-io/milvus/internal/kv/minio"
	"github.com/milvus-io/milvus/internal/log"
	"github.com/milvus-io/milvus/internal/proto/commonpb"
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/proto/milvuspb"
	"github.com/milvus-io/milvus/internal/proto/querypb"
	"github.com/milvus-io/milvus/internal/storage"
	"github.com/milvus-io/milvus/internal/util/sessionutil"
	"github.com/milvus-io/milvus/internal/util/typeutil"
)

const (
	queryNodeMetaPrefix = "queryCoord-queryNodeMeta"
	queryNodeInfoPrefix = "queryCoord-queryNodeInfo"
)

// Cluster manages all query node connections and grpc requests
type Cluster interface {
	reloadFromKV() error
	getComponentInfos(ctx context.Context) ([]*internalpb.ComponentInfo, error)

	loadSegments(ctx context.Context, nodeID int64, in *querypb.LoadSegmentsRequest) error
	releaseSegments(ctx context.Context, nodeID int64, in *querypb.ReleaseSegmentsRequest) error
	getNumSegments(nodeID int64) (int, error)

	watchDmChannels(ctx context.Context, nodeID int64, in *querypb.WatchDmChannelsRequest) error
	watchDeltaChannels(ctx context.Context, nodeID int64, in *querypb.WatchDeltaChannelsRequest) error
	//TODO:: removeDmChannel
	getNumDmChannels(nodeID int64) (int, error)

	hasWatchedQueryChannel(ctx context.Context, nodeID int64, collectionID UniqueID) bool
	hasWatchedDeltaChannel(ctx context.Context, nodeID int64, collectionID UniqueID) bool
	getCollectionInfosByID(ctx context.Context, nodeID int64) []*querypb.CollectionInfo
	addQueryChannel(ctx context.Context, nodeID int64, in *querypb.AddQueryChannelRequest) error
	removeQueryChannel(ctx context.Context, nodeID int64, in *querypb.RemoveQueryChannelRequest) error
	releaseCollection(ctx context.Context, nodeID int64, in *querypb.ReleaseCollectionRequest) error
	releasePartitions(ctx context.Context, nodeID int64, in *querypb.ReleasePartitionsRequest) error
	getSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) ([]*querypb.SegmentInfo, error)
	getSegmentInfoByNode(ctx context.Context, nodeID int64, in *querypb.GetSegmentInfoRequest) ([]*querypb.SegmentInfo, error)
	getSegmentInfoByID(ctx context.Context, segmentID UniqueID) (*querypb.SegmentInfo, error)

	registerNode(ctx context.Context, session *sessionutil.Session, id UniqueID, state nodeState) error
	getNodeInfoByID(nodeID int64) (Node, error)
	removeNodeInfo(nodeID int64) error
	stopNode(nodeID int64)
	onlineNodes() (map[int64]Node, error)
	isOnline(nodeID int64) (bool, error)
	offlineNodes() (map[int64]Node, error)
	hasNode(nodeID int64) bool

	allocateSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, wait bool, excludeNodeIDs []int64, includeNodeIDs []int64) error
	allocateChannelsToQueryNode(ctx context.Context, reqs []*querypb.WatchDmChannelsRequest, wait bool, excludeNodeIDs []int64) error

	getSessionVersion() int64

	getMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest) []queryNodeGetMetricsResponse
	estimateSegmentsSize(segments *querypb.LoadSegmentsRequest) (int64, error)
}

type newQueryNodeFn func(ctx context.Context, address string, id UniqueID, kv *etcdkv.EtcdKV) (Node, error)

type nodeState int

const (
	disConnect nodeState = 0
	online     nodeState = 1
	offline    nodeState = 2
)

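// queryNodeCluster implements the Cluster interface; it keeps track of the registered query nodes
// and forwards grpc requests to the node a request targets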
type queryNodeCluster struct {
	ctx    context.Context
	cancel context.CancelFunc
	client *etcdkv.EtcdKV
	dataKV kv.DataKV

	session        *sessionutil.Session
	sessionVersion int64

	sync.RWMutex
	clusterMeta      Meta
	nodes            map[int64]Node
	newNodeFn        newQueryNodeFn
	segmentAllocator SegmentAllocatePolicy
	channelAllocator ChannelAllocatePolicy
	segSizeEstimator func(request *querypb.LoadSegmentsRequest, dataKV kv.DataKV) (int64, error)
}

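// newQueryNodeCluster creates a queryNodeCluster, reloads node sessions and collection meta from
// etcd, and initializes the MinIO client used to estimate segment size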
func newQueryNodeCluster(ctx context.Context, clusterMeta Meta, kv *etcdkv.EtcdKV, newNodeFn newQueryNodeFn, session *sessionutil.Session) (Cluster, error) {
	childCtx, cancel := context.WithCancel(ctx)
	nodes := make(map[int64]Node)
	c := &queryNodeCluster{
		ctx:              childCtx,
		cancel:           cancel,
		client:           kv,
		session:          session,
		clusterMeta:      clusterMeta,
		nodes:            nodes,
		newNodeFn:        newNodeFn,
		segmentAllocator: defaultSegAllocatePolicy(),
		channelAllocator: defaultChannelAllocatePolicy(),
		segSizeEstimator: defaultSegEstimatePolicy(),
	}
	err := c.reloadFromKV()
	if err != nil {
		return nil, err
	}

	option := &minioKV.Option{
		Address:           Params.MinioEndPoint,
		AccessKeyID:       Params.MinioAccessKeyID,
		SecretAccessKeyID: Params.MinioSecretAccessKey,
		UseSSL:            Params.MinioUseSSLStr,
		CreateBucket:      true,
		BucketName:        Params.MinioBucketName,
	}

	c.dataKV, err = minioKV.NewMinIOKV(ctx, option)
	if err != nil {
		return nil, err
	}

	return c, nil
}

// reloadFromKV reloads the sessions of both currently online query nodes and nodes that were
// registered before a restart, then restores each node's collection meta from etcd
func (c *queryNodeCluster) reloadFromKV() error {
	toLoadMetaNodeIDs := make([]int64, 0)
	// get current online session
	onlineNodeSessions, version, _ := c.session.GetSessions(typeutil.QueryNodeRole)
	onlineSessionMap := make(map[int64]*sessionutil.Session)
	for _, session := range onlineNodeSessions {
		nodeID := session.ServerID
		onlineSessionMap[nodeID] = session
	}
	for nodeID, session := range onlineSessionMap {
		log.Debug("ReloadFromKV: register a queryNode to cluster", zap.Any("nodeID", nodeID))
		err := c.registerNode(c.ctx, session, nodeID, disConnect)
		if err != nil {
			log.Error("query node failed to register", zap.Int64("nodeID", nodeID), zap.String("error info", err.Error()))
			return err
		}
		toLoadMetaNodeIDs = append(toLoadMetaNodeIDs, nodeID)
	}
	c.sessionVersion = version

	// load node information before power off from etcd
	oldStringNodeIDs, oldNodeSessions, err := c.client.LoadWithPrefix(queryNodeInfoPrefix)
	if err != nil {
		log.Error("reloadFromKV: get previous node info from etcd error", zap.Error(err))
		return err
	}
	for index := range oldStringNodeIDs {
		nodeID, err := strconv.ParseInt(filepath.Base(oldStringNodeIDs[index]), 10, 64)
		if err != nil {
			log.Error("WatchNodeLoop: parse nodeID error", zap.Error(err))
			return err
		}
		if _, ok := onlineSessionMap[nodeID]; !ok {
			session := &sessionutil.Session{}
			err = json.Unmarshal([]byte(oldNodeSessions[index]), session)
			if err != nil {
				log.Error("WatchNodeLoop: unmarshal session error", zap.Error(err))
				return err
			}
			err = c.registerNode(context.Background(), session, nodeID, offline)
			if err != nil {
				log.Debug("ReloadFromKV: failed to add queryNode to cluster", zap.Int64("nodeID", nodeID), zap.String("error info", err.Error()))
				return err
			}
			toLoadMetaNodeIDs = append(toLoadMetaNodeIDs, nodeID)
		}
	}

	// load collection meta of queryNode from etcd
	for _, nodeID := range toLoadMetaNodeIDs {
		infoPrefix := fmt.Sprintf("%s/%d", queryNodeMetaPrefix, nodeID)
		_, collectionValues, err := c.client.LoadWithPrefix(infoPrefix)
		if err != nil {
			return err
		}
		for _, value := range collectionValues {
			collectionInfo := &querypb.CollectionInfo{}
			err = proto.Unmarshal([]byte(value), collectionInfo)
			if err != nil {
				return err
			}
			err = c.nodes[nodeID].setCollectionInfo(collectionInfo)
			if err != nil {
				log.Debug("ReloadFromKV: failed to add queryNode meta to cluster", zap.Int64("nodeID", nodeID), zap.String("error info", err.Error()))
				return err
			}
			log.Debug("ReloadFromKV: reload collection info from etcd", zap.Any("info", collectionInfo))
		}
	}
	return nil
}

func (c *queryNodeCluster) getSessionVersion() int64 {
	return c.sessionVersion
}

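// getComponentInfos returns the component state of every online query node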
func (c *queryNodeCluster) getComponentInfos(ctx context.Context) ([]*internalpb.ComponentInfo, error) {
	c.RLock()
	defer c.RUnlock()
	subComponentInfos := make([]*internalpb.ComponentInfo, 0)
	nodes, err := c.getOnlineNodes()
	if err != nil {
		log.Debug("GetComponentInfos: failed get on service nodes", zap.String("error info", err.Error()))
		return nil, err
	}
	for _, node := range nodes {
		componentState := node.getComponentInfo(ctx)
		subComponentInfos = append(subComponentInfos, componentState)
	}

	return subComponentInfos, nil
}

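// loadSegments sends a LoadSegmentsRequest to the query node identified by nodeID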
func (c *queryNodeCluster) loadSegments(ctx context.Context, nodeID int64, in *querypb.LoadSegmentsRequest) error {
	c.RLock()
	var targetNode Node
	if node, ok := c.nodes[nodeID]; ok {
		targetNode = node
	}
	c.RUnlock()

	if targetNode != nil {
		err := targetNode.loadSegments(ctx, in)
		if err != nil {
			log.Debug("loadSegments: queryNode load segments error", zap.Int64("nodeID", nodeID), zap.String("error info", err.Error()))
			return err
		}

		return nil
	}
	return fmt.Errorf("loadSegments: Can't find query node by nodeID, nodeID = %d", nodeID)
}

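// releaseSegments sends a ReleaseSegmentsRequest to the query node identified by nodeID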
func (c *queryNodeCluster) releaseSegments(ctx context.Context, nodeID int64, in *querypb.ReleaseSegmentsRequest) error {
	c.RLock()
	var targetNode Node
	if node, ok := c.nodes[nodeID]; ok {
		targetNode = node
	}
	c.RUnlock()

	if targetNode != nil {
		if !targetNode.isOnline() {
			return errors.New("node offline")
		}

		err := targetNode.releaseSegments(ctx, in)
		if err != nil {
			log.Debug("releaseSegments: queryNode release segments error", zap.Int64("nodeID", nodeID), zap.String("error info", err.Error()))
			return err
		}

		return nil
	}

	return fmt.Errorf("releaseSegments: Can't find query node by nodeID, nodeID = %d", nodeID)
}

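// watchDmChannels tells the query node to watch the given dm channels and records the channel
// assignment in the cluster meta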
func (c *queryNodeCluster) watchDmChannels(ctx context.Context, nodeID int64, in *querypb.WatchDmChannelsRequest) error {
	c.RLock()
	var targetNode Node
	if node, ok := c.nodes[nodeID]; ok {
		targetNode = node
	}
	c.RUnlock()

	if targetNode != nil {
		err := targetNode.watchDmChannels(ctx, in)
		if err != nil {
			log.Debug("watchDmChannels: queryNode watch dm channel error", zap.String("error", err.Error()))
			return err
		}
		channels := make([]string, 0)
		for _, info := range in.Infos {
			channels = append(channels, info.ChannelName)
		}

		collectionID := in.CollectionID
		err = c.clusterMeta.addDmChannel(collectionID, nodeID, channels)
		if err != nil {
			log.Debug("watchDmChannels: queryNode watch dm channel error", zap.String("error", err.Error()))
			return err
		}

		return nil
	}
	return fmt.Errorf("watchDmChannels: Can't find query node by nodeID, nodeID = %d", nodeID)
}

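// watchDeltaChannels tells the query node to watch the given delta channels and records the
// channel info in the cluster meta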
func (c *queryNodeCluster) watchDeltaChannels(ctx context.Context, nodeID int64, in *querypb.WatchDeltaChannelsRequest) error {
	c.RLock()
	var targetNode Node
	if node, ok := c.nodes[nodeID]; ok {
		targetNode = node
	}
	c.RUnlock()

	if targetNode != nil {
		err := targetNode.watchDeltaChannels(ctx, in)
		if err != nil {
			log.Debug("watchDeltaChannels: queryNode watch delta channel error", zap.String("error", err.Error()))
			return err
		}
		err = c.clusterMeta.setDeltaChannel(in.CollectionID, in.Infos)
		if err != nil {
			log.Debug("watchDeltaChannels: queryNode watch delta channel error", zap.String("error", err.Error()))
			return err
		}

		return nil
	}

	return fmt.Errorf("watchDeltaChannels: Can't find query node by nodeID, nodeID = %d", nodeID)
}

func (c *queryNodeCluster) hasWatchedDeltaChannel(ctx context.Context, nodeID int64, collectionID UniqueID) bool {
	c.RLock()
	defer c.RUnlock()

	return c.nodes[nodeID].hasWatchedDeltaChannel(collectionID)
}

func (c *queryNodeCluster) hasWatchedQueryChannel(ctx context.Context, nodeID int64, collectionID UniqueID) bool {
	c.RLock()
	defer c.RUnlock()

	return c.nodes[nodeID].hasWatchedQueryChannel(collectionID)
}

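// addQueryChannel tells the query node to add the query channel described in the request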
func (c *queryNodeCluster) addQueryChannel(ctx context.Context, nodeID int64, in *querypb.AddQueryChannelRequest) error {
	c.RLock()
	var targetNode Node
	if node, ok := c.nodes[nodeID]; ok {
		targetNode = node
	}
	c.RUnlock()

	if targetNode != nil {
		err := targetNode.addQueryChannel(ctx, in)
		if err != nil {
			log.Debug("addQueryChannel: queryNode add query channel error", zap.String("error", err.Error()))
			return err
		}
		return nil
	}

	return fmt.Errorf("addQueryChannel: can't find query node by nodeID, nodeID = %d", nodeID)
}

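// removeQueryChannel tells the query node to remove the query channel described in the request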
func (c *queryNodeCluster) removeQueryChannel(ctx context.Context, nodeID int64, in *querypb.RemoveQueryChannelRequest) error {
	c.RLock()
	var targetNode Node
	if node, ok := c.nodes[nodeID]; ok {
		targetNode = node
	}
	c.RUnlock()

	if targetNode != nil {
		err := targetNode.removeQueryChannel(ctx, in)
		if err != nil {
			log.Debug("removeQueryChannel: queryNode remove query channel error", zap.String("error", err.Error()))
			return err
		}

		return nil
	}

	return fmt.Errorf("removeQueryChannel: can't find query node by nodeID, nodeID = %d", nodeID)
}

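// releaseCollection releases a collection on the given query node and removes it from the cluster meta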
func (c *queryNodeCluster) releaseCollection(ctx context.Context, nodeID int64, in *querypb.ReleaseCollectionRequest) error {
	c.RLock()
	var targetNode Node
	if node, ok := c.nodes[nodeID]; ok {
		targetNode = node
	}
	c.RUnlock()

	if targetNode != nil {
		err := targetNode.releaseCollection(ctx, in)
		if err != nil {
			log.Debug("releaseCollection: queryNode release collection error", zap.String("error", err.Error()))
			return err
		}
		err = c.clusterMeta.releaseCollection(in.CollectionID)
		if err != nil {
			log.Debug("releaseCollection: meta release collection error", zap.String("error", err.Error()))
			return err
		}
		return nil
	}

	return fmt.Errorf("releaseCollection: can't find query node by nodeID, nodeID = %d", nodeID)
}

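// releasePartitions releases the given partitions on the query node and removes them from the cluster meta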
func (c *queryNodeCluster) releasePartitions(ctx context.Context, nodeID int64, in *querypb.ReleasePartitionsRequest) error {
	c.RLock()
	var targetNode Node
	if node, ok := c.nodes[nodeID]; ok {
		targetNode = node
	}
	c.RUnlock()

	if targetNode != nil {
		err := targetNode.releasePartitions(ctx, in)
		if err != nil {
			log.Debug("releasePartitions: queryNode release partitions error", zap.String("error", err.Error()))
			return err
		}

		for _, partitionID := range in.PartitionIDs {
			err = c.clusterMeta.releasePartition(in.CollectionID, partitionID)
			if err != nil {
				log.Debug("releasePartitions: meta release partitions error", zap.String("error", err.Error()))
				return err
			}
		}
		return nil
	}

	return fmt.Errorf("releasePartitions: can't find query node by nodeID, nodeID = %d", nodeID)
}

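// getSegmentInfoByID looks up the node a segment is loaded on via the cluster meta, then asks that
// node for the segment's current info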
func (c *queryNodeCluster) getSegmentInfoByID(ctx context.Context, segmentID UniqueID) (*querypb.SegmentInfo, error) {
	c.RLock()
	defer c.RUnlock()

	segmentInfo, err := c.clusterMeta.getSegmentInfoByID(segmentID)
	if err != nil {
		return nil, err
	}
	if node, ok := c.nodes[segmentInfo.NodeID]; ok {
		res, err := node.getSegmentInfo(ctx, &querypb.GetSegmentInfoRequest{
			Base: &commonpb.MsgBase{
				MsgType: commonpb.MsgType_SegmentInfo,
			},
			CollectionID: segmentInfo.CollectionID,
		})
		if err != nil {
			return nil, err
		}
		if res != nil {
			for _, info := range res.Infos {
				if info.SegmentID == segmentID {
					return info, nil
				}
			}
		}
		return nil, fmt.Errorf("updateSegmentInfo: can't find segment %d on query node %d", segmentID, segmentInfo.NodeID)
	}

	return nil, fmt.Errorf("updateSegmentInfo: can't find query node by nodeID, nodeID = %d", segmentInfo.NodeID)
}

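// getSegmentInfo collects segment info from every query node in the cluster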
func (c *queryNodeCluster) getSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) ([]*querypb.SegmentInfo, error) {
	c.RLock()
	defer c.RUnlock()

	segmentInfos := make([]*querypb.SegmentInfo, 0)
	for _, node := range c.nodes {
		res, err := node.getSegmentInfo(ctx, in)
		if err != nil {
			return nil, err
		}
		if res != nil {
			segmentInfos = append(segmentInfos, res.Infos...)
		}
	}

	//TODO::update meta
	return segmentInfos, nil
}

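// getSegmentInfoByNode returns the segment info reported by a single query node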
func (c *queryNodeCluster) getSegmentInfoByNode(ctx context.Context, nodeID int64, in *querypb.GetSegmentInfoRequest) ([]*querypb.SegmentInfo, error) {
	c.RLock()
	defer c.RUnlock()

	if node, ok := c.nodes[nodeID]; ok {
		res, err := node.getSegmentInfo(ctx, in)
		if err != nil {
			return nil, err
		}
		return res.Infos, nil
	}

	return nil, fmt.Errorf("getSegmentInfoByNode: can't find query node by nodeID, nodeID = %d", nodeID)
}

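// queryNodeGetMetricsResponse pairs a query node's GetMetrics response with the error returned by the call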
type queryNodeGetMetricsResponse struct {
	resp *milvuspb.GetMetricsResponse
	err  error
}

func (c *queryNodeCluster) getMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest) []queryNodeGetMetricsResponse {
	c.RLock()
	defer c.RUnlock()

	ret := make([]queryNodeGetMetricsResponse, 0, len(c.nodes))
	for _, node := range c.nodes {
		resp, err := node.getMetrics(ctx, in)
		ret = append(ret, queryNodeGetMetricsResponse{
			resp: resp,
			err:  err,
		})
	}

	return ret
}

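// getNumDmChannels returns the number of dm channels assigned to the given node according to the cluster meta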
func (c *queryNodeCluster) getNumDmChannels(nodeID int64) (int, error) {
	c.RLock()
	defer c.RUnlock()

	if _, ok := c.nodes[nodeID]; !ok {
		return 0, fmt.Errorf("getNumDmChannels: Can't find query node by nodeID, nodeID = %d", nodeID)
	}

	numChannel := 0
	collectionInfos := c.clusterMeta.showCollections()
	for _, info := range collectionInfos {
		for _, channelInfo := range info.ChannelInfos {
			if channelInfo.NodeIDLoaded == nodeID {
				numChannel++
			}
		}
	}
	return numChannel, nil
}

func (c *queryNodeCluster) getNumSegments(nodeID int64) (int, error) {
	c.RLock()
	defer c.RUnlock()

	if _, ok := c.nodes[nodeID]; !ok {
		return 0, fmt.Errorf("getNumSegments: Can't find query node by nodeID, nodeID = %d", nodeID)
	}

	numSegment := 0
	segmentInfos := make([]*querypb.SegmentInfo, 0)
	collectionInfos := c.clusterMeta.showCollections()
	for _, info := range collectionInfos {
		res := c.clusterMeta.showSegmentInfos(info.CollectionID, nil)
		segmentInfos = append(segmentInfos, res...)
	}
	for _, info := range segmentInfos {
		if info.NodeID == nodeID {
			numSegment++
		}
	}
	return numSegment, nil
}

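// registerNode saves the node's session to etcd, creates a query node client and adds it to the
// cluster; nodes registered in the disConnect state are started asynchronously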
func (c *queryNodeCluster) registerNode(ctx context.Context, session *sessionutil.Session, id UniqueID, state nodeState) error {
	c.Lock()
	defer c.Unlock()

	if _, ok := c.nodes[id]; !ok {
		sessionJSON, err := json.Marshal(session)
		if err != nil {
			log.Debug("registerNode: marshal session error", zap.Int64("nodeID", id), zap.Any("address", session))
			return err
		}
		key := fmt.Sprintf("%s/%d", queryNodeInfoPrefix, id)
		err = c.client.Save(key, string(sessionJSON))
		if err != nil {
			return err
		}
		node, err := c.newNodeFn(ctx, session.Address, id, c.client)
		if err != nil {
			log.Debug("registerNode: create a new query node failed", zap.Int64("nodeID", id), zap.Error(err))
			return err
		}
		node.setState(state)
		if state < online {
			go node.start()
		}
		c.nodes[id] = node
		log.Debug("registerNode: create a new query node", zap.Int64("nodeID", id), zap.String("address", session.Address))
		return nil
	}
	return fmt.Errorf("registerNode: node %d alredy exists in cluster", id)
}

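// getNodeInfoByID returns the node info of the query node with the given ID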
func (c *queryNodeCluster) getNodeInfoByID(nodeID int64) (Node, error) {
	c.RLock()
	defer c.RUnlock()

	if node, ok := c.nodes[nodeID]; ok {
		nodeInfo, err := node.getNodeInfo()
		if err != nil {
			return nil, err
		}
		return nodeInfo, nil
	}

	return nil, fmt.Errorf("getNodeInfoByID: query node %d not exist", nodeID)
}

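// removeNodeInfo deletes the node's session from etcd and removes the node from the cluster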
func (c *queryNodeCluster) removeNodeInfo(nodeID int64) error {
	c.Lock()
	defer c.Unlock()

	key := fmt.Sprintf("%s/%d", queryNodeInfoPrefix, nodeID)
	err := c.client.Remove(key)
	if err != nil {
		return err
	}

	if _, ok := c.nodes[nodeID]; ok {
		err = c.nodes[nodeID].clearNodeInfo()
		if err != nil {
			return err
		}
		delete(c.nodes, nodeID)
		log.Debug("removeNodeInfo: delete nodeInfo in cluster MetaReplica and etcd", zap.Int64("nodeID", nodeID))
	}

	return nil
}

func (c *queryNodeCluster) stopNode(nodeID int64) {
	c.RLock()
	defer c.RUnlock()

	if node, ok := c.nodes[nodeID]; ok {
		node.stop()
		log.Debug("stopNode: queryNode offline", zap.Int64("nodeID", nodeID))
	}
}

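// onlineNodes returns all query nodes that are currently online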
func (c *queryNodeCluster) onlineNodes() (map[int64]Node, error) {
	c.RLock()
	defer c.RUnlock()

	return c.getOnlineNodes()
}

func (c *queryNodeCluster) getOnlineNodes() (map[int64]Node, error) {
	nodes := make(map[int64]Node)
	for nodeID, node := range c.nodes {
		if node.isOnline() {
			nodes[nodeID] = node
		}
	}
	if len(nodes) == 0 {
		return nil, errors.New("getOnlineNodes: no queryNode is alive")
	}

	return nodes, nil
}

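// offlineNodes returns all query nodes that are currently offline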
func (c *queryNodeCluster) offlineNodes() (map[int64]Node, error) {
	c.RLock()
	defer c.RUnlock()

	return c.getOfflineNodes()
}

func (c *queryNodeCluster) hasNode(nodeID int64) bool {
	c.RLock()
	defer c.RUnlock()

	if _, ok := c.nodes[nodeID]; ok {
		return true
	}

	return false
}

func (c *queryNodeCluster) getOfflineNodes() (map[int64]Node, error) {
	nodes := make(map[int64]Node)
	for nodeID, node := range c.nodes {
		if node.isOffline() {
			nodes[nodeID] = node
		}
	}
	if len(nodes) == 0 {
		return nil, errors.New("getOfflineNodes: no queryNode is offline")
	}

	return nodes, nil
}

func (c *queryNodeCluster) isOnline(nodeID int64) (bool, error) {
	c.RLock()
	defer c.RUnlock()

	if node, ok := c.nodes[nodeID]; ok {
		return node.isOnline(), nil
	}

	return false, fmt.Errorf("isOnline: query node %d not exist", nodeID)
}

//func (c *queryNodeCluster) printMeta() {
//	c.RLock()
//	defer c.RUnlock()
//
//	for id, node := range c.nodes {
//		if node.isOnline() {
//			collectionInfos := node.showCollections()
//			for _, info := range collectionInfos {
//				log.Debug("PrintMeta: query coordinator cluster info: collectionInfo", zap.Int64("nodeID", id), zap.Int64("collectionID", info.CollectionID), zap.Any("info", info))
//			}
//
//			queryChannelInfos := node.showWatchedQueryChannels()
//			for _, info := range queryChannelInfos {
//				log.Debug("PrintMeta: query coordinator cluster info: watchedQueryChannelInfo", zap.Int64("nodeID", id), zap.Int64("collectionID", info.CollectionID), zap.Any("info", info))
//			}
//		}
//	}
//}

func (c *queryNodeCluster) getCollectionInfosByID(ctx context.Context, nodeID int64) []*querypb.CollectionInfo {
	c.RLock()
	defer c.RUnlock()
	if node, ok := c.nodes[nodeID]; ok {
		return node.showCollections()
	}

	return nil
}

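// allocateSegmentsToQueryNode assigns load-segment requests to query nodes using the cluster's segment allocate policy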
func (c *queryNodeCluster) allocateSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, wait bool, excludeNodeIDs []int64, includeNodeIDs []int64) error {
	return c.segmentAllocator(ctx, reqs, c, wait, excludeNodeIDs, includeNodeIDs)
}

func (c *queryNodeCluster) allocateChannelsToQueryNode(ctx context.Context, reqs []*querypb.WatchDmChannelsRequest, wait bool, excludeNodeIDs []int64) error {
	return c.channelAllocator(ctx, reqs, c, wait, excludeNodeIDs)
}

func (c *queryNodeCluster) estimateSegmentsSize(segments *querypb.LoadSegmentsRequest) (int64, error) {
	return c.segSizeEstimator(segments, c.dataKV)
}

func defaultSegEstimatePolicy() segEstimatePolicy {
	return estimateSegmentsSize
}

type segEstimatePolicy func(request *querypb.LoadSegmentsRequest, dataKv kv.DataKV) (int64, error)

func estimateSegmentsSize(segments *querypb.LoadSegmentsRequest, kvClient kv.DataKV) (int64, error) {
	segmentSize := int64(0)

	//TODO:: collection has multi vector field
	//vecFields := make([]int64, 0)
	//for _, field := range segments.Schema.Fields {
	//	if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector {
	//		vecFields = append(vecFields, field.FieldID)
	//	}
	//}
	// get fields data size, if len(indexFieldIDs) == 0, vector field would be involved in fieldBinLogs
	for _, loadInfo := range segments.Infos {
		// get index size
		if loadInfo.EnableIndex {
			for _, pathInfo := range loadInfo.IndexPathInfos {
				for _, path := range pathInfo.IndexFilePaths {
					indexSize, err := storage.EstimateMemorySize(kvClient, path)
					if err != nil {
						indexSize, err = storage.GetBinlogSize(kvClient, path)
						if err != nil {
							return 0, err
						}
					}
					segmentSize += indexSize
				}
			}
			continue
		}

		// get binlog size
		for _, binlogPath := range loadInfo.BinlogPaths {
			for _, path := range binlogPath.Binlogs {
				binlogSize, err := storage.EstimateMemorySize(kvClient, path)
				if err != nil {
					binlogSize, err = storage.GetBinlogSize(kvClient, path)
					if err != nil {
						return 0, err
					}
				}
				segmentSize += binlogSize
			}
		}
	}

	return segmentSize, nil
}