提交 a6f1de03 编写于 作者: X xige-16 提交者: yefu.chen

Optimize search performance in query node

Signed-off-by: Nxige-16 <xi.ge@zilliz.com>
上级 fd282d3c
......@@ -20,7 +20,7 @@ SegmentInternalInterface::FillTargetEntry(const query::Plan* plan, QueryResult&
AssertInfo(plan, "empty plan");
auto size = results.result_distances_.size();
Assert(results.internal_seg_offsets_.size() == size);
Assert(results.result_offsets_.size() == size);
// Assert(results.result_offsets_.size() == size);
Assert(results.row_data_.size() == 0);
// std::vector<int64_t> row_ids(size);
......
......@@ -234,6 +234,56 @@ ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits,
}
}
CStatus
ReorganizeSingleQueryResult(CMarshaledHits* c_marshaled_hits,
CPlaceholderGroup* c_placeholder_groups,
int64_t num_groups,
CQueryResult c_search_result,
CPlan c_plan) {
try {
auto marshaledHits = std::make_unique<MarshaledHits>(num_groups);
auto search_result = (SearchResult*)c_search_result;
auto topk = GetTopK(c_plan);
std::vector<int64_t> num_queries_peer_group;
for (int i = 0; i < num_groups; i++) {
auto num_queries = GetNumOfQueries(c_placeholder_groups[i]);
num_queries_peer_group.push_back(num_queries);
}
int64_t fill_hit_offset = 0;
for (int i = 0; i < num_groups; i++) {
MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i];
for (int j = 0; j < num_queries_peer_group[i]; j++) {
milvus::proto::milvus::Hits hits;
for (int k = 0; k < topk; k++, fill_hit_offset++) {
hits.add_scores(search_result->result_distances_[fill_hit_offset]);
auto& row_data = search_result->row_data_[fill_hit_offset];
hits.add_row_data(row_data.data(), row_data.size());
int64_t result_id;
memcpy(&result_id, row_data.data(), sizeof(int64_t));
hits.add_ids(result_id);
}
auto blob = hits.SerializeAsString();
hits_peer_group.hits_.push_back(blob);
hits_peer_group.blob_length_.push_back(blob.size());
}
}
auto status = CStatus();
status.error_code = Success;
status.error_msg = "";
auto marshled_res = (CMarshaledHits)marshaledHits.release();
*c_marshaled_hits = marshled_res;
return status;
} catch (std::exception& e) {
auto status = CStatus();
status.error_code = UnexpectedError;
status.error_msg = strdup(e.what());
*c_marshaled_hits = nullptr;
return status;
}
}
int64_t
GetHitsBlobSize(CMarshaledHits c_marshaled_hits) {
int64_t total_size = 0;
......
......@@ -38,6 +38,13 @@ ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits,
int64_t num_segments,
CPlan c_plan);
CStatus
ReorganizeSingleQueryResult(CMarshaledHits* c_marshaled_hits,
CPlaceholderGroup* c_placeholder_groups,
int64_t num_groups,
CQueryResult c_search_result,
CPlan c_plan);
int64_t
GetHitsBlobSize(CMarshaledHits c_marshaled_hits);
......
......@@ -166,7 +166,7 @@ import (
// Type: milvuspb.PlaceholderType_VectorFloat,
// Values: [][]byte{searchRowByteData},
// }
// placeholderGroup := milvuspb.PlaceholderGroup{
// placeholderGroup := milvuspb.searchRequest{
// Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue},
// }
// placeGroupByte, err := proto.Marshal(&placeholderGroup)
......@@ -175,7 +175,7 @@ import (
// }
// query := milvuspb.SearchRequest{
// Dsl: dslString,
// PlaceholderGroup: placeGroupByte,
// searchRequest: placeGroupByte,
// }
// queryByte, err := proto.Marshal(&query)
// if err != nil {
......@@ -489,7 +489,7 @@ import (
// Type: milvuspb.PlaceholderType_VectorBinary,
// Values: [][]byte{searchRowData},
// }
// placeholderGroup := milvuspb.PlaceholderGroup{
// placeholderGroup := milvuspb.searchRequest{
// Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue},
// }
// placeGroupByte, err := proto.Marshal(&placeholderGroup)
......@@ -498,7 +498,7 @@ import (
// }
// query := milvuspb.SearchRequest{
// Dsl: dslString,
// PlaceholderGroup: placeGroupByte,
// searchRequest: placeGroupByte,
// }
// queryByte, err := proto.Marshal(&query)
// if err != nil {
......@@ -1186,9 +1186,9 @@ func TestSegmentLoad_Search_Vector(t *testing.T) {
assert.NoError(t, err)
plan, err := createPlan(*collection, dslString)
assert.NoError(t, err)
holder, err := parserPlaceholderGroup(plan, placeHolderGroupBlob)
holder, err := parseSearchRequest(plan, placeHolderGroupBlob)
assert.NoError(t, err)
placeholderGroups := make([]*PlaceholderGroup, 0)
placeholderGroups := make([]*searchRequest, 0)
placeholderGroups = append(placeholderGroups, holder)
// wait for segment building index
......
......@@ -49,32 +49,32 @@ func (plan *Plan) delete() {
C.DeletePlan(plan.cPlan)
}
type PlaceholderGroup struct {
type searchRequest struct {
cPlaceholderGroup C.CPlaceholderGroup
}
func parserPlaceholderGroup(plan *Plan, placeHolderBlob []byte) (*PlaceholderGroup, error) {
if len(placeHolderBlob) == 0 {
func parseSearchRequest(plan *Plan, searchRequestBlob []byte) (*searchRequest, error) {
if len(searchRequestBlob) == 0 {
return nil, errors.New("empty search request")
}
var blobPtr = unsafe.Pointer(&placeHolderBlob[0])
blobSize := C.long(len(placeHolderBlob))
var blobPtr = unsafe.Pointer(&searchRequestBlob[0])
blobSize := C.long(len(searchRequestBlob))
var cPlaceholderGroup C.CPlaceholderGroup
status := C.ParsePlaceholderGroup(plan.cPlan, blobPtr, blobSize, &cPlaceholderGroup)
if err := HandleCStatus(&status, "parser placeholder group failed"); err != nil {
if err := HandleCStatus(&status, "parser searchRequest failed"); err != nil {
return nil, err
}
var newPlaceholderGroup = &PlaceholderGroup{cPlaceholderGroup: cPlaceholderGroup}
return newPlaceholderGroup, nil
var newSearchRequest = &searchRequest{cPlaceholderGroup: cPlaceholderGroup}
return newSearchRequest, nil
}
func (pg *PlaceholderGroup) getNumOfQuery() int64 {
func (pg *searchRequest) getNumOfQuery() int64 {
numQueries := C.GetNumOfQueries(pg.cPlaceholderGroup)
return int64(numQueries)
}
func (pg *PlaceholderGroup) delete() {
func (pg *searchRequest) delete() {
C.DeletePlaceholderGroup(pg.cPlaceholderGroup)
}
......@@ -68,7 +68,7 @@ func TestPlan_PlaceholderGroup(t *testing.T) {
placeGroupByte, err := proto.Marshal(&placeholderGroup)
assert.Nil(t, err)
holder, err := parserPlaceholderGroup(plan, placeGroupByte)
holder, err := parseSearchRequest(plan, placeGroupByte)
assert.NoError(t, err)
assert.NotNil(t, holder)
numQueries := holder.getNumOfQuery()
......
......@@ -10,10 +10,13 @@ package querynode
*/
import "C"
import (
"errors"
"fmt"
"strconv"
"sync"
"unsafe"
"errors"
"github.com/zilliztech/milvus-distributed/internal/log"
)
type SearchResult struct {
......@@ -46,24 +49,31 @@ func reduceSearchResults(searchResults []*SearchResult, numSegments int64, inRed
}
func fillTargetEntry(plan *Plan, searchResults []*SearchResult, matchedSegments []*Segment, inReduced []bool) error {
for i, value := range inReduced {
if value {
err := matchedSegments[i].fillTargetEntry(plan, searchResults[i])
if err != nil {
return err
}
wg := &sync.WaitGroup{}
fmt.Println(inReduced)
for i := range inReduced {
if inReduced[i] {
wg.Add(1)
go func(i int) {
err := matchedSegments[i].fillTargetEntry(plan, searchResults[i])
if err != nil {
log.Error(err.Error())
}
wg.Done()
}(i)
}
}
wg.Wait()
return nil
}
func reorganizeQueryResults(plan *Plan, placeholderGroups []*PlaceholderGroup, searchResults []*SearchResult, numSegments int64, inReduced []bool) (*MarshaledHits, error) {
func reorganizeQueryResults(plan *Plan, searchRequests []*searchRequest, searchResults []*SearchResult, numSegments int64, inReduced []bool) (*MarshaledHits, error) {
cPlaceholderGroups := make([]C.CPlaceholderGroup, 0)
for _, pg := range placeholderGroups {
for _, pg := range searchRequests {
cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup)
}
var cPlaceHolderGroupPtr = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
var cNumGroup = (C.long)(len(placeholderGroups))
var cNumGroup = (C.long)(len(searchRequests))
cSearchResults := make([]C.CQueryResult, 0)
for _, res := range searchResults {
......@@ -86,6 +96,28 @@ func reorganizeQueryResults(plan *Plan, placeholderGroups []*PlaceholderGroup, s
return &MarshaledHits{cMarshaledHits: cMarshaledHits}, nil
}
func reorganizeSingleQueryResult(plan *Plan, placeholderGroups []*searchRequest, searchResult *SearchResult) (*MarshaledHits, error) {
cPlaceholderGroups := make([]C.CPlaceholderGroup, 0)
for _, pg := range placeholderGroups {
cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup)
}
var cPlaceHolderGroupPtr = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
var cNumGroup = (C.long)(len(placeholderGroups))
cSearchResult := searchResult.cQueryResult
var cMarshaledHits C.CMarshaledHits
status := C.ReorganizeSingleQueryResult(&cMarshaledHits, cPlaceHolderGroupPtr, cNumGroup, cSearchResult, plan.cPlan)
errorCode := status.error_code
if errorCode != 0 {
errorMsg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return nil, errors.New("reorganizeQueryResults failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
}
return &MarshaledHits{cMarshaledHits: cMarshaledHits}, nil
}
func (mh *MarshaledHits) getHitsBlobSize() int64 {
res := C.GetHitsBlobSize(mh.cMarshaledHits)
return int64(res)
......
......@@ -54,9 +54,9 @@ func TestReduce_AllFunc(t *testing.T) {
plan, err := createPlan(*collection, dslString)
assert.NoError(t, err)
holder, err := parserPlaceholderGroup(plan, placeGroupByte)
holder, err := parseSearchRequest(plan, placeGroupByte)
assert.NoError(t, err)
placeholderGroups := make([]*PlaceholderGroup, 0)
placeholderGroups := make([]*searchRequest, 0)
placeholderGroups = append(placeholderGroups, holder)
searchResults := make([]*SearchResult, 0)
......
......@@ -271,13 +271,13 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
if err != nil {
return err
}
placeHolderGroupBlob := query.PlaceholderGroup
placeholderGroup, err := parserPlaceholderGroup(plan, placeHolderGroupBlob)
searchRequestBlob := query.PlaceholderGroup
searchReq, err := parseSearchRequest(plan, searchRequestBlob)
if err != nil {
return err
}
placeholderGroups := make([]*PlaceholderGroup, 0)
placeholderGroups = append(placeholderGroups, placeholderGroup)
searchRequests := make([]*searchRequest, 0)
searchRequests = append(searchRequests, searchReq)
searchResults := make([]*SearchResult, 0)
matchedSegments := make([]*Segment, 0)
......@@ -315,7 +315,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
if err != nil {
return err
}
searchResult, err := segment.segmentSearch(plan, placeholderGroups, []Timestamp{searchTimestamp})
searchResult, err := segment.segmentSearch(plan, searchRequests, []Timestamp{searchTimestamp})
if err != nil {
return err
......@@ -326,7 +326,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
}
if len(searchResults) <= 0 {
for _, group := range placeholderGroups {
for _, group := range searchRequests {
nq := group.getNumOfQuery()
nilHits := make([][]byte, nq)
hit := &milvuspb.Hits{}
......@@ -363,17 +363,30 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
inReduced := make([]bool, len(searchResults))
numSegment := int64(len(searchResults))
err2 := reduceSearchResults(searchResults, numSegment, inReduced)
if err2 != nil {
return err2
}
err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
if err != nil {
return err
}
marshaledHits, err := reorganizeQueryResults(plan, placeholderGroups, searchResults, numSegment, inReduced)
if err != nil {
return err
var marshaledHits *MarshaledHits = nil
if numSegment == 1 {
inReduced[0] = true
err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
if err != nil {
return err
}
marshaledHits, err = reorganizeSingleQueryResult(plan, searchRequests, searchResults[0])
if err != nil {
return err
}
} else {
err = reduceSearchResults(searchResults, numSegment, inReduced)
if err != nil {
return err
}
err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
if err != nil {
return err
}
marshaledHits, err = reorganizeQueryResults(plan, searchRequests, searchResults, numSegment, inReduced)
if err != nil {
return err
}
}
hitsBlob, err := marshaledHits.getHitsBlob()
if err != nil {
......@@ -381,14 +394,14 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
}
var offset int64 = 0
for index := range placeholderGroups {
for index := range searchRequests {
hitBlobSizePeerQuery, err := marshaledHits.hitBlobSizeInGroup(int64(index))
if err != nil {
return err
}
hits := make([][]byte, 0)
for _, len := range hitBlobSizePeerQuery {
hits = append(hits, hitsBlob[offset:offset+len])
hits := make([][]byte, len(hitBlobSizePeerQuery))
for i, len := range hitBlobSizePeerQuery {
hits[i] = hitsBlob[offset : offset+len]
//test code to checkout marshaled hits
//marshaledHit := hitsBlob[offset:offset+len]
//unMarshaledHit := milvuspb.Hits{}
......@@ -436,7 +449,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
deleteSearchResults(searchResults)
deleteMarshaledHits(marshaledHits)
plan.delete()
placeholderGroup.delete()
searchReq.delete()
return nil
}
......
......@@ -214,7 +214,7 @@ func (s *Segment) getMemSize() int64 {
}
func (s *Segment) segmentSearch(plan *Plan,
placeHolderGroups []*PlaceholderGroup,
searchRequests []*searchRequest,
timestamp []Timestamp) (*SearchResult, error) {
/*
CStatus
......@@ -229,14 +229,14 @@ func (s *Segment) segmentSearch(plan *Plan,
return nil, errors.New("null seg core pointer")
}
cPlaceholderGroups := make([]C.CPlaceholderGroup, 0)
for _, pg := range placeHolderGroups {
for _, pg := range searchRequests {
cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup)
}
var searchResult SearchResult
var cTimestamp = (*C.ulong)(&timestamp[0])
var cPlaceHolder = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
var cNumGroups = C.int(len(placeHolderGroups))
var cNumGroups = C.int(len(searchRequests))
log.Debug("do search on segment", zap.Int64("segmentID", s.segmentID), zap.Int32("segmentType", int32(s.segmentType)))
var status = C.Search(s.segmentPtr, plan.cPlan, cPlaceHolder, cTimestamp, cNumGroups, &searchResult.cQueryResult)
......@@ -257,6 +257,7 @@ func (s *Segment) fillTargetEntry(plan *Plan,
return errors.New("null seg core pointer")
}
log.Debug("segment fill target entry, ", zap.Int64("segment ID = ", s.segmentID))
var status = C.FillTargetEntry(s.segmentPtr, plan.cPlan, result.cQueryResult)
errorCode := status.error_code
......
......@@ -355,9 +355,9 @@ func TestSegment_segmentSearch(t *testing.T) {
searchTimestamp := Timestamp(1020)
plan, err := createPlan(*collection, dslString)
assert.NoError(t, err)
holder, err := parserPlaceholderGroup(plan, placeHolderGroupBlob)
holder, err := parseSearchRequest(plan, placeHolderGroupBlob)
assert.NoError(t, err)
placeholderGroups := make([]*PlaceholderGroup, 0)
placeholderGroups := make([]*searchRequest, 0)
placeholderGroups = append(placeholderGroups, holder)
searchResults := make([]*SearchResult, 0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册