未验证 提交 52c6a270 编写于 作者: X XuanYang-cn 提交者: GitHub

Fix binary results unstable (#19401)

See also: #19338, #19366, 19367
Signed-off-by: Nyangxuan <xuan.yang@zilliz.com>
Signed-off-by: Nyangxuan <xuan.yang@zilliz.com>
上级 7819297f
......@@ -564,11 +564,21 @@ func selectHighestScoreIndex(subSearchResultData []*schemapb.SearchResultData, s
}
sIdx := subSearchNqOffset[i][qi] + cursors[i]
sScore := subSearchResultData[i].Scores[sIdx]
// Choose the larger score idx or the smaller pk idx with the same score
if sScore > maxScore {
subSearchIdx = i
resultDataIdx = sIdx
maxScore = sScore
} else if sScore == maxScore {
sID := typeutil.GetPK(subSearchResultData[i].GetIds(), sIdx)
tmpID := typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)
if typeutil.ComparePK(sID, tmpID) {
subSearchIdx = i
resultDataIdx = sIdx
maxScore = sScore
}
}
}
return subSearchIdx, resultDataIdx
......
......@@ -1218,7 +1218,7 @@ func TestTaskSearch_selectHighestScoreIndex(t *testing.T) {
},
},
},
Scores: []float32{1.2, 1.0, 0.7, 0.6, 0.4, 0.2},
Scores: []float32{1.2, 1.0, 0.7, 0.5, 0.4, 0.2},
Topks: []int64{2, 2, 2},
},
},
......@@ -1280,7 +1280,7 @@ func TestTaskSearch_selectHighestScoreIndex(t *testing.T) {
},
},
},
Scores: []float32{1.2, 1.0, 0.7, 0.6, 0.4, 0.2},
Scores: []float32{1.2, 1.0, 0.7, 0.5, 0.4, 0.2},
Topks: []int64{2, 2, 2},
},
},
......
......@@ -183,17 +183,32 @@ func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se
}
func selectSearchResultData(dataArray []*schemapb.SearchResultData, resultOffsets [][]int64, offsets []int64, qi int64) int {
sel := -1
maxDistance := -1 * float32(math.MaxFloat32)
var (
sel = -1
maxDistance = -1 * float32(math.MaxFloat32)
resultDataIdx int64 = -1
)
for i, offset := range offsets { // query num, the number of ways to merge
if offset >= dataArray[i].Topks[qi] {
continue
}
idx := resultOffsets[i][qi] + offset
distance := dataArray[i].Scores[idx]
if distance > maxDistance {
sel = i
maxDistance = distance
resultDataIdx = idx
} else if distance == maxDistance {
sID := typeutil.GetPK(dataArray[i].GetIds(), idx)
tmpID := typeutil.GetPK(dataArray[sel].GetIds(), resultDataIdx)
if typeutil.ComparePK(sID, tmpID) {
sel = i
maxDistance = distance
resultDataIdx = idx
}
}
}
return sel
......
......@@ -35,7 +35,7 @@ func (s *byPK) Swap(i, j int) {
}
func (s *byPK) Less(i, j int) bool {
return typeutil.ComparePK(s.r.GetIds(), i, j)
return typeutil.ComparePKInSlice(s.r.GetIds(), i, j)
}
func swapFieldData(field *schemapb.FieldData, i int, j int) {
......
......@@ -657,8 +657,8 @@ func SwapPK(data *schemapb.IDs, i, j int) {
}
}
// ComparePK returns if i-th PK < j-th PK
func ComparePK(data *schemapb.IDs, i, j int) bool {
// ComparePKInSlice returns if i-th PK < j-th PK
func ComparePKInSlice(data *schemapb.IDs, i, j int) bool {
switch f := data.GetIdField().(type) {
case *schemapb.IDs_IntId:
return f.IntId.Data[i] < f.IntId.Data[j]
......@@ -668,6 +668,17 @@ func ComparePK(data *schemapb.IDs, i, j int) bool {
return false
}
// ComparePK returns if i-th PK of dataA > j-th PK of dataB
func ComparePK(pkA, pkB interface{}) bool {
switch pkA.(type) {
case int64:
return pkA.(int64) < pkB.(int64)
case string:
return pkA.(string) < pkB.(string)
}
return false
}
type ResultWithID interface {
GetIds() *schemapb.IDs
}
......
......@@ -662,18 +662,18 @@ func TestComparePk(t *testing.T) {
AppendPKs(intPks, int64(3))
require.Equal(t, []int64{1, 2, 3}, intPks.GetIntId().GetData())
less := ComparePK(intPks, 0, 1)
less := ComparePKInSlice(intPks, 0, 1)
assert.True(t, less)
less = ComparePK(intPks, 0, 2)
less = ComparePKInSlice(intPks, 0, 2)
assert.True(t, less)
less = ComparePK(intPks, 1, 2)
less = ComparePKInSlice(intPks, 1, 2)
assert.True(t, less)
less = ComparePK(intPks, 1, 0)
less = ComparePKInSlice(intPks, 1, 0)
assert.False(t, less)
less = ComparePK(intPks, 2, 0)
less = ComparePKInSlice(intPks, 2, 0)
assert.False(t, less)
less = ComparePK(intPks, 2, 1)
less = ComparePKInSlice(intPks, 2, 1)
assert.False(t, less)
strPks := &schemapb.IDs{}
......@@ -683,17 +683,17 @@ func TestComparePk(t *testing.T) {
require.Equal(t, []string{"1", "2", "3"}, strPks.GetStrId().GetData())
less = ComparePK(strPks, 0, 1)
less = ComparePKInSlice(strPks, 0, 1)
assert.True(t, less)
less = ComparePK(strPks, 0, 2)
less = ComparePKInSlice(strPks, 0, 2)
assert.True(t, less)
less = ComparePK(strPks, 1, 2)
less = ComparePKInSlice(strPks, 1, 2)
assert.True(t, less)
less = ComparePK(strPks, 1, 0)
less = ComparePKInSlice(strPks, 1, 0)
assert.False(t, less)
less = ComparePK(strPks, 2, 0)
less = ComparePKInSlice(strPks, 2, 0)
assert.False(t, less)
less = ComparePK(strPks, 2, 1)
less = ComparePKInSlice(strPks, 2, 1)
assert.False(t, less)
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册