diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index ac98b4c3279d8729e0092bcc9b5aebbae25880c1..5fb72d5c5c80d5cd43f5c64c99042ffa8c7727c3 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -564,11 +564,21 @@ func selectHighestScoreIndex(subSearchResultData []*schemapb.SearchResultData, s } sIdx := subSearchNqOffset[i][qi] + cursors[i] sScore := subSearchResultData[i].Scores[sIdx] + + // Choose the larger score idx or the smaller pk idx with the same score if sScore > maxScore { subSearchIdx = i resultDataIdx = sIdx - maxScore = sScore + } else if sScore == maxScore { + sID := typeutil.GetPK(subSearchResultData[i].GetIds(), sIdx) + tmpID := typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx) + + if typeutil.ComparePK(sID, tmpID) { + subSearchIdx = i + resultDataIdx = sIdx + maxScore = sScore + } } } return subSearchIdx, resultDataIdx diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index a67f535dcb808db9ce2b2289931ac560fb2da887..b58951888855665a740e73b91aca0edbe1fb41de 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -1218,7 +1218,7 @@ func TestTaskSearch_selectHighestScoreIndex(t *testing.T) { }, }, }, - Scores: []float32{1.2, 1.0, 0.7, 0.6, 0.4, 0.2}, + Scores: []float32{1.2, 1.0, 0.7, 0.5, 0.4, 0.2}, Topks: []int64{2, 2, 2}, }, }, @@ -1280,7 +1280,7 @@ func TestTaskSearch_selectHighestScoreIndex(t *testing.T) { }, }, }, - Scores: []float32{1.2, 1.0, 0.7, 0.6, 0.4, 0.2}, + Scores: []float32{1.2, 1.0, 0.7, 0.5, 0.4, 0.2}, Topks: []int64{2, 2, 2}, }, }, diff --git a/internal/querynode/result.go b/internal/querynode/result.go index 22332f03e534430d09effaf2bc52531116191759..e4a1bd605668a19976f1a8f8066e8b400571f291 100644 --- a/internal/querynode/result.go +++ b/internal/querynode/result.go @@ -183,17 +183,32 @@ func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se } func selectSearchResultData(dataArray []*schemapb.SearchResultData, resultOffsets [][]int64, offsets []int64, qi int64) int { - sel := -1 - maxDistance := -1 * float32(math.MaxFloat32) + var ( + sel = -1 + maxDistance = -1 * float32(math.MaxFloat32) + resultDataIdx int64 = -1 + ) for i, offset := range offsets { // query num, the number of ways to merge if offset >= dataArray[i].Topks[qi] { continue } + idx := resultOffsets[i][qi] + offset distance := dataArray[i].Scores[idx] + if distance > maxDistance { sel = i maxDistance = distance + resultDataIdx = idx + } else if distance == maxDistance { + sID := typeutil.GetPK(dataArray[i].GetIds(), idx) + tmpID := typeutil.GetPK(dataArray[sel].GetIds(), resultDataIdx) + + if typeutil.ComparePK(sID, tmpID) { + sel = i + maxDistance = distance + resultDataIdx = idx + } } } return sel diff --git a/internal/querynode/result_sorter.go b/internal/querynode/result_sorter.go index 8ed18826d38c2f97d2024cb689c828eb437ee878..85eea42f61cdbe3eb08e6a4386a027f2306cddd5 100644 --- a/internal/querynode/result_sorter.go +++ b/internal/querynode/result_sorter.go @@ -35,7 +35,7 @@ func (s *byPK) Swap(i, j int) { } func (s *byPK) Less(i, j int) bool { - return typeutil.ComparePK(s.r.GetIds(), i, j) + return typeutil.ComparePKInSlice(s.r.GetIds(), i, j) } func swapFieldData(field *schemapb.FieldData, i int, j int) { diff --git a/internal/util/typeutil/schema.go b/internal/util/typeutil/schema.go index 34c33e9774eb492f5a89da51a82a5fd464c81270..e2499af04f1d28a8076be8fee86812300f0e4a9c 100644 --- a/internal/util/typeutil/schema.go +++ b/internal/util/typeutil/schema.go @@ -657,8 +657,8 @@ func SwapPK(data *schemapb.IDs, i, j int) { } } -// ComparePK returns if i-th PK < j-th PK -func ComparePK(data *schemapb.IDs, i, j int) bool { +// ComparePKInSlice returns if i-th PK < j-th PK +func ComparePKInSlice(data *schemapb.IDs, i, j int) bool { switch f := data.GetIdField().(type) { case *schemapb.IDs_IntId: return f.IntId.Data[i] < f.IntId.Data[j] @@ -668,6 +668,17 @@ func ComparePK(data *schemapb.IDs, i, j int) bool { return false } +// ComparePK returns if i-th PK of dataA > j-th PK of dataB +func ComparePK(pkA, pkB interface{}) bool { + switch pkA.(type) { + case int64: + return pkA.(int64) < pkB.(int64) + case string: + return pkA.(string) < pkB.(string) + } + return false +} + type ResultWithID interface { GetIds() *schemapb.IDs } diff --git a/internal/util/typeutil/schema_test.go b/internal/util/typeutil/schema_test.go index ee96bb40842a09c5b132b369868d0113952cee20..5ca217046e97acfc31747624843e733772f75c9a 100644 --- a/internal/util/typeutil/schema_test.go +++ b/internal/util/typeutil/schema_test.go @@ -662,18 +662,18 @@ func TestComparePk(t *testing.T) { AppendPKs(intPks, int64(3)) require.Equal(t, []int64{1, 2, 3}, intPks.GetIntId().GetData()) - less := ComparePK(intPks, 0, 1) + less := ComparePKInSlice(intPks, 0, 1) assert.True(t, less) - less = ComparePK(intPks, 0, 2) + less = ComparePKInSlice(intPks, 0, 2) assert.True(t, less) - less = ComparePK(intPks, 1, 2) + less = ComparePKInSlice(intPks, 1, 2) assert.True(t, less) - less = ComparePK(intPks, 1, 0) + less = ComparePKInSlice(intPks, 1, 0) assert.False(t, less) - less = ComparePK(intPks, 2, 0) + less = ComparePKInSlice(intPks, 2, 0) assert.False(t, less) - less = ComparePK(intPks, 2, 1) + less = ComparePKInSlice(intPks, 2, 1) assert.False(t, less) strPks := &schemapb.IDs{} @@ -683,17 +683,17 @@ func TestComparePk(t *testing.T) { require.Equal(t, []string{"1", "2", "3"}, strPks.GetStrId().GetData()) - less = ComparePK(strPks, 0, 1) + less = ComparePKInSlice(strPks, 0, 1) assert.True(t, less) - less = ComparePK(strPks, 0, 2) + less = ComparePKInSlice(strPks, 0, 2) assert.True(t, less) - less = ComparePK(strPks, 1, 2) + less = ComparePKInSlice(strPks, 1, 2) assert.True(t, less) - less = ComparePK(strPks, 1, 0) + less = ComparePKInSlice(strPks, 1, 0) assert.False(t, less) - less = ComparePK(strPks, 2, 0) + less = ComparePKInSlice(strPks, 2, 0) assert.False(t, less) - less = ComparePK(strPks, 2, 1) + less = ComparePKInSlice(strPks, 2, 1) assert.False(t, less) }