diff --git a/internal/rootcoord/dml_channels.go b/internal/rootcoord/dml_channels.go index d98fd69737064996450a9d0e8531834371b042e7..37e5a00edfdd51ddbdc234055e0ed138f4f90491 100644 --- a/internal/rootcoord/dml_channels.go +++ b/internal/rootcoord/dml_channels.go @@ -17,13 +17,13 @@ package rootcoord import ( + "container/heap" "context" "fmt" "sync" "github.com/milvus-io/milvus/internal/metrics" - "go.uber.org/atomic" "go.uber.org/zap" "github.com/milvus-io/milvus/internal/log" @@ -31,9 +31,102 @@ import ( ) type dmlMsgStream struct { - ms msgstream.MsgStream - mutex sync.RWMutex - refcnt int64 + ms msgstream.MsgStream + mutex sync.RWMutex + + refcnt int64 // current in use count + used int64 // total used counter in current run, not stored in meta so meant to be inaccurate + idx int64 // idx for name + pos int // position in the heap slice +} + +// RefCnt returns refcnt with mutex protection. +func (dms *dmlMsgStream) RefCnt() int64 { + dms.mutex.RLock() + defer dms.mutex.RUnlock() + return dms.refcnt +} + +// Used returns used with mutex protection. +func (dms *dmlMsgStream) Used() int64 { + dms.mutex.RLock() + defer dms.mutex.RUnlock() + return dms.used +} + +// IncRefcnt increases refcnt. +func (dms *dmlMsgStream) IncRefcnt() { + dms.mutex.Lock() + defer dms.mutex.Unlock() + dms.refcnt++ +} + +// BookUsage increases used, acting like reservation usage. +func (dms *dmlMsgStream) BookUsage() { + dms.mutex.Lock() + defer dms.mutex.Unlock() + dms.used++ +} + +// DecRefCnt decreases refcnt only. +func (dms *dmlMsgStream) DecRefCnt() { + dms.mutex.Lock() + defer dms.mutex.Unlock() + if dms.refcnt > 0 { + dms.refcnt-- + } else { + log.Warn("Try to remove channel with no ref count", zap.Int64("idx", dms.idx)) + } +} + +// channelsHeap implements heap.Interface to perform like a priority queue. +type channelsHeap []*dmlMsgStream + +// Len is the number of elements in the collection. 
+func (h channelsHeap) Len() int { + return len(h) +} + +// Less reports whether the element with index i +// must sort before the element with index j. +func (h channelsHeap) Less(i int, j int) bool { + ei, ej := h[i], h[j] + // use less refcnt first + rci, rcj := ei.RefCnt(), ej.RefCnt() + if rci != rcj { + return rci < rcj + } + + // use less used channel first + ui, uj := ei.Used(), ej.Used() + if ui != uj { + return ui < uj + } + + // all numbers same, use the smaller idx one + return ei.idx < ej.idx +} + +// Swap swaps the elements with indexes i and j. +func (h channelsHeap) Swap(i int, j int) { + h[i], h[j] = h[j], h[i] + h[i].pos, h[j].pos = i, j +} + +// Push adds a new element to the heap. +func (h *channelsHeap) Push(x interface{}) { + item := x.(*dmlMsgStream) + *h = append(*h, item) +} + +// Pop implements heap.Interface, pop the last value. +func (h *channelsHeap) Pop() interface{} { + old := *h + n := len(old) + item := old[n-1] + old[n-1] = nil + *h = old[0 : n-1] + return item } type dmlChannels struct { @@ -41,18 +134,21 @@ type dmlChannels struct { factory msgstream.Factory namePrefix string capacity int64 - idx *atomic.Int64 - pool sync.Map + // pool maintains channelName => dmlMsgStream mapping, stable + pool sync.Map + // mut protects channelsHeap only + mut sync.Mutex + // channelsHeap is the heap to pop next dms for use + channelsHeap channelsHeap } func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePrefix string, chanNum int64) *dmlChannels { d := &dmlChannels{ - ctx: ctx, - factory: factory, - namePrefix: chanNamePrefix, - capacity: chanNum, - idx: atomic.NewInt64(0), - pool: sync.Map{}, + ctx: ctx, + factory: factory, + namePrefix: chanNamePrefix, + capacity: chanNum, + channelsHeap: make([]*dmlMsgStream, 0, chanNum), } for i := int64(0); i < chanNum; i++ { @@ -63,12 +159,19 @@ func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePref panic("Failed to add msgstream") } 
ms.AsProducer([]string{name}) - d.pool.Store(name, &dmlMsgStream{ + + dms := &dmlMsgStream{ ms: ms, - mutex: sync.RWMutex{}, refcnt: 0, - }) + used: 0, + idx: i, + pos: int(i), + } + d.pool.Store(name, dms) + d.channelsHeap = append(d.channelsHeap, dms) } + + heap.Init(&d.channelsHeap) log.Debug("init dml channels", zap.Int64("num", chanNum)) metrics.RootCoordNumOfDMLChannel.Add(float64(chanNum)) metrics.RootCoordNumOfMsgStream.Add(float64(chanNum)) @@ -77,20 +180,24 @@ func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePref } func (d *dmlChannels) getChannelName() string { - cnt := d.idx.Inc() - return genChannelName(d.namePrefix, (cnt-1)%d.capacity) + d.mut.Lock() + defer d.mut.Unlock() + // get first item from heap + item := d.channelsHeap[0] + item.BookUsage() + heap.Fix(&d.channelsHeap, 0) + return genChannelName(d.namePrefix, item.idx) } func (d *dmlChannels) listChannels() []string { var chanNames []string + d.pool.Range( func(k, v interface{}) bool { dms := v.(*dmlMsgStream) - dms.mutex.RLock() - if dms.refcnt > 0 { - chanNames = append(chanNames, k.(string)) + if dms.RefCnt() > 0 { + chanNames = append(chanNames, genChannelName(d.namePrefix, dms.idx)) } - dms.mutex.RUnlock() return true }) return chanNames @@ -161,9 +268,10 @@ func (d *dmlChannels) addChannels(names ...string) { } dms := v.(*dmlMsgStream) - dms.mutex.Lock() - dms.refcnt++ - dms.mutex.Unlock() + d.mut.Lock() + dms.IncRefcnt() + heap.Fix(&d.channelsHeap, dms.pos) + d.mut.Unlock() } } @@ -176,13 +284,10 @@ func (d *dmlChannels) removeChannels(names ...string) { } dms := v.(*dmlMsgStream) - dms.mutex.Lock() - if dms.refcnt > 0 { - dms.refcnt-- - } else { - log.Warn("Try to remove channel with no ref count", zap.String("channel name", name)) - } - dms.mutex.Unlock() + d.mut.Lock() + dms.DecRefCnt() + heap.Fix(&d.channelsHeap, dms.pos) + d.mut.Unlock() } } diff --git a/internal/rootcoord/dml_channels_test.go b/internal/rootcoord/dml_channels_test.go index 
0086ae46434d756757810575ea5659400f55c4ec..65a05f8c72d6d30e769dec8fe066b8fe97380356 100644 --- a/internal/rootcoord/dml_channels_test.go +++ b/internal/rootcoord/dml_channels_test.go @@ -17,8 +17,10 @@ package rootcoord import ( + "container/heap" "context" "errors" + "math/rand" "sync" "testing" @@ -31,6 +33,93 @@ import ( "github.com/stretchr/testify/require" ) +func TestDmlMsgStream(t *testing.T) { + t.Run("RefCnt", func(t *testing.T) { + + dms := &dmlMsgStream{refcnt: 0} + assert.Equal(t, int64(0), dms.RefCnt()) + assert.Equal(t, int64(0), dms.Used()) + + dms.IncRefcnt() + assert.Equal(t, int64(1), dms.RefCnt()) + dms.BookUsage() + assert.Equal(t, int64(1), dms.Used()) + + dms.DecRefCnt() + assert.Equal(t, int64(0), dms.RefCnt()) + assert.Equal(t, int64(1), dms.Used()) + + dms.DecRefCnt() + assert.Equal(t, int64(0), dms.RefCnt()) + assert.Equal(t, int64(1), dms.Used()) + }) +} + +func TestChannelsHeap(t *testing.T) { + chanNum := 16 + var h channelsHeap + h = make([]*dmlMsgStream, 0, chanNum) + + for i := int64(0); i < int64(chanNum); i++ { + dms := &dmlMsgStream{ + refcnt: 0, + used: 0, + idx: i, + pos: int(i), + } + h = append(h, dms) + } + + check := func(h channelsHeap) bool { + for i := 0; i < chanNum; i++ { + if h[i].pos != i { + return false + } + if i*2+1 < chanNum { + if !h.Less(i, i*2+1) { + t.Log("left", i) + return false + } + } + if i*2+2 < chanNum { + if !h.Less(i, i*2+2) { + t.Log("right", i) + return false + } + } + } + return true + } + + heap.Init(&h) + + assert.True(t, check(h)) + + // add usage for all + for i := 0; i < chanNum; i++ { + h[0].BookUsage() + h[0].IncRefcnt() + heap.Fix(&h, 0) + } + + assert.True(t, check(h)) + for i := 0; i < chanNum; i++ { + assert.EqualValues(t, 1, h[i].RefCnt()) + assert.EqualValues(t, 1, h[i].Used()) + } + + randIdx := rand.Intn(chanNum) + + target := h[randIdx] + h[randIdx].DecRefCnt() + heap.Fix(&h, randIdx) + assert.EqualValues(t, 0, target.pos) + + next := heap.Pop(&h).(*dmlMsgStream) + + assert.Equal(t, 
target, next) +} + func TestDmlChannels(t *testing.T) { const ( dmlChanPrefix = "rootcoord-dml"