未验证 提交 a077bad8 编写于 作者: Y yah01 提交者: GitHub

Make updating replica atomic, balance plan idempotent (#17686)

Signed-off-by: Nyah01 <yang.cen@zilliz.com>
上级 7d51b652
package querycoord package querycoord
import "sort" import (
"sort"
)
type balancer interface { type Balancer interface {
addNode(nodeID int64) ([]*balancePlan, error) AddNode(nodeID int64) ([]*balancePlan, error)
removeNode(nodeID int64) []*balancePlan RemoveNode(nodeID int64) []*balancePlan
rebalance() []*balancePlan Rebalance() []*balancePlan
} }
// Plan for adding/removing node from replica,
// adds node into targetReplica,
// removes node from sourceReplica.
// Set the replica ID to invalidReplicaID to avoid adding/removing into/from replica
type balancePlan struct { type balancePlan struct {
nodeID int64 nodes []UniqueID
sourceReplica int64 sourceReplica UniqueID
targetReplica int64 targetReplica UniqueID
} }
type replicaBalancer struct { type replicaBalancer struct {
...@@ -23,7 +29,7 @@ func newReplicaBalancer(meta Meta, cluster Cluster) *replicaBalancer { ...@@ -23,7 +29,7 @@ func newReplicaBalancer(meta Meta, cluster Cluster) *replicaBalancer {
return &replicaBalancer{meta, cluster} return &replicaBalancer{meta, cluster}
} }
func (b *replicaBalancer) addNode(nodeID int64) ([]*balancePlan, error) { func (b *replicaBalancer) AddNode(nodeID int64) ([]*balancePlan, error) {
// allocate this node to all collections replicas // allocate this node to all collections replicas
var ret []*balancePlan var ret []*balancePlan
collections := b.meta.showCollections() collections := b.meta.showCollections()
...@@ -36,6 +42,25 @@ func (b *replicaBalancer) addNode(nodeID int64) ([]*balancePlan, error) { ...@@ -36,6 +42,25 @@ func (b *replicaBalancer) addNode(nodeID int64) ([]*balancePlan, error) {
continue continue
} }
foundNode := false
for _, replica := range replicas {
for _, replicaNode := range replica.NodeIds {
if replicaNode == nodeID {
foundNode = true
break
}
}
if foundNode {
break
}
}
// This node is serving this collection
if foundNode {
continue
}
replicaAvailableMemory := make(map[UniqueID]uint64, len(replicas)) replicaAvailableMemory := make(map[UniqueID]uint64, len(replicas))
for _, replica := range replicas { for _, replica := range replicas {
replicaAvailableMemory[replica.ReplicaID] = getReplicaAvailableMemory(b.cluster, replica) replicaAvailableMemory[replica.ReplicaID] = getReplicaAvailableMemory(b.cluster, replica)
...@@ -48,7 +73,7 @@ func (b *replicaBalancer) addNode(nodeID int64) ([]*balancePlan, error) { ...@@ -48,7 +73,7 @@ func (b *replicaBalancer) addNode(nodeID int64) ([]*balancePlan, error) {
}) })
ret = append(ret, &balancePlan{ ret = append(ret, &balancePlan{
nodeID: nodeID, nodes: []UniqueID{nodeID},
sourceReplica: invalidReplicaID, sourceReplica: invalidReplicaID,
targetReplica: replicas[0].GetReplicaID(), targetReplica: replicas[0].GetReplicaID(),
}) })
...@@ -56,11 +81,11 @@ func (b *replicaBalancer) addNode(nodeID int64) ([]*balancePlan, error) { ...@@ -56,11 +81,11 @@ func (b *replicaBalancer) addNode(nodeID int64) ([]*balancePlan, error) {
return ret, nil return ret, nil
} }
func (b *replicaBalancer) removeNode(nodeID int64) []*balancePlan { func (b *replicaBalancer) RemoveNode(nodeID int64) []*balancePlan {
// for this version, querynode does not support move from a replica to another // for this version, querynode does not support move from a replica to another
return nil return nil
} }
func (b *replicaBalancer) rebalance() []*balancePlan { func (b *replicaBalancer) Rebalance() []*balancePlan {
return nil return nil
} }
package querycoord
import (
"context"
"testing"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/stretchr/testify/assert"
)
func TestAddNode(t *testing.T) {
defer removeAllSession()
ctx := context.Background()
coord, err := startQueryCoord(ctx)
assert.NoError(t, err)
defer coord.Stop()
node1, err := startQueryNodeServer(ctx)
assert.NoError(t, err)
defer node1.stop()
node2, err := startQueryNodeServer(ctx)
assert.NoError(t, err)
defer node2.stop()
waitQueryNodeOnline(coord.cluster, node1.queryNodeID)
waitQueryNodeOnline(coord.cluster, node2.queryNodeID)
loadCollectionReq := &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
},
CollectionID: defaultCollectionID,
Schema: genDefaultCollectionSchema(false),
ReplicaNumber: 1,
}
status, err := coord.LoadCollection(ctx, loadCollectionReq)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
waitLoadCollectionDone(ctx, coord, defaultCollectionID)
plans, err := coord.groupBalancer.AddNode(node1.queryNodeID)
assert.NoError(t, err)
assert.Equal(t, 0, len(plans))
plans, err = coord.groupBalancer.AddNode(node2.queryNodeID)
assert.NoError(t, err)
assert.Equal(t, 0, len(plans))
newNodeID := node2.queryNodeID + 1
plans, err = coord.groupBalancer.AddNode(newNodeID)
assert.NoError(t, err)
assert.Equal(t, 1, len(plans))
}
...@@ -81,7 +81,7 @@ func waitLoadCollectionDone(ctx context.Context, queryCoord *QueryCoord, collect ...@@ -81,7 +81,7 @@ func waitLoadCollectionDone(ctx context.Context, queryCoord *QueryCoord, collect
return errors.New("showCollection failed") return errors.New("showCollection failed")
} }
loadDone := true loadDone := len(res.InMemoryPercentages) > 0
for _, percent := range res.InMemoryPercentages { for _, percent := range res.InMemoryPercentages {
if percent < 100 { if percent < 100 {
loadDone = false loadDone = false
...@@ -90,6 +90,8 @@ func waitLoadCollectionDone(ctx context.Context, queryCoord *QueryCoord, collect ...@@ -90,6 +90,8 @@ func waitLoadCollectionDone(ctx context.Context, queryCoord *QueryCoord, collect
if loadDone { if loadDone {
break break
} }
time.Sleep(500 * time.Millisecond)
} }
return nil return nil
......
...@@ -97,7 +97,7 @@ type QueryCoord struct { ...@@ -97,7 +97,7 @@ type QueryCoord struct {
factory dependency.Factory factory dependency.Factory
chunkManager storage.ChunkManager chunkManager storage.ChunkManager
groupBalancer balancer groupBalancer Balancer
} }
// Register register query service at etcd // Register register query service at etcd
...@@ -341,16 +341,37 @@ func (qc *QueryCoord) watchNodeLoop() { ...@@ -341,16 +341,37 @@ func (qc *QueryCoord) watchNodeLoop() {
defer qc.loopWg.Done() defer qc.loopWg.Done()
log.Info("QueryCoord start watch node loop") log.Info("QueryCoord start watch node loop")
unallocatedNodes := qc.getUnallocatedNodes() onlineNodes := qc.cluster.OnlineNodeIDs()
for _, n := range unallocatedNodes { for _, node := range onlineNodes {
if err := qc.allocateNode(n); err != nil { if err := qc.allocateNode(node); err != nil {
log.Warn("unable to allcoate node", zap.Int64("nodeID", n), zap.Error(err)) log.Warn("unable to allcoate node", zap.Int64("nodeID", node), zap.Error(err))
} }
} }
go qc.loadBalanceNodeLoop(ctx) go qc.loadBalanceNodeLoop(ctx)
for _, nodeID := range qc.cluster.OfflineNodeIDs() { offlineNodes := make(typeutil.UniqueSet)
qc.offlineNodesChan <- nodeID collections := qc.meta.showCollections()
for _, collection := range collections {
for _, replicaID := range collection.ReplicaIds {
replica, err := qc.meta.getReplicaByID(replicaID)
if err != nil {
log.Warn("failed to get replica",
zap.Int64("replicaID", replicaID),
zap.Error(err))
continue
}
for _, node := range replica.NodeIds {
ok, err := qc.cluster.IsOnline(node)
if err != nil || !ok {
offlineNodes.Insert(node)
}
}
}
}
for node := range offlineNodes {
qc.offlineNodesChan <- node
} }
// TODO silverxia add Rewatch logic // TODO silverxia add Rewatch logic
...@@ -359,7 +380,7 @@ func (qc *QueryCoord) watchNodeLoop() { ...@@ -359,7 +380,7 @@ func (qc *QueryCoord) watchNodeLoop() {
} }
func (qc *QueryCoord) allocateNode(nodeID int64) error { func (qc *QueryCoord) allocateNode(nodeID int64) error {
plans, err := qc.groupBalancer.addNode(nodeID) plans, err := qc.groupBalancer.AddNode(nodeID)
if err != nil { if err != nil {
return err return err
} }
...@@ -371,22 +392,6 @@ func (qc *QueryCoord) allocateNode(nodeID int64) error { ...@@ -371,22 +392,6 @@ func (qc *QueryCoord) allocateNode(nodeID int64) error {
return nil return nil
} }
func (qc *QueryCoord) getUnallocatedNodes() []int64 {
onlines := qc.cluster.OnlineNodeIDs()
var ret []int64
for _, n := range onlines {
replica, err := qc.meta.getReplicasByNodeID(n)
if err != nil {
log.Warn("failed to get replica", zap.Int64("nodeID", n), zap.Error(err))
continue
}
if replica == nil {
ret = append(ret, n)
}
}
return ret
}
func (qc *QueryCoord) handleNodeEvent(ctx context.Context) { func (qc *QueryCoord) handleNodeEvent(ctx context.Context) {
for { for {
select { select {
...@@ -439,6 +444,8 @@ func (qc *QueryCoord) handleNodeEvent(ctx context.Context) { ...@@ -439,6 +444,8 @@ func (qc *QueryCoord) handleNodeEvent(ctx context.Context) {
} }
func (qc *QueryCoord) loadBalanceNodeLoop(ctx context.Context) { func (qc *QueryCoord) loadBalanceNodeLoop(ctx context.Context) {
const LoadBalanceRetryAfter = 100 * time.Millisecond
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
...@@ -463,13 +470,14 @@ func (qc *QueryCoord) loadBalanceNodeLoop(ctx context.Context) { ...@@ -463,13 +470,14 @@ func (qc *QueryCoord) loadBalanceNodeLoop(ctx context.Context) {
meta: qc.meta, meta: qc.meta,
} }
qc.metricsCacheManager.InvalidateSystemInfoMetrics() qc.metricsCacheManager.InvalidateSystemInfoMetrics()
//TODO:: deal enqueue error
err := qc.scheduler.Enqueue(loadBalanceTask) err := qc.scheduler.Enqueue(loadBalanceTask)
if err != nil { if err != nil {
log.Warn("failed to enqueue LoadBalance task into the scheduler", log.Warn("failed to enqueue LoadBalance task into the scheduler",
zap.Int64("nodeID", node), zap.Int64("nodeID", node),
zap.Error(err)) zap.Error(err))
qc.offlineNodesChan <- node qc.offlineNodesChan <- node
time.Sleep(LoadBalanceRetryAfter)
continue continue
} }
...@@ -483,6 +491,7 @@ func (qc *QueryCoord) loadBalanceNodeLoop(ctx context.Context) { ...@@ -483,6 +491,7 @@ func (qc *QueryCoord) loadBalanceNodeLoop(ctx context.Context) {
zap.Int64("nodeID", node), zap.Int64("nodeID", node),
zap.Error(err)) zap.Error(err))
qc.offlineNodesChan <- node qc.offlineNodesChan <- node
time.Sleep(LoadBalanceRetryAfter)
continue continue
} }
......
...@@ -238,7 +238,7 @@ func TestWatchNodeLoop(t *testing.T) { ...@@ -238,7 +238,7 @@ func TestWatchNodeLoop(t *testing.T) {
err = removeNodeSession(nodeID) err = removeNodeSession(nodeID)
assert.Nil(t, err) assert.Nil(t, err)
waitAllQueryNodeOffline(queryCoord.cluster, onlineNodeIDs) waitAllQueryNodeOffline(queryCoord.cluster, onlineNodeIDs...)
queryCoord.Stop() queryCoord.Stop()
err = removeAllSession() err = removeAllSession()
...@@ -620,6 +620,8 @@ func TestLoadBalanceSegmentLoop(t *testing.T) { ...@@ -620,6 +620,8 @@ func TestLoadBalanceSegmentLoop(t *testing.T) {
if len(segmentInfos) > 0 { if len(segmentInfos) > 0 {
break break
} }
time.Sleep(time.Second)
} }
queryCoord.Stop() queryCoord.Stop()
......
...@@ -65,7 +65,7 @@ func removeAllSession() error { ...@@ -65,7 +65,7 @@ func removeAllSession() error {
return nil return nil
} }
func waitAllQueryNodeOffline(cluster Cluster, nodeIDs []int64) bool { func waitAllQueryNodeOffline(cluster Cluster, nodeIDs ...int64) bool {
for { for {
allOffline := true allOffline := true
for _, nodeID := range nodeIDs { for _, nodeID := range nodeIDs {
...@@ -136,7 +136,7 @@ func TestQueryNode_MultiNode_stop(t *testing.T) { ...@@ -136,7 +136,7 @@ func TestQueryNode_MultiNode_stop(t *testing.T) {
err = removeNodeSession(queryNode2.queryNodeID) err = removeNodeSession(queryNode2.queryNodeID)
assert.Nil(t, err) assert.Nil(t, err)
waitAllQueryNodeOffline(queryCoord.cluster, onlineNodeIDs) waitAllQueryNodeOffline(queryCoord.cluster, onlineNodeIDs...)
queryCoord.Stop() queryCoord.Stop()
err = removeAllSession() err = removeAllSession()
assert.Nil(t, err) assert.Nil(t, err)
...@@ -182,7 +182,7 @@ func TestQueryNode_MultiNode_reStart(t *testing.T) { ...@@ -182,7 +182,7 @@ func TestQueryNode_MultiNode_reStart(t *testing.T) {
err = removeNodeSession(queryNode3.queryNodeID) err = removeNodeSession(queryNode3.queryNodeID)
assert.Nil(t, err) assert.Nil(t, err)
waitAllQueryNodeOffline(queryCoord.cluster, onlineNodeIDs) waitAllQueryNodeOffline(queryCoord.cluster, onlineNodeIDs...)
queryCoord.Stop() queryCoord.Stop()
err = removeAllSession() err = removeAllSession()
assert.Nil(t, err) assert.Nil(t, err)
......
...@@ -169,12 +169,13 @@ func (rep *ReplicaInfos) ApplyBalancePlan(p *balancePlan, kv kv.MetaKv) error { ...@@ -169,12 +169,13 @@ func (rep *ReplicaInfos) ApplyBalancePlan(p *balancePlan, kv kv.MetaKv) error {
// generate ReplicaInfo to save to MetaKv // generate ReplicaInfo to save to MetaKv
if sourceReplica != nil { if sourceReplica != nil {
// remove node from replica node list // remove node from replica node list
removeNodeFromReplica(sourceReplica, p.nodeID) sourceReplica.NodeIds = removeFromSlice(sourceReplica.NodeIds, p.nodes...)
replicasChanged = append(replicasChanged, sourceReplica) replicasChanged = append(replicasChanged, sourceReplica)
} }
if targetReplica != nil { if targetReplica != nil {
// add node to replica // add node to replica
targetReplica.NodeIds = append(targetReplica.NodeIds, p.nodeID) targetReplica.NodeIds = append(targetReplica.NodeIds, p.nodes...)
targetReplica.NodeIds = uniqueSlice(targetReplica.NodeIds)
replicasChanged = append(replicasChanged, targetReplica) replicasChanged = append(replicasChanged, targetReplica)
} }
...@@ -223,18 +224,6 @@ func (rep *ReplicaInfos) UpdateShardLeader(replicaID UniqueID, dmChannel string, ...@@ -223,18 +224,6 @@ func (rep *ReplicaInfos) UpdateShardLeader(replicaID UniqueID, dmChannel string,
return nil return nil
} }
// removeNodeFromReplica helper function to remove nodeID from replica NodeIds list.
func removeNodeFromReplica(replica *milvuspb.ReplicaInfo, nodeID int64) *milvuspb.ReplicaInfo {
for i := 0; i < len(replica.NodeIds); i++ {
if replica.NodeIds[i] != nodeID {
continue
}
replica.NodeIds = append(replica.NodeIds[:i], replica.NodeIds[i+1:]...)
return replica
}
return replica
}
// save the replicas into etcd. // save the replicas into etcd.
func saveReplica(meta kv.MetaKv, replicas ...*milvuspb.ReplicaInfo) error { func saveReplica(meta kv.MetaKv, replicas ...*milvuspb.ReplicaInfo) error {
data := make(map[string]string) data := make(map[string]string)
......
...@@ -135,7 +135,7 @@ func TestReplicaInfos_ApplyBalancePlan(t *testing.T) { ...@@ -135,7 +135,7 @@ func TestReplicaInfos_ApplyBalancePlan(t *testing.T) {
t.Run("source replica not exist", func(t *testing.T) { t.Run("source replica not exist", func(t *testing.T) {
replicas := NewReplicaInfos() replicas := NewReplicaInfos()
err := replicas.ApplyBalancePlan(&balancePlan{ err := replicas.ApplyBalancePlan(&balancePlan{
nodeID: 1, nodes: []UniqueID{1},
sourceReplica: 1, sourceReplica: 1,
targetReplica: invalidReplicaID, targetReplica: invalidReplicaID,
}, kv) }, kv)
...@@ -145,7 +145,7 @@ func TestReplicaInfos_ApplyBalancePlan(t *testing.T) { ...@@ -145,7 +145,7 @@ func TestReplicaInfos_ApplyBalancePlan(t *testing.T) {
t.Run("target replica not exist", func(t *testing.T) { t.Run("target replica not exist", func(t *testing.T) {
replicas := NewReplicaInfos() replicas := NewReplicaInfos()
err := replicas.ApplyBalancePlan(&balancePlan{ err := replicas.ApplyBalancePlan(&balancePlan{
nodeID: 1, nodes: []UniqueID{1},
sourceReplica: invalidReplicaID, sourceReplica: invalidReplicaID,
targetReplica: 1, targetReplica: 1,
}, kv) }, kv)
...@@ -162,7 +162,7 @@ func TestReplicaInfos_ApplyBalancePlan(t *testing.T) { ...@@ -162,7 +162,7 @@ func TestReplicaInfos_ApplyBalancePlan(t *testing.T) {
}) })
err := replicas.ApplyBalancePlan(&balancePlan{ err := replicas.ApplyBalancePlan(&balancePlan{
nodeID: 2, nodes: []UniqueID{2},
sourceReplica: invalidReplicaID, sourceReplica: invalidReplicaID,
targetReplica: 1, targetReplica: 1,
}, kv) }, kv)
...@@ -189,7 +189,7 @@ func TestReplicaInfos_ApplyBalancePlan(t *testing.T) { ...@@ -189,7 +189,7 @@ func TestReplicaInfos_ApplyBalancePlan(t *testing.T) {
}) })
err := replicas.ApplyBalancePlan(&balancePlan{ err := replicas.ApplyBalancePlan(&balancePlan{
nodeID: 1, nodes: []UniqueID{1},
sourceReplica: 1, sourceReplica: 1,
targetReplica: invalidReplicaID, targetReplica: invalidReplicaID,
}, kv) }, kv)
...@@ -216,7 +216,7 @@ func TestReplicaInfos_ApplyBalancePlan(t *testing.T) { ...@@ -216,7 +216,7 @@ func TestReplicaInfos_ApplyBalancePlan(t *testing.T) {
}) })
err := replicas.ApplyBalancePlan(&balancePlan{ err := replicas.ApplyBalancePlan(&balancePlan{
nodeID: 2, nodes: []UniqueID{2},
sourceReplica: 1, sourceReplica: 1,
targetReplica: invalidReplicaID, targetReplica: invalidReplicaID,
}, kv) }, kv)
...@@ -235,7 +235,7 @@ func TestReplicaInfos_ApplyBalancePlan(t *testing.T) { ...@@ -235,7 +235,7 @@ func TestReplicaInfos_ApplyBalancePlan(t *testing.T) {
}) })
err := replicas.ApplyBalancePlan(&balancePlan{ err := replicas.ApplyBalancePlan(&balancePlan{
nodeID: 2, nodes: []UniqueID{2},
sourceReplica: invalidReplicaID, sourceReplica: invalidReplicaID,
targetReplica: 1, targetReplica: 1,
}, kv) }, kv)
......
...@@ -2306,26 +2306,13 @@ func (lbt *loadBalanceTask) globalPostExecute(ctx context.Context) error { ...@@ -2306,26 +2306,13 @@ func (lbt *loadBalanceTask) globalPostExecute(ctx context.Context) error {
offlineNodes.Insert(nodeID) offlineNodes.Insert(nodeID)
} }
for _, replica := range replicas { for replicaID := range replicas {
replica := replica replicaID := replicaID
wg.Go(func() error { wg.Go(func() error {
onlineNodes := make([]UniqueID, 0, len(replica.NodeIds)) return lbt.meta.applyReplicaBalancePlan(&balancePlan{
for _, nodeID := range replica.NodeIds { nodes: lbt.SourceNodeIDs,
if !offlineNodes.Contain(nodeID) { sourceReplica: replicaID,
onlineNodes = append(onlineNodes, nodeID) })
}
}
replica.NodeIds = onlineNodes
err := lbt.meta.setReplicaInfo(replica)
if err != nil {
log.Error("failed to remove offline nodes from replica info",
zap.Int64("replicaID", replica.ReplicaID),
zap.Error(err))
return err
}
return nil
}) })
} }
} }
......
...@@ -685,12 +685,14 @@ func (scheduler *TaskScheduler) scheduleLoop() { ...@@ -685,12 +685,14 @@ func (scheduler *TaskScheduler) scheduleLoop() {
} }
} }
err = triggerTask.globalPostExecute(triggerTask.traceCtx()) if triggerTask.getResultInfo().ErrorCode == commonpb.ErrorCode_Success {
if err != nil { err = triggerTask.globalPostExecute(triggerTask.traceCtx())
log.Error("scheduleLoop: failed to execute globalPostExecute() of task", if err != nil {
zap.Int64("taskID", triggerTask.getTaskID()), log.Error("scheduleLoop: failed to execute globalPostExecute() of task",
zap.Error(err)) zap.Int64("taskID", triggerTask.getTaskID()),
triggerTask.setResultInfo(err) zap.Error(err))
triggerTask.setResultInfo(err)
}
} }
err = removeTaskFromKVFn(triggerTask) err = removeTaskFromKVFn(triggerTask)
......
...@@ -203,6 +203,12 @@ func removeFromSlice(origin []UniqueID, del ...UniqueID) []UniqueID { ...@@ -203,6 +203,12 @@ func removeFromSlice(origin []UniqueID, del ...UniqueID) []UniqueID {
return set.Collect() return set.Collect()
} }
func uniqueSlice(origin []UniqueID) []UniqueID {
set := make(typeutil.UniqueSet, len(origin))
set.Insert(origin...)
return set.Collect()
}
func getReplicaAvailableMemory(cluster Cluster, replica *milvuspb.ReplicaInfo) uint64 { func getReplicaAvailableMemory(cluster Cluster, replica *milvuspb.ReplicaInfo) uint64 {
availableMemory := uint64(0) availableMemory := uint64(0)
nodes := getNodeInfos(cluster, replica.NodeIds) nodes := getNodeInfos(cluster, replica.NodeIds)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册