未验证 提交 0ab70271 编写于 作者: C Cai Yudong 提交者: GitHub

Correct handling of empty search results (#7244)

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
上级 e782ee94
......@@ -144,17 +144,15 @@ message SearchResults {
common.Status status = 2;
string result_channelID = 3;
string metric_type = 4;
repeated bytes hits = 5;
int64 num_queries = 5;
int64 top_k = 6;
repeated int64 sealed_segmentIDs_searched = 7;
repeated string channelIDs_searched = 8;
repeated int64 global_sealed_segmentIDs = 9;
// sliced_blob holds a serialized schema.SearchResultData message
bytes sliced_blob = 9;
int64 sliced_num_count = 10;
int64 sliced_offset = 11;
repeated int64 sealed_segmentIDs_searched = 6;
repeated string channelIDs_searched = 7;
repeated int64 global_sealed_segmentIDs = 8;
bytes sliced_blob = 10;
int64 sliced_num_count = 11;
int64 sliced_offset = 12;
}
message RetrieveRequest {
......
......@@ -1659,9 +1659,8 @@ func decodeSearchResults(searchResults []*internalpb.SearchResults) ([]*schemapb
// return decodeSearchResultsParallelByCPU(searchResults)
}
func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64, metricType string, maxParallel int) (*milvuspb.SearchResults, error) {
nq := searchResultData[0].NumQueries
topk := searchResultData[0].TopK
func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64,
nq int64, topk int64, metricType string, maxParallel int) (*milvuspb.SearchResults, error) {
log.Debug("reduceSearchResultDataParallel",
zap.Int("len(searchResultData)", len(searchResultData)),
......@@ -1887,25 +1886,26 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
return ret, nil
}
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64, metricType string) (*milvuspb.SearchResults, error) {
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64,
nq int64, topk int64, metricType string) (*milvuspb.SearchResults, error) {
t := time.Now()
defer func() {
log.Debug("reduceSearchResults", zap.Any("time cost", time.Since(t)))
}()
return reduceSearchResultDataParallel(searchResultData, availableQueryNodeNum, metricType, runtime.NumCPU())
return reduceSearchResultDataParallel(searchResultData, availableQueryNodeNum, nq, topk, metricType, runtime.NumCPU())
}
// printSearchResult prints the contents of one partial search result to
// standard output; it appears to be a debugging aid.
//
// Each entry of partialSearchResult.Hits is unmarshaled as a milvuspb.Hits
// message, and its IDs and Scores are printed on separate lines.
// It panics if any entry fails to unmarshal.
func printSearchResult(partialSearchResult *internalpb.SearchResults) {
for i := 0; i < len(partialSearchResult.Hits); i++ {
// Decode into a fresh message on every iteration.
testHits := milvuspb.Hits{}
err := proto.Unmarshal(partialSearchResult.Hits[i], &testHits)
if err != nil {
// A decode failure aborts via panic; no error is returned to the caller.
panic(err)
}
fmt.Println(testHits.IDs)
fmt.Println(testHits.Scores)
}
}
//func printSearchResult(partialSearchResult *internalpb.SearchResults) {
// for i := 0; i < len(partialSearchResult.Hits); i++ {
// testHits := milvuspb.Hits{}
// err := proto.Unmarshal(partialSearchResult.Hits[i], &testHits)
// if err != nil {
// panic(err)
// }
// fmt.Println(testHits.IDs)
// fmt.Println(testHits.Scores)
// }
//}
func (st *SearchTask) PostExecute(ctx context.Context) error {
t0 := time.Now()
......@@ -1947,10 +1947,10 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
availableQueryNodeNum = 0
for _, partialSearchResult := range filterSearchResult {
if partialSearchResult.SlicedBlob == nil {
filterReason += "nq is zero\n"
continue
filterReason += "empty search result\n"
} else {
availableQueryNodeNum++
}
availableQueryNodeNum++
}
log.Debug("Proxy Search PostExecute stage2", zap.Any("availableQueryNodeNum", availableQueryNodeNum))
......@@ -1962,6 +1962,10 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
ErrorCode: commonpb.ErrorCode_Success,
Reason: filterReason,
},
Results: &schemapb.SearchResultData{
NumQueries: searchResults[0].NumQueries,
Topks: make([]int64, searchResults[0].NumQueries),
},
}
return nil
}
......@@ -1972,7 +1976,8 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
return err
}
st.result, err = reduceSearchResultData(results, int64(availableQueryNodeNum), searchResults[0].MetricType)
st.result, err = reduceSearchResultData(results, int64(availableQueryNodeNum),
searchResults[0].NumQueries, searchResults[0].TopK, searchResults[0].MetricType)
if err != nil {
return err
}
......
......@@ -867,31 +867,7 @@ func (q *queryCollection) search(msg queryMsg) error {
sp.LogFields(oplog.String("statistical time", "segment search end"))
if len(searchResults) <= 0 {
for _, group := range searchRequests {
nq := group.getNumOfQuery()
nilHits := make([][]byte, nq)
hit := &milvuspb.Hits{}
for i := 0; i < int(nq); i++ {
bs, err := proto.Marshal(hit)
if err != nil {
return err
}
nilHits[i] = bs
}
// TODO: remove inefficient code in cgo and use SearchResultData directly
// TODO: Currently add a translate layer from hits to SearchResultData
// TODO: hits marshal and unmarshal is likely bottleneck
transformed, err := translateHits(schema, searchMsg.OutputFieldsId, nilHits)
if err != nil {
return err
}
byteBlobs, err := proto.Marshal(transformed)
if err != nil {
return err
}
for range searchRequests {
resultChannelInt := 0
searchResultMsg := &msgstream.SearchResultMsg{
BaseMsg: msgstream.BaseMsg{Ctx: searchMsg.Ctx, HashValues: []uint32{uint32(resultChannelInt)}},
......@@ -904,11 +880,12 @@ func (q *queryCollection) search(msg queryMsg) error {
},
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
ResultChannelID: searchMsg.ResultChannelID,
Hits: nilHits,
SlicedBlob: byteBlobs,
MetricType: plan.getMetricType(),
NumQueries: queryNum,
TopK: topK,
SlicedBlob: nil,
SlicedOffset: 1,
SlicedNumCount: 1,
MetricType: plan.getMetricType(),
SealedSegmentIDsSearched: sealedSegmentSearched,
ChannelIDsSearched: q.collection.getVChannels(),
GlobalSealedSegmentIDs: globalSealedSegments,
......@@ -995,11 +972,12 @@ func (q *queryCollection) search(msg queryMsg) error {
},
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
ResultChannelID: searchMsg.ResultChannelID,
Hits: hits,
MetricType: plan.getMetricType(),
NumQueries: queryNum,
TopK: topK,
SlicedBlob: byteBlobs,
SlicedOffset: 1,
SlicedNumCount: 1,
MetricType: plan.getMetricType(),
SealedSegmentIDsSearched: sealedSegmentSearched,
ChannelIDsSearched: q.collection.getVChannels(),
GlobalSealedSegmentIDs: globalSealedSegments,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册