Unverified · Commit 297d75fc authored by XuanYang-cn, committed by GitHub

Enable query pagination (#19231)

Signed-off-by: yangxuan <xuan.yang@zilliz.com>
Parent 563917b3
......@@ -116,7 +116,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
// if limit is provided
limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, queryParamsPair)
if err != nil {
return &queryParams{}, nil
return &queryParams{limit: typeutil.Unlimited}, nil
}
limit, err = strconv.ParseInt(limitStr, 0, 64)
if err != nil || limit <= 0 {
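For context, here is a minimal standalone sketch of the parsing behavior this hunk changes: when no `limit` key is supplied, the parser now falls back to `typeutil.Unlimited` (-1) instead of returning a zero-valued `queryParams`, and an `offset` without a `limit` is ignored (see the updated test table further down). The function name `parseLimitOffset` and the plain map input are illustrative only, not the actual proxy API.

```go
package main

import (
	"fmt"
	"strconv"
)

// Unlimited mirrors typeutil.Unlimited from this patch; -1 means "no limit".
const Unlimited int64 = -1

// parseLimitOffset is a hypothetical, simplified stand-in for parseQueryParams:
// it handles only the "limit"/"offset" pair and returns Unlimited when no
// limit is provided, matching what the new test cases expect.
func parseLimitOffset(params map[string]string) (limit, offset int64, err error) {
	limitStr, ok := params["limit"]
	if !ok {
		// No limit given: pagination is effectively disabled.
		return Unlimited, 0, nil
	}
	limit, err = strconv.ParseInt(limitStr, 0, 64)
	if err != nil || limit <= 0 {
		return 0, 0, fmt.Errorf("invalid limit %q, should be a positive integer", limitStr)
	}
	if offsetStr, ok := params["offset"]; ok {
		offset, err = strconv.ParseInt(offsetStr, 0, 64)
		if err != nil || offset < 0 {
			return 0, 0, fmt.Errorf("invalid offset %q, should be a non-negative integer", offsetStr)
		}
	}
	return limit, offset, nil
}

func main() {
	fmt.Println(parseLimitOffset(map[string]string{}))                            // -1 0 <nil>
	fmt.Println(parseLimitOffset(map[string]string{"limit": "1", "offset": "2"})) // 1 2 <nil>
}
```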
......@@ -331,7 +331,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
metrics.ProxyDecodeResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), metrics.QueryLabel).Observe(0.0)
tr.CtxRecord(ctx, "reduceResultStart")
t.result, err = mergeRetrieveResults(ctx, t.toReduceResults)
t.result, err = reduceRetrieveResults(ctx, t.toReduceResults, t.queryParams)
if err != nil {
return err
}
......@@ -409,46 +409,66 @@ func IDs2Expr(fieldName string, ids *schemapb.IDs) string {
return fieldName + " in [ " + idsStr + " ]"
}
func mergeRetrieveResults(ctx context.Context, retrieveResults []*internalpb.RetrieveResults) (*milvuspb.QueryResults, error) {
var ret *milvuspb.QueryResults
var skipDupCnt int64
var idSet = make(map[interface{}]struct{})
func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.RetrieveResults, queryParams *queryParams) (*milvuspb.QueryResults, error) {
log.Ctx(ctx).Debug("reduceInternelRetrieveResults", zap.Int("len(retrieveResults)", len(retrieveResults)))
var (
ret = &milvuspb.QueryResults{}
skipDupCnt int64
loopEnd int
)
// merge results and remove duplicates
for _, rr := range retrieveResults {
numPks := typeutil.GetSizeOfIDs(rr.GetIds())
// skip empty result, it will break merge result
if rr == nil || rr.Ids == nil || rr.GetIds() == nil || numPks == 0 {
validRetrieveResults := []*internalpb.RetrieveResults{}
for _, r := range retrieveResults {
size := typeutil.GetSizeOfIDs(r.GetIds())
if r == nil || len(r.GetFieldsData()) == 0 || size == 0 {
continue
}
validRetrieveResults = append(validRetrieveResults, r)
loopEnd += size
}
if len(validRetrieveResults) == 0 {
return ret, nil
}
if ret == nil {
ret = &milvuspb.QueryResults{
FieldsData: make([]*schemapb.FieldData, len(rr.FieldsData)),
ret.FieldsData = make([]*schemapb.FieldData, len(validRetrieveResults[0].GetFieldsData()))
idSet := make(map[interface{}]struct{})
cursors := make([]int64, len(validRetrieveResults))
if queryParams != nil && queryParams.limit != typeutil.Unlimited {
loopEnd = int(queryParams.limit)
if queryParams.offset > 0 {
for i := int64(0); i < queryParams.offset; i++ {
sel := typeutil.SelectMinPK(validRetrieveResults, cursors)
if sel == -1 {
return ret, nil
}
cursors[sel]++
}
}
}
if len(ret.FieldsData) != len(rr.FieldsData) {
return nil, fmt.Errorf("mismatch FieldData in proxy RetrieveResults, expect %d get %d", len(ret.FieldsData), len(rr.FieldsData))
for j := 0; j < loopEnd; j++ {
sel := typeutil.SelectMinPK(validRetrieveResults, cursors)
if sel == -1 {
break
}
for i := 0; i < numPks; i++ {
id := typeutil.GetPK(rr.GetIds(), int64(i))
if _, ok := idSet[id]; !ok {
typeutil.AppendFieldData(ret.FieldsData, rr.FieldsData, int64(i))
idSet[id] = struct{}{}
} else {
// primary keys duplicate
skipDupCnt++
}
pk := typeutil.GetPK(validRetrieveResults[sel].GetIds(), cursors[sel])
if _, ok := idSet[pk]; !ok {
typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel])
idSet[pk] = struct{}{}
} else {
// primary keys duplicate
skipDupCnt++
}
cursors[sel]++
}
log.Ctx(ctx).Debug("skip duplicated query result", zap.Int64("count", skipDupCnt))
if ret == nil {
ret = &milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{},
}
if skipDupCnt > 0 {
log.Ctx(ctx).Debug("skip duplicated query result while reducing QueryResults", zap.Int64("count", skipDupCnt))
}
return ret, nil
......
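The reduce above is essentially a k-way merge over per-shard results that are already ordered by primary key: the `offset` is consumed by advancing cursors before anything is appended, and the emit loop is capped at `limit`. The sketch below reproduces that idea on plain `int64` keys so the pagination logic can be read in isolation; it is an illustration of the technique, not the proxy code, though it keeps the same details (duplicate PKs are skipped, and the offset is only honored when a limit is set, as in `reduceRetrieveResults`).

```go
package main

import "fmt"

// paginateMerge is a simplified model of reduceRetrieveResults: each input
// slice is already sorted by primary key; we repeatedly pick the smallest
// head across all slices, skip `offset` picks, and emit at most `limit`.
// limit == -1 means unlimited, mirroring typeutil.Unlimited in the patch.
func paginateMerge(results [][]int64, offset, limit int64) []int64 {
	cursors := make([]int64, len(results))

	// selectMin returns the index of the slice whose current head is smallest,
	// or -1 when every cursor is exhausted (same contract as SelectMinPK).
	selectMin := func() int {
		sel, best := -1, int64(0)
		for i, c := range cursors {
			if int(c) >= len(results[i]) {
				continue
			}
			if sel == -1 || results[i][c] < best {
				sel, best = i, results[i][c]
			}
		}
		return sel
	}

	loopEnd := int64(0)
	for _, r := range results {
		loopEnd += int64(len(r))
	}

	if limit != -1 {
		loopEnd = limit
		// As in the patch, the offset is only consumed together with a limit:
		// advance the merge cursors without emitting anything.
		for i := int64(0); i < offset; i++ {
			sel := selectMin()
			if sel == -1 {
				return nil
			}
			cursors[sel]++
		}
	}

	var out []int64
	seen := map[int64]struct{}{}
	for j := int64(0); j < loopEnd; j++ {
		sel := selectMin()
		if sel == -1 {
			break
		}
		pk := results[sel][cursors[sel]]
		if _, dup := seen[pk]; !dup {
			out = append(out, pk)
			seen[pk] = struct{}{}
		}
		cursors[sel]++
	}
	return out
}

func main() {
	shards := [][]int64{{1, 3}, {2, 4}}
	fmt.Println(paginateMerge(shards, 0, -1)) // [1 2 3 4]
	fmt.Println(paginateMerge(shards, 1, 2))  // [2 3]
}
```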
......@@ -6,18 +6,20 @@ import (
"testing"
"time"
"github.com/milvus-io/milvus/internal/common"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/api/commonpb"
"github.com/milvus-io/milvus/api/milvuspb"
"github.com/milvus-io/milvus/api/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
......@@ -351,10 +353,10 @@ func TestTaskQuery_functions(t *testing.T) {
outLimit int64
outOffset int64
}{
{"empty input", []string{}, []string{}, false, 0, 0},
{"empty input", []string{}, []string{}, false, typeutil.Unlimited, 0},
{"valid limit=1", []string{LimitKey}, []string{"1"}, false, 1, 0},
{"valid limit=1, offset=2", []string{LimitKey, OffsetKey}, []string{"1", "2"}, false, 1, 2},
{"valid no limit, offset=2", []string{OffsetKey}, []string{"2"}, false, 0, 0},
{"valid no limit, offset=2", []string{OffsetKey}, []string{"2"}, false, typeutil.Unlimited, 0},
{"invalid limit str", []string{LimitKey}, []string{"a"}, true, 0, 0},
{"invalid limit zero", []string{LimitKey}, []string{"0"}, true, 0, 0},
{"invalid offset negative", []string{LimitKey, OffsetKey}, []string{"1", "-1"}, true, 0, 0},
......@@ -383,4 +385,264 @@ func TestTaskQuery_functions(t *testing.T) {
})
}
})
t.Run("test reduceRetrieveResults", func(t *testing.T) {
const (
Dim = 8
Int64FieldName = "Int64Field"
FloatVectorFieldName = "FloatVectorField"
Int64FieldID = common.StartOfUserFieldID + 1
FloatVectorFieldID = common.StartOfUserFieldID + 2
)
Int64Array := []int64{11, 22}
FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0}
var fieldDataArray1 []*schemapb.FieldData
fieldDataArray1 = append(fieldDataArray1, getFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1))
fieldDataArray1 = append(fieldDataArray1, getFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim))
var fieldDataArray2 []*schemapb.FieldData
fieldDataArray2 = append(fieldDataArray2, getFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1))
fieldDataArray2 = append(fieldDataArray2, getFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim))
t.Run("test skip dupPK 2", func(t *testing.T) {
result1 := &internalpb.RetrieveResults{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{0, 1},
},
},
},
FieldsData: fieldDataArray1,
}
result2 := &internalpb.RetrieveResults{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{0, 1},
},
},
},
FieldsData: fieldDataArray2,
}
result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{result1, result2}, nil)
assert.NoError(t, err)
assert.Equal(t, 2, len(result.GetFieldsData()))
assert.Equal(t, Int64Array, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
assert.InDeltaSlice(t, FloatVector, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
})
t.Run("test nil results", func(t *testing.T) {
ret, err := reduceRetrieveResults(context.Background(), nil, nil)
assert.NoError(t, err)
assert.Empty(t, ret.GetFieldsData())
})
t.Run("test merge", func(t *testing.T) {
r1 := &internalpb.RetrieveResults{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1, 3},
},
},
},
FieldsData: fieldDataArray1,
}
r2 := &internalpb.RetrieveResults{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{2, 4},
},
},
},
FieldsData: fieldDataArray2,
}
resultFloat := []float32{
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0,
11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0}
t.Run("test limited", func(t *testing.T) {
tests := []struct {
description string
limit int64
}{
{"limit 1", 1},
{"limit 2", 2},
{"limit 3", 3},
{"limit 4", 4},
}
resultField0 := []int64{11, 11, 22, 22}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{r1, r2}, &queryParams{limit: test.limit})
assert.Equal(t, 2, len(result.GetFieldsData()))
assert.Equal(t, resultField0[0:test.limit], result.GetFieldsData()[0].GetScalars().GetLongData().Data)
assert.InDeltaSlice(t, resultFloat[0:test.limit*Dim], result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
assert.NoError(t, err)
})
}
})
t.Run("test offset", func(t *testing.T) {
tests := []struct {
description string
offset int64
}{
{"offset 0", 0},
{"offset 1", 1},
{"offset 2", 2},
{"offset 3", 3},
}
resultField0 := []int64{11, 11, 22, 22}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{r1, r2}, &queryParams{limit: 1, offset: test.offset})
assert.NoError(t, err)
assert.Equal(t, 2, len(result.GetFieldsData()))
assert.Equal(t, resultField0[test.offset:test.offset+1], result.GetFieldsData()[0].GetScalars().GetLongData().Data)
assert.InDeltaSlice(t, resultFloat[test.offset*Dim:(test.offset+1)*Dim], result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
})
}
})
})
})
}
func getFieldData(fieldName string, fieldID int64, fieldType schemapb.DataType, fieldValue interface{}, dim int64) *schemapb.FieldData {
var fieldData *schemapb.FieldData
switch fieldType {
case schemapb.DataType_Bool:
fieldData = &schemapb.FieldData{
Type: schemapb.DataType_Bool,
FieldName: fieldName,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_BoolData{
BoolData: &schemapb.BoolArray{
Data: fieldValue.([]bool),
},
},
},
},
FieldId: fieldID,
}
case schemapb.DataType_Int32:
fieldData = &schemapb.FieldData{
Type: schemapb.DataType_Int32,
FieldName: fieldName,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: fieldValue.([]int32),
},
},
},
},
FieldId: fieldID,
}
case schemapb.DataType_Int64:
fieldData = &schemapb.FieldData{
Type: schemapb.DataType_Int64,
FieldName: fieldName,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: fieldValue.([]int64),
},
},
},
},
FieldId: fieldID,
}
case schemapb.DataType_Float:
fieldData = &schemapb.FieldData{
Type: schemapb.DataType_Float,
FieldName: fieldName,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_FloatData{
FloatData: &schemapb.FloatArray{
Data: fieldValue.([]float32),
},
},
},
},
FieldId: fieldID,
}
case schemapb.DataType_Double:
fieldData = &schemapb.FieldData{
Type: schemapb.DataType_Double,
FieldName: fieldName,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_DoubleData{
DoubleData: &schemapb.DoubleArray{
Data: fieldValue.([]float64),
},
},
},
},
FieldId: fieldID,
}
case schemapb.DataType_VarChar:
fieldData = &schemapb.FieldData{
Type: schemapb.DataType_VarChar,
FieldName: fieldName,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: fieldValue.([]string),
},
},
},
},
FieldId: fieldID,
}
case schemapb.DataType_BinaryVector:
fieldData = &schemapb.FieldData{
Type: schemapb.DataType_BinaryVector,
FieldName: fieldName,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: dim,
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: fieldValue.([]byte),
},
},
},
FieldId: fieldID,
}
case schemapb.DataType_FloatVector:
fieldData = &schemapb.FieldData{
Type: schemapb.DataType_FloatVector,
FieldName: fieldName,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: dim,
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: fieldValue.([]float32),
},
},
},
},
FieldId: fieldID,
}
default:
log.Warn("not supported field type", zap.String("fieldType", fieldType.String()))
}
return fieldData
}
......@@ -962,7 +962,7 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que
msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs()))
results = append(results, streamingResult)
ret, err2 := mergeInternalRetrieveResults(ctx, results)
ret, err2 := mergeInternalRetrieveResult(ctx, results, req.Req.GetLimit())
if err2 != nil {
failRet.Status.Reason = err2.Error()
return failRet, nil
......@@ -1026,7 +1026,7 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
if err := runningGp.Wait(); err != nil {
return failRet, nil
}
ret, err := mergeInternalRetrieveResults(ctx, toMergeResults)
ret, err := mergeInternalRetrieveResult(ctx, toMergeResults, req.GetReq().GetLimit())
if err != nil {
failRet.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
failRet.Status.Reason = err.Error()
......
......@@ -22,22 +22,21 @@ import (
"math"
"strconv"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/golang/protobuf/proto"
"go.uber.org/zap"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus/api/commonpb"
"github.com/milvus-io/milvus/api/schemapb"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
const (
unlimited int = -1
)
var _ typeutil.ResultWithID = &internalpb.RetrieveResults{}
var _ typeutil.ResultWithID = &segcorepb.RetrieveResults{}
func reduceStatisticResponse(results []*internalpb.GetStatisticsResponse) (*internalpb.GetStatisticsResponse, error) {
mergedResults := map[string]interface{}{
......@@ -238,8 +237,11 @@ func encodeSearchResultData(searchResultData *schemapb.SearchResultData, nq int6
return
}
func mergeInternalRetrieveResultsV2(ctx context.Context, retrieveResults []*internalpb.RetrieveResults, limit int) (*internalpb.RetrieveResults, error) {
log.Ctx(ctx).Debug("reduceInternelRetrieveResults", zap.Int("len(retrieveResults)", len(retrieveResults)))
func mergeInternalRetrieveResult(ctx context.Context, retrieveResults []*internalpb.RetrieveResults, limit int64) (*internalpb.RetrieveResults, error) {
log.Ctx(ctx).Debug("reduceInternelRetrieveResults",
zap.Int64("limit", limit),
zap.Int("len(retrieveResults)", len(retrieveResults)),
)
var (
ret = &internalpb.RetrieveResults{
Ids: &schemapb.IDs{},
......@@ -263,15 +265,15 @@ func mergeInternalRetrieveResultsV2(ctx context.Context, retrieveResults []*inte
return ret, nil
}
if limit != unlimited {
loopEnd = limit
if limit != typeutil.Unlimited {
loopEnd = int(limit)
}
ret.FieldsData = make([]*schemapb.FieldData, len(validRetrieveResults[0].GetFieldsData()))
idSet := make(map[interface{}]struct{})
cursors := make([]int64, len(validRetrieveResults))
for j := 0; j < loopEnd; j++ {
sel := selectMinPK(validRetrieveResults, cursors)
sel := typeutil.SelectMinPK(validRetrieveResults, cursors)
if sel == -1 {
break
}
......@@ -295,58 +297,11 @@ func mergeInternalRetrieveResultsV2(ctx context.Context, retrieveResults []*inte
return ret, nil
}
// TODO: largely based on function mergeSegcoreRetrieveResults, need rewriting
func mergeInternalRetrieveResults(ctx context.Context, retrieveResults []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) {
var ret *internalpb.RetrieveResults
var skipDupCnt int64
var idSet = make(map[interface{}]struct{})
// merge results and remove duplicates
for _, rr := range retrieveResults {
// skip if fields data is empty
if len(rr.FieldsData) == 0 {
continue
}
if ret == nil {
ret = &internalpb.RetrieveResults{
Ids: &schemapb.IDs{},
FieldsData: make([]*schemapb.FieldData, len(rr.FieldsData)),
}
}
if len(ret.FieldsData) != len(rr.FieldsData) {
log.Ctx(ctx).Warn("mismatch FieldData in RetrieveResults")
return nil, fmt.Errorf("mismatch FieldData in RetrieveResults")
}
numPks := typeutil.GetSizeOfIDs(rr.GetIds())
for i := 0; i < numPks; i++ {
id := typeutil.GetPK(rr.GetIds(), int64(i))
if _, ok := idSet[id]; !ok {
typeutil.AppendPKs(ret.Ids, id)
typeutil.AppendFieldData(ret.FieldsData, rr.FieldsData, int64(i))
idSet[id] = struct{}{}
} else {
// primary keys duplicate
skipDupCnt++
}
}
}
// not found, return default values indicating not result found
if ret == nil {
ret = &internalpb.RetrieveResults{
Ids: &schemapb.IDs{},
FieldsData: []*schemapb.FieldData{},
}
}
return ret, nil
}
func mergeSegcoreRetrieveResultsV2(ctx context.Context, retrieveResults []*segcorepb.RetrieveResults, limit int) (*segcorepb.RetrieveResults, error) {
log.Ctx(ctx).Debug("reduceSegcoreRetrieveResults", zap.Int("len(retrieveResults)", len(retrieveResults)))
func mergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcorepb.RetrieveResults, limit int64) (*segcorepb.RetrieveResults, error) {
log.Ctx(ctx).Debug("reduceSegcoreRetrieveResults",
zap.Int64("limit", limit),
zap.Int("len(retrieveResults)", len(retrieveResults)),
)
var (
ret = &segcorepb.RetrieveResults{
Ids: &schemapb.IDs{},
......@@ -370,15 +325,15 @@ func mergeSegcoreRetrieveResultsV2(ctx context.Context, retrieveResults []*segco
return ret, nil
}
if limit != unlimited {
loopEnd = limit
if limit != typeutil.Unlimited {
loopEnd = int(limit)
}
ret.FieldsData = make([]*schemapb.FieldData, len(validRetrieveResults[0].GetFieldsData()))
idSet := make(map[interface{}]struct{})
cursors := make([]int64, len(validRetrieveResults))
for j := 0; j < loopEnd; j++ {
sel := selectMinPK(validRetrieveResults, cursors)
sel := typeutil.SelectMinPK(validRetrieveResults, cursors)
if sel == -1 {
break
}
......@@ -402,103 +357,6 @@ func mergeSegcoreRetrieveResultsV2(ctx context.Context, retrieveResults []*segco
return ret, nil
}
type ResultWithID interface {
GetIds() *schemapb.IDs
}
var _ ResultWithID = &internalpb.RetrieveResults{}
var _ ResultWithID = &segcorepb.RetrieveResults{}
func selectMinPK[T ResultWithID](results []T, cursors []int64) int {
var (
sel = -1
minIntPK int64 = math.MaxInt64
firstStr = true
minStrPK = ""
)
for i, cursor := range cursors {
if int(cursor) >= typeutil.GetSizeOfIDs(results[i].GetIds()) {
continue
}
pkInterface := typeutil.GetPK(results[i].GetIds(), cursor)
switch pk := pkInterface.(type) {
case string:
if firstStr || pk < minStrPK {
firstStr = false
minStrPK = pk
sel = i
}
case int64:
if pk < minIntPK {
minIntPK = pk
sel = i
}
default:
continue
}
}
return sel
}
func mergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) {
var (
ret *segcorepb.RetrieveResults
skipDupCnt int64
idSet = make(map[interface{}]struct{})
)
// merge results and remove duplicates
for _, rr := range retrieveResults {
// skip empty result, it will break merge result
if rr == nil || len(rr.Offset) == 0 {
continue
}
if ret == nil {
ret = &segcorepb.RetrieveResults{
Ids: &schemapb.IDs{},
FieldsData: make([]*schemapb.FieldData, len(rr.FieldsData)),
}
}
if len(ret.FieldsData) != len(rr.FieldsData) {
return nil, fmt.Errorf("mismatch FieldData in RetrieveResults")
}
pkHitNum := typeutil.GetSizeOfIDs(rr.GetIds())
for i := 0; i < pkHitNum; i++ {
id := typeutil.GetPK(rr.GetIds(), int64(i))
if _, ok := idSet[id]; !ok {
typeutil.AppendPKs(ret.Ids, id)
typeutil.AppendFieldData(ret.FieldsData, rr.FieldsData, int64(i))
idSet[id] = struct{}{}
} else {
// primary keys duplicate
skipDupCnt++
}
}
}
if skipDupCnt > 0 {
log.Ctx(ctx).Debug("skip duplicated query result", zap.Int64("count", skipDupCnt))
}
// not found, return default values indicating not result found
if ret == nil {
ret = &segcorepb.RetrieveResults{
Ids: &schemapb.IDs{},
FieldsData: []*schemapb.FieldData{},
}
}
return ret, nil
}
// func printSearchResultData(data *schemapb.SearchResultData, header string) {
// size := len(data.Ids.GetIntId().Data)
// if size != len(data.Scores) {
......
......@@ -26,6 +26,7 @@ import (
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
func TestResult_mergeSegcoreRetrieveResults(t *testing.T) {
......@@ -71,7 +72,7 @@ func TestResult_mergeSegcoreRetrieveResults(t *testing.T) {
FieldsData: fieldDataArray2,
}
result, err := mergeSegcoreRetrieveResultsV2(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, unlimited)
result, err := mergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, typeutil.Unlimited)
assert.NoError(t, err)
assert.Equal(t, 2, len(result.GetFieldsData()))
assert.Equal(t, []int64{0, 1}, result.GetIds().GetIntId().GetData())
......@@ -80,7 +81,7 @@ func TestResult_mergeSegcoreRetrieveResults(t *testing.T) {
})
t.Run("test nil results", func(t *testing.T) {
ret, err := mergeSegcoreRetrieveResultsV2(context.Background(), nil, unlimited)
ret, err := mergeSegcoreRetrieveResults(context.Background(), nil, typeutil.Unlimited)
assert.NoError(t, err)
assert.Empty(t, ret.GetIds())
assert.Empty(t, ret.GetFieldsData())
......@@ -98,7 +99,7 @@ func TestResult_mergeSegcoreRetrieveResults(t *testing.T) {
FieldsData: fieldDataArray1,
}
ret, err := mergeSegcoreRetrieveResultsV2(context.Background(), []*segcorepb.RetrieveResults{r}, unlimited)
ret, err := mergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{r}, typeutil.Unlimited)
assert.NoError(t, err)
assert.Empty(t, ret.GetIds())
assert.Empty(t, ret.GetFieldsData())
......@@ -137,7 +138,7 @@ func TestResult_mergeSegcoreRetrieveResults(t *testing.T) {
t.Run("test limited", func(t *testing.T) {
tests := []struct {
description string
limit int
limit int64
}{
{"limit 1", 1},
{"limit 2", 2},
......@@ -148,9 +149,9 @@ func TestResult_mergeSegcoreRetrieveResults(t *testing.T) {
resultField0 := []int64{11, 11, 22, 22}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
result, err := mergeSegcoreRetrieveResultsV2(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, test.limit)
result, err := mergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, test.limit)
assert.Equal(t, 2, len(result.GetFieldsData()))
assert.Equal(t, test.limit, len(result.GetIds().GetIntId().GetData()))
assert.Equal(t, int(test.limit), len(result.GetIds().GetIntId().GetData()))
assert.Equal(t, resultIDs[0:test.limit], result.GetIds().GetIntId().GetData())
assert.Equal(t, resultField0[0:test.limit], result.GetFieldsData()[0].GetScalars().GetLongData().Data)
assert.InDeltaSlice(t, resultFloat[0:test.limit*Dim], result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
......@@ -160,7 +161,7 @@ func TestResult_mergeSegcoreRetrieveResults(t *testing.T) {
})
t.Run("test int ID", func(t *testing.T) {
result, err := mergeSegcoreRetrieveResultsV2(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, unlimited)
result, err := mergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, typeutil.Unlimited)
assert.Equal(t, 2, len(result.GetFieldsData()))
assert.Equal(t, []int64{1, 2, 3, 4}, result.GetIds().GetIntId().GetData())
assert.Equal(t, []int64{11, 11, 22, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
......@@ -181,7 +182,7 @@ func TestResult_mergeSegcoreRetrieveResults(t *testing.T) {
Data: []string{"b", "d"},
}}}
result, err := mergeSegcoreRetrieveResultsV2(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, unlimited)
result, err := mergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, typeutil.Unlimited)
assert.NoError(t, err)
assert.Equal(t, 2, len(result.GetFieldsData()))
assert.Equal(t, []string{"a", "b", "c", "d"}, result.GetIds().GetStrId().GetData())
......@@ -234,7 +235,7 @@ func TestResult_mergeInternalRetrieveResults(t *testing.T) {
FieldsData: fieldDataArray2,
}
result, err := mergeInternalRetrieveResultsV2(context.Background(), []*internalpb.RetrieveResults{result1, result2}, unlimited)
result, err := mergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2}, typeutil.Unlimited)
assert.NoError(t, err)
assert.Equal(t, 2, len(result.GetFieldsData()))
assert.Equal(t, []int64{0, 1}, result.GetIds().GetIntId().GetData())
......@@ -243,7 +244,7 @@ func TestResult_mergeInternalRetrieveResults(t *testing.T) {
})
t.Run("test nil results", func(t *testing.T) {
ret, err := mergeInternalRetrieveResultsV2(context.Background(), nil, unlimited)
ret, err := mergeInternalRetrieveResult(context.Background(), nil, typeutil.Unlimited)
assert.NoError(t, err)
assert.Empty(t, ret.GetIds())
assert.Empty(t, ret.GetFieldsData())
......@@ -280,7 +281,7 @@ func TestResult_mergeInternalRetrieveResults(t *testing.T) {
t.Run("test limited", func(t *testing.T) {
tests := []struct {
description string
limit int
limit int64
}{
{"limit 1", 1},
{"limit 2", 2},
......@@ -291,9 +292,9 @@ func TestResult_mergeInternalRetrieveResults(t *testing.T) {
resultField0 := []int64{11, 11, 22, 22}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
result, err := mergeInternalRetrieveResultsV2(context.Background(), []*internalpb.RetrieveResults{r1, r2}, test.limit)
result, err := mergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2}, test.limit)
assert.Equal(t, 2, len(result.GetFieldsData()))
assert.Equal(t, test.limit, len(result.GetIds().GetIntId().GetData()))
assert.Equal(t, int(test.limit), len(result.GetIds().GetIntId().GetData()))
assert.Equal(t, resultIDs[0:test.limit], result.GetIds().GetIntId().GetData())
assert.Equal(t, resultField0[0:test.limit], result.GetFieldsData()[0].GetScalars().GetLongData().Data)
assert.InDeltaSlice(t, resultFloat[0:test.limit*Dim], result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
......@@ -303,7 +304,7 @@ func TestResult_mergeInternalRetrieveResults(t *testing.T) {
})
t.Run("test int ID", func(t *testing.T) {
result, err := mergeInternalRetrieveResultsV2(context.Background(), []*internalpb.RetrieveResults{r1, r2}, unlimited)
result, err := mergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2}, typeutil.Unlimited)
assert.Equal(t, 2, len(result.GetFieldsData()))
assert.Equal(t, []int64{1, 2, 3, 4}, result.GetIds().GetIntId().GetData())
assert.Equal(t, []int64{11, 11, 22, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
......@@ -316,15 +317,19 @@ func TestResult_mergeInternalRetrieveResults(t *testing.T) {
IdField: &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: []string{"a", "c"},
}}}
},
},
}
r2.Ids = &schemapb.IDs{
IdField: &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: []string{"b", "d"},
}}}
},
},
}
result, err := mergeInternalRetrieveResultsV2(context.Background(), []*internalpb.RetrieveResults{r1, r2}, unlimited)
result, err := mergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2}, typeutil.Unlimited)
assert.NoError(t, err)
assert.Equal(t, 2, len(result.GetFieldsData()))
assert.Equal(t, []string{"a", "b", "c", "d"}, result.GetIds().GetStrId().GetData())
......
......@@ -85,7 +85,7 @@ func (q *queryTask) queryOnStreaming() error {
}
q.tr.RecordSpan()
mergedResult, err := mergeSegcoreRetrieveResults(ctx, sResults)
mergedResult, err := mergeSegcoreRetrieveResults(ctx, sResults, q.iReq.GetLimit())
if err != nil {
return err
}
......@@ -132,7 +132,7 @@ func (q *queryTask) queryOnHistorical() error {
return err
}
mergedResult, err := mergeSegcoreRetrieveResults(ctx, retrieveResults)
mergedResult, err := mergeSegcoreRetrieveResults(ctx, retrieveResults, q.req.GetReq().GetLimit())
if err != nil {
return err
}
......
......@@ -19,6 +19,7 @@ package typeutil
import (
"errors"
"fmt"
"math"
"strconv"
"github.com/milvus-io/milvus/api/schemapb"
......@@ -666,3 +667,43 @@ func ComparePK(data *schemapb.IDs, i, j int) bool {
}
return false
}
type ResultWithID interface {
GetIds() *schemapb.IDs
}
// SelectMinPK select the index of the minPK in results T of the cursors.
func SelectMinPK[T ResultWithID](results []T, cursors []int64) int {
var (
sel = -1
minIntPK int64 = math.MaxInt64
firstStr = true
minStrPK string
)
for i, cursor := range cursors {
if int(cursor) >= GetSizeOfIDs(results[i].GetIds()) {
continue
}
pkInterface := GetPK(results[i].GetIds(), cursor)
switch pk := pkInterface.(type) {
case string:
if firstStr || pk < minStrPK {
firstStr = false
minStrPK = pk
sel = i
}
case int64:
if pk < minIntPK {
minIntPK = pk
sel = i
}
default:
continue
}
}
return sel
}
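For readers unfamiliar with the cursor convention, here is a small hypothetical usage of the relocated `SelectMinPK` helper: each result keeps its IDs sorted, `cursors[i]` points at the next unconsumed ID of result `i`, and the function returns the index whose current ID is globally smallest, or -1 once every cursor is exhausted. The `fakeResult` type and `intIDs` helper are invented for the example; `schemapb.IDs`, `typeutil.SelectMinPK`, and `typeutil.GetPK` come from the repository, with import paths taken from this diff.

```go
package main

import (
	"fmt"

	"github.com/milvus-io/milvus/api/schemapb"
	"github.com/milvus-io/milvus/internal/util/typeutil"
)

// fakeResult is a throwaway type that satisfies typeutil.ResultWithID,
// standing in for internalpb/segcorepb RetrieveResults in this sketch.
type fakeResult struct {
	ids *schemapb.IDs
}

func (f *fakeResult) GetIds() *schemapb.IDs { return f.ids }

func intIDs(data ...int64) *schemapb.IDs {
	return &schemapb.IDs{
		IdField: &schemapb.IDs_IntId{
			IntId: &schemapb.LongArray{Data: data},
		},
	}
}

func main() {
	// Two already-sorted result sets, as produced by different shards/segments.
	results := []*fakeResult{
		{ids: intIDs(1, 3)},
		{ids: intIDs(2, 4)},
	}
	cursors := make([]int64, len(results))

	// Drain both results in global primary-key order: prints 1, 2, 3, 4.
	for {
		sel := typeutil.SelectMinPK(results, cursors)
		if sel == -1 {
			break
		}
		fmt.Println(typeutil.GetPK(results[sel].GetIds(), cursors[sel]))
		cursors[sel]++
	}
}
```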
......@@ -48,6 +48,8 @@ const (
DataNodeRole = "datanode"
)
const Unlimited int64 = -1
func ServerTypeMap() map[string]interface{} {
return map[string]interface{}{
EmbeddedRole: nil,
......