未验证 提交 0b354cba 编写于 作者: Y yihao.dai 提交者: GitHub

Fix timetick block caused by dml task pop failed (#23277)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
上级 52765295
......@@ -5,10 +5,18 @@ type removeDMLStreamFuncType = func(collectionID UniqueID) error
// mockChannelsMgr is a test double that embeds channelsMgr and carries
// optional function hooks. The mock's getChannels and getVChannels methods
// forward to the matching hook when that hook is non-nil.
type mockChannelsMgr struct {
channelsMgr
// getChannelsFunc, when set, supplies the result of getChannels.
getChannelsFunc func(collectionID UniqueID) ([]pChan, error)
// Embedded function-typed hook invoked by getVChannels when non-nil.
getVChannelsFuncType
// Embedded hook of type func(collectionID UniqueID) error.
// NOTE(review): presumably consulted by removeDMLStream; its use is not
// visible in this excerpt — confirm.
removeDMLStreamFuncType
}
// getChannels forwards to the injected getChannelsFunc hook. When no hook has
// been installed it reports no channels and no error.
func (m *mockChannelsMgr) getChannels(collectionID UniqueID) ([]pChan, error) {
	hook := m.getChannelsFunc
	if hook == nil {
		return nil, nil
	}
	return hook(collectionID)
}
func (m *mockChannelsMgr) getVChannels(collectionID UniqueID) ([]vChan, error) {
if m.getVChannelsFuncType != nil {
return m.getVChannelsFuncType(collectionID)
......
......@@ -82,11 +82,19 @@ func (dt *deleteTask) OnEnqueue() error {
}
// getChannels returns the pChan list for the collection this delete task
// targets.
//
// If dt.pChannels is already non-empty it is returned as is. Otherwise the
// collection ID is resolved by name through globalMetaCache, the channels are
// requested from dt.chMgr, and a successful result is stored in dt.pChannels
// so that later calls no longer depend on the meta cache or channels manager.
// Any lookup error is returned unchanged and nothing is cached.
func (dt *deleteTask) getChannels() ([]pChan, error) {
	if len(dt.pChannels) != 0 {
		return dt.pChannels, nil
	}
	collID, err := globalMetaCache.GetCollectionID(dt.ctx, dt.deleteMsg.CollectionName)
	if err != nil {
		return nil, err
	}
	// Do not return the manager's result directly: it must be recorded on the
	// task first, otherwise every call repeats the lookup and can fail later.
	channels, err := dt.chMgr.getChannels(collID)
	if err != nil {
		return nil, err
	}
	dt.pChannels = channels
	return channels, nil
}
func getPrimaryKeysFromExpr(schema *schemapb.CollectionSchema, expr string) (res *schemapb.IDs, rowNum int64, err error) {
......
package proxy
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/msgpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
func Test_getPrimaryKeysFromExpr(t *testing.T) {
......@@ -37,3 +42,41 @@ func Test_getPrimaryKeysFromExpr(t *testing.T) {
assert.Error(t, err)
})
}
// TestDeleteTask exercises deleteTask.getChannels against a mocked meta cache
// and a mocked channels manager.
func TestDeleteTask(t *testing.T) {
	t.Run("test getChannels", func(t *testing.T) {
		const collName = "col-0"
		collID := UniqueID(0)
		want := []pChan{"mock-chan-0", "mock-chan-1"}

		// Every collection name resolves to the same fixed ID.
		metaCache := newMockCache()
		metaCache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
			return collID, nil
		})
		globalMetaCache = metaCache

		mgr := newMockChannelsMgr()
		mgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
			return want, nil
		}

		req := msgpb.DeleteRequest{CollectionName: collName}
		dt := deleteTask{
			ctx:       context.Background(),
			deleteMsg: &msgstream.DeleteMsg{DeleteRequest: req},
			chMgr:     mgr,
		}

		// First call goes through the manager and records the result on the task.
		got, err := dt.getChannels()
		assert.NoError(t, err)
		assert.ElementsMatch(t, want, got)
		assert.ElementsMatch(t, want, dt.pChannels)

		// From now on the manager fails; a second call must be answered from
		// the task's own pChannels without invoking getChannelsFunc again.
		mgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
			return nil, fmt.Errorf("mock err")
		}
		got, err = dt.getChannels()
		assert.NoError(t, err)
		assert.ElementsMatch(t, want, got)
	})
}
......@@ -71,11 +71,19 @@ func (it *insertTask) EndTs() Timestamp {
}
// getChannels returns the pChan list for the collection this insert task
// targets.
//
// If it.pChannels is already non-empty it is returned as is. Otherwise the
// collection ID is resolved by name through globalMetaCache, the channels are
// requested from it.chMgr, and a successful result is stored in it.pChannels
// so that later calls no longer depend on the meta cache or channels manager.
// Any lookup error is returned unchanged and nothing is cached.
func (it *insertTask) getChannels() ([]pChan, error) {
	if len(it.pChannels) != 0 {
		return it.pChannels, nil
	}
	collID, err := globalMetaCache.GetCollectionID(it.ctx, it.insertMsg.CollectionName)
	if err != nil {
		return nil, err
	}
	// Do not return the manager's result directly: it must be recorded on the
	// task first, otherwise every call repeats the lookup and can fail later.
	channels, err := it.chMgr.getChannels(collID)
	if err != nil {
		return nil, err
	}
	it.pChannels = channels
	return channels, nil
}
func (it *insertTask) OnEnqueue() error {
......
package proxy
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
......@@ -8,6 +10,8 @@ import (
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/msgpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
func TestInsertTask_CheckAligned(t *testing.T) {
......@@ -219,3 +223,41 @@ func TestInsertTask_CheckAligned(t *testing.T) {
err = case2.insertMsg.CheckAligned()
assert.NoError(t, err)
}
// TestInsertTask exercises insertTask.getChannels against a mocked meta cache
// and a mocked channels manager.
func TestInsertTask(t *testing.T) {
	t.Run("test getChannels", func(t *testing.T) {
		const collName = "col-0"
		collID := UniqueID(0)
		want := []pChan{"mock-chan-0", "mock-chan-1"}

		// Every collection name resolves to the same fixed ID.
		metaCache := newMockCache()
		metaCache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
			return collID, nil
		})
		globalMetaCache = metaCache

		mgr := newMockChannelsMgr()
		mgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
			return want, nil
		}

		req := msgpb.InsertRequest{CollectionName: collName}
		it := insertTask{
			ctx:       context.Background(),
			insertMsg: &msgstream.InsertMsg{InsertRequest: req},
			chMgr:     mgr,
		}

		// First call goes through the manager and records the result on the task.
		got, err := it.getChannels()
		assert.NoError(t, err)
		assert.ElementsMatch(t, want, got)
		assert.ElementsMatch(t, want, it.pChannels)

		// From now on the manager fails; a second call must be answered from
		// the task's own pChannels without invoking getChannelsFunc again.
		mgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
			return nil, fmt.Errorf("mock err")
		}
		got, err = it.getChannels()
		assert.NoError(t, err)
		assert.ElementsMatch(t, want, got)
	})
}
......@@ -227,15 +227,19 @@ func (queue *dmTaskQueue) Enqueue(t task) error {
// This statsLock has two functions:
// 1) Protect member pChanStatisticsInfos
// 2) Serialize the timestamp allocation for dml tasks
queue.statsLock.Lock()
defer queue.statsLock.Unlock()
//1. preAdd will check whether provided task is valid or addable
//and get the current pChannels for this dmTask
pChannels, dmt, err := queue.preAddPChanStats(t)
dmt := t.(dmlTask)
pChannels, err := dmt.getChannels()
if err != nil {
log.Warn("getChannels failed when Enqueue", zap.Any("tID", t.ID()), zap.Error(err))
return err
}
//2. enqueue dml task
queue.statsLock.Lock()
defer queue.statsLock.Unlock()
err = queue.baseTaskQueue.Enqueue(t)
if err != nil {
return err
......@@ -265,19 +269,6 @@ func (queue *dmTaskQueue) PopActiveTask(taskID UniqueID) task {
return t
}
// preAddPChanStats verifies that t is a dmlTask and fetches its channels.
// It returns the channels together with the task viewed as a dmlTask, or an
// error when t is not a dmlTask or when the channel lookup fails (the latter
// is also logged at warn level).
func (queue *dmTaskQueue) preAddPChanStats(t task) ([]pChan, dmlTask, error) {
	dmT, isDml := t.(dmlTask)
	if !isDml {
		return nil, nil, fmt.Errorf("proxy preAddPChanStats reflect to dmlTask failed, tID:%v", t.ID())
	}

	channels, err := dmT.getChannels()
	if err == nil {
		return channels, dmT, nil
	}

	log.Warn("Proxy dmTaskQueue preAddPChanStats getChannels failed", zap.Any("tID", t.ID()),
		zap.Error(err))
	return nil, nil, err
}
func (queue *dmTaskQueue) commitPChanStats(dmt dmlTask, pChannels []pChan) {
//1. prepare new stat for all pChannels
newStats := make(map[pChan]pChanStatistics)
......@@ -312,34 +303,31 @@ func (queue *dmTaskQueue) commitPChanStats(dmt dmlTask, pChannels []pChan) {
}
}
func (queue *dmTaskQueue) popPChanStats(t task) error {
if dmT, ok := t.(dmlTask); ok {
channels, err := dmT.getChannels()
if err != nil {
return err
}
taskTs := t.BeginTs()
for _, cName := range channels {
info, ok := queue.pChanStatisticsInfos[cName]
if ok {
delete(info.tsSet, taskTs)
if len(info.tsSet) <= 0 {
delete(queue.pChanStatisticsInfos, cName)
} else {
newMinTs := info.maxTs
for ts := range info.tsSet {
if newMinTs > ts {
newMinTs = ts
}
func (queue *dmTaskQueue) popPChanStats(t task) {
channels, err := t.(dmlTask).getChannels()
if err != nil {
err = fmt.Errorf("get channels failed when popPChanStats, err=%w", err)
log.Error(err.Error())
panic(err)
}
taskTs := t.BeginTs()
for _, cName := range channels {
info, ok := queue.pChanStatisticsInfos[cName]
if ok {
delete(info.tsSet, taskTs)
if len(info.tsSet) <= 0 {
delete(queue.pChanStatisticsInfos, cName)
} else {
newMinTs := info.maxTs
for ts := range info.tsSet {
if newMinTs > ts {
newMinTs = ts
}
info.minTs = newMinTs
}
info.minTs = newMinTs
}
}
} else {
return fmt.Errorf("proxy dmTaskQueue popPChanStats reflect to dmlTask failed, tID:%v", t.ID())
}
return nil
}
func (queue *dmTaskQueue) getPChanStatsInfo() (map[pChan]*pChanStatistics, error) {
......
......@@ -26,7 +26,11 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/msgpb"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
func TestBaseTaskQueue(t *testing.T) {
......@@ -197,11 +201,6 @@ func TestDmTaskQueue_Basic(t *testing.T) {
assert.True(t, queue.utEmpty())
assert.False(t, queue.utFull())
//test wrong task type
dqlTask := newDefaultMockDqlTask()
err = queue.Enqueue(dqlTask)
assert.NotNil(t, err)
st := newDefaultMockDmlTask()
stID := st.ID()
......@@ -563,3 +562,51 @@ func TestTaskScheduler(t *testing.T) {
wg.Wait()
}
// TestTaskScheduler_concurrentPushAndPop runs 100 goroutines that each
// enqueue an insert task on the dm queue, schedule a task, mark it active,
// make their own channels manager start failing, and then pop the active
// task. The pop must complete without panicking.
func TestTaskScheduler_concurrentPushAndPop(t *testing.T) {
	const (
		collName = "col-0"
		workers  = 100
	)
	collID := UniqueID(0)
	pChans := []pChan{"mock-chan-0", "mock-chan-1"}

	// Every collection name resolves to the same fixed ID.
	metaCache := newMockCache()
	metaCache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
		return collID, nil
	})
	globalMetaCache = metaCache

	sched, err := newTaskScheduler(context.Background(), newMockTsoAllocator(), newSimpleMockMsgStreamFactory())
	assert.NoError(t, err)

	var wg sync.WaitGroup
	wg.Add(workers)
	for i := 0; i < workers; i++ {
		go func() {
			defer wg.Done()

			mgr := newMockChannelsMgr()
			mgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
				return pChans, nil
			}
			req := msgpb.InsertRequest{
				Base:           &commonpb.MsgBase{},
				CollectionName: collName,
			}
			it := &insertTask{
				ctx:       context.Background(),
				insertMsg: &msgstream.InsertMsg{InsertRequest: req},
				chMgr:     mgr,
			}

			enqueueErr := sched.dmQueue.Enqueue(it)
			assert.NoError(t, enqueueErr)

			scheduled := sched.scheduleDmTask()
			sched.dmQueue.AddActiveTask(scheduled)

			// The channels manager fails from here on.
			mgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
				return nil, fmt.Errorf("mock err")
			}
			sched.dmQueue.PopActiveTask(scheduled.ID()) // assert no panic
		}()
	}
	wg.Wait()
}
......@@ -116,11 +116,19 @@ func (it *upsertTask) getPChanStats() (map[pChan]pChanStatistics, error) {
}
// getChannels returns the pChan list for the collection this upsert task
// targets.
//
// If it.pChannels is already non-empty it is returned as is. Otherwise the
// collection ID is resolved by name through globalMetaCache, the channels are
// requested from it.chMgr, and a successful result is stored in it.pChannels
// so that later calls no longer depend on the meta cache or channels manager.
// Any lookup error is returned unchanged and nothing is cached.
func (it *upsertTask) getChannels() ([]pChan, error) {
	if len(it.pChannels) != 0 {
		return it.pChannels, nil
	}
	collID, err := globalMetaCache.GetCollectionID(it.ctx, it.req.CollectionName)
	if err != nil {
		return nil, err
	}
	// Do not return the manager's result directly: it must be recorded on the
	// task first, otherwise every call repeats the lookup and can fail later.
	channels, err := it.chMgr.getChannels(collID)
	if err != nil {
		return nil, err
	}
	it.pChannels = channels
	return channels, nil
}
func (it *upsertTask) OnEnqueue() error {
......
......@@ -16,6 +16,8 @@
package proxy
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
......@@ -26,6 +28,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
func TestUpsertTask_CheckAligned(t *testing.T) {
......@@ -290,3 +293,39 @@ func TestUpsertTask_CheckAligned(t *testing.T) {
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.NoError(t, err)
}
// TestUpsertTask exercises upsertTask.getChannels against a mocked meta cache
// and a mocked channels manager.
func TestUpsertTask(t *testing.T) {
	t.Run("test getChannels", func(t *testing.T) {
		const collName = "col-0"
		collID := UniqueID(0)
		want := []pChan{"mock-chan-0", "mock-chan-1"}

		// Every collection name resolves to the same fixed ID.
		metaCache := newMockCache()
		metaCache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
			return collID, nil
		})
		globalMetaCache = metaCache

		mgr := newMockChannelsMgr()
		mgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
			return want, nil
		}

		ut := upsertTask{
			ctx:   context.Background(),
			req:   &milvuspb.UpsertRequest{CollectionName: collName},
			chMgr: mgr,
		}

		// First call goes through the manager and records the result on the task.
		got, err := ut.getChannels()
		assert.NoError(t, err)
		assert.ElementsMatch(t, want, got)
		assert.ElementsMatch(t, want, ut.pChannels)

		// From now on the manager fails; a second call must be answered from
		// the task's own pChannels without invoking getChannelsFunc again.
		mgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
			return nil, fmt.Errorf("mock err")
		}
		got, err = ut.getChannels()
		assert.NoError(t, err)
		assert.ElementsMatch(t, want, got)
	})
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册