未验证 提交 0c3f92d7 编写于 作者: S SimFG 提交者: GitHub

Improve the panic code about the rootcoord/session/rocksmq (#24859) (#25024)

Signed-off-by: NSimFG <bang.fu@zilliz.com>
上级 b88e74a1
......@@ -1051,7 +1051,7 @@ func (rmq *rocksmq) updateAckedInfo(topicName, groupName string, firstID UniqueI
consumers, ok := vals.([]*Consumer)
if !ok || len(consumers) == 0 {
log.Error("update ack with no consumer", zap.String("topic", topicName))
panic("update ack with no consumer")
return nil
}
// find min id of all consumer
......
......@@ -1229,7 +1229,7 @@ func TestRocksmq_updateAckedInfoErr(t *testing.T) {
rmq.DestroyConsumerGroup(topicName, groupName+strconv.Itoa(i))
}
// update acked for topic without any consumer
assert.Panics(t, func() { rmq.updateAckedInfo(topicName, groupName, 0, ids[len(ids)-1]) })
assert.Nil(t, rmq.updateAckedInfo(topicName, groupName, 0, ids[len(ids)-1]))
}
func TestRocksmq_Info(t *testing.T) {
......
......@@ -24,6 +24,8 @@ import (
"strings"
"sync"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/pkg/metrics"
"go.uber.org/zap"
......@@ -259,14 +261,21 @@ func (d *dmlChannels) getChannelNum() int {
return len(d.listChannels())
}
func (d *dmlChannels) getMsgStreamByName(chanName string) (*dmlMsgStream, error) {
v, ok := d.pool.Load(chanName)
if !ok {
log.Error("invalid channel name", zap.String("chanName", chanName))
return nil, errors.Newf("invalid channel name: %s", chanName)
}
return v.(*dmlMsgStream), nil
}
func (d *dmlChannels) broadcast(chanNames []string, pack *msgstream.MsgPack) error {
for _, chanName := range chanNames {
v, ok := d.pool.Load(chanName)
if !ok {
log.Error("invalid channel name", zap.String("chanName", chanName))
panic("invalid channel name: " + chanName)
dms, err := d.getMsgStreamByName(chanName)
if err != nil {
return err
}
dms := v.(*dmlMsgStream)
dms.mutex.RLock()
if dms.refcnt > 0 {
......@@ -284,12 +293,10 @@ func (d *dmlChannels) broadcast(chanNames []string, pack *msgstream.MsgPack) err
func (d *dmlChannels) broadcastMark(chanNames []string, pack *msgstream.MsgPack) (map[string][]byte, error) {
result := make(map[string][]byte)
for _, chanName := range chanNames {
v, ok := d.pool.Load(chanName)
if !ok {
log.Error("invalid channel name", zap.String("chanName", chanName))
panic("invalid channel name: " + chanName)
dms, err := d.getMsgStreamByName(chanName)
if err != nil {
return result, err
}
dms := v.(*dmlMsgStream)
dms.mutex.RLock()
if dms.refcnt > 0 {
......@@ -313,12 +320,10 @@ func (d *dmlChannels) broadcastMark(chanNames []string, pack *msgstream.MsgPack)
func (d *dmlChannels) addChannels(names ...string) {
for _, name := range names {
v, ok := d.pool.Load(name)
if !ok {
log.Error("invalid channel name", zap.String("chanName", name))
panic("invalid channel name: " + name)
dms, err := d.getMsgStreamByName(name)
if err != nil {
continue
}
dms := v.(*dmlMsgStream)
d.mut.Lock()
dms.IncRefcnt()
......@@ -329,12 +334,10 @@ func (d *dmlChannels) addChannels(names ...string) {
func (d *dmlChannels) removeChannels(names ...string) {
for _, name := range names {
v, ok := d.pool.Load(name)
if !ok {
log.Error("invalid channel name", zap.String("chanName", name))
panic("invalid channel name: " + name)
dms, err := d.getMsgStreamByName(name)
if err != nil {
continue
}
dms := v.(*dmlMsgStream)
d.mut.Lock()
dms.DecRefCnt()
......
......@@ -138,10 +138,13 @@ func TestDmlChannels(t *testing.T) {
assert.Equal(t, 0, len(chanNames))
randStr := funcutil.RandomString(8)
assert.Panics(t, func() { dml.addChannels(randStr) })
assert.Panics(t, func() { dml.broadcast([]string{randStr}, nil) })
assert.Panics(t, func() { dml.broadcastMark([]string{randStr}, nil) })
assert.Panics(t, func() { dml.removeChannels(randStr) })
dml.addChannels(randStr)
assert.Error(t, dml.broadcast([]string{randStr}, nil))
{
_, err := dml.broadcastMark([]string{randStr}, nil)
assert.Error(t, err)
}
dml.removeChannels(randStr)
chans0 := dml.getChannelNames(2)
dml.addChannels(chans0...)
......
......@@ -359,7 +359,16 @@ func (s *Session) getSessionKey() string {
}
func (s *Session) initWatchSessionCh() {
getResp, err := s.etcdCli.Get(context.Background(), s.getSessionKey())
var (
err error
getResp *clientv3.GetResponse
)
err = retry.Do(context.Background(), func() error {
getResp, err = s.etcdCli.Get(context.Background(), s.getSessionKey())
log.Warn("fail to get the session key from the etcd", zap.Error(err))
return err
}, retry.Attempts(100))
if err != nil {
panic(err)
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册