未验证 提交 d3c503f3 编写于 作者: D dragondriver 提交者: GitHub

Calculate the real topk in proxy (#6132)

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
上级 6f4ad331
...@@ -1568,7 +1568,7 @@ func reduceSearchResultsParallel(hits [][]*milvuspb.Hits, nq, availableQueryNode ...@@ -1568,7 +1568,7 @@ func reduceSearchResultsParallel(hits [][]*milvuspb.Hits, nq, availableQueryNode
return ret return ret
} }
func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string, maxParallel int) *milvuspb.SearchResults { func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string, maxParallel int) (*milvuspb.SearchResults, error) {
log.Debug("reduceSearchResultDataParallel", zap.Any("NumOfGoRoutines", maxParallel)) log.Debug("reduceSearchResultDataParallel", zap.Any("NumOfGoRoutines", maxParallel))
ret := &milvuspb.SearchResults{ ret := &milvuspb.SearchResults{
...@@ -1593,10 +1593,12 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat ...@@ -1593,10 +1593,12 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
const minFloat32 = -1 * float32(math.MaxFloat32) const minFloat32 = -1 * float32(math.MaxFloat32)
// TODO(yukun): Use parallel function // TODO(yukun): Use parallel function
realTopK := -1
for idx := 0; idx < nq; idx++ { for idx := 0; idx < nq; idx++ {
locs := make([]int, availableQueryNodeNum) locs := make([]int, availableQueryNodeNum)
for j := 0; j < topk; j++ { j := 0
for ; j < topk; j++ {
valid := false valid := false
choice, maxDistance := 0, minFloat32 choice, maxDistance := 0, minFloat32
for q, loc := range locs { // query num, the number of ways to merge for q, loc := range locs { // query num, the number of ways to merge
...@@ -1696,7 +1698,7 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat ...@@ -1696,7 +1698,7 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
} }
default: default:
log.Debug("Not supported field type") log.Debug("Not supported field type")
return nil return nil, fmt.Errorf("not supported field type: %s", fieldData.Type.String())
} }
case *schemapb.FieldData_Vectors: case *schemapb.FieldData_Vectors:
dim := fieldType.Vectors.Dim dim := fieldType.Vectors.Dim
...@@ -1729,9 +1731,15 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat ...@@ -1729,9 +1731,15 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
ret.Results.Scores = append(ret.Results.Scores, searchResultData[choice].Scores[idx*topk+choiceOffset]) ret.Results.Scores = append(ret.Results.Scores, searchResultData[choice].Scores[idx*topk+choiceOffset])
locs[choice]++ locs[choice]++
} }
if realTopK != -1 && realTopK != j {
log.Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
// return nil, errors.New("the length (topk) between all result of query is different")
}
realTopK = j
} }
ret.Results.TopK = int64(realTopK)
if metricType != "IP" { if metricType != "IP" {
for k := range ret.Results.Scores { for k := range ret.Results.Scores {
ret.Results.Scores[k] *= -1 ret.Results.Scores[k] *= -1
...@@ -1742,7 +1750,7 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat ...@@ -1742,7 +1750,7 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
// return nil // return nil
// } // }
return ret return ret, nil
} }
func reduceSearchResultsSerial(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults { func reduceSearchResultsSerial(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
...@@ -1767,7 +1775,7 @@ func reduceSearchResults(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, top ...@@ -1767,7 +1775,7 @@ func reduceSearchResults(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, top
return reduceSearchResultsParallelByCPU(hits, nq, availableQueryNodeNum, topk, metricType) return reduceSearchResultsParallelByCPU(hits, nq, availableQueryNodeNum, topk, metricType)
} }
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults { func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string) (*milvuspb.SearchResults, error) {
t := time.Now() t := time.Now()
defer func() { defer func() {
log.Debug("reduceSearchResults", zap.Any("time cost", time.Since(t))) log.Debug("reduceSearchResults", zap.Any("time cost", time.Since(t)))
...@@ -1853,7 +1861,10 @@ func (st *SearchTask) PostExecute(ctx context.Context) error { ...@@ -1853,7 +1861,10 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
} }
nq := results[0].NumQueries nq := results[0].NumQueries
topk := results[0].TopK topk := 0
for _, partialResult := range results {
topk = getMax(topk, int(partialResult.TopK))
}
if nq <= 0 { if nq <= 0 {
st.result = &milvuspb.SearchResults{ st.result = &milvuspb.SearchResults{
Status: &commonpb.Status{ Status: &commonpb.Status{
...@@ -1864,7 +1875,10 @@ func (st *SearchTask) PostExecute(ctx context.Context) error { ...@@ -1864,7 +1875,10 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
return nil return nil
} }
st.result = reduceSearchResultData(results, int(nq), availableQueryNodeNum, int(topk), searchResults[0].MetricType) st.result, err = reduceSearchResultData(results, int(nq), availableQueryNodeNum, topk, searchResults[0].MetricType)
if err != nil {
return err
}
schema, err := globalMetaCache.GetCollectionSchema(ctx, st.query.CollectionName) schema, err := globalMetaCache.GetCollectionSchema(ctx, st.query.CollectionName)
if err != nil { if err != nil {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册