From af3c14a8c4840c126b9ce0276f202c4303773755 Mon Sep 17 00:00:00 2001
From: bigsheeper
Date: Tue, 13 Oct 2020 16:20:39 +0800
Subject: [PATCH] Add batched search support

Signed-off-by: bigsheeper
---
 proxy/src/message_client/ClientV2.cpp | 75 ++++++++++++++++++---------
 reader/read_node/query_node.go        | 57 +++++++++++++-------
 reader/read_node/segment.go           |  4 +-
 sdk/examples/simple/search.cpp        |  5 +-
 4 files changed, 94 insertions(+), 47 deletions(-)

diff --git a/proxy/src/message_client/ClientV2.cpp b/proxy/src/message_client/ClientV2.cpp
index 92a17ca49..ab430e1c8 100644
--- a/proxy/src/message_client/ClientV2.cpp
+++ b/proxy/src/message_client/ClientV2.cpp
@@ -76,14 +76,21 @@ Aggregation(std::vector<std::shared_ptr<grpc::QueryResult>> results, milvus::grp
     }
 
     std::vector<float> all_scores;
-    std::vector<float> all_distance;
-    std::vector<int64_t> all_entities_ids;
+
+    // Proxy get numQueries from row_num.
+    auto numQueries = results[0]->row_num();
+    auto topK = results[0]->distances_size() / numQueries;
+
+    // 2d array for multiple queries
+    std::vector<std::vector<float>> all_distance(numQueries);
+    std::vector<std::vector<int64_t>> all_entities_ids(numQueries);
+
     std::vector<bool> all_valid_row;
    	std::vector<grpc::RowData> all_row_data;
     std::vector<grpc::KeyValuePair> all_kv_pairs;
 
     grpc::Status status;
-    int row_num = 0;
+//    int row_num = 0;
 
     for (auto &result_per_node : results) {
         if (result_per_node->status().error_code() != grpc::ErrorCode::SUCCESS) {
@@ -91,46 +98,66 @@ Aggregation(std::vector<std::shared_ptr<grpc::QueryResult>> results, milvus::grp
 //            one_node_res->entities().status().error_code() != grpc::ErrorCode::SUCCESS) {
             return Status(DB_ERROR, "QueryNode return wrong status!");
         }
 
-        for (int j = 0; j < result_per_node->distances_size(); j++) {
-            all_scores.push_back(result_per_node->scores()[j]);
-            all_distance.push_back(result_per_node->distances()[j]);
-//            all_kv_pairs.push_back(result_per_node->extra_params()[j]);
-        }
-        for (int k = 0; k < result_per_node->entities().ids_size(); ++k) {
-            all_entities_ids.push_back(result_per_node->entities().ids(k));
-//            all_valid_row.push_back(result_per_node->entities().valid_row(k));
-//            all_row_data.push_back(result_per_node->entities().rows_data(k));
-        }
-        if (result_per_node->row_num() > row_num) {
-            row_num = result_per_node->row_num();
+
+//        assert(result_per_node->row_num() == numQueries);
+
+        for (int i = 0; i < numQueries; i++) {
+            for (int j = i * topK; j < (i + 1) * topK && j < result_per_node->distances_size(); j++) {
+                all_scores.push_back(result_per_node->scores()[j]);
+                all_distance[i].push_back(result_per_node->distances()[j]);
+                all_entities_ids[i].push_back(result_per_node->entities().ids(j));
+            }
         }
+
+//        for (int j = 0; j < result_per_node->distances_size(); j++) {
+//            all_scores.push_back(result_per_node->scores()[j]);
+//            all_distance.push_back(result_per_node->distances()[j]);
+////            all_kv_pairs.push_back(result_per_node->extra_params()[j]);
+//        }
+//        for (int k = 0; k < result_per_node->entities().ids_size(); ++k) {
+//            all_entities_ids.push_back(result_per_node->entities().ids(k));
+////            all_valid_row.push_back(result_per_node->entities().valid_row(k));
+////            all_row_data.push_back(result_per_node->entities().rows_data(k));
+//        }
+
+//        if (result_per_node->row_num() > row_num) {
+//            row_num = result_per_node->row_num();
+//        }
         status = result_per_node->status();
     }
 
-    std::vector<int> index(all_distance.size());
+    std::vector<std::vector<int>> index_array;
+    for (int i = 0; i < numQueries; i++) {
+        auto &distance = all_distance[i];
+        std::vector<int> index(distance.size());
 
-    iota(index.begin(), index.end(), 0);
+        iota(index.begin(), index.end(), 0);
+
+        std::stable_sort(index.begin(), index.end(),
+                         [&distance](size_t i1, size_t i2) { return distance[i1] < distance[i2]; });
+        index_array.emplace_back(index);
+    }
 
-    std::stable_sort(index.begin(), index.end(),
-                     [&all_distance](size_t i1, size_t i2) { return all_distance[i1] > all_distance[i2]; });
 
     grpc::Entities result_entities;
-    for (int m = 0; m < result->row_num(); ++m) {
-        result->add_scores(all_scores[index[m]]);
-        result->add_distances(all_distance[index[m]]);
+    for (int i = 0; i < numQueries; i++) {
+        for (int m = 0; m < topK; ++m) {
+            result->add_scores(all_scores[index_array[i][m]]);
+            result->add_distances(all_distance[i][index_array[i][m]]);
 //        result->add_extra_params();
 //        result->mutable_extra_params(m)->CopyFrom(all_kv_pairs[index[m]]);
-        result_entities.add_ids(all_entities_ids[index[m]]);
+            result_entities.add_ids(all_entities_ids[i][index_array[i][m]]);
 //        result_entities.add_valid_row(all_valid_row[index[m]]);
 //        result_entities.add_rows_data();
 //        result_entities.mutable_rows_data(m)->CopyFrom(all_row_data[index[m]]);
+        }
     }
     result_entities.mutable_status()->CopyFrom(status);
 
-    result->set_row_num(row_num);
+    result->set_row_num(numQueries);
     result->mutable_entities()->CopyFrom(result_entities);
     result->set_query_id(results[0]->query_id());
 //    result->set_client_id(results[0]->client_id());
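
Note on the hunk above: each query node now returns a flat, query-major result buffer (numQueries blocks of topK hits), and the proxy pools every node's candidates per query, argsorts them by ascending distance, and keeps the best topK for each query. The Go sketch below only illustrates that merge strategy; the NodeResult type and the Aggregate function are hypothetical stand-ins, not part of the patch.

package aggregation

import "sort"

// NodeResult stands in for one query node's grpc::QueryResult: IDs and
// Distances are laid out query-major, i.e. hit j of query i sits at i*topK+j.
type NodeResult struct {
	IDs       []int64
	Distances []float32
}

// Aggregate merges per-node top-K lists into a single top-K list per query.
func Aggregate(nodes []NodeResult, numQueries, topK int) ([][]int64, [][]float32) {
	ids := make([][]int64, numQueries)
	dists := make([][]float32, numQueries)
	// 1. Pool every node's candidates for each query.
	for _, n := range nodes {
		for i := 0; i < numQueries; i++ {
			lo, hi := i*topK, (i+1)*topK
			if hi > len(n.Distances) {
				hi = len(n.Distances)
			}
			ids[i] = append(ids[i], n.IDs[lo:hi]...)
			dists[i] = append(dists[i], n.Distances[lo:hi]...)
		}
	}
	// 2. Argsort each query's pool by ascending distance and keep topK.
	for i := range dists {
		idx := make([]int, len(dists[i]))
		for k := range idx {
			idx[k] = k
		}
		sort.SliceStable(idx, func(a, b int) bool { return dists[i][idx[a]] < dists[i][idx[b]] })
		if len(idx) > topK {
			idx = idx[:topK]
		}
		mergedIDs := make([]int64, len(idx))
		mergedDists := make([]float32, len(idx))
		for k, p := range idx {
			mergedIDs[k], mergedDists[k] = ids[i][p], dists[i][p]
		}
		ids[i], dists[i] = mergedIDs, mergedDists
	}
	return ids, dists
}
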
diff --git a/reader/read_node/query_node.go b/reader/read_node/query_node.go
index 1e7d80867..a0049965a 100644
--- a/reader/read_node/query_node.go
+++ b/reader/read_node/query_node.go
@@ -276,7 +276,7 @@ func (node *QueryNode) RunInsertDelete(wg *sync.WaitGroup) {
 
 		if node.msgCounter.InsertCounter/CountInsertMsgBaseline != BaselineCounter {
 			node.WriteQueryLog()
-			BaselineCounter = node.msgCounter.InsertCounter/CountInsertMsgBaseline
+			BaselineCounter = node.msgCounter.InsertCounter / CountInsertMsgBaseline
 		}
 
 		if msgLen[0] == 0 && len(node.buffer.InsertDeleteBuffer) <= 0 {
@@ -339,10 +339,10 @@ func (node *QueryNode) RunSearch(wg *sync.WaitGroup) {
 		case msg := <-node.messageClient.GetSearchChan():
 			node.messageClient.SearchMsg = node.messageClient.SearchMsg[:0]
 			node.messageClient.SearchMsg = append(node.messageClient.SearchMsg, msg)
-			fmt.Println("Do Search...")
 			//for {
 			//if node.messageClient.SearchMsg[0].Timestamp < node.queryNodeTimeSync.ServiceTimeSync {
 			var status = node.Search(node.messageClient.SearchMsg)
+			fmt.Println("Do Search done")
 			if status.ErrorCode != 0 {
 				fmt.Println("Search Failed")
 				node.PublishFailedSearchResult()
@@ -504,8 +504,8 @@ func (node *QueryNode) DoInsertAndDelete() msgPb.Status {
 		}
 		wg.Add(1)
 		var deleteTimestamps = node.deleteData.deleteTimestamps[segmentID]
-		fmt.Println("Doing delete......")
 		go node.DoDelete(segmentID, &deleteIDs, &deleteTimestamps, &wg)
+		fmt.Println("Do delete done")
 	}
 	wg.Wait()
 
@@ -513,7 +513,6 @@ func (node *QueryNode) DoInsertAndDelete() msgPb.Status {
 }
 
 func (node *QueryNode) DoInsert(segmentID int64, wg *sync.WaitGroup) msgPb.Status {
-	fmt.Println("Doing insert..., len = ", len(node.insertData.insertIDs[segmentID]))
 	var targetSegment, err = node.GetSegmentBySegmentID(segmentID)
 	if err != nil {
 		fmt.Println(err.Error())
@@ -526,6 +525,7 @@ func (node *QueryNode) DoInsert(segmentID int64, wg *sync.WaitGroup) msgPb.Statu
 	offsets := node.insertData.insertOffset[segmentID]
 
 	err = targetSegment.SegmentInsert(offsets, &ids, &timestamps, &records)
+	fmt.Println("Do insert done, len = ", len(node.insertData.insertIDs[segmentID]))
 
 	node.QueryLog(len(ids))
 
@@ -584,8 +584,6 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
 	// TODO: Do not receive batched search requests
 	for _, msg := range searchMessages {
 		var clientId = msg.ClientId
-		var resultsTmp = make([]SearchResultTmp, 0)
-
 		var searchTimestamp = msg.Timestamp
 
 		// ServiceTimeSync update by readerTimeSync, which is get from proxy.
@@ -610,6 +608,11 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
 
 		// 2. Get query information from query json
 		query := node.QueryJson2Info(&queryJson)
 
+		// 2d slice for receiving multiple queries's results
+		var resultsTmp = make([][]SearchResultTmp, query.NumQueries)
+		for i := 0; i < int(query.NumQueries); i++ {
+			resultsTmp[i] = make([]SearchResultTmp, 0)
+		}
 		// 3. Do search in all segments
 		for _, segment := range node.SegmentsMap {
@@ -625,18 +628,30 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
 				return msgPb.Status{ErrorCode: 1}
 			}
 
-			for i := 0; i < len(res.ResultIds); i++ {
-				resultsTmp = append(resultsTmp, SearchResultTmp{ResultId: res.ResultIds[i], ResultDistance: res.ResultDistances[i]})
+			for i := 0; i < int(query.NumQueries); i++ {
+				for j := i * query.TopK; j < (i+1)*query.TopK; j++ {
+					resultsTmp[i] = append(resultsTmp[i], SearchResultTmp{
+						ResultId:       res.ResultIds[j],
+						ResultDistance: res.ResultDistances[j],
+					})
+				}
 			}
 		}
 
 		// 4. Reduce results
-		sort.Slice(resultsTmp, func(i, j int) bool {
-			return resultsTmp[i].ResultDistance < resultsTmp[j].ResultDistance
-		})
-		if len(resultsTmp) > query.TopK {
-			resultsTmp = resultsTmp[:query.TopK]
+		for _, rTmp := range resultsTmp {
+			sort.Slice(rTmp, func(i, j int) bool {
+				return rTmp[i].ResultDistance < rTmp[j].ResultDistance
+			})
+		}
+
+		for _, rTmp := range resultsTmp {
+			if len(rTmp) > query.TopK {
+				rTmp = rTmp[:query.TopK]
+			}
 		}
+
+
 		var entities = msgPb.Entities{
 			Ids: make([]int64, 0),
 		}
@@ -649,15 +664,19 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
 			QueryId:   msg.Uid,
 			ClientId:  clientId,
 		}
-		for _, res := range resultsTmp {
-			results.Entities.Ids = append(results.Entities.Ids, res.ResultId)
-			results.Distances = append(results.Distances, res.ResultDistance)
-			results.Scores = append(results.Distances, float32(0))
+		for _, rTmp := range resultsTmp {
+			for _, res := range rTmp {
+				results.Entities.Ids = append(results.Entities.Ids, res.ResultId)
+				results.Distances = append(results.Distances, res.ResultDistance)
+				results.Scores = append(results.Distances, float32(0))
+			}
 		}
-
-		results.RowNum = int64(len(results.Distances))
+		// Send numQueries to RowNum.
+		results.RowNum = query.NumQueries
 
 		// 5. publish result to pulsar
+		//fmt.Println(results.Entities.Ids)
+		//fmt.Println(results.Distances)
 		node.PublishSearchResult(&results)
 	}
 
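
The reader side mirrors the same query-major layout: SegmentSearch now returns query.NumQueries blocks of query.TopK hits, and Search() re-slices, sorts, and truncates them per query. A compact Go sketch of that reduce step follows; it is illustrative only, and the Result type is a stand-in for the reader's SearchResultTmp.

package reduce

import "sort"

// Result is a stand-in for the reader's SearchResultTmp.
type Result struct {
	ID       int64
	Distance float32
}

// reducePerQuery slices a flat, query-major candidate buffer into one list
// per query, sorts each list by ascending distance, and keeps at most topK.
func reducePerQuery(ids []int64, distances []float32, numQueries, topK int) [][]Result {
	out := make([][]Result, numQueries)
	for i := 0; i < numQueries; i++ {
		for j := i * topK; j < (i+1)*topK && j < len(ids); j++ {
			out[i] = append(out[i], Result{ID: ids[j], Distance: distances[j]})
		}
	}
	for i := range out {
		sort.Slice(out[i], func(a, b int) bool { return out[i][a].Distance < out[i][b].Distance })
		if len(out[i]) > topK {
			// Truncation has to go through the index: re-slicing a range-loop
			// variable (rTmp = rTmp[:topK]) would not update the outer slice.
			out[i] = out[i][:topK]
		}
	}
	return out
}
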
diff --git a/reader/read_node/segment.go b/reader/read_node/segment.go
index 21be00f35..2c718881f 100644
--- a/reader/read_node/segment.go
+++ b/reader/read_node/segment.go
@@ -218,8 +218,8 @@ func (s *Segment) SegmentSearch(query *QueryInfo, timestamp uint64, vectorRecord
 		field_name: C.CString(query.FieldName),
 	}
 
-	resultIds := make([]int64, query.TopK)
-	resultDistances := make([]float32, query.TopK)
+	resultIds := make([]int64, int64(query.TopK) * query.NumQueries)
+	resultDistances := make([]float32, int64(query.TopK) * query.NumQueries)
 
 	var cTimestamp = C.ulong(timestamp)
 	var cResultIds = (*C.long)(&resultIds[0])
diff --git a/sdk/examples/simple/search.cpp b/sdk/examples/simple/search.cpp
index 35fde2e39..818025db6 100644
--- a/sdk/examples/simple/search.cpp
+++ b/sdk/examples/simple/search.cpp
@@ -17,6 +17,7 @@
 #include "utils/Utils.h"
 #include
 
+const int NUM_OF_VECTOR = 1;
 const int TOP_K = 10;
 const int LOOP = 1000;
 
@@ -32,7 +33,7 @@ get_vector_param() {
 
     std::normal_distribution dis(0, 1);
 
-    for (int j = 0; j < 1; ++j) {
+    for (int j = 0; j < NUM_OF_VECTOR; ++j) {
         milvus::VectorData vectorData;
         std::vector<float> float_data;
         for (int i = 0; i < DIM; ++i) {
@@ -44,7 +45,7 @@ get_vector_param() {
     }
 
     nlohmann::json vector_param_json;
-    vector_param_json["num_queries"] = 1;
+    vector_param_json["num_queries"] = NUM_OF_VECTOR;
     vector_param_json["topK"] = TOP_K;
     vector_param_json["field_name"] = "field_vec";
     std::string vector_param_json_string = vector_param_json.dump();
-- 
GitLab
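
For completeness, the SDK example above advertises the batch size through the num_queries field of the vector parameter JSON. A rough Go equivalent of that parameter block is shown below; the key names are taken from search.cpp, everything else (constant names, the map itself) is illustrative.

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	const numOfVector = 1 // mirrors NUM_OF_VECTOR in search.cpp
	const topK = 10       // mirrors TOP_K in search.cpp

	// Same keys the SDK example writes into vector_param_json.
	vectorParam := map[string]interface{}{
		"num_queries": numOfVector,
		"topK":        topK,
		"field_name":  "field_vec",
	}
	b, err := json.Marshal(vectorParam)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(b)) // {"field_name":"field_vec","num_queries":1,"topK":10}
}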