提交 af3c14a8 编写于 作者: B bigsheeper 提交者: yefu.chen

Add batched search support

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
上级 d7e6b993
......@@ -76,14 +76,21 @@ Aggregation(std::vector<std::shared_ptr<grpc::QueryResult>> results, milvus::grp
}
std::vector<float> all_scores;
std::vector<float> all_distance;
std::vector<int64_t> all_entities_ids;
// Proxy get numQueries from row_num.
auto numQueries = results[0]->row_num();
auto topK = results[0]->distances_size() / numQueries;
// 2d array for multiple queries
std::vector<std::vector<float>> all_distance(numQueries);
std::vector<std::vector<int64_t>> all_entities_ids(numQueries);
std::vector<bool> all_valid_row;
std::vector<grpc::RowData> all_row_data;
std::vector<grpc::KeyValuePair> all_kv_pairs;
grpc::Status status;
int row_num = 0;
// int row_num = 0;
for (auto &result_per_node : results) {
if (result_per_node->status().error_code() != grpc::ErrorCode::SUCCESS) {
......@@ -91,46 +98,66 @@ Aggregation(std::vector<std::shared_ptr<grpc::QueryResult>> results, milvus::grp
// one_node_res->entities().status().error_code() != grpc::ErrorCode::SUCCESS) {
return Status(DB_ERROR, "QueryNode return wrong status!");
}
for (int j = 0; j < result_per_node->distances_size(); j++) {
all_scores.push_back(result_per_node->scores()[j]);
all_distance.push_back(result_per_node->distances()[j]);
// all_kv_pairs.push_back(result_per_node->extra_params()[j]);
}
for (int k = 0; k < result_per_node->entities().ids_size(); ++k) {
all_entities_ids.push_back(result_per_node->entities().ids(k));
// all_valid_row.push_back(result_per_node->entities().valid_row(k));
// all_row_data.push_back(result_per_node->entities().rows_data(k));
}
if (result_per_node->row_num() > row_num) {
row_num = result_per_node->row_num();
// assert(result_per_node->row_num() == numQueries);
for (int i = 0; i < numQueries; i++) {
for (int j = i * topK; j < (i + 1) * topK && j < result_per_node->distances_size(); j++) {
all_scores.push_back(result_per_node->scores()[j]);
all_distance[i].push_back(result_per_node->distances()[j]);
all_entities_ids[i].push_back(result_per_node->entities().ids(j));
}
}
// for (int j = 0; j < result_per_node->distances_size(); j++) {
// all_scores.push_back(result_per_node->scores()[j]);
// all_distance.push_back(result_per_node->distances()[j]);
//// all_kv_pairs.push_back(result_per_node->extra_params()[j]);
// }
// for (int k = 0; k < result_per_node->entities().ids_size(); ++k) {
// all_entities_ids.push_back(result_per_node->entities().ids(k));
//// all_valid_row.push_back(result_per_node->entities().valid_row(k));
//// all_row_data.push_back(result_per_node->entities().rows_data(k));
// }
// if (result_per_node->row_num() > row_num) {
// row_num = result_per_node->row_num();
// }
status = result_per_node->status();
}
std::vector<int> index(all_distance.size());
std::vector<std::vector<int>> index_array;
for (int i = 0; i < numQueries; i++) {
auto &distance = all_distance[i];
std::vector<int> index(distance.size());
iota(index.begin(), index.end(), 0);
iota(index.begin(), index.end(), 0);
std::stable_sort(index.begin(), index.end(),
[&distance](size_t i1, size_t i2) { return distance[i1] < distance[i2]; });
index_array.emplace_back(index);
}
std::stable_sort(index.begin(), index.end(),
[&all_distance](size_t i1, size_t i2) { return all_distance[i1] > all_distance[i2]; });
grpc::Entities result_entities;
for (int m = 0; m < result->row_num(); ++m) {
result->add_scores(all_scores[index[m]]);
result->add_distances(all_distance[index[m]]);
for (int i = 0; i < numQueries; i++) {
for (int m = 0; m < topK; ++m) {
result->add_scores(all_scores[index_array[i][m]]);
result->add_distances(all_distance[i][index_array[i][m]]);
// result->add_extra_params();
// result->mutable_extra_params(m)->CopyFrom(all_kv_pairs[index[m]]);
result_entities.add_ids(all_entities_ids[index[m]]);
result_entities.add_ids(all_entities_ids[i][index_array[i][m]]);
// result_entities.add_valid_row(all_valid_row[index[m]]);
// result_entities.add_rows_data();
// result_entities.mutable_rows_data(m)->CopyFrom(all_row_data[index[m]]);
}
}
result_entities.mutable_status()->CopyFrom(status);
result->set_row_num(row_num);
result->set_row_num(numQueries);
result->mutable_entities()->CopyFrom(result_entities);
result->set_query_id(results[0]->query_id());
// result->set_client_id(results[0]->client_id());
......
......@@ -276,7 +276,7 @@ func (node *QueryNode) RunInsertDelete(wg *sync.WaitGroup) {
if node.msgCounter.InsertCounter/CountInsertMsgBaseline != BaselineCounter {
node.WriteQueryLog()
BaselineCounter = node.msgCounter.InsertCounter/CountInsertMsgBaseline
BaselineCounter = node.msgCounter.InsertCounter / CountInsertMsgBaseline
}
if msgLen[0] == 0 && len(node.buffer.InsertDeleteBuffer) <= 0 {
......@@ -339,10 +339,10 @@ func (node *QueryNode) RunSearch(wg *sync.WaitGroup) {
case msg := <-node.messageClient.GetSearchChan():
node.messageClient.SearchMsg = node.messageClient.SearchMsg[:0]
node.messageClient.SearchMsg = append(node.messageClient.SearchMsg, msg)
fmt.Println("Do Search...")
//for {
//if node.messageClient.SearchMsg[0].Timestamp < node.queryNodeTimeSync.ServiceTimeSync {
var status = node.Search(node.messageClient.SearchMsg)
fmt.Println("Do Search done")
if status.ErrorCode != 0 {
fmt.Println("Search Failed")
node.PublishFailedSearchResult()
......@@ -504,8 +504,8 @@ func (node *QueryNode) DoInsertAndDelete() msgPb.Status {
}
wg.Add(1)
var deleteTimestamps = node.deleteData.deleteTimestamps[segmentID]
fmt.Println("Doing delete......")
go node.DoDelete(segmentID, &deleteIDs, &deleteTimestamps, &wg)
fmt.Println("Do delete done")
}
wg.Wait()
......@@ -513,7 +513,6 @@ func (node *QueryNode) DoInsertAndDelete() msgPb.Status {
}
func (node *QueryNode) DoInsert(segmentID int64, wg *sync.WaitGroup) msgPb.Status {
fmt.Println("Doing insert..., len = ", len(node.insertData.insertIDs[segmentID]))
var targetSegment, err = node.GetSegmentBySegmentID(segmentID)
if err != nil {
fmt.Println(err.Error())
......@@ -526,6 +525,7 @@ func (node *QueryNode) DoInsert(segmentID int64, wg *sync.WaitGroup) msgPb.Statu
offsets := node.insertData.insertOffset[segmentID]
err = targetSegment.SegmentInsert(offsets, &ids, &timestamps, &records)
fmt.Println("Do insert done, len = ", len(node.insertData.insertIDs[segmentID]))
node.QueryLog(len(ids))
......@@ -584,8 +584,6 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
// TODO: Do not receive batched search requests
for _, msg := range searchMessages {
var clientId = msg.ClientId
var resultsTmp = make([]SearchResultTmp, 0)
var searchTimestamp = msg.Timestamp
// ServiceTimeSync update by readerTimeSync, which is get from proxy.
......@@ -610,6 +608,11 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
// 2. Get query information from query json
query := node.QueryJson2Info(&queryJson)
// 2d slice for receiving multiple queries's results
var resultsTmp = make([][]SearchResultTmp, query.NumQueries)
for i := 0; i < int(query.NumQueries); i++ {
resultsTmp[i] = make([]SearchResultTmp, 0)
}
// 3. Do search in all segments
for _, segment := range node.SegmentsMap {
......@@ -625,18 +628,30 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
return msgPb.Status{ErrorCode: 1}
}
for i := 0; i < len(res.ResultIds); i++ {
resultsTmp = append(resultsTmp, SearchResultTmp{ResultId: res.ResultIds[i], ResultDistance: res.ResultDistances[i]})
for i := 0; i < int(query.NumQueries); i++ {
for j := i * query.TopK; j < (i+1)*query.TopK; j++ {
resultsTmp[i] = append(resultsTmp[i], SearchResultTmp{
ResultId: res.ResultIds[j],
ResultDistance: res.ResultDistances[j],
})
}
}
}
// 4. Reduce results
sort.Slice(resultsTmp, func(i, j int) bool {
return resultsTmp[i].ResultDistance < resultsTmp[j].ResultDistance
})
if len(resultsTmp) > query.TopK {
resultsTmp = resultsTmp[:query.TopK]
for _, rTmp := range resultsTmp {
sort.Slice(rTmp, func(i, j int) bool {
return rTmp[i].ResultDistance < rTmp[j].ResultDistance
})
}
for _, rTmp := range resultsTmp {
if len(rTmp) > query.TopK {
rTmp = rTmp[:query.TopK]
}
}
var entities = msgPb.Entities{
Ids: make([]int64, 0),
}
......@@ -649,15 +664,19 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
QueryId: msg.Uid,
ClientId: clientId,
}
for _, res := range resultsTmp {
results.Entities.Ids = append(results.Entities.Ids, res.ResultId)
results.Distances = append(results.Distances, res.ResultDistance)
results.Scores = append(results.Distances, float32(0))
for _, rTmp := range resultsTmp {
for _, res := range rTmp {
results.Entities.Ids = append(results.Entities.Ids, res.ResultId)
results.Distances = append(results.Distances, res.ResultDistance)
results.Scores = append(results.Distances, float32(0))
}
}
results.RowNum = int64(len(results.Distances))
// Send numQueries to RowNum.
results.RowNum = query.NumQueries
// 5. publish result to pulsar
//fmt.Println(results.Entities.Ids)
//fmt.Println(results.Distances)
node.PublishSearchResult(&results)
}
......
......@@ -218,8 +218,8 @@ func (s *Segment) SegmentSearch(query *QueryInfo, timestamp uint64, vectorRecord
field_name: C.CString(query.FieldName),
}
resultIds := make([]int64, query.TopK)
resultDistances := make([]float32, query.TopK)
resultIds := make([]int64, int64(query.TopK) * query.NumQueries)
resultDistances := make([]float32, int64(query.TopK) * query.NumQueries)
var cTimestamp = C.ulong(timestamp)
var cResultIds = (*C.long)(&resultIds[0])
......
......@@ -17,6 +17,7 @@
#include "utils/Utils.h"
#include <random>
const int NUM_OF_VECTOR = 1;
const int TOP_K = 10;
const int LOOP = 1000;
......@@ -32,7 +33,7 @@ get_vector_param() {
std::normal_distribution<float> dis(0, 1);
for (int j = 0; j < 1; ++j) {
for (int j = 0; j < NUM_OF_VECTOR; ++j) {
milvus::VectorData vectorData;
std::vector<float> float_data;
for (int i = 0; i < DIM; ++i) {
......@@ -44,7 +45,7 @@ get_vector_param() {
}
nlohmann::json vector_param_json;
vector_param_json["num_queries"] = 1;
vector_param_json["num_queries"] = NUM_OF_VECTOR;
vector_param_json["topK"] = TOP_K;
vector_param_json["field_name"] = "field_vec";
std::string vector_param_json_string = vector_param_json.dump();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册