未验证 提交 d4c297b1 编写于 作者: C Cai Yudong 提交者: GitHub

Enhance dml channel operations (#12143)

Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>
上级 ac175b6f
......@@ -22,11 +22,16 @@ import (
"github.com/milvus-io/milvus/internal/msgstream"
)
type dmlMsgStream struct {
ms msgstream.MsgStream
mutex sync.RWMutex
refcnt int64
}
type dmlChannels struct {
core *Core
namePrefix string
capacity int64
refcnt sync.Map
idx *atomic.Int64
pool sync.Map
}
......@@ -36,60 +41,71 @@ func newDmlChannels(c *Core, chanNamePrefix string, chanNum int64) *dmlChannels
core: c,
namePrefix: chanNamePrefix,
capacity: chanNum,
refcnt: sync.Map{},
idx: atomic.NewInt64(0),
pool: sync.Map{},
}
var i int64
for i = 0; i < chanNum; i++ {
name := fmt.Sprintf("%s_%d", d.namePrefix, i)
for i := int64(0); i < chanNum; i++ {
name := getDmlChannelName(d.namePrefix, i)
ms, err := c.msFactory.NewMsgStream(c.ctx)
if err != nil {
log.Error("Failed to add msgstream", zap.String("name", name), zap.Error(err))
panic("Failed to add msgstream")
}
d.pool.Store(name, &ms)
d.pool.Store(name, &dmlMsgStream{
ms: ms,
mutex: sync.RWMutex{},
refcnt: 0,
})
}
log.Debug("init dml channels", zap.Int64("num", chanNum))
return d
}
func (d *dmlChannels) GetDmlMsgStreamName() string {
cnt := d.idx.Load()
name := fmt.Sprintf("%s_%d", d.namePrefix, cnt)
d.idx.Store((cnt + 1) % d.capacity)
return name
cnt := d.idx.Inc()
return getDmlChannelName(d.namePrefix, (cnt-1)%d.capacity)
}
// ListChannels lists all dml channel names
func (d *dmlChannels) ListChannels() []string {
// ListPhysicalChannels lists all dml channel names
func (d *dmlChannels) ListPhysicalChannels() []string {
var chanNames []string
d.refcnt.Range(
d.pool.Range(
func(k, v interface{}) bool {
chanNames = append(chanNames, k.(string))
dms := v.(*dmlMsgStream)
dms.mutex.RLock()
if dms.refcnt > 0 {
chanNames = append(chanNames, k.(string))
}
dms.mutex.RUnlock()
return true
})
return chanNames
}
// GetNumChannels get current dml channel count
func (d *dmlChannels) GetNumChannels() int {
return len(d.ListChannels())
func (d *dmlChannels) GetPhysicalChannelNum() int {
return len(d.ListPhysicalChannels())
}
// Broadcast broadcasts msg pack into specified channel
func (d *dmlChannels) Broadcast(chanNames []string, pack *msgstream.MsgPack) error {
for _, chanName := range chanNames {
// only in-use chanName exist in refcnt
if _, ok := d.refcnt.Load(chanName); ok {
v, _ := d.pool.Load(chanName)
if err := (*(v.(*msgstream.MsgStream))).Broadcast(pack); err != nil {
v, ok := d.pool.Load(chanName)
if !ok {
log.Error("invalid channel name", zap.String("chanName", chanName))
panic("invalid channel name: " + chanName)
}
dms := v.(*dmlMsgStream)
dms.mutex.RLock()
if dms.refcnt > 0 {
if err := dms.ms.Broadcast(pack); err != nil {
log.Error("Broadcast failed", zap.String("chanName", chanName))
return err
}
} else {
return fmt.Errorf("channel %s not exist", chanName)
}
dms.mutex.RUnlock()
}
return nil
}
......@@ -98,22 +114,28 @@ func (d *dmlChannels) Broadcast(chanNames []string, pack *msgstream.MsgPack) err
func (d *dmlChannels) BroadcastMark(chanNames []string, pack *msgstream.MsgPack) (map[string][]byte, error) {
result := make(map[string][]byte)
for _, chanName := range chanNames {
// only in-use chanName exist in refcnt
if _, ok := d.refcnt.Load(chanName); ok {
v, _ := d.pool.Load(chanName)
ids, err := (*(v.(*msgstream.MsgStream))).BroadcastMark(pack)
v, ok := d.pool.Load(chanName)
if !ok {
log.Error("invalid channel name", zap.String("chanName", chanName))
panic("invalid channel name: " + chanName)
}
dms := v.(*dmlMsgStream)
dms.mutex.RLock()
if dms.refcnt > 0 {
ids, err := dms.ms.BroadcastMark(pack)
if err != nil {
log.Error("BroadcastMark failed", zap.String("chanName", chanName))
return result, err
}
for chanName, idList := range ids {
for cn, idList := range ids {
// idList should have length 1, just flat by iteration
for _, id := range idList {
result[chanName] = id.Serialize()
result[cn] = id.Serialize()
}
}
} else {
return result, fmt.Errorf("channel %s not exist", chanName)
}
dms.mutex.RUnlock()
}
return result, nil
}
......@@ -121,38 +143,43 @@ func (d *dmlChannels) BroadcastMark(chanNames []string, pack *msgstream.MsgPack)
// AddProducerChannels add named channels as producer
func (d *dmlChannels) AddProducerChannels(names ...string) {
for _, name := range names {
if v, ok := d.pool.Load(name); ok {
var cnt int64
if _, ok := d.refcnt.Load(name); !ok {
ms := *(v.(*msgstream.MsgStream))
ms.AsProducer([]string{name})
cnt = 1
} else {
v, _ := d.refcnt.Load(name)
cnt = v.(int64) + 1
}
d.refcnt.Store(name, cnt)
log.Debug("assign dml channel", zap.String("chanName", name), zap.Int64("refcnt", cnt))
} else {
v, ok := d.pool.Load(name)
if !ok {
log.Error("invalid channel name", zap.String("chanName", name))
panic("invalid channel name: " + name)
}
dms := v.(*dmlMsgStream)
dms.mutex.Lock()
if dms.refcnt == 0 {
dms.ms.AsProducer([]string{name})
}
dms.refcnt++
dms.mutex.Unlock()
}
}
// RemoveProducerChannels removes specified channels
func (d *dmlChannels) RemoveProducerChannels(names ...string) {
for _, name := range names {
if v, ok := d.refcnt.Load(name); ok {
cnt := v.(int64)
if cnt > 1 {
d.refcnt.Store(name, cnt-1)
} else {
v1, _ := d.pool.Load(name)
ms := *(v1.(*msgstream.MsgStream))
ms.Close()
d.refcnt.Delete(name)
v, ok := d.pool.Load(name)
if !ok {
log.Error("invalid channel name", zap.String("chanName", name))
panic("invalid channel name: " + name)
}
dms := v.(*dmlMsgStream)
dms.mutex.Lock()
if dms.refcnt > 0 {
dms.refcnt--
if dms.refcnt == 0 {
dms.ms.Close()
}
}
dms.mutex.Unlock()
}
}
func getDmlChannelName(prefix string, idx int64) string {
return fmt.Sprintf("%s_%d", prefix, idx)
}
......@@ -13,7 +13,6 @@ package rootcoord
import (
"context"
"fmt"
"testing"
"github.com/milvus-io/milvus/internal/msgstream"
......@@ -43,36 +42,35 @@ func TestDmlChannels(t *testing.T) {
assert.Nil(t, err)
dml := newDmlChannels(core, dmlChanPrefix, totalDmlChannelNum)
chanNames := dml.ListChannels()
chanNames := dml.ListPhysicalChannels()
assert.Equal(t, 0, len(chanNames))
randStr := funcutil.RandomString(8)
assert.Panics(t, func() { dml.AddProducerChannels(randStr) })
err = dml.Broadcast([]string{randStr}, nil)
assert.NotNil(t, err)
assert.EqualError(t, err, fmt.Sprintf("channel %s not exist", randStr))
assert.Panics(t, func() { dml.Broadcast([]string{randStr}, nil) })
assert.Panics(t, func() { dml.BroadcastMark([]string{randStr}, nil) })
assert.Panics(t, func() { dml.RemoveProducerChannels(randStr) })
// dml_xxx_0 => {chanName0, chanName2}
// dml_xxx_1 => {chanName1}
chanName0 := dml.GetDmlMsgStreamName()
dml.AddProducerChannels(chanName0)
assert.Equal(t, 1, dml.GetNumChannels())
assert.Equal(t, 1, dml.GetPhysicalChannelNum())
chanName1 := dml.GetDmlMsgStreamName()
dml.AddProducerChannels(chanName1)
assert.Equal(t, 2, dml.GetNumChannels())
assert.Equal(t, 2, dml.GetPhysicalChannelNum())
chanName2 := dml.GetDmlMsgStreamName()
dml.AddProducerChannels(chanName2)
assert.Equal(t, 2, dml.GetNumChannels())
assert.Equal(t, 2, dml.GetPhysicalChannelNum())
dml.RemoveProducerChannels(chanName0)
assert.Equal(t, 2, dml.GetNumChannels())
assert.Equal(t, 2, dml.GetPhysicalChannelNum())
dml.RemoveProducerChannels(chanName1)
assert.Equal(t, 1, dml.GetNumChannels())
assert.Equal(t, 1, dml.GetPhysicalChannelNum())
dml.RemoveProducerChannels(chanName0)
assert.Equal(t, 0, dml.GetNumChannels())
assert.Equal(t, 0, dml.GetPhysicalChannelNum())
}
......@@ -482,7 +482,7 @@ func (c *Core) setMsgStreams() error {
metrics.RootCoordDDChannelTimeTick.Set(float64(tsoutil.Mod24H(t)))
//c.dmlChannels.BroadcastAll(&msgPack)
pc := c.dmlChannels.ListChannels()
pc := c.dmlChannels.ListPhysicalChannels()
pt := make([]uint64, len(pc))
for i := 0; i < len(pt); i++ {
pt[i] = t
......
......@@ -680,7 +680,7 @@ func TestRootCoord(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
assert.Equal(t, shardsNum, int32(core.dmlChannels.GetNumChannels()))
assert.Equal(t, shardsNum, int32(core.dmlChannels.GetPhysicalChannelNum()))
createMeta, err := core.MetaTable.GetCollectionByName(collName, 0)
assert.Nil(t, err)
......
......@@ -310,7 +310,7 @@ func (t *timetickSync) GetProxyNum() int {
// GetChanNum return the num of channel
func (t *timetickSync) GetChanNum() int {
return t.core.dmlChannels.GetNumChannels()
return t.core.dmlChannels.GetPhysicalChannelNum()
}
func minTimeTick(tt ...typeutil.Timestamp) typeutil.Timestamp {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册