Unverified commit b2e8ba7b, authored by dragondriver, committed by GitHub

Fix reduce algorithm in proxy search task (#8206)

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
Parent: c1e229cb
@@ -83,6 +83,8 @@ const (
CreateAliasTaskName = "CreateAliasTask"
DropAliasTaskName = "DropAliasTask"
AlterAliasTaskName = "AlterAliasTask"
+ minFloat32 = -1 * float32(math.MaxFloat32)
)
type task interface {
@@ -1755,8 +1757,6 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
}
}
- const minFloat32 = -1 * float32(math.MaxFloat32)
// TODO(yukun): Use parallel function
var realTopK int64 = -1
var idx int64
@@ -1766,17 +1766,14 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
j = 0
for ; j < topk; j++ {
- valid := true
- choice, maxDistance := 0, minFloat32
+ choice, maxDistance := -1, minFloat32
for q, loc := range locs { // query num, the number of ways to merge
if loc >= topk {
continue
}
curIdx := idx*topk + loc
id := searchResultData[q].Ids.GetIntId().Data[curIdx]
- if id == -1 {
- valid = false
- } else {
+ if id != -1 {
distance := searchResultData[q].Scores[curIdx]
if distance > maxDistance {
choice = q
@@ -1784,7 +1781,7 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
}
}
}
- if !valid {
+ if choice == -1 {
break
}
choiceOffset := locs[choice]
......
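The heart of the change is the candidate-selection loop above: the old code kept a per-round `valid` flag and aborted the merge for the whole query as soon as any one way's cursor pointed at a padding entry (id == -1), even though the other ways still held real hits. The fixed version starts `choice` at -1, only considers entries whose id is not -1, and stops only when no way has a valid candidate left. Below is a minimal, self-contained sketch of that k-way selection; the `result` type, `mergeOne` helper, and flat per-way layout are simplifications for illustration, not the proxy's actual structures:

```go
package main

import (
	"fmt"
	"math"
)

const minFloat32 = -1 * float32(math.MaxFloat32)

// result mimics one query node's answer for a single query:
// ids[k] == -1 marks a padding slot whose score is minFloat32.
type result struct {
	ids    []int64
	scores []float32
}

// mergeOne merges the candidates of several partial results for one query,
// keeping at most topk valid hits. choice == -1 after scanning all ways means
// every way is either exhausted or currently points at padding, so the merge
// stops; the old per-round `valid` flag stopped as soon as *any* way hit padding.
func mergeOne(parts []result, topk int) (ids []int64, scores []float32) {
	locs := make([]int, len(parts)) // cursor into each way
	for j := 0; j < topk; j++ {
		choice, maxDistance := -1, minFloat32
		for q, loc := range locs {
			if loc >= topk {
				continue // this way is exhausted
			}
			if id := parts[q].ids[loc]; id != -1 {
				if d := parts[q].scores[loc]; d > maxDistance {
					choice, maxDistance = q, d
				}
			}
		}
		if choice == -1 {
			break // no valid candidate left in any way
		}
		loc := locs[choice]
		ids = append(ids, parts[choice].ids[loc])
		scores = append(scores, parts[choice].scores[loc])
		locs[choice]++
	}
	return ids, scores
}

func main() {
	// Way b is half padding, like constructSearchResulstData(topk / 2) in the test below.
	a := result{ids: []int64{11, 12, 13, 14}, scores: []float32{4, 3, 2, 1}}
	b := result{ids: []int64{21, 22, -1, -1}, scores: []float32{10, 9, minFloat32, minFloat32}}
	ids, scores := mergeOne([]result{a, b}, 4)
	fmt.Println(ids, scores) // [21 22 11 12] [10 9 4 3]; the old code stopped after 2 hits
}
```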
@@ -1903,7 +1903,7 @@ func TestSearchTask_all(t *testing.T) {
for i := 0; i < nq; i++ {
for j := 0; j < topk; j++ {
offset := i*topk + j
- score := rand.Float32()
+ score := float32(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) // increasingly
id := int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
resultData.Scores[offset] = score
resultData.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data[offset] = id
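Presumably the switch from rand.Float32() to unique, monotonically increasing integer scores (hence the "// increasingly" comment) is what lets the test pin down the reduce's output: with no two candidates sharing a score, the expected merged ordering is fully deterministic.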
@@ -1981,6 +1981,250 @@ func TestSearchTask_all(t *testing.T) {
wg.Wait()
}
func TestSearchTask_7803_reduce(t *testing.T) {
var err error
Params.Init()
Params.SearchResultChannelNames = []string{funcutil.GenRandomStr()}
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
ctx := context.Background()
err = InitMetaCache(rc)
assert.NoError(t, err)
shardsNum := int32(2)
prefix := "TestSearchTask_7803_reduce"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
int64Field := "int64"
floatVecField := "fvec"
dim := 128
expr := fmt.Sprintf("%s > 0", int64Field)
nq := 10
topk := 10
nprobe := 10
schema := constructCollectionSchema(
int64Field,
floatVecField,
dim,
collectionName)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
createColT := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: shardsNum,
},
ctx: ctx,
rootCoord: rc,
result: nil,
schema: nil,
}
assert.NoError(t, createColT.OnEnqueue())
assert.NoError(t, createColT.PreExecute(ctx))
assert.NoError(t, createColT.Execute(ctx))
assert.NoError(t, createColT.PostExecute(ctx))
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
query := newMockGetChannelsService()
factory := newSimpleMockMsgStreamFactory()
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
defer chMgr.removeAllDMLStream()
defer chMgr.removeAllDQLStream()
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
assert.NoError(t, err)
qc := NewQueryCoordMock()
qc.Start()
defer qc.Stop()
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
MsgID: 0,
Timestamp: 0,
SourceID: Params.ProxyID,
},
DbID: 0,
CollectionID: collectionID,
Schema: nil,
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
req := constructSearchRequest(dbName, collectionName,
expr,
floatVecField,
nq, dim, nprobe, topk)
task := &searchTask{
Condition: NewTaskCondition(ctx),
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
MsgID: 0,
Timestamp: 0,
SourceID: Params.ProxyID,
},
ResultChannelID: strconv.FormatInt(Params.ProxyID, 10),
DbID: 0,
CollectionID: 0,
PartitionIDs: nil,
Dsl: "",
PlaceholderGroup: nil,
DslType: 0,
SerializedExprPlan: nil,
OutputFieldsId: nil,
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
},
ctx: ctx,
resultBuf: make(chan []*internalpb.SearchResults),
result: nil,
query: req,
chMgr: chMgr,
qc: qc,
}
// simple mock for query node
// TODO(dragondriver): should we replace this mock using RocksMq or MemMsgStream?
err = chMgr.createDQLStream(collectionID)
assert.NoError(t, err)
stream, err := chMgr.getDQLStream(collectionID)
assert.NoError(t, err)
var wg sync.WaitGroup
wg.Add(1)
consumeCtx, cancel := context.WithCancel(ctx)
go func() {
defer wg.Done()
for {
select {
case <-consumeCtx.Done():
return
case pack := <-stream.Chan():
for _, msg := range pack.Msgs {
_, ok := msg.(*msgstream.SearchMsg)
assert.True(t, ok)
// TODO(dragondriver): construct result according to the request
constructSearchResulstData := func(invalidNum int) *schemapb.SearchResultData {
resultData := &schemapb.SearchResultData{
NumQueries: int64(nq),
TopK: int64(topk),
FieldsData: nil,
Scores: make([]float32, nq*topk),
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: make([]int64, nq*topk),
},
},
},
Topks: make([]int64, nq),
}
for i := 0; i < nq; i++ {
for j := 0; j < topk; j++ {
offset := i*topk + j
if j >= invalidNum {
resultData.Scores[offset] = minFloat32
resultData.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data[offset] = -1
} else {
score := float32(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) // increasingly
id := int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
resultData.Scores[offset] = score
resultData.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data[offset] = id
}
}
resultData.Topks[i] = int64(topk)
}
return resultData
}
result1 := &internalpb.SearchResults{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SearchResult,
MsgID: 0,
Timestamp: 0,
SourceID: 0,
},
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
ResultChannelID: "",
MetricType: distance.L2,
NumQueries: int64(nq),
TopK: int64(topk),
SealedSegmentIDsSearched: nil,
ChannelIDsSearched: nil,
GlobalSealedSegmentIDs: nil,
SlicedBlob: nil,
SlicedNumCount: 1,
SlicedOffset: 0,
}
resultData := constructSearchResulstData(topk / 2)
sliceBlob, err := proto.Marshal(resultData)
assert.NoError(t, err)
result1.SlicedBlob = sliceBlob
result2 := &internalpb.SearchResults{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SearchResult,
MsgID: 0,
Timestamp: 0,
SourceID: 0,
},
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
ResultChannelID: "",
MetricType: distance.L2,
NumQueries: int64(nq),
TopK: int64(topk),
SealedSegmentIDsSearched: nil,
ChannelIDsSearched: nil,
GlobalSealedSegmentIDs: nil,
SlicedBlob: nil,
SlicedNumCount: 1,
SlicedOffset: 0,
}
resultData2 := constructSearchResulstData(topk - topk/2)
sliceBlob2, err := proto.Marshal(resultData2)
assert.NoError(t, err)
result2.SlicedBlob = sliceBlob2
// send search result
task.resultBuf <- []*internalpb.SearchResults{result1, result2}
}
}
}
}()
assert.NoError(t, task.OnEnqueue())
assert.NoError(t, task.PreExecute(ctx))
assert.NoError(t, task.Execute(ctx))
assert.NoError(t, task.PostExecute(ctx))
cancel()
wg.Wait()
}
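Both mocked partial results above hold only half valid entries per query, padded with (-1, minFloat32), so the fixed reduce should still assemble a full topk for every query. A hypothetical post-check like the following (not part of this commit; it assumes the flat i*topk+j layout that applies when every query yields a full topk, and reuses the schemapb/testify identifiers already imported by this file) sketches what that scenario implies:

```go
// checkReduced is a hypothetical helper: after the fixed reduce, every query
// should report a full topk and no hit may carry the -1 padding id. The old
// reduce broke the first property by stopping a query's merge as soon as one
// partial result ran out of valid entries.
func checkReduced(t *testing.T, data *schemapb.SearchResultData, nq, topk int) {
	ids := data.Ids.GetIntId().GetData()
	for i := 0; i < nq; i++ {
		assert.Equal(t, int64(topk), data.Topks[i])
		for j := 0; j < topk; j++ {
			assert.NotEqual(t, int64(-1), ids[i*topk+j])
		}
	}
}
```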
func TestSearchTask_Type(t *testing.T) {
Params.Init()
......