未验证 提交 318cdbeb 编写于 作者: C congqixia 提交者: GitHub

Add heap for dmlChannels to get next channel (#18451)

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
上级 0420a8be
......@@ -17,13 +17,13 @@
package rootcoord
import (
"container/heap"
"context"
"fmt"
"sync"
"github.com/milvus-io/milvus/internal/metrics"
"go.uber.org/atomic"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
......@@ -31,9 +31,102 @@ import (
)
// dmlMsgStream couples one DML message stream with the counters that are
// used to decide which channel is handed out next.
//
// Each field is declared exactly once; the repeated ms/mutex/refcnt lines
// were leftovers of the previous layout and made the struct invalid.
type dmlMsgStream struct {
	ms     msgstream.MsgStream
	mutex  sync.RWMutex // taken by the accessor/mutator methods around refcnt and used
	refcnt int64        // current in use count
	used   int64        // total used counter in current run, not stored in meta so meant to be inaccurate
	idx    int64        // index the channel name is generated from
	pos    int          // position in the heap slice
}
// RefCnt reports the current reference count; the value is read while the
// read lock is held.
func (dms *dmlMsgStream) RefCnt() int64 {
	dms.mutex.RLock()
	current := dms.refcnt
	dms.mutex.RUnlock()
	return current
}
// Used reports the usage counter; the value is read while the read lock is
// held.
func (dms *dmlMsgStream) Used() int64 {
	dms.mutex.RLock()
	total := dms.used
	dms.mutex.RUnlock()
	return total
}
// IncRefcnt adds one to refcnt while holding the write lock.
func (dms *dmlMsgStream) IncRefcnt() {
	dms.mutex.Lock()
	dms.refcnt++
	dms.mutex.Unlock()
}
// BookUsage adds one to used while holding the write lock, acting like a
// usage reservation.
func (dms *dmlMsgStream) BookUsage() {
	dms.mutex.Lock()
	dms.used++
	dms.mutex.Unlock()
}
// DecRefCnt lowers refcnt by one under the write lock. When refcnt is not
// positive nothing is changed and a warning carrying the stream idx is logged.
func (dms *dmlMsgStream) DecRefCnt() {
	dms.mutex.Lock()
	defer dms.mutex.Unlock()

	if dms.refcnt <= 0 {
		// nothing left to release
		log.Warn("Try to remove channel with no ref count", zap.Int64("idx", dms.idx))
		return
	}
	dms.refcnt--
}
// channelsHeap implements heap.Interface so that it behaves like a priority
// queue of dmlMsgStream pointers.
type channelsHeap []*dmlMsgStream
// Len reports how many streams the heap currently holds.
func (h channelsHeap) Len() int {
	size := len(h)
	return size
}
// Less reports whether the element at index i must sort before the element
// at index j. Keys are compared in order: lower refcnt first, then lower
// used, then lower idx.
func (h channelsHeap) Less(i int, j int) bool {
	left, right := h[i], h[j]

	// fewer active references wins
	if a, b := left.RefCnt(), right.RefCnt(); a != b {
		return a < b
	}
	// same refcnt: the less used stream wins
	if a, b := left.Used(), right.Used(); a != b {
		return a < b
	}
	// every counter equal: fall back to the smaller idx
	return left.idx < right.idx
}
// Swap exchanges the elements at indexes i and j and then rewrites each
// element's pos to the slot it now occupies.
func (h channelsHeap) Swap(i int, j int) {
	h[i], h[j] = h[j], h[i]
	h[i].pos = i
	h[j].pos = j
}
// Push appends x, which must be a *dmlMsgStream, to the end of the heap
// slice. It is the heap.Interface hook called by heap.Push.
//
// The element's pos is set to the slot it is appended at. Swap only rewrites
// pos for elements it actually moves, so without this an element that stays
// in place after being pushed would keep whatever pos it had before.
func (h *channelsHeap) Push(x interface{}) {
	item := x.(*dmlMsgStream)
	item.pos = len(*h)
	*h = append(*h, item)
}
// Pop implements heap.Interface: it removes the last element of the slice
// and returns it.
func (h *channelsHeap) Pop() interface{} {
	cur := *h
	last := len(cur) - 1
	tail := cur[last]
	// clear the vacated slot so the backing array no longer references it
	cur[last] = nil
	*h = cur[:last]
	return tail
}
type dmlChannels struct {
......@@ -41,18 +134,21 @@ type dmlChannels struct {
factory msgstream.Factory
namePrefix string
capacity int64
idx *atomic.Int64
pool sync.Map
// pool maintains channelName => dmlMsgStream mapping, stable
pool sync.Map
// mut protects channelsHeap only
mut sync.Mutex
// channelsHeap is the heap to pop next dms for use
channelsHeap channelsHeap
}
func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePrefix string, chanNum int64) *dmlChannels {
d := &dmlChannels{
ctx: ctx,
factory: factory,
namePrefix: chanNamePrefix,
capacity: chanNum,
idx: atomic.NewInt64(0),
pool: sync.Map{},
ctx: ctx,
factory: factory,
namePrefix: chanNamePrefix,
capacity: chanNum,
channelsHeap: make([]*dmlMsgStream, 0, chanNum),
}
for i := int64(0); i < chanNum; i++ {
......@@ -63,12 +159,19 @@ func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePref
panic("Failed to add msgstream")
}
ms.AsProducer([]string{name})
d.pool.Store(name, &dmlMsgStream{
dms := &dmlMsgStream{
ms: ms,
mutex: sync.RWMutex{},
refcnt: 0,
})
used: 0,
idx: i,
pos: int(i),
}
d.pool.Store(name, dms)
d.channelsHeap = append(d.channelsHeap, dms)
}
heap.Init(&d.channelsHeap)
log.Debug("init dml channels", zap.Int64("num", chanNum))
metrics.RootCoordNumOfDMLChannel.Add(float64(chanNum))
metrics.RootCoordNumOfMsgStream.Add(float64(chanNum))
......@@ -77,20 +180,24 @@ func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePref
}
// getChannelName picks the next DML channel to hand out and returns its name.
//
// The candidate is the heap root, i.e. the stream that sorts first under
// channelsHeap.Less. Its used counter is bumped immediately and the heap is
// re-ordered before the lock is released, so the booking is visible to the
// next caller.
//
// The former round-robin statements based on d.idx are gone: they returned
// before any of the heap logic could run.
//
// Indexing the root panics when the heap is empty.
func (d *dmlChannels) getChannelName() string {
	d.mut.Lock()
	defer d.mut.Unlock()
	// get first item from heap
	item := d.channelsHeap[0]
	item.BookUsage()
	// the root's sort key just changed, restore heap order
	heap.Fix(&d.channelsHeap, 0)
	return genChannelName(d.namePrefix, item.idx)
}
func (d *dmlChannels) listChannels() []string {
var chanNames []string
d.pool.Range(
func(k, v interface{}) bool {
dms := v.(*dmlMsgStream)
dms.mutex.RLock()
if dms.refcnt > 0 {
chanNames = append(chanNames, k.(string))
if dms.RefCnt() > 0 {
chanNames = append(chanNames, genChannelName(d.namePrefix, dms.idx))
}
dms.mutex.RUnlock()
return true
})
return chanNames
......@@ -161,9 +268,10 @@ func (d *dmlChannels) addChannels(names ...string) {
}
dms := v.(*dmlMsgStream)
dms.mutex.Lock()
dms.refcnt++
dms.mutex.Unlock()
d.mut.Lock()
dms.IncRefcnt()
heap.Fix(&d.channelsHeap, dms.pos)
d.mut.Unlock()
}
}
......@@ -176,13 +284,10 @@ func (d *dmlChannels) removeChannels(names ...string) {
}
dms := v.(*dmlMsgStream)
dms.mutex.Lock()
if dms.refcnt > 0 {
dms.refcnt--
} else {
log.Warn("Try to remove channel with no ref count", zap.String("channel name", name))
}
dms.mutex.Unlock()
d.mut.Lock()
dms.DecRefCnt()
heap.Fix(&d.channelsHeap, dms.pos)
d.mut.Unlock()
}
}
......
......@@ -17,8 +17,10 @@
package rootcoord
import (
"container/heap"
"context"
"errors"
"math/rand"
"sync"
"testing"
......@@ -31,6 +33,93 @@ import (
"github.com/stretchr/testify/require"
)
// TestDmlMsgStream checks the counter accessors and mutators of
// dmlMsgStream.
func TestDmlMsgStream(t *testing.T) {
	t.Run("RefCnt", func(t *testing.T) {
		var (
			none int64
			once int64 = 1
		)
		stream := &dmlMsgStream{refcnt: 0}

		// both counters start at zero
		assert.Equal(t, none, stream.RefCnt())
		assert.Equal(t, none, stream.Used())

		stream.IncRefcnt()
		assert.Equal(t, once, stream.RefCnt())

		stream.BookUsage()
		assert.Equal(t, once, stream.Used())

		// releasing the reference leaves used untouched
		stream.DecRefCnt()
		assert.Equal(t, none, stream.RefCnt())
		assert.Equal(t, once, stream.Used())

		// refcnt is already zero: a second release changes nothing
		stream.DecRefCnt()
		assert.Equal(t, none, stream.RefCnt())
		assert.Equal(t, once, stream.Used())
	})
}
// TestChannelsHeap builds a heap of fresh streams and verifies that the heap
// invariant and the pos bookkeeping hold after Init, after repeated Fix
// calls on the root, and after a single stream releases its reference.
func TestChannelsHeap(t *testing.T) {
	const chanNum = 16

	h := make(channelsHeap, 0, chanNum)
	for n := 0; n < chanNum; n++ {
		h = append(h, &dmlMsgStream{
			refcnt: 0,
			used:   0,
			idx:    int64(n),
			pos:    n,
		})
	}

	// verify reports whether every element records its own slot in pos and
	// sorts strictly before both of its children.
	verify := func(hp channelsHeap) bool {
		for parent := 0; parent < chanNum; parent++ {
			if hp[parent].pos != parent {
				return false
			}
			if left := parent*2 + 1; left < chanNum && !hp.Less(parent, left) {
				t.Log("left", parent)
				return false
			}
			if right := parent*2 + 2; right < chanNum && !hp.Less(parent, right) {
				t.Log("right", parent)
				return false
			}
		}
		return true
	}

	heap.Init(&h)
	assert.True(t, verify(h))

	// add usage for all: each round books the current root and re-sorts it
	for round := 0; round < chanNum; round++ {
		top := h[0]
		top.BookUsage()
		top.IncRefcnt()
		heap.Fix(&h, 0)
	}

	assert.True(t, verify(h))
	for _, dms := range h {
		assert.EqualValues(t, 1, dms.RefCnt())
		assert.EqualValues(t, 1, dms.Used())
	}

	// release one randomly chosen stream; it must move to the root
	randIdx := rand.Intn(chanNum)
	target := h[randIdx]
	target.DecRefCnt()
	heap.Fix(&h, randIdx)
	assert.EqualValues(t, 0, target.pos)

	next := heap.Pop(&h).(*dmlMsgStream)
	assert.Equal(t, target, next)
}
func TestDmlChannels(t *testing.T) {
const (
dmlChanPrefix = "rootcoord-dml"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册