diff --git a/internal/querycoord/meta.go b/internal/querycoord/meta.go index c1de1095ab7bb2ff46acaceae7a9cb3d9999e1ca..fa02f7f26c5eeecda28956a5a4eb49e4b73fd1b5 100644 --- a/internal/querycoord/meta.go +++ b/internal/querycoord/meta.go @@ -105,6 +105,7 @@ type Meta interface { getReplicasByCollectionID(collectionID int64) ([]*milvuspb.ReplicaInfo, error) getReplicasByNodeID(nodeID int64) ([]*milvuspb.ReplicaInfo, error) applyReplicaBalancePlan(p *balancePlan) error + updateShardLeader(replicaID UniqueID, dmChannel string, leaderID UniqueID, leaderAddr string) error } // MetaReplica records the current load information on all querynodes @@ -1288,6 +1289,10 @@ func (m *MetaReplica) applyReplicaBalancePlan(p *balancePlan) error { return m.replicas.ApplyBalancePlan(p, m.getKvClient()) } +func (m *MetaReplica) updateShardLeader(replicaID UniqueID, dmChannel string, leaderID UniqueID, leaderAddr string) error { + return m.replicas.UpdateShardLeader(replicaID, dmChannel, leaderID, leaderAddr, m.getKvClient()) +} + //func (m *MetaReplica) printMeta() { // m.RLock() // defer m.RUnlock() diff --git a/internal/querycoord/replica.go b/internal/querycoord/replica.go index 3381d800ae7f78859f309e553fe50bfe3b310868..3b4c16af3cccb22939520ae5f0e462ac5fb9350f 100644 --- a/internal/querycoord/replica.go +++ b/internal/querycoord/replica.go @@ -179,22 +179,9 @@ func (rep *ReplicaInfos) ApplyBalancePlan(p *balancePlan, kv kv.MetaKv) error { } // save to etcd first - if len(replicasChanged) > 0 { - data := make(map[string]string) - - for _, info := range replicasChanged { - infoBytes, err := proto.Marshal(info) - if err != nil { - return err - } - - key := fmt.Sprintf("%s/%d", ReplicaMetaPrefix, info.ReplicaID) - data[key] = string(infoBytes) - } - err := kv.MultiSave(data) - if err != nil { - return err - } + err := saveReplica(kv, replicasChanged...) + if err != nil { + return err } // apply change to in-memory meta @@ -209,6 +196,33 @@ func (rep *ReplicaInfos) ApplyBalancePlan(p *balancePlan, kv kv.MetaKv) error { return nil } +func (rep *ReplicaInfos) UpdateShardLeader(replicaID UniqueID, dmChannel string, leaderID UniqueID, leaderAddr string, meta kv.MetaKv) error { + rep.globalGuard.Lock() + defer rep.globalGuard.Unlock() + + replica, ok := rep.get(replicaID) + if !ok { + return fmt.Errorf("replica %v not found", replicaID) + } + + for _, shard := range replica.ShardReplicas { + if shard.DmChannelName == dmChannel { + shard.LeaderID = leaderID + shard.LeaderAddr = leaderAddr + break + } + } + + err := saveReplica(meta, replica) + if err != nil { + return err + } + + rep.upsert(replica) + + return nil +} + // removeNodeFromReplica helper function to remove nodeID from replica NodeIds list. func removeNodeFromReplica(replica *milvuspb.ReplicaInfo, nodeID int64) *milvuspb.ReplicaInfo { for i := 0; i < len(replica.NodeIds); i++ { @@ -220,3 +234,20 @@ func removeNodeFromReplica(replica *milvuspb.ReplicaInfo, nodeID int64) *milvusp } return replica } + +// save the replicas into etcd. +func saveReplica(meta kv.MetaKv, replicas ...*milvuspb.ReplicaInfo) error { + data := make(map[string]string) + + for _, info := range replicas { + infoBytes, err := proto.Marshal(info) + if err != nil { + return err + } + + key := fmt.Sprintf("%s/%d", ReplicaMetaPrefix, info.ReplicaID) + data[key] = string(infoBytes) + } + + return meta.MultiSave(data) +} diff --git a/internal/querycoord/task.go b/internal/querycoord/task.go index 54acc02d6f4f4b8ffa92e94a937bf1b4eb739941..8c9cbc707b87ed13c5c4c350324cde306f53a912 100644 --- a/internal/querycoord/task.go +++ b/internal/querycoord/task.go @@ -2405,50 +2405,36 @@ func (lbt *loadBalanceTask) globalPostExecute(ctx context.Context) error { for _, childTask := range lbt.getChildTask() { if task, ok := childTask.(*watchDmChannelTask); ok { wg.Go(func() error { - nodeInfo, err := lbt.cluster.getNodeInfoByID(task.NodeID) + leaderID := task.NodeID + dmChannel := task.Infos[0].ChannelName + + nodeInfo, err := lbt.cluster.getNodeInfoByID(leaderID) if err != nil { log.Error("failed to get node info to update shard leader info", zap.Int64("triggerTaskID", lbt.getTaskID()), zap.Int64("taskID", task.getTaskID()), - zap.Int64("nodeID", task.NodeID), - zap.String("dmChannel", task.Infos[0].ChannelName), + zap.Int64("nodeID", leaderID), + zap.String("dmChannel", dmChannel), zap.Error(err)) return err } - replica, err := lbt.meta.getReplicaByID(task.ReplicaID) + err = lbt.meta.updateShardLeader(task.ReplicaID, dmChannel, leaderID, nodeInfo.(*queryNode).address) if err != nil { - log.Error("failed to get replica to update shard leader info", + log.Error("failed to update shard leader info of replica", zap.Int64("triggerTaskID", lbt.getTaskID()), zap.Int64("taskID", task.getTaskID()), zap.Int64("replicaID", task.ReplicaID), - zap.String("dmChannel", task.Infos[0].ChannelName), + zap.String("dmChannel", dmChannel), zap.Error(err)) return err } - for _, shard := range replica.ShardReplicas { - if shard.DmChannelName == task.Infos[0].ChannelName { - log.Debug("LoadBalance: update shard leader", - zap.Int64("triggerTaskID", lbt.getTaskID()), - zap.Int64("taskID", task.getTaskID()), - zap.Int64("oldLeader", shard.LeaderID), - zap.Int64("newLeader", task.NodeID)) - shard.LeaderID = task.NodeID - shard.LeaderAddr = nodeInfo.(*queryNode).address - break - } - } - - err = lbt.meta.setReplicaInfo(replica) - if err != nil { - log.Error("failed to remove offline nodes from replica info", - zap.Int64("triggerTaskID", lbt.getTaskID()), - zap.Int64("taskID", task.getTaskID()), - zap.Int64("replicaID", replica.ReplicaID), - zap.Error(err)) - return err - } + log.Debug("LoadBalance: update shard leader", + zap.Int64("triggerTaskID", lbt.getTaskID()), + zap.Int64("taskID", task.getTaskID()), + zap.String("dmChannel", dmChannel), + zap.Int64("leader", leaderID)) return nil })