未验证 提交 214f40b2 编写于 作者: B bigsheeper 提交者: GitHub

Add timeout ts for search and query (#12890)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
上级 73b63b9a
......@@ -274,6 +274,11 @@ func (st *SearchMsg) TravelTs() Timestamp {
return st.GetTravelTimestamp()
}
// TimeoutTs reports the timeout timestamp carried by this search message,
// exactly as returned by GetTimeoutTimestamp.
func (st *SearchMsg) TimeoutTs() Timestamp {
	timeoutTs := st.GetTimeoutTimestamp()
	return timeoutTs
}
// Marshal is used to serialize a message pack into a byte array
func (st *SearchMsg) Marshal(input TsMsg) (MarshalType, error) {
searchTask := input.(*SearchMsg)
......@@ -396,6 +401,11 @@ func (rm *RetrieveMsg) TravelTs() Timestamp {
return rm.GetTravelTimestamp()
}
// TimeoutTs reports the timeout timestamp carried by this retrieve message,
// exactly as returned by GetTimeoutTimestamp.
func (rm *RetrieveMsg) TimeoutTs() Timestamp {
	timeoutTs := rm.GetTimeoutTimestamp()
	return timeoutTs
}
// Marshal is used to serialize a message pack into a byte array
func (rm *RetrieveMsg) Marshal(input TsMsg) (MarshalType, error) {
retrieveTask := input.(*RetrieveMsg)
......
......@@ -205,6 +205,7 @@ func TestSearchMsg(t *testing.T) {
OutputFieldsId: []int64{},
TravelTimestamp: 6,
GuaranteeTimestamp: 7,
TimeoutTimestamp: 8,
},
}
......@@ -219,6 +220,7 @@ func TestSearchMsg(t *testing.T) {
assert.Equal(t, int64(3), searchMsg.SourceID())
assert.Equal(t, uint64(7), searchMsg.GuaranteeTs())
assert.Equal(t, uint64(6), searchMsg.TravelTs())
assert.Equal(t, uint64(8), searchMsg.TimeoutTs())
bytes, err := searchMsg.Marshal(searchMsg)
assert.Nil(t, err)
......@@ -310,6 +312,7 @@ func TestRetrieveMsg(t *testing.T) {
OutputFieldsId: []int64{8, 9},
TravelTimestamp: 10,
GuaranteeTimestamp: 11,
TimeoutTimestamp: 12,
},
}
......@@ -324,6 +327,7 @@ func TestRetrieveMsg(t *testing.T) {
assert.Equal(t, int64(3), retrieveMsg.SourceID())
assert.Equal(t, uint64(11), retrieveMsg.GuaranteeTs())
assert.Equal(t, uint64(10), retrieveMsg.TravelTs())
assert.Equal(t, uint64(12), retrieveMsg.TimeoutTs())
bytes, err := retrieveMsg.Marshal(retrieveMsg)
assert.Nil(t, err)
......
......@@ -157,6 +157,7 @@ message SearchRequest {
repeated int64 output_fields_id = 10;
uint64 travel_timestamp = 11;
uint64 guarantee_timestamp = 12;
uint64 timeout_timestamp = 13;
}
message SearchResults {
......@@ -185,6 +186,7 @@ message RetrieveRequest {
repeated int64 output_fields_id = 7;
uint64 travel_timestamp = 8;
uint64 guarantee_timestamp = 9;
uint64 timeout_timestamp = 10;
}
message RetrieveResults {
......
......@@ -30,9 +30,9 @@ import (
"strings"
"unsafe"
"github.com/golang/protobuf/proto"
"go.uber.org/zap"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
......@@ -51,6 +51,7 @@ import (
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/internal/util/timerecord"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/milvus-io/milvus/internal/util/tsoutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
......@@ -1575,6 +1576,10 @@ func (st *searchTask) PreExecute(ctx context.Context) error {
}
st.SearchRequest.TravelTimestamp = travelTimestamp
st.SearchRequest.GuaranteeTimestamp = guaranteeTimestamp
deadline, ok := st.TraceCtx().Deadline()
if ok {
st.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
}
st.SearchRequest.ResultChannelID = Params.SearchResultChannelNames[0]
st.SearchRequest.DbID = 0 // todo
......@@ -1671,6 +1676,7 @@ func (st *searchTask) Execute(ctx context.Context) error {
zap.Any("collectionID", st.CollectionID),
zap.Any("msgID", tsMsg.ID()),
zap.Int("length of search msg", len(msgPack.Msgs)),
zap.Any("timeoutTs", st.SearchRequest.TimeoutTimestamp),
)
return err
}
......@@ -2175,6 +2181,10 @@ func (qt *queryTask) PreExecute(ctx context.Context) error {
}
qt.TravelTimestamp = travelTimestamp
qt.GuaranteeTimestamp = guaranteeTimestamp
deadline, ok := qt.TraceCtx().Deadline()
if ok {
qt.RetrieveRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
}
qt.ResultChannelID = Params.RetrieveResultChannelNames[0]
qt.DbID = 0 // todo(yukun)
......@@ -2255,7 +2265,12 @@ func (qt *queryTask) Execute(ctx context.Context) error {
}
}
err = stream.Produce(&msgPack)
log.Debug("proxy", zap.Int("length of retrieveMsg", len(msgPack.Msgs)))
log.Debug("proxy sent one retrieveMsg",
zap.Any("collectionID", qt.CollectionID),
zap.Any("msgID", tsMsg.ID()),
zap.Int("length of search msg", len(msgPack.Msgs)),
zap.Any("timeoutTs", qt.RetrieveRequest.TimeoutTimestamp),
)
if err != nil {
log.Debug("Failed to send retrieve request.",
zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
......
......@@ -29,9 +29,10 @@ import (
"testing"
"time"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/msgstream"
......@@ -42,8 +43,8 @@ import (
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/util/distance"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/internal/util/uniquegenerator"
"github.com/stretchr/testify/assert"
)
// TODO(dragondriver): add more test cases
......@@ -3010,6 +3011,16 @@ func TestSearchTask_PreExecute(t *testing.T) {
},
}
// search task with timeout
ctx1, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
// before preExecute
assert.Equal(t, typeutil.ZeroTimestamp, task.TimeoutTimestamp)
task.ctx = ctx1
assert.NoError(t, task.PreExecute(ctx))
// after preExecute
assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp)
// field not exist
task.query.OutputFields = []string{int64Field + funcutil.GenRandomStr()}
assert.Error(t, task.PreExecute(ctx))
......@@ -3467,7 +3478,18 @@ func TestQueryTask_all(t *testing.T) {
}()
assert.NoError(t, task.OnEnqueue())
// test query task with timeout
ctx1, cancel1 := context.WithTimeout(ctx, 10*time.Second)
defer cancel1()
// before preExecute
assert.Equal(t, typeutil.ZeroTimestamp, task.TimeoutTimestamp)
task.ctx = ctx1
assert.NoError(t, task.PreExecute(ctx))
// after preExecute
assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp)
task.ctx = ctx
assert.NoError(t, task.Execute(ctx))
assert.NoError(t, task.PostExecute(ctx))
......
......@@ -44,6 +44,7 @@ type queryMsg interface {
msgstream.TsMsg
GuaranteeTs() Timestamp
TravelTs() Timestamp
TimeoutTs() Timestamp
}
type queryCollection struct {
......@@ -299,6 +300,21 @@ func (q *queryCollection) setServiceableTime(t Timestamp) {
q.serviceableTime = t
}
// checkTimeout reports whether msg has already reached its timeout timestamp.
//
// A timeout timestamp that is not greater than typeutil.ZeroTimestamp is
// treated as "no timeout", so such a message is never reported as timed out.
// Otherwise the message has timed out once the current timestamp from
// tsoutil.GetCurrentTime is at or past the message's timeout timestamp.
func (q *queryCollection) checkTimeout(msg queryMsg) bool {
	deadlineTs := msg.TimeoutTs()
	if deadlineTs <= typeutil.ZeroTimestamp {
		// No timeout was attached to this message.
		return false
	}
	nowTs := tsoutil.GetCurrentTime()
	return nowTs >= deadlineTs
}
func (q *queryCollection) consumeQuery() {
for {
select {
......@@ -427,6 +443,17 @@ func (q *queryCollection) receiveQueryMsg(msg queryMsg) error {
msg.SetTraceCtx(ctx)
tr := timerecord.NewTimeRecorder(fmt.Sprintf("receiveQueryMsg %d", msg.ID()))
if q.checkTimeout(msg) {
err := errors.New(fmt.Sprintln("do query failed in receiveQueryMsg because timeout"+
", collectionID = ", collectionID,
", msgID = ", msg.ID()))
publishErr := q.publishFailedQueryResult(msg, err.Error())
if publishErr != nil {
return fmt.Errorf("first err = %s, second err = %s", err, publishErr)
}
return err
}
// check if collection has been released
collection, err := q.historical.replica.getCollectionByID(collectionID)
if err != nil {
......@@ -561,6 +588,19 @@ func (q *queryCollection) doUnsolvedQueryMsg() {
zap.Any("guaranteeTime_l", guaranteeTs),
zap.Any("serviceTime_l", serviceTime),
)
if q.checkTimeout(m) {
err := errors.New(fmt.Sprintln("do query failed in doUnsolvedQueryMsg because timeout"+
", collectionID = ", q.collectionID,
", msgID = ", m.ID()))
log.Warn(err.Error())
publishErr := q.publishFailedQueryResult(m, err.Error())
if publishErr != nil {
log.Error(publishErr.Error())
}
continue
}
if guaranteeTs <= q.getServiceableTime() {
unSolvedMsg = append(unSolvedMsg, m)
continue
......@@ -1370,10 +1410,5 @@ func (q *queryCollection) publishFailedQueryResult(msg msgstream.TsMsg, errMsg s
return fmt.Errorf("publish invalid msgType %d", msgType)
}
err := q.queryResultMsgStream.Produce(&msgPack)
if err != nil {
return err
}
return nil
return q.queryResultMsgStream.Produce(&msgPack)
}
......@@ -308,6 +308,13 @@ func TestQueryCollection_consumeQuery(t *testing.T) {
msg.Base.MsgType = commonpb.MsgType_CreateCollection
runConsumeQuery(msg)
})
t.Run("consume timeout msg", func(t *testing.T) {
msg, err := genSimpleRetrieveMsg()
assert.NoError(t, err)
msg.TimeoutTimestamp = tsoutil.GetCurrentTime() - Timestamp(time.Second<<18)
runConsumeQuery(msg)
})
}
func TestQueryCollection_TranslateHits(t *testing.T) {
......@@ -557,19 +564,38 @@ func TestQueryCollection_mergeRetrieveResults(t *testing.T) {
// TestQueryCollection_doUnsolvedQueryMsg starts doUnsolvedQueryMsg in a
// goroutine and hands it search messages through addToUnsolvedMsg: once with
// an unmodified message and once with a message whose TimeoutTimestamp has
// been moved before the current time.
//
// NOTE(review): this span looks like a rendered diff in which removed and
// added lines are interleaved without +/- markers, so several statements
// appear twice — confirm against the real file before relying on it.
func TestQueryCollection_doUnsolvedQueryMsg(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
assert.NoError(t, err)
// Sub-test 1: a search message with no timeout modification.
t.Run("test doUnsolvedQueryMsg", func(t *testing.T) {
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
assert.NoError(t, err)
timestamp := Timestamp(1000)
updateTSafe(queryCollection, timestamp)
timestamp := Timestamp(1000)
updateTSafe(queryCollection, timestamp)
go queryCollection.doUnsolvedQueryMsg()
go queryCollection.doUnsolvedQueryMsg()
msg, err := genSimpleSearchMsg()
assert.NoError(t, err)
queryCollection.addToUnsolvedMsg(msg)
msg, err := genSimpleSearchMsg()
assert.NoError(t, err)
queryCollection.addToUnsolvedMsg(msg)
// Presumably gives the background goroutine time to process the message — TODO confirm.
time.Sleep(200 * time.Millisecond)
})
// Sub-test 2: a search message whose timeout timestamp is already in the past.
t.Run("test doUnsolvedQueryMsg timeout", func(t *testing.T) {
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
assert.NoError(t, err)
time.Sleep(200 * time.Millisecond)
timestamp := Timestamp(1000)
updateTSafe(queryCollection, timestamp)
go queryCollection.doUnsolvedQueryMsg()
msg, err := genSimpleSearchMsg()
assert.NoError(t, err)
// Move the timeout before the current timestamp so the message counts as expired.
// NOTE(review): time.Second is a nanosecond count; the <<18 shift presumably
// mirrors the logical-bit width of the hybrid timestamp — verify the intended offset.
msg.TimeoutTimestamp = tsoutil.GetCurrentTime() - Timestamp(time.Second<<18)
queryCollection.addToUnsolvedMsg(msg)
// Presumably gives the background goroutine time to handle the expired message — TODO confirm.
time.Sleep(2000 * time.Millisecond)
})
}
func TestQueryCollection_search(t *testing.T) {
......
......@@ -16,6 +16,8 @@ import (
"time"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
const (
......@@ -28,6 +30,16 @@ func ComposeTS(physical, logical int64) uint64 {
return uint64((physical << logicalBits) + logical)
}
// ComposeTSByTime builds a timestamp from a wall-clock time and a logical
// counter. The physical part is the Unix time of physical in whole
// milliseconds (sub-millisecond precision is discarded by integer division);
// it is combined with logical via ComposeTS.
func ComposeTSByTime(physical time.Time, logical int64) uint64 {
	physicalMs := physical.UnixNano() / int64(time.Millisecond)
	return ComposeTS(physicalMs, logical)
}
// GetCurrentTime returns a timestamp composed from time.Now() with a logical
// part of 0, using ComposeTSByTime.
func GetCurrentTime() typeutil.Timestamp {
	now := time.Now()
	return ComposeTSByTime(now, 0)
}
// ParseTS parses the ts to (physical,logical).
func ParseTS(ts uint64) (time.Time, uint64) {
logical := ts & logicalBitsMask
......
......@@ -15,8 +15,11 @@ import (
"testing"
"time"
"github.com/milvus-io/milvus/internal/log"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
func TestParseHybridTs(t *testing.T) {
......@@ -28,3 +31,22 @@ func TestParseHybridTs(t *testing.T) {
zap.Int64("logical", logical),
zap.Any("physical time", physicalTime))
}
// Test_Tso covers the time-based timestamp helpers ComposeTSByTime and
// GetCurrentTime by round-tripping their results through ParseTS.
func Test_Tso(t *testing.T) {
	t.Run("test ComposeTSByTime", func(t *testing.T) {
		physical := time.Now()
		logical := int64(1000)
		timestamp := ComposeTSByTime(physical, logical)
		pRes, lRes := ParseTS(timestamp)
		// Compare at second granularity: ComposeTSByTime keeps only whole
		// milliseconds of the physical time.
		assert.Equal(t, physical.Unix(), pRes.Unix())
		assert.Equal(t, uint64(logical), lRes)
	})

	t.Run("test GetCurrentTime", func(t *testing.T) {
		curTime := GetCurrentTime()
		p, l := ParseTS(curTime)
		subTime := time.Since(p)
		// The physical part is truncated to whole milliseconds, so the parsed
		// time may already lag the real clock by up to 1ms before any
		// execution or scheduling delay is added. A sub-millisecond bound is
		// therefore flaky; use a generous tolerance that still catches a
		// wrong unit or a badly composed timestamp.
		assert.Less(t, subTime, time.Second)
		assert.Equal(t, typeutil.ZeroTimestamp, l)
	})
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册