未验证 提交 587ccc05 编写于 作者: C Cai Yudong 提交者: GitHub

Optimize dml_channels (#5783)

* update timetickSync::UpdateTimeTick
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* update dml_channels.go
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* fix unittest
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* remove ProduceAll and BroadcastAll
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
上级 5c18138a
...@@ -20,117 +20,111 @@ import ( ...@@ -20,117 +20,111 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
// dmlStream pairs a DML channel's message stream with a validity flag, so a
// channel can be marked as removed without closing its stream immediately.
type dmlStream struct {
// msgStream is the stream that Produce and Broadcast calls for this channel
// are forwarded to.
msgStream msgstream.MsgStream
// valid is set to true when the channel is added by AddProducerChannels and
// to false by RemoveProducerChannels. Once it is false, the stream is closed
// and dropped from the map after the next successful Produce/Broadcast on
// it, and GetNumChannles no longer counts the channel.
valid bool
}
type dmlChannels struct { type dmlChannels struct {
core *Core core *Core
lock sync.RWMutex lock sync.RWMutex
dml map[string]msgstream.MsgStream dml map[string]*dmlStream
} }
func newDMLChannels(c *Core) *dmlChannels { func newDMLChannels(c *Core) *dmlChannels {
return &dmlChannels{ return &dmlChannels{
core: c, core: c,
lock: sync.RWMutex{}, lock: sync.RWMutex{},
dml: make(map[string]msgstream.MsgStream), dml: make(map[string]*dmlStream),
} }
} }
func (d *dmlChannels) GetNumChannles() int { func (d *dmlChannels) GetNumChannles() int {
d.lock.RLock() d.lock.RLock()
defer d.lock.RUnlock() defer d.lock.RUnlock()
return len(d.dml) count := 0
} for _, ds := range d.dml {
if ds.valid {
func (d *dmlChannels) ProduceAll(pack *msgstream.MsgPack) { count++
d.lock.RLock()
defer d.lock.RUnlock()
for n, ms := range d.dml {
if err := ms.Produce(pack); err != nil {
log.Debug("msgstream produce error", zap.String("name", n), zap.Error(err))
}
}
}
func (d *dmlChannels) BroadcastMany(channels []string, pack *msgstream.MsgPack) error {
d.lock.RLock()
defer d.lock.RUnlock()
for _, ch := range channels {
ms, ok := d.dml[ch]
if !ok {
return fmt.Errorf("channel %s not exist", ch)
}
if err := ms.Broadcast(pack); err != nil {
return err
} }
} }
return nil return count
} }
func (d *dmlChannels) BroadcastAll(pack *msgstream.MsgPack) { func (d *dmlChannels) Produce(name string, pack *msgstream.MsgPack) error {
d.lock.RLock() d.lock.Lock()
defer d.lock.RUnlock() defer d.lock.Unlock()
for n, ms := range d.dml { ds, ok := d.dml[name]
if err := ms.Broadcast(pack); err != nil { if !ok {
log.Debug("msgstream broadcast error", zap.String("name", n), zap.Error(err)) return fmt.Errorf("channel %s not exist", name)
} }
if err := ds.msgStream.Produce(pack); err != nil {
return err
}
if !ds.valid {
ds.msgStream.Close()
delete(d.dml, name)
} }
return nil
} }
func (d *dmlChannels) Produce(name string, pack *msgstream.MsgPack) error { func (d *dmlChannels) Broadcast(name string, pack *msgstream.MsgPack) error {
d.lock.Lock() d.lock.Lock()
defer d.lock.Unlock() defer d.lock.Unlock()
var err error ds, ok := d.dml[name]
ms, ok := d.dml[name]
if !ok { if !ok {
ms, err = d.core.msFactory.NewMsgStream(d.core.ctx) return fmt.Errorf("channel %s not exist", name)
if err != nil {
return fmt.Errorf("create mstream failed, name = %s, error=%w", name, err)
}
ms.AsProducer([]string{name})
d.dml[name] = ms
} }
return ms.Produce(pack) if err := ds.msgStream.Broadcast(pack); err != nil {
return err
}
if !ds.valid {
ds.msgStream.Close()
delete(d.dml, name)
}
return nil
} }
func (d *dmlChannels) Broadcast(name string, pack *msgstream.MsgPack) error { func (d *dmlChannels) BroadcastAll(channels []string, pack *msgstream.MsgPack) error {
d.lock.Lock() d.lock.Lock()
defer d.lock.Unlock() defer d.lock.Unlock()
if len(name) == 0 { for _, ch := range channels {
return fmt.Errorf("channel name is empty") ds, ok := d.dml[ch]
} if !ok {
var err error return fmt.Errorf("channel %s not exist", ch)
ms, ok := d.dml[name] }
if !ok { if err := ds.msgStream.Broadcast(pack); err != nil {
ms, err = d.core.msFactory.NewMsgStream(d.core.ctx) return err
if err != nil { }
return fmt.Errorf("create msgtream failed, name = %s, error=%w", name, err) if !ds.valid {
ds.msgStream.Close()
delete(d.dml, ch)
} }
ms.AsProducer([]string{name})
d.dml[name] = ms
} }
return ms.Broadcast(pack) return nil
} }
func (d *dmlChannels) AddProducerChannels(names ...string) { func (d *dmlChannels) AddProducerChannels(names ...string) {
d.lock.Lock() d.lock.Lock()
defer d.lock.Unlock() defer d.lock.Unlock()
var err error
for _, name := range names { for _, name := range names {
log.Debug("add dml channel", zap.String("channel name", name)) log.Debug("add dml channel", zap.String("channel name", name))
ms, ok := d.dml[name] _, ok := d.dml[name]
if !ok { if !ok {
ms, err = d.core.msFactory.NewMsgStream(d.core.ctx) ms, err := d.core.msFactory.NewMsgStream(d.core.ctx)
if err != nil { if err != nil {
log.Debug("add msgstream failed", zap.String("name", name), zap.Error(err)) log.Debug("add msgstream failed", zap.String("name", name), zap.Error(err))
continue continue
} }
ms.AsProducer([]string{name}) ms.AsProducer([]string{name})
d.dml[name] = ms d.dml[name] = &dmlStream{
msgStream: ms,
valid: true,
}
} }
} }
} }
...@@ -141,22 +135,8 @@ func (d *dmlChannels) RemoveProducerChannels(names ...string) { ...@@ -141,22 +135,8 @@ func (d *dmlChannels) RemoveProducerChannels(names ...string) {
for _, name := range names { for _, name := range names {
log.Debug("delete dml channel", zap.String("channel name", name)) log.Debug("delete dml channel", zap.String("channel name", name))
if ms, ok := d.dml[name]; ok { if ds, ok := d.dml[name]; ok {
ms.Close() ds.valid = false
delete(d.dml, name)
}
}
}
func (d *dmlChannels) HasChannel(names ...string) bool {
d.lock.Lock()
defer d.lock.Unlock()
for _, name := range names {
if _, ok := d.dml[name]; !ok {
log.Debug("unknown channel", zap.String("channel name", name))
return false
} }
} }
return true
} }
...@@ -602,7 +602,7 @@ func (c *Core) setMsgStreams() error { ...@@ -602,7 +602,7 @@ func (c *Core) setMsgStreams() error {
CreateCollectionRequest: *req, CreateCollectionRequest: *req,
} }
msgPack.Msgs = append(msgPack.Msgs, msg) msgPack.Msgs = append(msgPack.Msgs, msg)
return c.dmlChannels.BroadcastMany(channelNames, &msgPack) return c.dmlChannels.BroadcastAll(channelNames, &msgPack)
} }
c.SendDdDropCollectionReq = func(ctx context.Context, req *internalpb.DropCollectionRequest, channelNames []string) error { c.SendDdDropCollectionReq = func(ctx context.Context, req *internalpb.DropCollectionRequest, channelNames []string) error {
...@@ -618,7 +618,7 @@ func (c *Core) setMsgStreams() error { ...@@ -618,7 +618,7 @@ func (c *Core) setMsgStreams() error {
DropCollectionRequest: *req, DropCollectionRequest: *req,
} }
msgPack.Msgs = append(msgPack.Msgs, msg) msgPack.Msgs = append(msgPack.Msgs, msg)
return c.dmlChannels.BroadcastMany(channelNames, &msgPack) return c.dmlChannels.BroadcastAll(channelNames, &msgPack)
} }
c.SendDdCreatePartitionReq = func(ctx context.Context, req *internalpb.CreatePartitionRequest, channelNames []string) error { c.SendDdCreatePartitionReq = func(ctx context.Context, req *internalpb.CreatePartitionRequest, channelNames []string) error {
...@@ -634,7 +634,7 @@ func (c *Core) setMsgStreams() error { ...@@ -634,7 +634,7 @@ func (c *Core) setMsgStreams() error {
CreatePartitionRequest: *req, CreatePartitionRequest: *req,
} }
msgPack.Msgs = append(msgPack.Msgs, msg) msgPack.Msgs = append(msgPack.Msgs, msg)
return c.dmlChannels.BroadcastMany(channelNames, &msgPack) return c.dmlChannels.BroadcastAll(channelNames, &msgPack)
} }
c.SendDdDropPartitionReq = func(ctx context.Context, req *internalpb.DropPartitionRequest, channelNames []string) error { c.SendDdDropPartitionReq = func(ctx context.Context, req *internalpb.DropPartitionRequest, channelNames []string) error {
...@@ -650,7 +650,7 @@ func (c *Core) setMsgStreams() error { ...@@ -650,7 +650,7 @@ func (c *Core) setMsgStreams() error {
DropPartitionRequest: *req, DropPartitionRequest: *req,
} }
msgPack.Msgs = append(msgPack.Msgs, msg) msgPack.Msgs = append(msgPack.Msgs, msg)
return c.dmlChannels.BroadcastMany(channelNames, &msgPack) return c.dmlChannels.BroadcastAll(channelNames, &msgPack)
} }
if Params.DataServiceSegmentChannel == "" { if Params.DataServiceSegmentChannel == "" {
...@@ -1885,12 +1885,6 @@ func (c *Core) UpdateChannelTimeTick(ctx context.Context, in *internalpb.Channel ...@@ -1885,12 +1885,6 @@ func (c *Core) UpdateChannelTimeTick(ctx context.Context, in *internalpb.Channel
status.Reason = fmt.Sprintf("UpdateChannelTimeTick receive invalid message %d", in.Base.GetMsgType()) status.Reason = fmt.Sprintf("UpdateChannelTimeTick receive invalid message %d", in.Base.GetMsgType())
return status, nil return status, nil
} }
if !c.dmlChannels.HasChannel(in.ChannelNames...) {
log.Debug("update time tick with unkonw channel", zap.Int("input channel size", len(in.ChannelNames)), zap.Strings("input channels", in.ChannelNames))
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = fmt.Sprintf("update time tick with unknown channel name, input channels = %v", in.ChannelNames)
return status, nil
}
err := c.chanTimeTick.UpdateTimeTick(in) err := c.chanTimeTick.UpdateTimeTick(in)
if err != nil { if err != nil {
status.ErrorCode = commonpb.ErrorCode_UnexpectedError status.ErrorCode = commonpb.ErrorCode_UnexpectedError
......
...@@ -45,14 +45,14 @@ func newTimeTickSync(core *Core) *timetickSync { ...@@ -45,14 +45,14 @@ func newTimeTickSync(core *Core) *timetickSync {
// sendToChannel send all channels' timetick to sendChan // sendToChannel send all channels' timetick to sendChan
// lock is needed by the invoker // lock is needed by the invoker
func (t *timetickSync) sendToChannel() { func (t *timetickSync) sendToChannel() {
if len(t.proxyTimeTick) == 0 {
return
}
for _, v := range t.proxyTimeTick { for _, v := range t.proxyTimeTick {
if v == nil { if v == nil {
return return
} }
} }
if len(t.proxyTimeTick) == 0 {
return
}
// clear proxyTimeTick and send a clone // clear proxyTimeTick and send a clone
ptt := make(map[typeutil.UniqueID]*internalpb.ChannelTimeTickMsg) ptt := make(map[typeutil.UniqueID]*internalpb.ChannelTimeTickMsg)
for k, v := range t.proxyTimeTick { for k, v := range t.proxyTimeTick {
...@@ -77,9 +77,11 @@ func (t *timetickSync) UpdateTimeTick(in *internalpb.ChannelTimeTickMsg) error { ...@@ -77,9 +77,11 @@ func (t *timetickSync) UpdateTimeTick(in *internalpb.ChannelTimeTickMsg) error {
if !ok { if !ok {
return fmt.Errorf("Skip ChannelTimeTickMsg from un-recognized proxy node %d", in.Base.SourceID) return fmt.Errorf("Skip ChannelTimeTickMsg from un-recognized proxy node %d", in.Base.SourceID)
} }
if prev != nil && prev.Timestamps[0] >= in.Timestamps[0] { if in.Base.SourceID == t.core.session.ServerID {
log.Debug("timestamp go back", zap.Int64("source id", in.Base.SourceID), zap.Uint64("prev ts", prev.Timestamps[0]), zap.Uint64("curr ts", in.Timestamps[0])) if prev != nil && prev.Timestamps[0] >= in.Timestamps[0] {
return nil log.Debug("timestamp go back", zap.Int64("source id", in.Base.SourceID), zap.Uint64("prev ts", prev.Timestamps[0]), zap.Uint64("curr ts", in.Timestamps[0]))
return nil
}
} }
t.proxyTimeTick[in.Base.SourceID] = in t.proxyTimeTick[in.Base.SourceID] = in
t.sendToChannel() t.sendToChannel()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册