From a6f1de036bcc8f653defdc3162a0a7347395de4a Mon Sep 17 00:00:00 2001 From: xige-16 Date: Tue, 30 Mar 2021 22:16:58 +0800 Subject: [PATCH] Optimize search performance in query node Signed-off-by: xige-16 --- .../core/src/segcore/SegmentInterface.cpp | 2 +- internal/core/src/segcore/reduce_c.cpp | 50 ++++++++++++++++ internal/core/src/segcore/reduce_c.h | 7 +++ internal/querynode/load_service_test.go | 12 ++-- internal/querynode/plan.go | 20 +++---- internal/querynode/plan_test.go | 2 +- internal/querynode/reduce.go | 52 +++++++++++++---- internal/querynode/reduce_test.go | 4 +- internal/querynode/search_collection.go | 57 ++++++++++++------- internal/querynode/segment.go | 7 ++- internal/querynode/segment_test.go | 4 +- 11 files changed, 160 insertions(+), 57 deletions(-) diff --git a/internal/core/src/segcore/SegmentInterface.cpp b/internal/core/src/segcore/SegmentInterface.cpp index 7c798ab89..eb7d5b1e7 100644 --- a/internal/core/src/segcore/SegmentInterface.cpp +++ b/internal/core/src/segcore/SegmentInterface.cpp @@ -20,7 +20,7 @@ SegmentInternalInterface::FillTargetEntry(const query::Plan* plan, QueryResult& AssertInfo(plan, "empty plan"); auto size = results.result_distances_.size(); Assert(results.internal_seg_offsets_.size() == size); - Assert(results.result_offsets_.size() == size); + // Assert(results.result_offsets_.size() == size); Assert(results.row_data_.size() == 0); // std::vector row_ids(size); diff --git a/internal/core/src/segcore/reduce_c.cpp b/internal/core/src/segcore/reduce_c.cpp index f792ed010..0fe4d1725 100644 --- a/internal/core/src/segcore/reduce_c.cpp +++ b/internal/core/src/segcore/reduce_c.cpp @@ -234,6 +234,56 @@ ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits, } } +CStatus +ReorganizeSingleQueryResult(CMarshaledHits* c_marshaled_hits, + CPlaceholderGroup* c_placeholder_groups, + int64_t num_groups, + CQueryResult c_search_result, + CPlan c_plan) { + try { + auto marshaledHits = std::make_unique(num_groups); + auto search_result = (SearchResult*)c_search_result; + auto topk = GetTopK(c_plan); + std::vector num_queries_peer_group; + for (int i = 0; i < num_groups; i++) { + auto num_queries = GetNumOfQueries(c_placeholder_groups[i]); + num_queries_peer_group.push_back(num_queries); + } + + int64_t fill_hit_offset = 0; + for (int i = 0; i < num_groups; i++) { + MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i]; + for (int j = 0; j < num_queries_peer_group[i]; j++) { + milvus::proto::milvus::Hits hits; + for (int k = 0; k < topk; k++, fill_hit_offset++) { + hits.add_scores(search_result->result_distances_[fill_hit_offset]); + auto& row_data = search_result->row_data_[fill_hit_offset]; + hits.add_row_data(row_data.data(), row_data.size()); + int64_t result_id; + memcpy(&result_id, row_data.data(), sizeof(int64_t)); + hits.add_ids(result_id); + } + auto blob = hits.SerializeAsString(); + hits_peer_group.hits_.push_back(blob); + hits_peer_group.blob_length_.push_back(blob.size()); + } + } + + auto status = CStatus(); + status.error_code = Success; + status.error_msg = ""; + auto marshled_res = (CMarshaledHits)marshaledHits.release(); + *c_marshaled_hits = marshled_res; + return status; + } catch (std::exception& e) { + auto status = CStatus(); + status.error_code = UnexpectedError; + status.error_msg = strdup(e.what()); + *c_marshaled_hits = nullptr; + return status; + } +} + int64_t GetHitsBlobSize(CMarshaledHits c_marshaled_hits) { int64_t total_size = 0; diff --git a/internal/core/src/segcore/reduce_c.h b/internal/core/src/segcore/reduce_c.h index 53915af2c..2c3626ba8 100644 --- a/internal/core/src/segcore/reduce_c.h +++ b/internal/core/src/segcore/reduce_c.h @@ -38,6 +38,13 @@ ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits, int64_t num_segments, CPlan c_plan); +CStatus +ReorganizeSingleQueryResult(CMarshaledHits* c_marshaled_hits, + CPlaceholderGroup* c_placeholder_groups, + int64_t num_groups, + CQueryResult c_search_result, + CPlan c_plan); + int64_t GetHitsBlobSize(CMarshaledHits c_marshaled_hits); diff --git a/internal/querynode/load_service_test.go b/internal/querynode/load_service_test.go index dd374f075..830d209b1 100644 --- a/internal/querynode/load_service_test.go +++ b/internal/querynode/load_service_test.go @@ -166,7 +166,7 @@ import ( // Type: milvuspb.PlaceholderType_VectorFloat, // Values: [][]byte{searchRowByteData}, // } -// placeholderGroup := milvuspb.PlaceholderGroup{ +// placeholderGroup := milvuspb.searchRequest{ // Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue}, // } // placeGroupByte, err := proto.Marshal(&placeholderGroup) @@ -175,7 +175,7 @@ import ( // } // query := milvuspb.SearchRequest{ // Dsl: dslString, -// PlaceholderGroup: placeGroupByte, +// searchRequest: placeGroupByte, // } // queryByte, err := proto.Marshal(&query) // if err != nil { @@ -489,7 +489,7 @@ import ( // Type: milvuspb.PlaceholderType_VectorBinary, // Values: [][]byte{searchRowData}, // } -// placeholderGroup := milvuspb.PlaceholderGroup{ +// placeholderGroup := milvuspb.searchRequest{ // Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue}, // } // placeGroupByte, err := proto.Marshal(&placeholderGroup) @@ -498,7 +498,7 @@ import ( // } // query := milvuspb.SearchRequest{ // Dsl: dslString, -// PlaceholderGroup: placeGroupByte, +// searchRequest: placeGroupByte, // } // queryByte, err := proto.Marshal(&query) // if err != nil { @@ -1186,9 +1186,9 @@ func TestSegmentLoad_Search_Vector(t *testing.T) { assert.NoError(t, err) plan, err := createPlan(*collection, dslString) assert.NoError(t, err) - holder, err := parserPlaceholderGroup(plan, placeHolderGroupBlob) + holder, err := parseSearchRequest(plan, placeHolderGroupBlob) assert.NoError(t, err) - placeholderGroups := make([]*PlaceholderGroup, 0) + placeholderGroups := make([]*searchRequest, 0) placeholderGroups = append(placeholderGroups, holder) // wait for segment building index diff --git a/internal/querynode/plan.go b/internal/querynode/plan.go index b168dec0d..c65d68380 100644 --- a/internal/querynode/plan.go +++ b/internal/querynode/plan.go @@ -49,32 +49,32 @@ func (plan *Plan) delete() { C.DeletePlan(plan.cPlan) } -type PlaceholderGroup struct { +type searchRequest struct { cPlaceholderGroup C.CPlaceholderGroup } -func parserPlaceholderGroup(plan *Plan, placeHolderBlob []byte) (*PlaceholderGroup, error) { - if len(placeHolderBlob) == 0 { +func parseSearchRequest(plan *Plan, searchRequestBlob []byte) (*searchRequest, error) { + if len(searchRequestBlob) == 0 { return nil, errors.New("empty search request") } - var blobPtr = unsafe.Pointer(&placeHolderBlob[0]) - blobSize := C.long(len(placeHolderBlob)) + var blobPtr = unsafe.Pointer(&searchRequestBlob[0]) + blobSize := C.long(len(searchRequestBlob)) var cPlaceholderGroup C.CPlaceholderGroup status := C.ParsePlaceholderGroup(plan.cPlan, blobPtr, blobSize, &cPlaceholderGroup) - if err := HandleCStatus(&status, "parser placeholder group failed"); err != nil { + if err := HandleCStatus(&status, "parser searchRequest failed"); err != nil { return nil, err } - var newPlaceholderGroup = &PlaceholderGroup{cPlaceholderGroup: cPlaceholderGroup} - return newPlaceholderGroup, nil + var newSearchRequest = &searchRequest{cPlaceholderGroup: cPlaceholderGroup} + return newSearchRequest, nil } -func (pg *PlaceholderGroup) getNumOfQuery() int64 { +func (pg *searchRequest) getNumOfQuery() int64 { numQueries := C.GetNumOfQueries(pg.cPlaceholderGroup) return int64(numQueries) } -func (pg *PlaceholderGroup) delete() { +func (pg *searchRequest) delete() { C.DeletePlaceholderGroup(pg.cPlaceholderGroup) } diff --git a/internal/querynode/plan_test.go b/internal/querynode/plan_test.go index 5131c4a3f..0484e5e19 100644 --- a/internal/querynode/plan_test.go +++ b/internal/querynode/plan_test.go @@ -68,7 +68,7 @@ func TestPlan_PlaceholderGroup(t *testing.T) { placeGroupByte, err := proto.Marshal(&placeholderGroup) assert.Nil(t, err) - holder, err := parserPlaceholderGroup(plan, placeGroupByte) + holder, err := parseSearchRequest(plan, placeGroupByte) assert.NoError(t, err) assert.NotNil(t, holder) numQueries := holder.getNumOfQuery() diff --git a/internal/querynode/reduce.go b/internal/querynode/reduce.go index 027eaa9fc..33cd03076 100644 --- a/internal/querynode/reduce.go +++ b/internal/querynode/reduce.go @@ -10,10 +10,13 @@ package querynode */ import "C" import ( + "errors" + "fmt" "strconv" + "sync" "unsafe" - "errors" + "github.com/zilliztech/milvus-distributed/internal/log" ) type SearchResult struct { @@ -46,24 +49,31 @@ func reduceSearchResults(searchResults []*SearchResult, numSegments int64, inRed } func fillTargetEntry(plan *Plan, searchResults []*SearchResult, matchedSegments []*Segment, inReduced []bool) error { - for i, value := range inReduced { - if value { - err := matchedSegments[i].fillTargetEntry(plan, searchResults[i]) - if err != nil { - return err - } + wg := &sync.WaitGroup{} + fmt.Println(inReduced) + for i := range inReduced { + if inReduced[i] { + wg.Add(1) + go func(i int) { + err := matchedSegments[i].fillTargetEntry(plan, searchResults[i]) + if err != nil { + log.Error(err.Error()) + } + wg.Done() + }(i) } } + wg.Wait() return nil } -func reorganizeQueryResults(plan *Plan, placeholderGroups []*PlaceholderGroup, searchResults []*SearchResult, numSegments int64, inReduced []bool) (*MarshaledHits, error) { +func reorganizeQueryResults(plan *Plan, searchRequests []*searchRequest, searchResults []*SearchResult, numSegments int64, inReduced []bool) (*MarshaledHits, error) { cPlaceholderGroups := make([]C.CPlaceholderGroup, 0) - for _, pg := range placeholderGroups { + for _, pg := range searchRequests { cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup) } var cPlaceHolderGroupPtr = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0]) - var cNumGroup = (C.long)(len(placeholderGroups)) + var cNumGroup = (C.long)(len(searchRequests)) cSearchResults := make([]C.CQueryResult, 0) for _, res := range searchResults { @@ -86,6 +96,28 @@ func reorganizeQueryResults(plan *Plan, placeholderGroups []*PlaceholderGroup, s return &MarshaledHits{cMarshaledHits: cMarshaledHits}, nil } +func reorganizeSingleQueryResult(plan *Plan, placeholderGroups []*searchRequest, searchResult *SearchResult) (*MarshaledHits, error) { + cPlaceholderGroups := make([]C.CPlaceholderGroup, 0) + for _, pg := range placeholderGroups { + cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup) + } + var cPlaceHolderGroupPtr = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0]) + var cNumGroup = (C.long)(len(placeholderGroups)) + + cSearchResult := searchResult.cQueryResult + var cMarshaledHits C.CMarshaledHits + + status := C.ReorganizeSingleQueryResult(&cMarshaledHits, cPlaceHolderGroupPtr, cNumGroup, cSearchResult, plan.cPlan) + errorCode := status.error_code + + if errorCode != 0 { + errorMsg := C.GoString(status.error_msg) + defer C.free(unsafe.Pointer(status.error_msg)) + return nil, errors.New("reorganizeQueryResults failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg) + } + return &MarshaledHits{cMarshaledHits: cMarshaledHits}, nil +} + func (mh *MarshaledHits) getHitsBlobSize() int64 { res := C.GetHitsBlobSize(mh.cMarshaledHits) return int64(res) diff --git a/internal/querynode/reduce_test.go b/internal/querynode/reduce_test.go index 7e4ee506e..79d0425c2 100644 --- a/internal/querynode/reduce_test.go +++ b/internal/querynode/reduce_test.go @@ -54,9 +54,9 @@ func TestReduce_AllFunc(t *testing.T) { plan, err := createPlan(*collection, dslString) assert.NoError(t, err) - holder, err := parserPlaceholderGroup(plan, placeGroupByte) + holder, err := parseSearchRequest(plan, placeGroupByte) assert.NoError(t, err) - placeholderGroups := make([]*PlaceholderGroup, 0) + placeholderGroups := make([]*searchRequest, 0) placeholderGroups = append(placeholderGroups, holder) searchResults := make([]*SearchResult, 0) diff --git a/internal/querynode/search_collection.go b/internal/querynode/search_collection.go index 44e093226..fb4f0422d 100644 --- a/internal/querynode/search_collection.go +++ b/internal/querynode/search_collection.go @@ -271,13 +271,13 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error { if err != nil { return err } - placeHolderGroupBlob := query.PlaceholderGroup - placeholderGroup, err := parserPlaceholderGroup(plan, placeHolderGroupBlob) + searchRequestBlob := query.PlaceholderGroup + searchReq, err := parseSearchRequest(plan, searchRequestBlob) if err != nil { return err } - placeholderGroups := make([]*PlaceholderGroup, 0) - placeholderGroups = append(placeholderGroups, placeholderGroup) + searchRequests := make([]*searchRequest, 0) + searchRequests = append(searchRequests, searchReq) searchResults := make([]*SearchResult, 0) matchedSegments := make([]*Segment, 0) @@ -315,7 +315,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error { if err != nil { return err } - searchResult, err := segment.segmentSearch(plan, placeholderGroups, []Timestamp{searchTimestamp}) + searchResult, err := segment.segmentSearch(plan, searchRequests, []Timestamp{searchTimestamp}) if err != nil { return err @@ -326,7 +326,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error { } if len(searchResults) <= 0 { - for _, group := range placeholderGroups { + for _, group := range searchRequests { nq := group.getNumOfQuery() nilHits := make([][]byte, nq) hit := &milvuspb.Hits{} @@ -363,17 +363,30 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error { inReduced := make([]bool, len(searchResults)) numSegment := int64(len(searchResults)) - err2 := reduceSearchResults(searchResults, numSegment, inReduced) - if err2 != nil { - return err2 - } - err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced) - if err != nil { - return err - } - marshaledHits, err := reorganizeQueryResults(plan, placeholderGroups, searchResults, numSegment, inReduced) - if err != nil { - return err + var marshaledHits *MarshaledHits = nil + if numSegment == 1 { + inReduced[0] = true + err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced) + if err != nil { + return err + } + marshaledHits, err = reorganizeSingleQueryResult(plan, searchRequests, searchResults[0]) + if err != nil { + return err + } + } else { + err = reduceSearchResults(searchResults, numSegment, inReduced) + if err != nil { + return err + } + err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced) + if err != nil { + return err + } + marshaledHits, err = reorganizeQueryResults(plan, searchRequests, searchResults, numSegment, inReduced) + if err != nil { + return err + } } hitsBlob, err := marshaledHits.getHitsBlob() if err != nil { @@ -381,14 +394,14 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error { } var offset int64 = 0 - for index := range placeholderGroups { + for index := range searchRequests { hitBlobSizePeerQuery, err := marshaledHits.hitBlobSizeInGroup(int64(index)) if err != nil { return err } - hits := make([][]byte, 0) - for _, len := range hitBlobSizePeerQuery { - hits = append(hits, hitsBlob[offset:offset+len]) + hits := make([][]byte, len(hitBlobSizePeerQuery)) + for i, len := range hitBlobSizePeerQuery { + hits[i] = hitsBlob[offset : offset+len] //test code to checkout marshaled hits //marshaledHit := hitsBlob[offset:offset+len] //unMarshaledHit := milvuspb.Hits{} @@ -436,7 +449,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error { deleteSearchResults(searchResults) deleteMarshaledHits(marshaledHits) plan.delete() - placeholderGroup.delete() + searchReq.delete() return nil } diff --git a/internal/querynode/segment.go b/internal/querynode/segment.go index f7c7fdd8c..5b0286dac 100644 --- a/internal/querynode/segment.go +++ b/internal/querynode/segment.go @@ -214,7 +214,7 @@ func (s *Segment) getMemSize() int64 { } func (s *Segment) segmentSearch(plan *Plan, - placeHolderGroups []*PlaceholderGroup, + searchRequests []*searchRequest, timestamp []Timestamp) (*SearchResult, error) { /* CStatus @@ -229,14 +229,14 @@ func (s *Segment) segmentSearch(plan *Plan, return nil, errors.New("null seg core pointer") } cPlaceholderGroups := make([]C.CPlaceholderGroup, 0) - for _, pg := range placeHolderGroups { + for _, pg := range searchRequests { cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup) } var searchResult SearchResult var cTimestamp = (*C.ulong)(×tamp[0]) var cPlaceHolder = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0]) - var cNumGroups = C.int(len(placeHolderGroups)) + var cNumGroups = C.int(len(searchRequests)) log.Debug("do search on segment", zap.Int64("segmentID", s.segmentID), zap.Int32("segmentType", int32(s.segmentType))) var status = C.Search(s.segmentPtr, plan.cPlan, cPlaceHolder, cTimestamp, cNumGroups, &searchResult.cQueryResult) @@ -257,6 +257,7 @@ func (s *Segment) fillTargetEntry(plan *Plan, return errors.New("null seg core pointer") } + log.Debug("segment fill target entry, ", zap.Int64("segment ID = ", s.segmentID)) var status = C.FillTargetEntry(s.segmentPtr, plan.cPlan, result.cQueryResult) errorCode := status.error_code diff --git a/internal/querynode/segment_test.go b/internal/querynode/segment_test.go index d4e03a9f3..0f4518d25 100644 --- a/internal/querynode/segment_test.go +++ b/internal/querynode/segment_test.go @@ -355,9 +355,9 @@ func TestSegment_segmentSearch(t *testing.T) { searchTimestamp := Timestamp(1020) plan, err := createPlan(*collection, dslString) assert.NoError(t, err) - holder, err := parserPlaceholderGroup(plan, placeHolderGroupBlob) + holder, err := parseSearchRequest(plan, placeHolderGroupBlob) assert.NoError(t, err) - placeholderGroups := make([]*PlaceholderGroup, 0) + placeholderGroups := make([]*searchRequest, 0) placeholderGroups = append(placeholderGroups, holder) searchResults := make([]*SearchResult, 0) -- GitLab