未验证 提交 d3c503f3 编写于 作者: D dragondriver 提交者: GitHub

Calculate the real topk in proxy (#6132)

Signed-off-by: Ndragondriver <jiquan.long@zilliz.com>
上级 6f4ad331
......@@ -1568,7 +1568,7 @@ func reduceSearchResultsParallel(hits [][]*milvuspb.Hits, nq, availableQueryNode
return ret
}
func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string, maxParallel int) *milvuspb.SearchResults {
func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string, maxParallel int) (*milvuspb.SearchResults, error) {
log.Debug("reduceSearchResultDataParallel", zap.Any("NumOfGoRoutines", maxParallel))
ret := &milvuspb.SearchResults{
......@@ -1593,10 +1593,12 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
const minFloat32 = -1 * float32(math.MaxFloat32)
// TODO(yukun): Use parallel function
realTopK := -1
for idx := 0; idx < nq; idx++ {
locs := make([]int, availableQueryNodeNum)
for j := 0; j < topk; j++ {
j := 0
for ; j < topk; j++ {
valid := false
choice, maxDistance := 0, minFloat32
for q, loc := range locs { // query num, the number of ways to merge
......@@ -1696,7 +1698,7 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
}
default:
log.Debug("Not supported field type")
return nil
return nil, fmt.Errorf("not supported field type: %s", fieldData.Type.String())
}
case *schemapb.FieldData_Vectors:
dim := fieldType.Vectors.Dim
......@@ -1729,9 +1731,15 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
ret.Results.Scores = append(ret.Results.Scores, searchResultData[choice].Scores[idx*topk+choiceOffset])
locs[choice]++
}
if realTopK != -1 && realTopK != j {
log.Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
// return nil, errors.New("the length (topk) between all result of query is different")
}
realTopK = j
}
ret.Results.TopK = int64(realTopK)
if metricType != "IP" {
for k := range ret.Results.Scores {
ret.Results.Scores[k] *= -1
......@@ -1742,7 +1750,7 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
// return nil
// }
return ret
return ret, nil
}
func reduceSearchResultsSerial(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
......@@ -1767,7 +1775,7 @@ func reduceSearchResults(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, top
return reduceSearchResultsParallelByCPU(hits, nq, availableQueryNodeNum, topk, metricType)
}
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string) (*milvuspb.SearchResults, error) {
t := time.Now()
defer func() {
log.Debug("reduceSearchResults", zap.Any("time cost", time.Since(t)))
......@@ -1853,7 +1861,10 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
}
nq := results[0].NumQueries
topk := results[0].TopK
topk := 0
for _, partialResult := range results {
topk = getMax(topk, int(partialResult.TopK))
}
if nq <= 0 {
st.result = &milvuspb.SearchResults{
Status: &commonpb.Status{
......@@ -1864,7 +1875,10 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
return nil
}
st.result = reduceSearchResultData(results, int(nq), availableQueryNodeNum, int(topk), searchResults[0].MetricType)
st.result, err = reduceSearchResultData(results, int(nq), availableQueryNodeNum, topk, searchResults[0].MetricType)
if err != nil {
return err
}
schema, err := globalMetaCache.GetCollectionSchema(ctx, st.query.CollectionName)
if err != nil {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册