From f31ed089b5da01bca41e3d77daa0c60f67c249d8 Mon Sep 17 00:00:00 2001 From: godchen Date: Fri, 26 Nov 2021 22:45:24 +0800 Subject: [PATCH] Add timeout for reader next (#12308) Signed-off-by: godchen --- .../flow_graph_dmstream_input_node_test.go | 3 ++ internal/msgstream/mq_msgstream.go | 42 ++++++++++--------- internal/msgstream/mq_msgstream_test.go | 20 +++++++-- internal/msgstream/msgstream.go | 1 + internal/proxy/mock_test.go | 4 ++ internal/querynode/segment_loader.go | 13 +++++- 6 files changed, 57 insertions(+), 26 deletions(-) diff --git a/internal/datanode/flow_graph_dmstream_input_node_test.go b/internal/datanode/flow_graph_dmstream_input_node_test.go index 0f8d13fbf..bc14b02c1 100644 --- a/internal/datanode/flow_graph_dmstream_input_node_test.go +++ b/internal/datanode/flow_graph_dmstream_input_node_test.go @@ -104,6 +104,9 @@ func (mtm *mockTtMsgStream) SeekReaders(msgPositions []*internalpb.MsgPosition) func (mtm *mockTtMsgStream) Next(ctx context.Context, channelName string) (msgstream.TsMsg, error) { return nil, nil } +func (mtm *mockTtMsgStream) HasNext(channelName string) bool { + return true +} func TestNewDmInputNode(t *testing.T) { ctx := context.Background() diff --git a/internal/msgstream/mq_msgstream.go b/internal/msgstream/mq_msgstream.go index 7d7ea3f8d..2c33d76cf 100644 --- a/internal/msgstream/mq_msgstream.go +++ b/internal/msgstream/mq_msgstream.go @@ -575,27 +575,29 @@ func (ms *mqMsgStream) Next(ctx context.Context, channelName string) (TsMsg, err if !ok { return nil, fmt.Errorf("reader for channel %s is not exist", channelName) } - if reader.HasNext() { - msg, err := reader.Next(ctx) - if err != nil { - return nil, err - } - tsMsg, err := ms.getTsMsgFromConsumerMsg(msg) - if err != nil { - log.Error("Failed to getTsMsgFromConsumerMsg", zap.Error(err)) - return nil, errors.New("Failed to getTsMsgFromConsumerMsg") - } - pos := tsMsg.Position() - tsMsg.SetPosition(&MsgPosition{ - ChannelName: pos.ChannelName, - MsgID: pos.MsgID, - Timestamp: tsMsg.BeginTs(), - }) - return tsMsg, nil + msg, err := reader.Next(ctx) + if err != nil { + return nil, err } - log.Debug("All data has been read, there is no more data", zap.String("channel", channelName)) - return nil, nil - + tsMsg, err := ms.getTsMsgFromConsumerMsg(msg) + if err != nil { + log.Error("Failed to getTsMsgFromConsumerMsg", zap.Error(err)) + return nil, errors.New("Failed to getTsMsgFromConsumerMsg") + } + pos := tsMsg.Position() + tsMsg.SetPosition(&MsgPosition{ + ChannelName: pos.ChannelName, + MsgID: pos.MsgID, + Timestamp: tsMsg.BeginTs(), + }) + return tsMsg, nil +} +func (ms *mqMsgStream) HasNext(channelName string) bool { + reader, ok := ms.readers[channelName] + if !ok { + return false + } + return reader.HasNext() } // Seek reset the subscription associated with this consumer to a specific position, the seek position is exclusive diff --git a/internal/msgstream/mq_msgstream_test.go b/internal/msgstream/mq_msgstream_test.go index 88484ba81..441a9ec2c 100644 --- a/internal/msgstream/mq_msgstream_test.go +++ b/internal/msgstream/mq_msgstream_test.go @@ -1296,6 +1296,8 @@ func TestStream_MqMsgStream_Reader(t *testing.T) { defer readStream.Close() var seekPosition *internalpb.MsgPosition for i := 0; i < n; i++ { + hasNext := readStream.HasNext(c) + assert.True(t, hasNext) result, err := readStream.Next(ctx, c) assert.Nil(t, err) assert.Equal(t, result.ID(), int64(i)) @@ -1303,8 +1305,12 @@ func TestStream_MqMsgStream_Reader(t *testing.T) { seekPosition = result.Position() } } - result, err := readStream.Next(ctx, c) - assert.Nil(t, err) + hasNext := readStream.HasNext(c) + assert.False(t, hasNext) + timeoutCtx1, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + result, err := readStream.Next(timeoutCtx1, c) + assert.NotNil(t, err) assert.Nil(t, result) readStream2 := getPulsarReader(pulsarAddress, readerChannels) @@ -1312,12 +1318,18 @@ func TestStream_MqMsgStream_Reader(t *testing.T) { readStream2.SeekReaders([]*internalpb.MsgPosition{seekPosition}) for i := p; i < 10; i++ { + hasNext := readStream2.HasNext(c) + assert.True(t, hasNext) result, err := readStream2.Next(ctx, c) assert.Nil(t, err) assert.Equal(t, result.ID(), int64(i)) } - result2, err := readStream2.Next(ctx, c) - assert.Nil(t, err) + hasNext = readStream2.HasNext(c) + assert.False(t, hasNext) + timeoutCtx2, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + result2, err := readStream2.Next(timeoutCtx2, c) + assert.NotNil(t, err) assert.Nil(t, result2) } diff --git a/internal/msgstream/msgstream.go b/internal/msgstream/msgstream.go index b49115d72..3b815a897 100644 --- a/internal/msgstream/msgstream.go +++ b/internal/msgstream/msgstream.go @@ -69,6 +69,7 @@ type MsgStream interface { BroadcastMark(*MsgPack) (map[string][]MessageID, error) Consume() *MsgPack Next(ctx context.Context, channelName string) (TsMsg, error) + HasNext(channelName string) bool Seek(offset []*MsgPosition) error SeekReaders(msgPositions []*internalpb.MsgPosition) error } diff --git a/internal/proxy/mock_test.go b/internal/proxy/mock_test.go index 051452433..526aa4173 100644 --- a/internal/proxy/mock_test.go +++ b/internal/proxy/mock_test.go @@ -288,6 +288,10 @@ func (ms *simpleMockMsgStream) Next(ctx context.Context, channelName string) (ms return nil, nil } +func (ms *simpleMockMsgStream) HasNext(channelName string) bool { + return true +} + func (ms *simpleMockMsgStream) AsConsumerWithPosition(channels []string, subName string, position mqclient.SubscriptionInitialPosition) { } diff --git a/internal/querynode/segment_loader.go b/internal/querynode/segment_loader.go index 515ef7697..fb8657910 100644 --- a/internal/querynode/segment_loader.go +++ b/internal/querynode/segment_loader.go @@ -18,6 +18,7 @@ import ( "path" "strconv" "sync" + "time" "go.uber.org/zap" @@ -37,6 +38,8 @@ import ( "github.com/milvus-io/milvus/internal/util/funcutil" ) +const timeoutForEachRead = 10 * time.Second + // segmentLoader is only responsible for loading the field data from binlog type segmentLoader struct { historicalReplica ReplicaInterface @@ -458,24 +461,30 @@ func (loader *segmentLoader) FromDmlCPLoadDelete(ctx context.Context, collection deleteOffset: make(map[UniqueID]int64), } log.Debug("start read msg from stream reader") - for { + for stream.HasNext(pChannelName) { + ctx, cancel := context.WithTimeout(ctx, timeoutForEachRead) tsMsg, err := stream.Next(ctx, pChannelName) if err != nil { + cancel() return err } if tsMsg == nil { - break + cancel() + continue } if tsMsg.Type() == commonpb.MsgType_Delete { dmsg := tsMsg.(*msgstream.DeleteMsg) if dmsg.CollectionID != collectionID { + cancel() continue } log.Debug("delete pk", zap.Any("pk", dmsg.PrimaryKeys)) processDeleteMessages(loader.historicalReplica, dmsg, delData) } + cancel() } + log.Debug("All data has been read, there is no more data", zap.String("channel", pChannelName)) for segmentID, pks := range delData.deleteIDs { segment, err := loader.historicalReplica.getSegmentByID(segmentID) if err != nil { -- GitLab