Commit af3c14a8 authored by bigsheeper, committed by yefu.chen

Add batched search support

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
Parent d7e6b993
@@ -76,14 +76,21 @@ Aggregation(std::vector<std::shared_ptr<grpc::QueryResult>> results, milvus::grp
     }
     std::vector<float> all_scores;
-    std::vector<float> all_distance;
-    std::vector<int64_t> all_entities_ids;
+    // Proxy get numQueries from row_num.
+    auto numQueries = results[0]->row_num();
+    auto topK = results[0]->distances_size() / numQueries;
+    // 2d array for multiple queries
+    std::vector<std::vector<float>> all_distance(numQueries);
+    std::vector<std::vector<int64_t>> all_entities_ids(numQueries);
     std::vector<bool> all_valid_row;
     std::vector<grpc::RowData> all_row_data;
     std::vector<grpc::KeyValuePair> all_kv_pairs;
     grpc::Status status;
-    int row_num = 0;
+    // int row_num = 0;
     for (auto &result_per_node : results) {
         if (result_per_node->status().error_code() != grpc::ErrorCode::SUCCESS) {
@@ -91,46 +98,66 @@ Aggregation(std::vector<std::shared_ptr<grpc::QueryResult>> results, milvus::grp
             // one_node_res->entities().status().error_code() != grpc::ErrorCode::SUCCESS) {
             return Status(DB_ERROR, "QueryNode return wrong status!");
         }
-        for (int j = 0; j < result_per_node->distances_size(); j++) {
-            all_scores.push_back(result_per_node->scores()[j]);
-            all_distance.push_back(result_per_node->distances()[j]);
-            // all_kv_pairs.push_back(result_per_node->extra_params()[j]);
-        }
-        for (int k = 0; k < result_per_node->entities().ids_size(); ++k) {
-            all_entities_ids.push_back(result_per_node->entities().ids(k));
-            // all_valid_row.push_back(result_per_node->entities().valid_row(k));
-            // all_row_data.push_back(result_per_node->entities().rows_data(k));
-        }
-        if (result_per_node->row_num() > row_num) {
-            row_num = result_per_node->row_num();
-        }
+        // assert(result_per_node->row_num() == numQueries);
+        for (int i = 0; i < numQueries; i++) {
+            for (int j = i * topK; j < (i + 1) * topK && j < result_per_node->distances_size(); j++) {
+                all_scores.push_back(result_per_node->scores()[j]);
+                all_distance[i].push_back(result_per_node->distances()[j]);
+                all_entities_ids[i].push_back(result_per_node->entities().ids(j));
+            }
+        }
+        // for (int j = 0; j < result_per_node->distances_size(); j++) {
+        //     all_scores.push_back(result_per_node->scores()[j]);
+        //     all_distance.push_back(result_per_node->distances()[j]);
+        ////     all_kv_pairs.push_back(result_per_node->extra_params()[j]);
+        // }
+        // for (int k = 0; k < result_per_node->entities().ids_size(); ++k) {
+        //     all_entities_ids.push_back(result_per_node->entities().ids(k));
+        ////     all_valid_row.push_back(result_per_node->entities().valid_row(k));
+        ////     all_row_data.push_back(result_per_node->entities().rows_data(k));
+        // }
+        // if (result_per_node->row_num() > row_num) {
+        //     row_num = result_per_node->row_num();
+        // }
         status = result_per_node->status();
     }
-    std::vector<int> index(all_distance.size());
-    iota(index.begin(), index.end(), 0);
-    std::stable_sort(index.begin(), index.end(),
-                     [&all_distance](size_t i1, size_t i2) { return all_distance[i1] > all_distance[i2]; });
+    std::vector<std::vector<int>> index_array;
+    for (int i = 0; i < numQueries; i++) {
+        auto &distance = all_distance[i];
+        std::vector<int> index(distance.size());
+        iota(index.begin(), index.end(), 0);
+        std::stable_sort(index.begin(), index.end(),
+                         [&distance](size_t i1, size_t i2) { return distance[i1] < distance[i2]; });
+        index_array.emplace_back(index);
+    }
     grpc::Entities result_entities;
-    for (int m = 0; m < result->row_num(); ++m) {
-        result->add_scores(all_scores[index[m]]);
-        result->add_distances(all_distance[index[m]]);
-        // result->add_extra_params();
-        // result->mutable_extra_params(m)->CopyFrom(all_kv_pairs[index[m]]);
-        result_entities.add_ids(all_entities_ids[index[m]]);
-        // result_entities.add_valid_row(all_valid_row[index[m]]);
-        // result_entities.add_rows_data();
-        // result_entities.mutable_rows_data(m)->CopyFrom(all_row_data[index[m]]);
-    }
+    for (int i = 0; i < numQueries; i++) {
+        for (int m = 0; m < topK; ++m) {
+            result->add_scores(all_scores[index_array[i][m]]);
+            result->add_distances(all_distance[i][index_array[i][m]]);
+            // result->add_extra_params();
+            // result->mutable_extra_params(m)->CopyFrom(all_kv_pairs[index[m]]);
+            result_entities.add_ids(all_entities_ids[i][index_array[i][m]]);
+            // result_entities.add_valid_row(all_valid_row[index[m]]);
+            // result_entities.add_rows_data();
+            // result_entities.mutable_rows_data(m)->CopyFrom(all_row_data[index[m]]);
+        }
+    }
     result_entities.mutable_status()->CopyFrom(status);
-    result->set_row_num(row_num);
+    result->set_row_num(numQueries);
     result->mutable_entities()->CopyFrom(result_entities);
     result->set_query_id(results[0]->query_id());
     // result->set_client_id(results[0]->client_id());
...
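For readers following the change above: the proxy now merges a per-query top-K list instead of one global list. The snippet below is a minimal standalone sketch of that merge, not the committed code; it assumes every query node replies with numQueries * topK distances and ids laid out query-major, gathers each query's candidates from all nodes, argsorts them by ascending distance with iota + stable_sort as the diff does, and keeps the first topK. Plain std::vector stands in for the grpc::QueryResult protos.

// Minimal sketch of the per-query merge done in Aggregation (illustrative only).
// Assumes every node's reply holds numQueries * topK (distance, id) pairs,
// grouped query-major: entries [q*topK, (q+1)*topK) belong to query q.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

struct NodeResult {
    std::vector<float> distances;  // size = numQueries * topK
    std::vector<int64_t> ids;      // size = numQueries * topK
};

int main() {
    const int numQueries = 2;
    const int topK = 3;
    std::vector<NodeResult> nodes = {
        {{0.1f, 0.4f, 0.9f, 0.2f, 0.5f, 0.8f}, {1, 2, 3, 4, 5, 6}},
        {{0.3f, 0.6f, 0.7f, 0.1f, 0.2f, 0.9f}, {7, 8, 9, 10, 11, 12}},
    };

    for (int q = 0; q < numQueries; ++q) {
        // 1. Collect query q's candidates from every node.
        std::vector<float> dist;
        std::vector<int64_t> ids;
        for (const auto& node : nodes) {
            for (int j = q * topK; j < (q + 1) * topK; ++j) {
                dist.push_back(node.distances[j]);
                ids.push_back(node.ids[j]);
            }
        }
        // 2. Argsort by ascending distance, like the iota + stable_sort in the diff.
        std::vector<int> index(dist.size());
        std::iota(index.begin(), index.end(), 0);
        std::stable_sort(index.begin(), index.end(),
                         [&dist](size_t a, size_t b) { return dist[a] < dist[b]; });
        // 3. Keep the global top-K for this query.
        std::cout << "query " << q << ":";
        for (int m = 0; m < topK; ++m) {
            std::cout << " (" << ids[index[m]] << ", " << dist[index[m]] << ")";
        }
        std::cout << "\n";
    }
    return 0;
}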
@@ -276,7 +276,7 @@ func (node *QueryNode) RunInsertDelete(wg *sync.WaitGroup) {
 		if node.msgCounter.InsertCounter/CountInsertMsgBaseline != BaselineCounter {
 			node.WriteQueryLog()
-			BaselineCounter = node.msgCounter.InsertCounter/CountInsertMsgBaseline
+			BaselineCounter = node.msgCounter.InsertCounter / CountInsertMsgBaseline
 		}
 		if msgLen[0] == 0 && len(node.buffer.InsertDeleteBuffer) <= 0 {
@@ -339,10 +339,10 @@ func (node *QueryNode) RunSearch(wg *sync.WaitGroup) {
 		case msg := <-node.messageClient.GetSearchChan():
 			node.messageClient.SearchMsg = node.messageClient.SearchMsg[:0]
 			node.messageClient.SearchMsg = append(node.messageClient.SearchMsg, msg)
-			fmt.Println("Do Search...")
 			//for {
 			//if node.messageClient.SearchMsg[0].Timestamp < node.queryNodeTimeSync.ServiceTimeSync {
 			var status = node.Search(node.messageClient.SearchMsg)
+			fmt.Println("Do Search done")
 			if status.ErrorCode != 0 {
 				fmt.Println("Search Failed")
 				node.PublishFailedSearchResult()
@@ -504,8 +504,8 @@ func (node *QueryNode) DoInsertAndDelete() msgPb.Status {
 		}
 		wg.Add(1)
 		var deleteTimestamps = node.deleteData.deleteTimestamps[segmentID]
-		fmt.Println("Doing delete......")
 		go node.DoDelete(segmentID, &deleteIDs, &deleteTimestamps, &wg)
+		fmt.Println("Do delete done")
 	}
 	wg.Wait()
@@ -513,7 +513,6 @@ func (node *QueryNode) DoInsertAndDelete() msgPb.Status {
 }
 func (node *QueryNode) DoInsert(segmentID int64, wg *sync.WaitGroup) msgPb.Status {
-	fmt.Println("Doing insert..., len = ", len(node.insertData.insertIDs[segmentID]))
 	var targetSegment, err = node.GetSegmentBySegmentID(segmentID)
 	if err != nil {
 		fmt.Println(err.Error())
@@ -526,6 +525,7 @@ func (node *QueryNode) DoInsert(segmentID int64, wg *sync.WaitGroup) msgPb.Status {
 	offsets := node.insertData.insertOffset[segmentID]
 	err = targetSegment.SegmentInsert(offsets, &ids, &timestamps, &records)
+	fmt.Println("Do insert done, len = ", len(node.insertData.insertIDs[segmentID]))
 	node.QueryLog(len(ids))
@@ -584,8 +584,6 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
 	// TODO: Do not receive batched search requests
 	for _, msg := range searchMessages {
 		var clientId = msg.ClientId
-		var resultsTmp = make([]SearchResultTmp, 0)
 		var searchTimestamp = msg.Timestamp
 		// ServiceTimeSync update by readerTimeSync, which is get from proxy.
@@ -610,6 +608,11 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
 		// 2. Get query information from query json
 		query := node.QueryJson2Info(&queryJson)
+		// 2d slice for receiving multiple queries's results
+		var resultsTmp = make([][]SearchResultTmp, query.NumQueries)
+		for i := 0; i < int(query.NumQueries); i++ {
+			resultsTmp[i] = make([]SearchResultTmp, 0)
+		}
 		// 3. Do search in all segments
 		for _, segment := range node.SegmentsMap {
@@ -625,18 +628,30 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
 				return msgPb.Status{ErrorCode: 1}
 			}
-			for i := 0; i < len(res.ResultIds); i++ {
-				resultsTmp = append(resultsTmp, SearchResultTmp{ResultId: res.ResultIds[i], ResultDistance: res.ResultDistances[i]})
+			for i := 0; i < int(query.NumQueries); i++ {
+				for j := i * query.TopK; j < (i+1)*query.TopK; j++ {
+					resultsTmp[i] = append(resultsTmp[i], SearchResultTmp{
+						ResultId: res.ResultIds[j],
+						ResultDistance: res.ResultDistances[j],
+					})
+				}
 			}
 		}
 		// 4. Reduce results
-		sort.Slice(resultsTmp, func(i, j int) bool {
-			return resultsTmp[i].ResultDistance < resultsTmp[j].ResultDistance
-		})
-		if len(resultsTmp) > query.TopK {
-			resultsTmp = resultsTmp[:query.TopK]
+		for _, rTmp := range resultsTmp {
+			sort.Slice(rTmp, func(i, j int) bool {
+				return rTmp[i].ResultDistance < rTmp[j].ResultDistance
+			})
+		}
+		for _, rTmp := range resultsTmp {
+			if len(rTmp) > query.TopK {
+				rTmp = rTmp[:query.TopK]
+			}
 		}
 		var entities = msgPb.Entities{
 			Ids: make([]int64, 0),
 		}
@@ -649,15 +664,19 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
 			QueryId: msg.Uid,
 			ClientId: clientId,
 		}
-		for _, res := range resultsTmp {
-			results.Entities.Ids = append(results.Entities.Ids, res.ResultId)
-			results.Distances = append(results.Distances, res.ResultDistance)
-			results.Scores = append(results.Distances, float32(0))
+		for _, rTmp := range resultsTmp {
+			for _, res := range rTmp {
+				results.Entities.Ids = append(results.Entities.Ids, res.ResultId)
+				results.Distances = append(results.Distances, res.ResultDistance)
+				results.Scores = append(results.Distances, float32(0))
+			}
 		}
-		results.RowNum = int64(len(results.Distances))
+		// Send numQueries to RowNum.
+		results.RowNum = query.NumQueries
 		// 5. publish result to pulsar
+		//fmt.Println(results.Entities.Ids)
+		//fmt.Println(results.Distances)
 		node.PublishSearchResult(&results)
 	}
...
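The Go changes above fix the result layout that the proxy's Aggregation depends on: each reply carries RowNum = NumQueries and lays Distances/Ids out query-major, so the proxy recovers the per-query width as distances_size() / row_num(). Below is a small illustrative check of that contract; it is written in C++ only to stay consistent with the other sketches, and the SearchReply struct is a stand-in, not the real message type.

// Hypothetical illustration of the QueryNode -> Proxy result layout used here.
// A reply carries row_num = numQueries and numQueries * topK distances, grouped
// query-major; the consumer recovers topK as distances.size() / row_num.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

struct SearchReply {           // stand-in for the query result message, not the real proto
    int64_t row_num;           // numQueries
    std::vector<float> distances;
    std::vector<int64_t> ids;
};

int main() {
    // Two queries, topK = 2, already reduced per query by the node.
    SearchReply reply{2, {0.1f, 0.4f, 0.2f, 0.5f}, {11, 12, 21, 22}};

    auto numQueries = reply.row_num;
    assert(reply.distances.size() % static_cast<size_t>(numQueries) == 0);
    auto topK = static_cast<int64_t>(reply.distances.size()) / numQueries;  // = 2

    for (int64_t q = 0; q < numQueries; ++q) {
        std::cout << "query " << q << ":";
        for (int64_t k = 0; k < topK; ++k) {
            auto idx = q * topK + k;  // query-major indexing
            std::cout << " (" << reply.ids[idx] << ", " << reply.distances[idx] << ")";
        }
        std::cout << "\n";
    }
    return 0;
}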
@@ -218,8 +218,8 @@ func (s *Segment) SegmentSearch(query *QueryInfo, timestamp uint64, vectorRecord
 		field_name: C.CString(query.FieldName),
 	}
-	resultIds := make([]int64, query.TopK)
-	resultDistances := make([]float32, query.TopK)
+	resultIds := make([]int64, int64(query.TopK) * query.NumQueries)
+	resultDistances := make([]float32, int64(query.TopK) * query.NumQueries)
 	var cTimestamp = C.ulong(timestamp)
 	var cResultIds = (*C.long)(&resultIds[0])
...
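resultIds and resultDistances above are the flat output buffers handed through cgo to the segment search, so they must be allocated for every query up front. The toy C++ function below is not the real segcore search API; it only mimics the buffer shape to show why the length is topK * numQueries, with query q's k-th hit landing at index q*topK + k.

// Toy stand-in for a search call that fills caller-allocated flat buffers,
// illustrating why resultIds/resultDistances are sized topK * numQueries.
// This is NOT the real segcore C API, just the same buffer shape.
#include <cstdint>
#include <iostream>
#include <vector>

void FakeSegmentSearch(int numQueries, int topK,
                       int64_t* result_ids, float* result_distances) {
    for (int q = 0; q < numQueries; ++q) {
        for (int k = 0; k < topK; ++k) {
            int idx = q * topK + k;                  // query-major layout
            result_ids[idx] = 100 * q + k;           // dummy ids
            result_distances[idx] = 0.1f * (k + 1);  // dummy ascending distances
        }
    }
}

int main() {
    const int numQueries = 2;
    const int topK = 3;
    std::vector<int64_t> resultIds(static_cast<size_t>(topK) * numQueries);
    std::vector<float> resultDistances(static_cast<size_t>(topK) * numQueries);

    FakeSegmentSearch(numQueries, topK, resultIds.data(), resultDistances.data());

    for (size_t i = 0; i < resultIds.size(); ++i) {
        std::cout << resultIds[i] << ":" << resultDistances[i] << " ";
    }
    std::cout << "\n";
    return 0;
}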
@@ -17,6 +17,7 @@
 #include "utils/Utils.h"
 #include <random>
+const int NUM_OF_VECTOR = 1;
 const int TOP_K = 10;
 const int LOOP = 1000;
@@ -32,7 +33,7 @@ get_vector_param() {
     std::normal_distribution<float> dis(0, 1);
-    for (int j = 0; j < 1; ++j) {
+    for (int j = 0; j < NUM_OF_VECTOR; ++j) {
         milvus::VectorData vectorData;
         std::vector<float> float_data;
         for (int i = 0; i < DIM; ++i) {
@@ -44,7 +45,7 @@ get_vector_param() {
     }
     nlohmann::json vector_param_json;
-    vector_param_json["num_queries"] = 1;
+    vector_param_json["num_queries"] = NUM_OF_VECTOR;
     vector_param_json["topK"] = TOP_K;
     vector_param_json["field_name"] = "field_vec";
     std::string vector_param_json_string = vector_param_json.dump();
...
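The SDK test above drives batched search purely through the vector parameter JSON. As a usage sketch, the snippet below builds the same JSON keys (num_queries, topK, field_name) and parses them back into a small QueryInfo struct; the struct and the parsing step are illustrative only (the real mapping happens in QueryJson2Info on the query node side), and it assumes the nlohmann/json single-include header is available, as in the SDK examples.

// Sketch of the vector search parameter produced by get_vector_param() above and
// how a consumer could turn it back into (num_queries, topK, field_name).
// The QueryInfo struct here is illustrative; it is not the SDK's type.
#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

struct QueryInfo {
    int64_t num_queries;
    int topk;
    std::string field_name;
};

int main() {
    const int NUM_OF_VECTOR = 1;  // number of query vectors in one request
    const int TOP_K = 10;

    // Same keys the SDK test writes.
    nlohmann::json vector_param_json;
    vector_param_json["num_queries"] = NUM_OF_VECTOR;
    vector_param_json["topK"] = TOP_K;
    vector_param_json["field_name"] = "field_vec";
    std::string vector_param_json_string = vector_param_json.dump();

    // A consumer parses the same string back out.
    auto parsed = nlohmann::json::parse(vector_param_json_string);
    QueryInfo info{parsed["num_queries"].get<int64_t>(),
                   parsed["topK"].get<int>(),
                   parsed["field_name"].get<std::string>()};

    std::cout << "num_queries=" << info.num_queries
              << " topK=" << info.topk
              << " field_name=" << info.field_name << "\n";
    return 0;
}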