未验证 提交 52c6a270 编写于 作者: X XuanYang-cn 提交者: GitHub

Fix binary results unstable (#19401)

See also: #19338, #19366, #19367
Signed-off-by: Nyangxuan <xuan.yang@zilliz.com>
Signed-off-by: Nyangxuan <xuan.yang@zilliz.com>
上级 7819297f
...@@ -564,11 +564,21 @@ func selectHighestScoreIndex(subSearchResultData []*schemapb.SearchResultData, s ...@@ -564,11 +564,21 @@ func selectHighestScoreIndex(subSearchResultData []*schemapb.SearchResultData, s
} }
sIdx := subSearchNqOffset[i][qi] + cursors[i] sIdx := subSearchNqOffset[i][qi] + cursors[i]
sScore := subSearchResultData[i].Scores[sIdx] sScore := subSearchResultData[i].Scores[sIdx]
// Choose the entry with the larger score; when scores are equal, choose the one with the smaller PK
if sScore > maxScore { if sScore > maxScore {
subSearchIdx = i subSearchIdx = i
resultDataIdx = sIdx resultDataIdx = sIdx
maxScore = sScore maxScore = sScore
} else if sScore == maxScore {
sID := typeutil.GetPK(subSearchResultData[i].GetIds(), sIdx)
tmpID := typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)
if typeutil.ComparePK(sID, tmpID) {
subSearchIdx = i
resultDataIdx = sIdx
maxScore = sScore
}
} }
} }
return subSearchIdx, resultDataIdx return subSearchIdx, resultDataIdx
......
...@@ -1218,7 +1218,7 @@ func TestTaskSearch_selectHighestScoreIndex(t *testing.T) { ...@@ -1218,7 +1218,7 @@ func TestTaskSearch_selectHighestScoreIndex(t *testing.T) {
}, },
}, },
}, },
Scores: []float32{1.2, 1.0, 0.7, 0.6, 0.4, 0.2}, Scores: []float32{1.2, 1.0, 0.7, 0.5, 0.4, 0.2},
Topks: []int64{2, 2, 2}, Topks: []int64{2, 2, 2},
}, },
}, },
...@@ -1280,7 +1280,7 @@ func TestTaskSearch_selectHighestScoreIndex(t *testing.T) { ...@@ -1280,7 +1280,7 @@ func TestTaskSearch_selectHighestScoreIndex(t *testing.T) {
}, },
}, },
}, },
Scores: []float32{1.2, 1.0, 0.7, 0.6, 0.4, 0.2}, Scores: []float32{1.2, 1.0, 0.7, 0.5, 0.4, 0.2},
Topks: []int64{2, 2, 2}, Topks: []int64{2, 2, 2},
}, },
}, },
......
...@@ -183,17 +183,32 @@ func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se ...@@ -183,17 +183,32 @@ func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se
} }
func selectSearchResultData(dataArray []*schemapb.SearchResultData, resultOffsets [][]int64, offsets []int64, qi int64) int { func selectSearchResultData(dataArray []*schemapb.SearchResultData, resultOffsets [][]int64, offsets []int64, qi int64) int {
sel := -1 var (
maxDistance := -1 * float32(math.MaxFloat32) sel = -1
maxDistance = -1 * float32(math.MaxFloat32)
resultDataIdx int64 = -1
)
for i, offset := range offsets { // query num, the number of ways to merge for i, offset := range offsets { // query num, the number of ways to merge
if offset >= dataArray[i].Topks[qi] { if offset >= dataArray[i].Topks[qi] {
continue continue
} }
idx := resultOffsets[i][qi] + offset idx := resultOffsets[i][qi] + offset
distance := dataArray[i].Scores[idx] distance := dataArray[i].Scores[idx]
if distance > maxDistance { if distance > maxDistance {
sel = i sel = i
maxDistance = distance maxDistance = distance
resultDataIdx = idx
} else if distance == maxDistance {
sID := typeutil.GetPK(dataArray[i].GetIds(), idx)
tmpID := typeutil.GetPK(dataArray[sel].GetIds(), resultDataIdx)
if typeutil.ComparePK(sID, tmpID) {
sel = i
maxDistance = distance
resultDataIdx = idx
}
} }
} }
return sel return sel
......
...@@ -35,7 +35,7 @@ func (s *byPK) Swap(i, j int) { ...@@ -35,7 +35,7 @@ func (s *byPK) Swap(i, j int) {
} }
func (s *byPK) Less(i, j int) bool { func (s *byPK) Less(i, j int) bool {
return typeutil.ComparePK(s.r.GetIds(), i, j) return typeutil.ComparePKInSlice(s.r.GetIds(), i, j)
} }
func swapFieldData(field *schemapb.FieldData, i int, j int) { func swapFieldData(field *schemapb.FieldData, i int, j int) {
......
...@@ -657,8 +657,8 @@ func SwapPK(data *schemapb.IDs, i, j int) { ...@@ -657,8 +657,8 @@ func SwapPK(data *schemapb.IDs, i, j int) {
} }
} }
// ComparePK returns if i-th PK < j-th PK // ComparePKInSlice returns if i-th PK < j-th PK
func ComparePK(data *schemapb.IDs, i, j int) bool { func ComparePKInSlice(data *schemapb.IDs, i, j int) bool {
switch f := data.GetIdField().(type) { switch f := data.GetIdField().(type) {
case *schemapb.IDs_IntId: case *schemapb.IDs_IntId:
return f.IntId.Data[i] < f.IntId.Data[j] return f.IntId.Data[i] < f.IntId.Data[j]
...@@ -668,6 +668,17 @@ func ComparePK(data *schemapb.IDs, i, j int) bool { ...@@ -668,6 +668,17 @@ func ComparePK(data *schemapb.IDs, i, j int) bool {
return false return false
} }
// ComparePK reports whether pkA is strictly less than pkB.
//
// Both arguments are expected to hold the same dynamic type, either int64 or
// string. If pkA holds any other type (including nil), or if pkB does not hold
// the same type as pkA, ComparePK returns false rather than panicking.
func ComparePK(pkA, pkB interface{}) bool {
	switch a := pkA.(type) {
	case int64:
		// Comma-ok assertion: a mismatched pkB yields false instead of a panic.
		b, ok := pkB.(int64)
		return ok && a < b
	case string:
		b, ok := pkB.(string)
		return ok && a < b
	}
	return false
}
type ResultWithID interface { type ResultWithID interface {
GetIds() *schemapb.IDs GetIds() *schemapb.IDs
} }
......
...@@ -662,18 +662,18 @@ func TestComparePk(t *testing.T) { ...@@ -662,18 +662,18 @@ func TestComparePk(t *testing.T) {
AppendPKs(intPks, int64(3)) AppendPKs(intPks, int64(3))
require.Equal(t, []int64{1, 2, 3}, intPks.GetIntId().GetData()) require.Equal(t, []int64{1, 2, 3}, intPks.GetIntId().GetData())
less := ComparePK(intPks, 0, 1) less := ComparePKInSlice(intPks, 0, 1)
assert.True(t, less) assert.True(t, less)
less = ComparePK(intPks, 0, 2) less = ComparePKInSlice(intPks, 0, 2)
assert.True(t, less) assert.True(t, less)
less = ComparePK(intPks, 1, 2) less = ComparePKInSlice(intPks, 1, 2)
assert.True(t, less) assert.True(t, less)
less = ComparePK(intPks, 1, 0) less = ComparePKInSlice(intPks, 1, 0)
assert.False(t, less) assert.False(t, less)
less = ComparePK(intPks, 2, 0) less = ComparePKInSlice(intPks, 2, 0)
assert.False(t, less) assert.False(t, less)
less = ComparePK(intPks, 2, 1) less = ComparePKInSlice(intPks, 2, 1)
assert.False(t, less) assert.False(t, less)
strPks := &schemapb.IDs{} strPks := &schemapb.IDs{}
...@@ -683,17 +683,17 @@ func TestComparePk(t *testing.T) { ...@@ -683,17 +683,17 @@ func TestComparePk(t *testing.T) {
require.Equal(t, []string{"1", "2", "3"}, strPks.GetStrId().GetData()) require.Equal(t, []string{"1", "2", "3"}, strPks.GetStrId().GetData())
less = ComparePK(strPks, 0, 1) less = ComparePKInSlice(strPks, 0, 1)
assert.True(t, less) assert.True(t, less)
less = ComparePK(strPks, 0, 2) less = ComparePKInSlice(strPks, 0, 2)
assert.True(t, less) assert.True(t, less)
less = ComparePK(strPks, 1, 2) less = ComparePKInSlice(strPks, 1, 2)
assert.True(t, less) assert.True(t, less)
less = ComparePK(strPks, 1, 0) less = ComparePKInSlice(strPks, 1, 0)
assert.False(t, less) assert.False(t, less)
less = ComparePK(strPks, 2, 0) less = ComparePKInSlice(strPks, 2, 0)
assert.False(t, less) assert.False(t, less)
less = ComparePK(strPks, 2, 1) less = ComparePKInSlice(strPks, 2, 1)
assert.False(t, less) assert.False(t, less)
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册