diff --git a/internal/msgstream/msg.go b/internal/msgstream/msg.go index 15f4127b1b8b940308e1b30625f0f1d7846dec67..7e31b286139b2d503671508e779f14815c138577 100644 --- a/internal/msgstream/msg.go +++ b/internal/msgstream/msg.go @@ -32,6 +32,7 @@ type TsMsg interface { BeginTs() Timestamp EndTs() Timestamp Type() MsgType + SourceID() int64 HashKeys() []uint32 Marshal(TsMsg) (MarshalType, error) Unmarshal(MarshalType) (TsMsg, error) @@ -97,6 +98,10 @@ func (it *InsertMsg) Type() MsgType { return it.Base.MsgType } +func (it *InsertMsg) SourceID() int64 { + return it.Base.SourceID +} + func (it *InsertMsg) Marshal(input TsMsg) (MarshalType, error) { insertMsg := input.(*InsertMsg) insertRequest := &insertMsg.InsertRequest @@ -157,6 +162,10 @@ func (fl *FlushCompletedMsg) Type() MsgType { return fl.Base.MsgType } +func (fl *FlushCompletedMsg) SourceID() int64 { + return fl.Base.SourceID +} + func (fl *FlushCompletedMsg) Marshal(input TsMsg) (MarshalType, error) { flushCompletedMsgTask := input.(*FlushCompletedMsg) flushCompletedMsg := &flushCompletedMsgTask.SegmentFlushCompletedMsg @@ -206,6 +215,10 @@ func (dt *DeleteMsg) Type() MsgType { return dt.Base.MsgType } +func (dt *DeleteMsg) SourceID() int64 { + return dt.Base.SourceID +} + func (dt *DeleteMsg) Marshal(input TsMsg) (MarshalType, error) { deleteMsg := input.(*DeleteMsg) deleteRequest := &deleteMsg.DeleteRequest @@ -267,6 +280,10 @@ func (st *SearchMsg) Type() MsgType { return st.Base.MsgType } +func (st *SearchMsg) SourceID() int64 { + return st.Base.SourceID +} + func (st *SearchMsg) Marshal(input TsMsg) (MarshalType, error) { searchTask := input.(*SearchMsg) searchRequest := &searchTask.SearchRequest @@ -316,6 +333,10 @@ func (srt *SearchResultMsg) Type() MsgType { return srt.Base.MsgType } +func (srt *SearchResultMsg) SourceID() int64 { + return srt.Base.SourceID +} + func (srt *SearchResultMsg) Marshal(input TsMsg) (MarshalType, error) { searchResultTask := input.(*SearchResultMsg) searchResultRequest := &searchResultTask.SearchResults @@ -365,6 +386,10 @@ func (rm *RetrieveMsg) Type() MsgType { return rm.Base.MsgType } +func (rm *RetrieveMsg) SourceID() int64 { + return rm.Base.SourceID +} + func (rm *RetrieveMsg) Marshal(input TsMsg) (MarshalType, error) { retrieveTask := input.(*RetrieveMsg) retrieveRequest := &retrieveTask.RetrieveRequest @@ -414,6 +439,10 @@ func (rrm *RetrieveResultMsg) Type() MsgType { return rrm.Base.MsgType } +func (rrm *RetrieveResultMsg) SourceID() int64 { + return rrm.Base.SourceID +} + func (rrm *RetrieveResultMsg) Marshal(input TsMsg) (MarshalType, error) { retrieveResultTask := input.(*RetrieveResultMsg) retrieveResultRequest := &retrieveResultTask.RetrieveResults @@ -463,6 +492,10 @@ func (tst *TimeTickMsg) Type() MsgType { return tst.Base.MsgType } +func (tst *TimeTickMsg) SourceID() int64 { + return tst.Base.SourceID +} + func (tst *TimeTickMsg) Marshal(input TsMsg) (MarshalType, error) { timeTickTask := input.(*TimeTickMsg) timeTick := &timeTickTask.TimeTickMsg @@ -513,6 +546,10 @@ func (qs *QueryNodeStatsMsg) Type() MsgType { return qs.Base.MsgType } +func (qs *QueryNodeStatsMsg) SourceID() int64 { + return qs.Base.SourceID +} + func (qs *QueryNodeStatsMsg) Marshal(input TsMsg) (MarshalType, error) { queryNodeSegStatsTask := input.(*QueryNodeStatsMsg) queryNodeSegStats := &queryNodeSegStatsTask.QueryNodeStats @@ -560,6 +597,10 @@ func (ss *SegmentStatisticsMsg) Type() MsgType { return ss.Base.MsgType } +func (ss *SegmentStatisticsMsg) SourceID() int64 { + return ss.Base.SourceID +} + func (ss *SegmentStatisticsMsg) Marshal(input TsMsg) (MarshalType, error) { segStatsTask := input.(*SegmentStatisticsMsg) segStats := &segStatsTask.SegmentStatistics @@ -607,6 +648,10 @@ func (cc *CreateCollectionMsg) Type() MsgType { return cc.Base.MsgType } +func (cc *CreateCollectionMsg) SourceID() int64 { + return cc.Base.SourceID +} + func (cc *CreateCollectionMsg) Marshal(input TsMsg) (MarshalType, error) { createCollectionMsg := input.(*CreateCollectionMsg) createCollectionRequest := &createCollectionMsg.CreateCollectionRequest @@ -656,6 +701,10 @@ func (dc *DropCollectionMsg) Type() MsgType { return dc.Base.MsgType } +func (dc *DropCollectionMsg) SourceID() int64 { + return dc.Base.SourceID +} + func (dc *DropCollectionMsg) Marshal(input TsMsg) (MarshalType, error) { dropCollectionMsg := input.(*DropCollectionMsg) dropCollectionRequest := &dropCollectionMsg.DropCollectionRequest @@ -705,6 +754,10 @@ func (cp *CreatePartitionMsg) Type() MsgType { return cp.Base.MsgType } +func (cp *CreatePartitionMsg) SourceID() int64 { + return cp.Base.SourceID +} + func (cp *CreatePartitionMsg) Marshal(input TsMsg) (MarshalType, error) { createPartitionMsg := input.(*CreatePartitionMsg) createPartitionRequest := &createPartitionMsg.CreatePartitionRequest @@ -754,6 +807,10 @@ func (dp *DropPartitionMsg) Type() MsgType { return dp.Base.MsgType } +func (dp *DropPartitionMsg) SourceID() int64 { + return dp.Base.SourceID +} + func (dp *DropPartitionMsg) Marshal(input TsMsg) (MarshalType, error) { dropPartitionMsg := input.(*DropPartitionMsg) dropPartitionRequest := &dropPartitionMsg.DropPartitionRequest @@ -803,6 +860,10 @@ func (lim *LoadIndexMsg) Type() MsgType { return lim.Base.MsgType } +func (lim *LoadIndexMsg) SourceID() int64 { + return lim.Base.SourceID +} + func (lim *LoadIndexMsg) Marshal(input TsMsg) (MarshalType, error) { loadIndexMsg := input.(*LoadIndexMsg) loadIndexRequest := &loadIndexMsg.LoadIndex @@ -850,6 +911,10 @@ func (sim *SegmentInfoMsg) Type() MsgType { return sim.Base.MsgType } +func (sim *SegmentInfoMsg) SourceID() int64 { + return sim.Base.SourceID +} + func (sim *SegmentInfoMsg) Marshal(input TsMsg) (MarshalType, error) { segInfoMsg := input.(*SegmentInfoMsg) mb, err := proto.Marshal(&segInfoMsg.SegmentMsg) @@ -896,6 +961,10 @@ func (l *LoadBalanceSegmentsMsg) Type() MsgType { return l.Base.MsgType } +func (l *LoadBalanceSegmentsMsg) SourceID() int64 { + return l.Base.SourceID +} + func (l *LoadBalanceSegmentsMsg) Marshal(input TsMsg) (MarshalType, error) { load := input.(*LoadBalanceSegmentsMsg) loadReq := &load.LoadBalanceSegmentsRequest @@ -944,6 +1013,10 @@ func (m *DataNodeTtMsg) Type() MsgType { return m.Base.MsgType } +func (m *DataNodeTtMsg) SourceID() int64 { + return m.Base.SourceID +} + func (m *DataNodeTtMsg) Marshal(input TsMsg) (MarshalType, error) { msg := input.(*DataNodeTtMsg) t, err := proto.Marshal(&msg.DataNodeTtMsg) diff --git a/internal/querynode/query_collection.go b/internal/querynode/query_collection.go index db63ff7951fe087f90384514010e1af58a60c3b7..187c83059c98e2c37134d0491baf4e5249439c55 100644 --- a/internal/querynode/query_collection.go +++ b/internal/querynode/query_collection.go @@ -14,7 +14,6 @@ package querynode import ( "context" "encoding/binary" - "errors" "fmt" "math" "reflect" @@ -44,9 +43,8 @@ type queryCollection struct { historical *historical streaming *streaming - unsolvedMsgMu sync.Mutex // guards unsolvedMsg - unsolvedMsg []*msgstream.SearchMsg - unsolvedRetrieveMsg []*msgstream.RetrieveMsg + unsolvedMsgMu sync.Mutex // guards unsolvedMsg + unsolvedMsg []msgstream.TsMsg tSafeWatchers map[Channel]*tSafeWatcher watcherSelectCase []reflect.SelectCase @@ -67,7 +65,7 @@ func newQueryCollection(releaseCtx context.Context, streaming *streaming, factory msgstream.Factory) *queryCollection { - unsolvedMsg := make([]*msgstream.SearchMsg, 0) + unsolvedMsg := make([]msgstream.TsMsg, 0) queryStream, _ := factory.NewQueryMsgStream(releaseCtx) queryResultStream, _ := factory.NewQueryMsgStream(releaseCtx) @@ -96,8 +94,7 @@ func (q *queryCollection) start() { go q.queryMsgStream.Start() go q.queryResultMsgStream.Start() go q.consumeQuery() - go q.doUnsolvedMsgSearch() - go q.doUnsolvedMsgRetrieve() + go q.doUnsolvedQueryMsg() } func (q *queryCollection) close() { @@ -132,19 +129,13 @@ func (q *queryCollection) register() { } } -func (q *queryCollection) addToUnsolvedMsg(msg *msgstream.SearchMsg) { +func (q *queryCollection) addToUnsolvedMsg(msg msgstream.TsMsg) { q.unsolvedMsgMu.Lock() defer q.unsolvedMsgMu.Unlock() q.unsolvedMsg = append(q.unsolvedMsg, msg) } -func (q *queryCollection) addToUnsolvedRetrieveMsg(msg *msgstream.RetrieveMsg) { - q.unsolvedMsgMu.Lock() - defer q.unsolvedMsgMu.Unlock() - q.unsolvedRetrieveMsg = append(q.unsolvedRetrieveMsg, msg) -} - -func (q *queryCollection) popAllUnsolvedMsg() []*msgstream.SearchMsg { +func (q *queryCollection) popAllUnsolvedMsg() []msgstream.TsMsg { q.unsolvedMsgMu.Lock() defer q.unsolvedMsgMu.Unlock() tmp := q.unsolvedMsg @@ -152,14 +143,6 @@ func (q *queryCollection) popAllUnsolvedMsg() []*msgstream.SearchMsg { return tmp } -func (q *queryCollection) popAllUnsolvedRetrieveMsg() []*msgstream.RetrieveMsg { - q.unsolvedMsgMu.Lock() - defer q.unsolvedMsgMu.Unlock() - tmp := q.unsolvedRetrieveMsg - q.unsolvedRetrieveMsg = q.unsolvedRetrieveMsg[:0] - return tmp -} - func (q *queryCollection) waitNewTSafe() Timestamp { // block until any vChannel updating tSafe _, _, recvOK := reflect.Select(q.watcherSelectCase) @@ -201,17 +184,6 @@ func (q *queryCollection) setServiceableTime(t Timestamp) { } } -func (q *queryCollection) emptySearch(searchMsg *msgstream.SearchMsg) { - sp, ctx := trace.StartSpanFromContext(searchMsg.TraceCtx()) - defer sp.Finish() - searchMsg.SetTraceCtx(ctx) - err := q.search(searchMsg) - if err != nil { - log.Error(err.Error()) - q.publishFailedSearchResult(searchMsg, err.Error()) - } -} - func (q *queryCollection) consumeQuery() { for { select { @@ -233,11 +205,11 @@ func (q *queryCollection) consumeQuery() { for _, msg := range msgPack.Msgs { switch sm := msg.(type) { case *msgstream.SearchMsg: - q.receiveSearch(sm) + q.receiveQueryMsg(sm) case *msgstream.LoadBalanceSegmentsMsg: q.loadBalance(sm) case *msgstream.RetrieveMsg: - q.receiveRetrieve(sm) + q.receiveQueryMsg(sm) default: log.Warn("unsupported msg type in search channel", zap.Any("msg", sm)) } @@ -277,123 +249,86 @@ func (q *queryCollection) loadBalance(msg *msgstream.LoadBalanceSegmentsMsg) { // zap.Int("num of segment", len(msg.Infos))) } -func (q *queryCollection) receiveRetrieve(msg *msgstream.RetrieveMsg) { - if msg.CollectionID != q.collectionID { - log.Debug("not target collection retrieve request", - zap.Any("collectionID", msg.CollectionID), +func (q *queryCollection) receiveQueryMsg(msg msgstream.TsMsg) { + msgType := msg.Type() + var collectionID UniqueID + var msgTypeStr string + + switch msgType { + case commonpb.MsgType_Retrieve: + collectionID = msg.(*msgstream.RetrieveMsg).CollectionID + msgTypeStr = "retrieve" + log.Debug("consume retrieve message", + zap.Any("collectionID", collectionID), zap.Int64("msgID", msg.ID()), ) - return - } - - log.Debug("consume retrieve message", - zap.Any("collectionID", msg.CollectionID), - zap.Int64("msgID", msg.ID()), - ) - sp, ctx := trace.StartSpanFromContext(msg.TraceCtx()) - msg.SetTraceCtx(ctx) - - // check if collection has been released - collection, err := q.historical.replica.getCollectionByID(msg.CollectionID) - if err != nil { - log.Error(err.Error()) - q.publishFailedRetrieveResult(msg, err.Error()) - return - } - if msg.BeginTs() >= collection.getReleaseTime() { - err := errors.New("retrieve failed, collection has been released, msgID = " + - fmt.Sprintln(msg.ID()) + - ", collectionID = " + - fmt.Sprintln(msg.CollectionID)) + case commonpb.MsgType_Search: + collectionID = msg.(*msgstream.SearchMsg).CollectionID + msgTypeStr = "search" + log.Debug("consume search message", + zap.Any("collectionID", collectionID), + zap.Int64("msgID", msg.ID()), + ) + default: + err := fmt.Errorf("receive invalid msgType = %d", msgType) log.Error(err.Error()) - q.publishFailedRetrieveResult(msg, err.Error()) return } - - serviceTime := q.getServiceableTime() - if msg.BeginTs() > serviceTime { - bt, _ := tsoutil.ParseTS(msg.BeginTs()) - st, _ := tsoutil.ParseTS(serviceTime) - log.Debug("query node::receiveRetrieveMsg: add to unsolvedMsg", + if collectionID != q.collectionID { + log.Error("not target collection query request", zap.Any("collectionID", q.collectionID), - zap.Any("sm.BeginTs", bt), - zap.Any("serviceTime", st), - zap.Any("delta seconds", (msg.BeginTs()-serviceTime)/(1000*1000*1000)), - zap.Any("msgID", msg.ID()), - ) - q.addToUnsolvedRetrieveMsg(msg) - sp.LogFields( - oplog.String("send to unsolved buffer", "send to unsolved buffer"), - oplog.Object("begin ts", bt), - oplog.Object("serviceTime", st), - oplog.Float64("delta seconds", float64(msg.BeginTs()-serviceTime)/(1000.0*1000.0*1000.0)), - ) - sp.Finish() - return - } - log.Debug("doing retrieve in receiveRetrieveMsg...", - zap.Int64("collectionID", msg.CollectionID), - zap.Int64("msgID", msg.ID()), - ) - err = q.retrieve(msg) - if err != nil { - log.Error(err.Error()) - log.Debug("do retrieve failed in receiveRetrieveMsg, prepare to publish failed retrieve result", - zap.Int64("collectionID", msg.CollectionID), - zap.Int64("msgID", msg.ID()), - ) - q.publishFailedRetrieveResult(msg, err.Error()) - } - log.Debug("do retrieve done in receiveRetrieve", - zap.Int64("collectionID", msg.CollectionID), - zap.Int64("msgID", msg.ID()), - ) - sp.Finish() -} - -func (q *queryCollection) receiveSearch(msg *msgstream.SearchMsg) { - if msg.CollectionID != q.collectionID { - log.Debug("not target collection search request", - zap.Any("collectionID", msg.CollectionID), + zap.Int64("target collectionID", collectionID), zap.Int64("msgID", msg.ID()), ) return } - log.Debug("consume search message", - zap.Any("collectionID", msg.CollectionID), - zap.Int64("msgID", msg.ID()), - ) sp, ctx := trace.StartSpanFromContext(msg.TraceCtx()) msg.SetTraceCtx(ctx) // check if collection has been released - collection, err := q.historical.replica.getCollectionByID(msg.CollectionID) + collection, err := q.historical.replica.getCollectionByID(collectionID) if err != nil { log.Error(err.Error()) - q.publishFailedSearchResult(msg, err.Error()) + err = q.publishFailedQueryResult(msg, err.Error()) + if err != nil { + log.Error(err.Error()) + } else { + log.Debug("do query failed in receiveQueryMsg, publish failed query result", + zap.Int64("collectionID", collectionID), + zap.Int64("msgID", msg.ID()), + zap.String("msgType", msgTypeStr), + ) + } return } if msg.BeginTs() >= collection.getReleaseTime() { - err := errors.New("search failed, collection has been released, msgID = " + - fmt.Sprintln(msg.ID()) + - ", collectionID = " + - fmt.Sprintln(msg.CollectionID)) + err = fmt.Errorf("retrieve failed, collection has been released, msgID = %d, collectionID = %d", msg.ID(), collectionID) log.Error(err.Error()) - q.publishFailedSearchResult(msg, err.Error()) + err = q.publishFailedQueryResult(msg, err.Error()) + if err != nil { + log.Error(err.Error()) + } else { + log.Debug("do query failed in receiveQueryMsg, publish failed query result", + zap.Int64("collectionID", collectionID), + zap.Int64("msgID", msg.ID()), + zap.String("msgType", msgTypeStr), + ) + } return } serviceTime := q.getServiceableTime() - bt, _ := tsoutil.ParseTS(msg.BeginTs()) - st, _ := tsoutil.ParseTS(serviceTime) if msg.BeginTs() > serviceTime { - log.Debug("query node::receiveSearchMsg: add to unsolvedMsg", + bt, _ := tsoutil.ParseTS(msg.BeginTs()) + st, _ := tsoutil.ParseTS(serviceTime) + log.Debug("query node::receiveQueryMsg: add to unsolvedMsg", zap.Any("collectionID", q.collectionID), zap.Any("sm.BeginTs", bt), zap.Any("serviceTime", st), zap.Any("delta seconds", (msg.BeginTs()-serviceTime)/(1000*1000*1000)), zap.Any("msgID", msg.ID()), + zap.String("msgType", msgTypeStr), ) q.addToUnsolvedMsg(msg) sp.LogFields( @@ -405,36 +340,49 @@ func (q *queryCollection) receiveSearch(msg *msgstream.SearchMsg) { sp.Finish() return } - log.Debug("doing search in receiveSearchMsg...", - zap.Int64("collectionID", msg.CollectionID), + log.Debug("doing query in receiveQueryMsg...", + zap.Int64("collectionID", collectionID), zap.Int64("msgID", msg.ID()), - zap.Any("serviceTime_l", serviceTime), - zap.Any("searchTime_l", msg.BeginTs()), - zap.Any("serviceTime_p", st), - zap.Any("searchTime_p", bt), + zap.String("msgType", msgTypeStr), ) - err = q.search(msg) + switch msgType { + case commonpb.MsgType_Retrieve: + err = q.retrieve(msg) + case commonpb.MsgType_Search: + err = q.search(msg) + default: + err := fmt.Errorf("receive invalid msgType = %d", msgType) + log.Error(err.Error()) + return + } + if err != nil { log.Error(err.Error()) - log.Debug("do search failed in receiveSearchMsg, prepare to publish failed search result", - zap.Int64("collectionID", msg.CollectionID), - zap.Int64("msgID", msg.ID()), - ) - q.publishFailedSearchResult(msg, err.Error()) + err = q.publishFailedQueryResult(msg, err.Error()) + if err != nil { + log.Error(err.Error()) + } else { + log.Debug("do query failed in receiveQueryMsg, publish failed query result", + zap.Int64("collectionID", collectionID), + zap.Int64("msgID", msg.ID()), + zap.String("msgType", msgTypeStr), + ) + } } - log.Debug("do search done in receiveSearch", - zap.Int64("collectionID", msg.CollectionID), + log.Debug("do query done in receiveQueryMsg", + zap.Int64("collectionID", collectionID), zap.Int64("msgID", msg.ID()), + zap.String("msgType", msgTypeStr), ) sp.Finish() } -func (q *queryCollection) doUnsolvedMsgSearch() { - log.Debug("starting doUnsolvedMsgSearch...", zap.Any("collectionID", q.collectionID)) +func (q *queryCollection) doUnsolvedQueryMsg() { + log.Debug("starting doUnsolvedMsg...", zap.Any("collectionID", q.collectionID)) for { select { case <-q.releaseCtx.Done(): - log.Debug("stop searchCollection's doUnsolvedMsgSearch", zap.Int64("collectionID", q.collectionID)) + log.Debug("stop Collection's doUnsolvedMsg", zap.Int64("collectionID", q.collectionID)) return default: //time.Sleep(10 * time.Millisecond) @@ -445,64 +393,80 @@ func (q *queryCollection) doUnsolvedMsgSearch() { zap.Any("tSafe", st)) q.setServiceableTime(serviceTime) - //log.Debug("query node::doUnsolvedMsgSearch: setServiceableTime", + //log.Debug("query node::doUnsolvedMsg: setServiceableTime", // zap.Any("serviceTime", st), //) - searchMsg := make([]*msgstream.SearchMsg, 0) + unSolvedMsg := make([]msgstream.TsMsg, 0) tempMsg := q.popAllUnsolvedMsg() - for _, sm := range tempMsg { - bt, _ := tsoutil.ParseTS(sm.EndTs()) + for _, m := range tempMsg { + bt, _ := tsoutil.ParseTS(m.EndTs()) st, _ = tsoutil.ParseTS(serviceTime) - log.Debug("get search message from unsolvedMsg", - zap.Int64("collectionID", sm.CollectionID), - zap.Int64("msgID", sm.ID()), + log.Debug("get query message from unsolvedMsg", + zap.Int64("collectionID", q.collectionID), + zap.Int64("msgID", m.ID()), zap.Any("reqTime_p", bt), zap.Any("serviceTime_p", st), - zap.Any("reqTime_l", sm.EndTs()), + zap.Any("reqTime_l", m.EndTs()), zap.Any("serviceTime_l", serviceTime), ) - if sm.EndTs() <= serviceTime { - searchMsg = append(searchMsg, sm) + if m.EndTs() <= serviceTime { + unSolvedMsg = append(unSolvedMsg, m) continue } - log.Debug("query node::doUnsolvedMsgSearch: add to unsolvedMsg", + log.Debug("query node::doUnsolvedMsg: add to unsolvedMsg", zap.Any("collectionID", q.collectionID), zap.Any("sm.BeginTs", bt), zap.Any("serviceTime", st), - zap.Any("delta seconds", (sm.BeginTs()-serviceTime)/(1000*1000*1000)), - zap.Any("msgID", sm.ID()), + zap.Any("delta seconds", (m.BeginTs()-serviceTime)/(1000*1000*1000)), + zap.Any("msgID", m.ID()), ) - q.addToUnsolvedMsg(sm) + q.addToUnsolvedMsg(m) } - if len(searchMsg) <= 0 { + if len(unSolvedMsg) <= 0 { continue } - for _, sm := range searchMsg { - sp, ctx := trace.StartSpanFromContext(sm.TraceCtx()) - sm.SetTraceCtx(ctx) - log.Debug("doing search in doUnsolvedMsgSearch...", - zap.Int64("collectionID", sm.CollectionID), - zap.Int64("msgID", sm.ID()), + for _, m := range unSolvedMsg { + msgType := m.Type() + var err error + sp, ctx := trace.StartSpanFromContext(m.TraceCtx()) + m.SetTraceCtx(ctx) + log.Debug("doing search in doUnsolvedMsg...", + zap.Int64("collectionID", q.collectionID), + zap.Int64("msgID", m.ID()), ) - err := q.search(sm) + switch msgType { + case commonpb.MsgType_Retrieve: + err = q.retrieve(m) + case commonpb.MsgType_Search: + err = q.search(m) + default: + err := fmt.Errorf("receive invalid msgType = %d", msgType) + log.Error(err.Error()) + return + } + if err != nil { log.Error(err.Error()) - log.Debug("do search failed in doUnsolvedMsgSearch, prepare to publish failed search result", - zap.Int64("collectionID", sm.CollectionID), - zap.Int64("msgID", sm.ID()), - ) - q.publishFailedSearchResult(sm, err.Error()) + err = q.publishFailedQueryResult(m, err.Error()) + if err != nil { + log.Error(err.Error()) + } else { + log.Debug("do query failed in doUnsolvedMsg, publish failed query result", + zap.Int64("collectionID", q.collectionID), + zap.Int64("msgID", m.ID()), + ) + } } sp.Finish() - log.Debug("do search done in doUnsolvedMsgSearch", - zap.Int64("collectionID", sm.CollectionID), - zap.Int64("msgID", sm.ID()), + log.Debug("do query done in doUnsolvedMsg", + zap.Int64("collectionID", q.collectionID), + zap.Int64("msgID", m.ID()), ) } - log.Debug("doUnsolvedMsgSearch, do search done", zap.Int("num of searchMsg", len(searchMsg))) + log.Debug("doUnsolvedMsg: do query done", zap.Int("num of query msg", len(unSolvedMsg))) } } } @@ -731,7 +695,8 @@ func translateHits(schema *typeutil.SchemaHelper, fieldIDs []int64, rawHits [][] // TODO:: cache map[dsl]plan // TODO: reBatched search requests -func (q *queryCollection) search(searchMsg *msgstream.SearchMsg) error { +func (q *queryCollection) search(msg msgstream.TsMsg) error { + searchMsg := msg.(*msgstream.SearchMsg) sp, ctx := trace.StartSpanFromContext(searchMsg.TraceCtx()) defer sp.Finish() searchMsg.SetTraceCtx(ctx) @@ -873,7 +838,7 @@ func (q *queryCollection) search(searchMsg *msgstream.SearchMsg) error { zap.Any("vChannels", collection.getVChannels()), zap.Any("sealedSegmentSearched", sealedSegmentSearched), ) - err = q.publishSearchResult(searchResultMsg, searchMsg.CollectionID) + err = q.publishQueryResult(searchResultMsg, searchMsg.CollectionID) if err != nil { return err } @@ -993,7 +958,7 @@ func (q *queryCollection) search(searchMsg *msgstream.SearchMsg) error { // fmt.Println(testHits.IDs) // fmt.Println(testHits.Scores) //} - err = q.publishSearchResult(searchResultMsg, searchMsg.CollectionID) + err = q.publishQueryResult(searchResultMsg, searchMsg.CollectionID) if err != nil { return err } @@ -1008,147 +973,14 @@ func (q *queryCollection) search(searchMsg *msgstream.SearchMsg) error { return nil } -func (q *queryCollection) publishSearchResult(msg msgstream.TsMsg, collectionID UniqueID) error { - log.Debug("publishing search result...", - zap.Int64("collectionID", collectionID), - zap.Int64("msgID", msg.ID()), - ) - span, ctx := trace.StartSpanFromContext(msg.TraceCtx()) - defer span.Finish() - msg.SetTraceCtx(ctx) - msgPack := msgstream.MsgPack{} - msgPack.Msgs = append(msgPack.Msgs, msg) - err := q.queryResultMsgStream.Produce(&msgPack) - if err != nil { - log.Error("publishing search result failed, err = "+err.Error(), - zap.Int64("collectionID", collectionID), - zap.Int64("msgID", msg.ID()), - ) - } else { - log.Debug("publish search result done", - zap.Int64("collectionID", collectionID), - zap.Int64("msgID", msg.ID()), - ) - } - return err -} - -func (q *queryCollection) publishFailedSearchResult(searchMsg *msgstream.SearchMsg, errMsg string) { - span, ctx := trace.StartSpanFromContext(searchMsg.TraceCtx()) - defer span.Finish() - searchMsg.SetTraceCtx(ctx) - //log.Debug("Public fail SearchResult!") - msgPack := msgstream.MsgPack{} - - resultChannelInt := 0 - searchResultMsg := &msgstream.SearchResultMsg{ - BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(resultChannelInt)}}, - SearchResults: internalpb.SearchResults{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_SearchResult, - MsgID: searchMsg.Base.MsgID, - Timestamp: searchMsg.Base.Timestamp, - SourceID: searchMsg.Base.SourceID, - }, - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: errMsg}, - ResultChannelID: searchMsg.ResultChannelID, - }, - } - - msgPack.Msgs = append(msgPack.Msgs, searchResultMsg) - err := q.queryResultMsgStream.Produce(&msgPack) - if err != nil { - log.Error("publish FailedSearchResult failed" + err.Error()) - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -func (q *queryCollection) doUnsolvedMsgRetrieve() { - log.Debug("starting doUnsolvedMsgRetrieve...", zap.Any("collectionID", q.collectionID)) - for { - select { - case <-q.releaseCtx.Done(): - log.Debug("stop retrieveCollection's doUnsolvedMsgRertieve", zap.Int64("collectionID", q.collectionID)) - return - default: - //time.Sleep(10 * time.Millisecond) - serviceTime := q.waitNewTSafe() - st, _ := tsoutil.ParseTS(serviceTime) - log.Debug("get tSafe from flow graph", - zap.Int64("collectionID", q.collectionID), - zap.Any("tSafe", st)) - - q.setServiceableTime(serviceTime) - //log.Debug("query node::doUnsolvedMsgSearch: setServiceableTime", - // zap.Any("serviceTime", st), - //) - - retrieveMsg := make([]*msgstream.RetrieveMsg, 0) - tempMsg := q.popAllUnsolvedRetrieveMsg() - - for _, rm := range tempMsg { - bt, _ := tsoutil.ParseTS(rm.EndTs()) - st, _ = tsoutil.ParseTS(serviceTime) - log.Debug("get retrieve message from unsolvedMsg", - zap.Int64("collectionID", rm.CollectionID), - zap.Int64("msgID", rm.ID()), - zap.Any("reqTime_p", bt), - zap.Any("serviceTime_p", st), - zap.Any("reqTime_l", rm.EndTs()), - zap.Any("serviceTime_l", serviceTime), - ) - if rm.EndTs() <= serviceTime { - retrieveMsg = append(retrieveMsg, rm) - continue - } - log.Debug("query node::doUnsolvedMsgRetrieve: add to unsolvedMsg", - zap.Any("collectionID", q.collectionID), - zap.Any("sm.BeginTs", bt), - zap.Any("serviceTime", st), - zap.Any("delta seconds", (rm.BeginTs()-serviceTime)/(1000*1000*1000)), - zap.Any("msgID", rm.ID()), - ) - q.addToUnsolvedRetrieveMsg(rm) - } - - if len(retrieveMsg) <= 0 { - continue - } - for _, rm := range retrieveMsg { - sp, ctx := trace.StartSpanFromContext(rm.TraceCtx()) - rm.SetTraceCtx(ctx) - log.Debug("doing search in doUnsolvedMsgRetrieve...", - zap.Int64("collectionID", rm.CollectionID), - zap.Int64("msgID", rm.ID()), - ) - err := q.retrieve(rm) - if err != nil { - log.Error(err.Error()) - log.Debug("do retrieve failed in doUnsolvedMsgSearch, prepare to publish failed retrieve result", - zap.Int64("collectionID", rm.CollectionID), - zap.Int64("msgID", rm.ID()), - ) - q.publishFailedRetrieveResult(rm, err.Error()) - } - sp.Finish() - log.Debug("do retrieve done in doUnsolvedMsgSearch", - zap.Int64("collectionID", rm.CollectionID), - zap.Int64("msgID", rm.ID()), - ) - } - log.Debug("doUnsolvedMsgRetrieve, do retrieve done", zap.Int("num of retrieveMsg", len(retrieveMsg))) - } - } -} - -func (q *queryCollection) retrieve(retrieveMsg *msgstream.RetrieveMsg) error { +func (q *queryCollection) retrieve(msg msgstream.TsMsg) error { // TODO(yukun) // step 1: get retrieve object and defer destruction // step 2: for each segment, call retrieve to get ids proto buffer // step 3: merge all proto in go // step 4: publish results // retrieveProtoBlob, err := proto.Marshal(&retrieveMsg.RetrieveRequest) + retrieveMsg := msg.(*msgstream.RetrieveMsg) sp, ctx := trace.StartSpanFromContext(retrieveMsg.TraceCtx()) defer sp.Finish() retrieveMsg.SetTraceCtx(ctx) @@ -1237,8 +1069,6 @@ func (q *queryCollection) retrieve(retrieveMsg *msgstream.RetrieveMsg) error { } } - log.Debug("1111", zap.Any("len of mergeList", len(mergeList))) - result, err := mergeRetrieveResults(mergeList) if err != nil { return err @@ -1263,15 +1093,16 @@ func (q *queryCollection) retrieve(retrieveMsg *msgstream.RetrieveMsg) error { GlobalSealedSegmentIDs: sealedSegmentRetrieved, }, } - log.Debug("QueryNode RetrieveResultMsg", + + err3 := q.publishQueryResult(retrieveResultMsg, retrieveMsg.CollectionID) + if err3 != nil { + return err3 + } + log.Debug("QueryNode publish RetrieveResultMsg", zap.Any("vChannels", collection.getVChannels()), zap.Any("collectionID", collection.ID()), zap.Any("sealedSegmentRetrieved", sealedSegmentRetrieved), ) - err3 := q.publishRetrieveResult(retrieveResultMsg, retrieveMsg.CollectionID) - if err3 != nil { - return err3 - } return nil } @@ -1308,10 +1139,7 @@ func mergeRetrieveResults(dataArr []*segcorepb.RetrieveResults) (*segcorepb.Retr return final, nil } -func (q *queryCollection) publishRetrieveResult(msg msgstream.TsMsg, collectionID UniqueID) error { - log.Debug("publishing retrieve result...", - zap.Int64("msgID", msg.ID()), - zap.Int64("collectionID", collectionID)) +func (q *queryCollection) publishQueryResult(msg msgstream.TsMsg, collectionID UniqueID) error { span, ctx := trace.StartSpanFromContext(msg.TraceCtx()) defer span.Finish() msg.SetTraceCtx(ctx) @@ -1320,38 +1148,59 @@ func (q *queryCollection) publishRetrieveResult(msg msgstream.TsMsg, collectionI err := q.queryResultMsgStream.Produce(&msgPack) if err != nil { log.Error(err.Error()) - } else { - log.Debug("publish retrieve result done", - zap.Int64("msgID", msg.ID()), - zap.Int64("collectionID", collectionID)) } + return err } -func (q *queryCollection) publishFailedRetrieveResult(retrieveMsg *msgstream.RetrieveMsg, errMsg string) error { - span, ctx := trace.StartSpanFromContext(retrieveMsg.TraceCtx()) +func (q *queryCollection) publishFailedQueryResult(msg msgstream.TsMsg, errMsg string) error { + msgType := msg.Type() + span, ctx := trace.StartSpanFromContext(msg.TraceCtx()) defer span.Finish() - retrieveMsg.SetTraceCtx(ctx) + msg.SetTraceCtx(ctx) msgPack := msgstream.MsgPack{} resultChannelInt := 0 - retrieveResultMsg := &msgstream.RetrieveResultMsg{ - BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(resultChannelInt)}}, - RetrieveResults: internalpb.RetrieveResults{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_RetrieveResult, - MsgID: retrieveMsg.Base.MsgID, - Timestamp: retrieveMsg.Base.Timestamp, - SourceID: retrieveMsg.Base.SourceID, + baseMsg := msgstream.BaseMsg{ + HashValues: []uint32{uint32(resultChannelInt)}, + } + baseResult := &commonpb.MsgBase{ + MsgID: msg.ID(), + Timestamp: msg.BeginTs(), + SourceID: msg.SourceID(), + } + + switch msgType { + case commonpb.MsgType_Retrieve: + retrieveMsg := msg.(*msgstream.RetrieveMsg) + baseResult.MsgType = commonpb.MsgType_RetrieveResult + retrieveResultMsg := &msgstream.RetrieveResultMsg{ + BaseMsg: baseMsg, + RetrieveResults: internalpb.RetrieveResults{ + Base: baseResult, + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: errMsg}, + ResultChannelID: retrieveMsg.ResultChannelID, + Ids: nil, + FieldsData: nil, }, - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: errMsg}, - ResultChannelID: retrieveMsg.ResultChannelID, - Ids: nil, - FieldsData: nil, - }, + } + msgPack.Msgs = append(msgPack.Msgs, retrieveResultMsg) + case commonpb.MsgType_Search: + searchMsg := msg.(*msgstream.SearchMsg) + baseResult.MsgType = commonpb.MsgType_SearchResult + searchResultMsg := &msgstream.SearchResultMsg{ + BaseMsg: baseMsg, + SearchResults: internalpb.SearchResults{ + Base: baseResult, + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: errMsg}, + ResultChannelID: searchMsg.ResultChannelID, + }, + } + msgPack.Msgs = append(msgPack.Msgs, searchResultMsg) + default: + return fmt.Errorf("publish invalid msgType %d", msgType) } - msgPack.Msgs = append(msgPack.Msgs, retrieveResultMsg) err := q.queryResultMsgStream.Produce(&msgPack) if err != nil { return err