提交 82e152a6 编写于 作者: G groot

MS-245 Improve search result transfer performance


Former-commit-id: 4a75c36a20c30092bd39c56df720df17847be0b0
上级 05b36b46
...@@ -29,7 +29,8 @@ Please mark all change in change log and use the ticket from JIRA. ...@@ -29,7 +29,8 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-208 - Add buildinde interface for C++ SDK - MS-208 - Add buildinde interface for C++ SDK
- MS-212 - Support Inner product metric type - MS-212 - Support Inner product metric type
- MS-241 - Build Faiss with MKL if using Intel CPU; else build with OpenBlas - MS-241 - Build Faiss with MKL if using Intel CPU; else build with OpenBlas
- MS-242 - clean up cmake and change MAKE_BUILD_ARGS to be user defined variable - MS-242 - Clean up cmake and change MAKE_BUILD_ARGS to be user defined variable
- MS-245 - Improve search result transfer performance
## New Feature ## New Feature
- MS-180 - Add new mem manager - MS-180 - Add new mem manager
......
...@@ -209,17 +209,25 @@ ClientProxy::SearchVector(const std::string &table_name, ...@@ -209,17 +209,25 @@ ClientProxy::SearchVector(const std::string &table_name,
} }
//step 3: search vectors //step 3: search vectors
std::vector<thrift::TopKQueryResult> result_array; std::vector<thrift::TopKQueryBinResult> result_array;
ClientPtr()->interface()->SearchVector(result_array, table_name, thrift_records, thrift_ranges, topk); ClientPtr()->interface()->SearchVector2(result_array, table_name, thrift_records, thrift_ranges, topk);
//step 4: convert result array //step 4: convert result array
for(auto& thrift_topk_result : result_array) { for(auto& thrift_topk_result : result_array) {
TopKQueryResult result; TopKQueryResult result;
for(auto& thrift_query_result : thrift_topk_result.query_result_arrays) { size_t id_count = thrift_topk_result.id_array.size()/sizeof(int64_t);
size_t dist_count = thrift_topk_result.distance_array.size()/ sizeof(double);
if(id_count != dist_count) {
return Status(StatusCode::UnknownError, "illegal result");
}
int64_t* id_ptr = (int64_t*)thrift_topk_result.id_array.data();
double* dist_ptr = (double*)thrift_topk_result.distance_array.data();
for(size_t i = 0; i < id_count; i++) {
QueryResult query_result; QueryResult query_result;
query_result.id = thrift_query_result.id; query_result.id = id_ptr[i];
query_result.distance = thrift_query_result.distance; query_result.distance = dist_ptr[i];
result.query_result_arrays.emplace_back(query_result); result.query_result_arrays.emplace_back(query_result);
} }
......
...@@ -60,11 +60,22 @@ RequestHandler::SearchVector(std::vector<thrift::TopKQueryResult> &_return, ...@@ -60,11 +60,22 @@ RequestHandler::SearchVector(std::vector<thrift::TopKQueryResult> &_return,
const std::vector<thrift::Range> &query_range_array, const std::vector<thrift::Range> &query_range_array,
const int64_t topk) { const int64_t topk) {
// SERVER_LOG_DEBUG << "Entering RequestHandler::SearchVector"; // SERVER_LOG_DEBUG << "Entering RequestHandler::SearchVector";
BaseTaskPtr task_ptr = SearchVectorTask::Create(table_name, std::vector<std::string>(), query_record_array, BaseTaskPtr task_ptr = SearchVectorTask1::Create(table_name, std::vector<std::string>(), query_record_array,
query_range_array, topk, _return); query_range_array, topk, _return);
RequestScheduler::ExecTask(task_ptr); RequestScheduler::ExecTask(task_ptr);
} }
void
RequestHandler::SearchVector2(std::vector<thrift::TopKQueryBinResult> & _return,
const std::string& table_name,
const std::vector<thrift::RowRecord> & query_record_array,
const std::vector<thrift::Range> & query_range_array,
const int64_t topk) {
BaseTaskPtr task_ptr = SearchVectorTask2::Create(table_name, std::vector<std::string>(), query_record_array,
query_range_array, topk, _return);
RequestScheduler::ExecTask(task_ptr);
}
void void
RequestHandler::SearchVectorInFiles(std::vector<::milvus::thrift::TopKQueryResult> &_return, RequestHandler::SearchVectorInFiles(std::vector<::milvus::thrift::TopKQueryResult> &_return,
const std::string& table_name, const std::string& table_name,
...@@ -73,7 +84,7 @@ RequestHandler::SearchVectorInFiles(std::vector<::milvus::thrift::TopKQueryResul ...@@ -73,7 +84,7 @@ RequestHandler::SearchVectorInFiles(std::vector<::milvus::thrift::TopKQueryResul
const std::vector<::milvus::thrift::Range> &query_range_array, const std::vector<::milvus::thrift::Range> &query_range_array,
const int64_t topk) { const int64_t topk) {
// SERVER_LOG_DEBUG << "Entering RequestHandler::SearchVectorInFiles. file_id_array size = " << std::to_string(file_id_array.size()); // SERVER_LOG_DEBUG << "Entering RequestHandler::SearchVectorInFiles. file_id_array size = " << std::to_string(file_id_array.size());
BaseTaskPtr task_ptr = SearchVectorTask::Create(table_name, file_id_array, query_record_array, BaseTaskPtr task_ptr = SearchVectorTask1::Create(table_name, file_id_array, query_record_array,
query_range_array, topk, _return); query_range_array, topk, _return);
RequestScheduler::ExecTask(task_ptr); RequestScheduler::ExecTask(task_ptr);
} }
......
...@@ -106,6 +106,29 @@ public: ...@@ -106,6 +106,29 @@ public:
const std::vector<::milvus::thrift::Range> & query_range_array, const std::vector<::milvus::thrift::Range> & query_range_array,
const int64_t topk); const int64_t topk);
/**
* @brief Query vector
*
* This method is used to query vector in table.
*
* @param table_name, table_name is queried.
* @param query_record_array, all vector are going to be queried.
* @param query_range_array, optional ranges for conditional search. If not specified, search whole table
* @param topk, how many similarity vectors will be searched.
*
* @return query binary result array.
*
* @param table_name
* @param query_record_array
* @param query_range_array
* @param topk
*/
void SearchVector2(std::vector<::milvus::thrift::TopKQueryBinResult> & _return,
const std::string& table_name,
const std::vector<::milvus::thrift::RowRecord> & query_record_array,
const std::vector<::milvus::thrift::Range> & query_range_array,
const int64_t topk);
/** /**
* @brief Internal use query interface * @brief Internal use query interface
* *
......
...@@ -466,33 +466,21 @@ ServerError AddVectorTask::OnExecute() { ...@@ -466,33 +466,21 @@ ServerError AddVectorTask::OnExecute() {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
SearchVectorTask::SearchVectorTask(const std::string &table_name, SearchVectorTaskBase::SearchVectorTaskBase(const std::string &table_name,
const std::vector<std::string>& file_id_array, const std::vector<std::string>& file_id_array,
const std::vector<thrift::RowRecord> &query_record_array, const std::vector<thrift::RowRecord> &query_record_array,
const std::vector<thrift::Range> &query_range_array, const std::vector<thrift::Range> &query_range_array,
const int64_t top_k, const int64_t top_k)
std::vector<thrift::TopKQueryResult> &result_array)
: BaseTask(DQL_TASK_GROUP), : BaseTask(DQL_TASK_GROUP),
table_name_(table_name), table_name_(table_name),
file_id_array_(file_id_array), file_id_array_(file_id_array),
record_array_(query_record_array), record_array_(query_record_array),
range_array_(query_range_array), range_array_(query_range_array),
top_k_(top_k), top_k_(top_k) {
result_array_(result_array) {
}
BaseTaskPtr SearchVectorTask::Create(const std::string& table_name,
const std::vector<std::string>& file_id_array,
const std::vector<thrift::RowRecord> & query_record_array,
const std::vector<thrift::Range> & query_range_array,
const int64_t top_k,
std::vector<thrift::TopKQueryResult>& result_array) {
return std::shared_ptr<BaseTask>(new SearchVectorTask(table_name, file_id_array,
query_record_array, query_range_array, top_k, result_array));
} }
ServerError SearchVectorTask::OnExecute() { ServerError SearchVectorTaskBase::OnExecute() {
try { try {
TimeRecorder rc("SearchVectorTask"); TimeRecorder rc("SearchVectorTask");
...@@ -570,26 +558,106 @@ ServerError SearchVectorTask::OnExecute() { ...@@ -570,26 +558,106 @@ ServerError SearchVectorTask::OnExecute() {
rc.Record("do search"); rc.Record("do search");
//step 5: construct result array //step 5: construct result array
for(uint64_t i = 0; i < record_count; i++) { ConstructResult(results);
auto& result = results[i]; rc.Record("construct result");
const auto& record = record_array_[i]; rc.Elapse("total cost");
thrift::TopKQueryResult thrift_topk_result; } catch (std::exception& ex) {
for(auto& pair : result) { return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
thrift::QueryResult thrift_result; }
thrift_result.__set_id(pair.first);
thrift_result.__set_distance(pair.second);
thrift_topk_result.query_result_arrays.emplace_back(thrift_result); return SERVER_SUCCESS;
} }
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
SearchVectorTask1::SearchVectorTask1(const std::string &table_name,
const std::vector<std::string>& file_id_array,
const std::vector<thrift::RowRecord> &query_record_array,
const std::vector<thrift::Range> &query_range_array,
const int64_t top_k,
std::vector<thrift::TopKQueryResult> &result_array)
: SearchVectorTaskBase(table_name, file_id_array, query_record_array, query_range_array, top_k),
result_array_(result_array) {
}
BaseTaskPtr SearchVectorTask1::Create(const std::string& table_name,
const std::vector<std::string>& file_id_array,
const std::vector<thrift::RowRecord> & query_record_array,
const std::vector<thrift::Range> & query_range_array,
const int64_t top_k,
std::vector<thrift::TopKQueryResult>& result_array) {
return std::shared_ptr<BaseTask>(new SearchVectorTask1(table_name, file_id_array,
query_record_array, query_range_array, top_k, result_array));
}
ServerError SearchVectorTask1::ConstructResult(engine::QueryResults& results) {
for(uint64_t i = 0; i < results.size(); i++) {
auto& result = results[i];
const auto& record = record_array_[i];
thrift::TopKQueryResult thrift_topk_result;
for(auto& pair : result) {
thrift::QueryResult thrift_result;
thrift_result.__set_id(pair.first);
thrift_result.__set_distance(pair.second);
thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
}
result_array_.emplace_back(thrift_topk_result);
}
return SERVER_SUCCESS;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
SearchVectorTask2::SearchVectorTask2(const std::string &table_name,
const std::vector<std::string>& file_id_array,
const std::vector<thrift::RowRecord> &query_record_array,
const std::vector<thrift::Range> &query_range_array,
const int64_t top_k,
std::vector<thrift::TopKQueryBinResult> &result_array)
: SearchVectorTaskBase(table_name, file_id_array, query_record_array, query_range_array, top_k),
result_array_(result_array) {
}
BaseTaskPtr SearchVectorTask2::Create(const std::string& table_name,
const std::vector<std::string>& file_id_array,
const std::vector<thrift::RowRecord> & query_record_array,
const std::vector<thrift::Range> & query_range_array,
const int64_t top_k,
std::vector<thrift::TopKQueryBinResult>& result_array) {
return std::shared_ptr<BaseTask>(new SearchVectorTask2(table_name, file_id_array,
query_record_array, query_range_array, top_k, result_array));
}
ServerError SearchVectorTask2::ConstructResult(engine::QueryResults& results) {
for(size_t i = 0; i < results.size(); i++) {
auto& result = results[i];
thrift::TopKQueryBinResult thrift_topk_result;
if(result.empty()) {
result_array_.emplace_back(thrift_topk_result); result_array_.emplace_back(thrift_topk_result);
continue;
} }
rc.Record("construct result");
rc.Elapse("total cost");
} catch (std::exception& ex) { std::string str_ids, str_distances;
return SetError(SERVER_UNEXPECTED_ERROR, ex.what()); str_ids.resize(sizeof(engine::IDNumber)*result.size());
str_distances.resize(sizeof(double)*result.size());
engine::IDNumber* ids_ptr = (engine::IDNumber*)str_ids.data();
double* distance_ptr = (double*)str_distances.data();
for(size_t k = 0; k < results.size(); k++) {
auto& pair = result[k];
ids_ptr[k] = pair.first;
distance_ptr[k] = pair.second;
}
thrift_topk_result.__set_id_array(str_ids);
thrift_topk_result.__set_distance_array(str_distances);
result_array_.emplace_back(thrift_topk_result);
} }
return SERVER_SUCCESS; return SERVER_SUCCESS;
......
...@@ -129,7 +129,28 @@ private: ...@@ -129,7 +129,28 @@ private:
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class SearchVectorTask : public BaseTask { class SearchVectorTaskBase : public BaseTask {
protected:
SearchVectorTaskBase(const std::string& table_name,
const std::vector<std::string>& file_id_array,
const std::vector<::milvus::thrift::RowRecord> & query_record_array,
const std::vector<::milvus::thrift::Range> & query_range_array,
const int64_t top_k);
ServerError OnExecute() override;
virtual ServerError ConstructResult(engine::QueryResults& results) = 0;
protected:
std::string table_name_;
std::vector<std::string> file_id_array_;
int64_t top_k_;
const std::vector<::milvus::thrift::RowRecord>& record_array_;
const std::vector<::milvus::thrift::Range>& range_array_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class SearchVectorTask1 : public SearchVectorTaskBase {
public: public:
static BaseTaskPtr Create(const std::string& table_name, static BaseTaskPtr Create(const std::string& table_name,
const std::vector<std::string>& file_id_array, const std::vector<std::string>& file_id_array,
...@@ -139,24 +160,43 @@ public: ...@@ -139,24 +160,43 @@ public:
std::vector<::milvus::thrift::TopKQueryResult>& result_array); std::vector<::milvus::thrift::TopKQueryResult>& result_array);
protected: protected:
SearchVectorTask(const std::string& table_name, SearchVectorTask1(const std::string& table_name,
const std::vector<std::string>& file_id_array, const std::vector<std::string>& file_id_array,
const std::vector<::milvus::thrift::RowRecord> & query_record_array, const std::vector<::milvus::thrift::RowRecord> & query_record_array,
const std::vector<::milvus::thrift::Range> & query_range_array, const std::vector<::milvus::thrift::Range> & query_range_array,
const int64_t top_k, const int64_t top_k,
std::vector<::milvus::thrift::TopKQueryResult>& result_array); std::vector<::milvus::thrift::TopKQueryResult>& result_array);
ServerError OnExecute() override; ServerError ConstructResult(engine::QueryResults& results) override;
private: private:
std::string table_name_;
std::vector<std::string> file_id_array_;
int64_t top_k_;
const std::vector<::milvus::thrift::RowRecord>& record_array_;
const std::vector<::milvus::thrift::Range>& range_array_;
std::vector<::milvus::thrift::TopKQueryResult>& result_array_; std::vector<::milvus::thrift::TopKQueryResult>& result_array_;
}; };
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class SearchVectorTask2 : public SearchVectorTaskBase {
public:
static BaseTaskPtr Create(const std::string& table_name,
const std::vector<std::string>& file_id_array,
const std::vector<::milvus::thrift::RowRecord> & query_record_array,
const std::vector<::milvus::thrift::Range> & query_range_array,
const int64_t top_k,
std::vector<::milvus::thrift::TopKQueryBinResult>& result_array);
protected:
SearchVectorTask2(const std::string& table_name,
const std::vector<std::string>& file_id_array,
const std::vector<::milvus::thrift::RowRecord> & query_record_array,
const std::vector<::milvus::thrift::Range> & query_range_array,
const int64_t top_k,
std::vector<::milvus::thrift::TopKQueryBinResult>& result_array);
ServerError ConstructResult(engine::QueryResults& results) override;
private:
std::vector<::milvus::thrift::TopKQueryBinResult>& result_array_;
};
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class GetTableRowCountTask : public BaseTask { class GetTableRowCountTask : public BaseTask {
public: public:
......
...@@ -104,6 +104,25 @@ class MilvusServiceIf { ...@@ -104,6 +104,25 @@ class MilvusServiceIf {
*/ */
virtual void SearchVector(std::vector<TopKQueryResult> & _return, const std::string& table_name, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk) = 0; virtual void SearchVector(std::vector<TopKQueryResult> & _return, const std::string& table_name, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk) = 0;
/**
* @brief Query vector
*
* This method is used to query vector in table.
*
* @param table_name, table_name is queried.
* @param query_record_array, all vector are going to be queried.
* @param query_range_array, optional ranges for conditional search. If not specified, search whole table
* @param topk, how many similarity vectors will be searched.
*
* @return query binary result array.
*
* @param table_name
* @param query_record_array
* @param query_range_array
* @param topk
*/
virtual void SearchVector2(std::vector<TopKQueryBinResult> & _return, const std::string& table_name, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk) = 0;
/** /**
* @brief Internal use query interface * @brief Internal use query interface
* *
...@@ -218,6 +237,9 @@ class MilvusServiceNull : virtual public MilvusServiceIf { ...@@ -218,6 +237,9 @@ class MilvusServiceNull : virtual public MilvusServiceIf {
void SearchVector(std::vector<TopKQueryResult> & /* _return */, const std::string& /* table_name */, const std::vector<RowRecord> & /* query_record_array */, const std::vector<Range> & /* query_range_array */, const int64_t /* topk */) { void SearchVector(std::vector<TopKQueryResult> & /* _return */, const std::string& /* table_name */, const std::vector<RowRecord> & /* query_record_array */, const std::vector<Range> & /* query_range_array */, const int64_t /* topk */) {
return; return;
} }
void SearchVector2(std::vector<TopKQueryBinResult> & /* _return */, const std::string& /* table_name */, const std::vector<RowRecord> & /* query_record_array */, const std::vector<Range> & /* query_range_array */, const int64_t /* topk */) {
return;
}
void SearchVectorInFiles(std::vector<TopKQueryResult> & /* _return */, const std::string& /* table_name */, const std::vector<std::string> & /* file_id_array */, const std::vector<RowRecord> & /* query_record_array */, const std::vector<Range> & /* query_range_array */, const int64_t /* topk */) { void SearchVectorInFiles(std::vector<TopKQueryResult> & /* _return */, const std::string& /* table_name */, const std::vector<std::string> & /* file_id_array */, const std::vector<RowRecord> & /* query_record_array */, const std::vector<Range> & /* query_range_array */, const int64_t /* topk */) {
return; return;
} }
...@@ -912,6 +934,139 @@ class MilvusService_SearchVector_presult { ...@@ -912,6 +934,139 @@ class MilvusService_SearchVector_presult {
}; };
typedef struct _MilvusService_SearchVector2_args__isset {
_MilvusService_SearchVector2_args__isset() : table_name(false), query_record_array(false), query_range_array(false), topk(false) {}
bool table_name :1;
bool query_record_array :1;
bool query_range_array :1;
bool topk :1;
} _MilvusService_SearchVector2_args__isset;
class MilvusService_SearchVector2_args {
public:
MilvusService_SearchVector2_args(const MilvusService_SearchVector2_args&);
MilvusService_SearchVector2_args& operator=(const MilvusService_SearchVector2_args&);
MilvusService_SearchVector2_args() : table_name(), topk(0) {
}
virtual ~MilvusService_SearchVector2_args() throw();
std::string table_name;
std::vector<RowRecord> query_record_array;
std::vector<Range> query_range_array;
int64_t topk;
_MilvusService_SearchVector2_args__isset __isset;
void __set_table_name(const std::string& val);
void __set_query_record_array(const std::vector<RowRecord> & val);
void __set_query_range_array(const std::vector<Range> & val);
void __set_topk(const int64_t val);
bool operator == (const MilvusService_SearchVector2_args & rhs) const
{
if (!(table_name == rhs.table_name))
return false;
if (!(query_record_array == rhs.query_record_array))
return false;
if (!(query_range_array == rhs.query_range_array))
return false;
if (!(topk == rhs.topk))
return false;
return true;
}
bool operator != (const MilvusService_SearchVector2_args &rhs) const {
return !(*this == rhs);
}
bool operator < (const MilvusService_SearchVector2_args & ) const;
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
};
class MilvusService_SearchVector2_pargs {
public:
virtual ~MilvusService_SearchVector2_pargs() throw();
const std::string* table_name;
const std::vector<RowRecord> * query_record_array;
const std::vector<Range> * query_range_array;
const int64_t* topk;
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
};
typedef struct _MilvusService_SearchVector2_result__isset {
_MilvusService_SearchVector2_result__isset() : success(false), e(false) {}
bool success :1;
bool e :1;
} _MilvusService_SearchVector2_result__isset;
class MilvusService_SearchVector2_result {
public:
MilvusService_SearchVector2_result(const MilvusService_SearchVector2_result&);
MilvusService_SearchVector2_result& operator=(const MilvusService_SearchVector2_result&);
MilvusService_SearchVector2_result() {
}
virtual ~MilvusService_SearchVector2_result() throw();
std::vector<TopKQueryBinResult> success;
Exception e;
_MilvusService_SearchVector2_result__isset __isset;
void __set_success(const std::vector<TopKQueryBinResult> & val);
void __set_e(const Exception& val);
bool operator == (const MilvusService_SearchVector2_result & rhs) const
{
if (!(success == rhs.success))
return false;
if (!(e == rhs.e))
return false;
return true;
}
bool operator != (const MilvusService_SearchVector2_result &rhs) const {
return !(*this == rhs);
}
bool operator < (const MilvusService_SearchVector2_result & ) const;
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
};
typedef struct _MilvusService_SearchVector2_presult__isset {
_MilvusService_SearchVector2_presult__isset() : success(false), e(false) {}
bool success :1;
bool e :1;
} _MilvusService_SearchVector2_presult__isset;
class MilvusService_SearchVector2_presult {
public:
virtual ~MilvusService_SearchVector2_presult() throw();
std::vector<TopKQueryBinResult> * success;
Exception e;
_MilvusService_SearchVector2_presult__isset __isset;
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
};
typedef struct _MilvusService_SearchVectorInFiles_args__isset { typedef struct _MilvusService_SearchVectorInFiles_args__isset {
_MilvusService_SearchVectorInFiles_args__isset() : table_name(false), file_id_array(false), query_record_array(false), query_range_array(false), topk(false) {} _MilvusService_SearchVectorInFiles_args__isset() : table_name(false), file_id_array(false), query_record_array(false), query_range_array(false), topk(false) {}
bool table_name :1; bool table_name :1;
...@@ -1531,6 +1686,9 @@ class MilvusServiceClient : virtual public MilvusServiceIf { ...@@ -1531,6 +1686,9 @@ class MilvusServiceClient : virtual public MilvusServiceIf {
void SearchVector(std::vector<TopKQueryResult> & _return, const std::string& table_name, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk); void SearchVector(std::vector<TopKQueryResult> & _return, const std::string& table_name, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk);
void send_SearchVector(const std::string& table_name, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk); void send_SearchVector(const std::string& table_name, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk);
void recv_SearchVector(std::vector<TopKQueryResult> & _return); void recv_SearchVector(std::vector<TopKQueryResult> & _return);
void SearchVector2(std::vector<TopKQueryBinResult> & _return, const std::string& table_name, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk);
void send_SearchVector2(const std::string& table_name, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk);
void recv_SearchVector2(std::vector<TopKQueryBinResult> & _return);
void SearchVectorInFiles(std::vector<TopKQueryResult> & _return, const std::string& table_name, const std::vector<std::string> & file_id_array, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk); void SearchVectorInFiles(std::vector<TopKQueryResult> & _return, const std::string& table_name, const std::vector<std::string> & file_id_array, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk);
void send_SearchVectorInFiles(const std::string& table_name, const std::vector<std::string> & file_id_array, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk); void send_SearchVectorInFiles(const std::string& table_name, const std::vector<std::string> & file_id_array, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk);
void recv_SearchVectorInFiles(std::vector<TopKQueryResult> & _return); void recv_SearchVectorInFiles(std::vector<TopKQueryResult> & _return);
...@@ -1567,6 +1725,7 @@ class MilvusServiceProcessor : public ::apache::thrift::TDispatchProcessor { ...@@ -1567,6 +1725,7 @@ class MilvusServiceProcessor : public ::apache::thrift::TDispatchProcessor {
void process_BuildIndex(int32_t seqid, ::apache::thrift::protocol::TProtocol* iprot, ::apache::thrift::protocol::TProtocol* oprot, void* callContext); void process_BuildIndex(int32_t seqid, ::apache::thrift::protocol::TProtocol* iprot, ::apache::thrift::protocol::TProtocol* oprot, void* callContext);
void process_AddVector(int32_t seqid, ::apache::thrift::protocol::TProtocol* iprot, ::apache::thrift::protocol::TProtocol* oprot, void* callContext); void process_AddVector(int32_t seqid, ::apache::thrift::protocol::TProtocol* iprot, ::apache::thrift::protocol::TProtocol* oprot, void* callContext);
void process_SearchVector(int32_t seqid, ::apache::thrift::protocol::TProtocol* iprot, ::apache::thrift::protocol::TProtocol* oprot, void* callContext); void process_SearchVector(int32_t seqid, ::apache::thrift::protocol::TProtocol* iprot, ::apache::thrift::protocol::TProtocol* oprot, void* callContext);
void process_SearchVector2(int32_t seqid, ::apache::thrift::protocol::TProtocol* iprot, ::apache::thrift::protocol::TProtocol* oprot, void* callContext);
void process_SearchVectorInFiles(int32_t seqid, ::apache::thrift::protocol::TProtocol* iprot, ::apache::thrift::protocol::TProtocol* oprot, void* callContext); void process_SearchVectorInFiles(int32_t seqid, ::apache::thrift::protocol::TProtocol* iprot, ::apache::thrift::protocol::TProtocol* oprot, void* callContext);
void process_DescribeTable(int32_t seqid, ::apache::thrift::protocol::TProtocol* iprot, ::apache::thrift::protocol::TProtocol* oprot, void* callContext); void process_DescribeTable(int32_t seqid, ::apache::thrift::protocol::TProtocol* iprot, ::apache::thrift::protocol::TProtocol* oprot, void* callContext);
void process_GetTableRowCount(int32_t seqid, ::apache::thrift::protocol::TProtocol* iprot, ::apache::thrift::protocol::TProtocol* oprot, void* callContext); void process_GetTableRowCount(int32_t seqid, ::apache::thrift::protocol::TProtocol* iprot, ::apache::thrift::protocol::TProtocol* oprot, void* callContext);
...@@ -1581,6 +1740,7 @@ class MilvusServiceProcessor : public ::apache::thrift::TDispatchProcessor { ...@@ -1581,6 +1740,7 @@ class MilvusServiceProcessor : public ::apache::thrift::TDispatchProcessor {
processMap_["BuildIndex"] = &MilvusServiceProcessor::process_BuildIndex; processMap_["BuildIndex"] = &MilvusServiceProcessor::process_BuildIndex;
processMap_["AddVector"] = &MilvusServiceProcessor::process_AddVector; processMap_["AddVector"] = &MilvusServiceProcessor::process_AddVector;
processMap_["SearchVector"] = &MilvusServiceProcessor::process_SearchVector; processMap_["SearchVector"] = &MilvusServiceProcessor::process_SearchVector;
processMap_["SearchVector2"] = &MilvusServiceProcessor::process_SearchVector2;
processMap_["SearchVectorInFiles"] = &MilvusServiceProcessor::process_SearchVectorInFiles; processMap_["SearchVectorInFiles"] = &MilvusServiceProcessor::process_SearchVectorInFiles;
processMap_["DescribeTable"] = &MilvusServiceProcessor::process_DescribeTable; processMap_["DescribeTable"] = &MilvusServiceProcessor::process_DescribeTable;
processMap_["GetTableRowCount"] = &MilvusServiceProcessor::process_GetTableRowCount; processMap_["GetTableRowCount"] = &MilvusServiceProcessor::process_GetTableRowCount;
...@@ -1670,6 +1830,16 @@ class MilvusServiceMultiface : virtual public MilvusServiceIf { ...@@ -1670,6 +1830,16 @@ class MilvusServiceMultiface : virtual public MilvusServiceIf {
return; return;
} }
void SearchVector2(std::vector<TopKQueryBinResult> & _return, const std::string& table_name, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk) {
size_t sz = ifaces_.size();
size_t i = 0;
for (; i < (sz - 1); ++i) {
ifaces_[i]->SearchVector2(_return, table_name, query_record_array, query_range_array, topk);
}
ifaces_[i]->SearchVector2(_return, table_name, query_record_array, query_range_array, topk);
return;
}
void SearchVectorInFiles(std::vector<TopKQueryResult> & _return, const std::string& table_name, const std::vector<std::string> & file_id_array, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk) { void SearchVectorInFiles(std::vector<TopKQueryResult> & _return, const std::string& table_name, const std::vector<std::string> & file_id_array, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk) {
size_t sz = ifaces_.size(); size_t sz = ifaces_.size();
size_t i = 0; size_t i = 0;
...@@ -1767,6 +1937,9 @@ class MilvusServiceConcurrentClient : virtual public MilvusServiceIf { ...@@ -1767,6 +1937,9 @@ class MilvusServiceConcurrentClient : virtual public MilvusServiceIf {
void SearchVector(std::vector<TopKQueryResult> & _return, const std::string& table_name, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk); void SearchVector(std::vector<TopKQueryResult> & _return, const std::string& table_name, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk);
int32_t send_SearchVector(const std::string& table_name, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk); int32_t send_SearchVector(const std::string& table_name, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk);
void recv_SearchVector(std::vector<TopKQueryResult> & _return, const int32_t seqid); void recv_SearchVector(std::vector<TopKQueryResult> & _return, const int32_t seqid);
void SearchVector2(std::vector<TopKQueryBinResult> & _return, const std::string& table_name, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk);
int32_t send_SearchVector2(const std::string& table_name, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk);
void recv_SearchVector2(std::vector<TopKQueryBinResult> & _return, const int32_t seqid);
void SearchVectorInFiles(std::vector<TopKQueryResult> & _return, const std::string& table_name, const std::vector<std::string> & file_id_array, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk); void SearchVectorInFiles(std::vector<TopKQueryResult> & _return, const std::string& table_name, const std::vector<std::string> & file_id_array, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk);
int32_t send_SearchVectorInFiles(const std::string& table_name, const std::vector<std::string> & file_id_array, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk); int32_t send_SearchVectorInFiles(const std::string& table_name, const std::vector<std::string> & file_id_array, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk);
void recv_SearchVectorInFiles(std::vector<TopKQueryResult> & _return, const int32_t seqid); void recv_SearchVectorInFiles(std::vector<TopKQueryResult> & _return, const int32_t seqid);
......
...@@ -120,6 +120,28 @@ class MilvusServiceHandler : virtual public MilvusServiceIf { ...@@ -120,6 +120,28 @@ class MilvusServiceHandler : virtual public MilvusServiceIf {
printf("SearchVector\n"); printf("SearchVector\n");
} }
/**
* @brief Query vector
*
* This method is used to query vector in table.
*
* @param table_name, table_name is queried.
* @param query_record_array, all vector are going to be queried.
* @param query_range_array, optional ranges for conditional search. If not specified, search whole table
* @param topk, how many similarity vectors will be searched.
*
* @return query binary result array.
*
* @param table_name
* @param query_record_array
* @param query_range_array
* @param topk
*/
void SearchVector2(std::vector<TopKQueryBinResult> & _return, const std::string& table_name, const std::vector<RowRecord> & query_record_array, const std::vector<Range> & query_range_array, const int64_t topk) {
// Your implementation goes here
printf("SearchVector2\n");
}
/** /**
* @brief Internal use query interface * @brief Internal use query interface
* *
......
...@@ -781,4 +781,119 @@ void TopKQueryResult::printTo(std::ostream& out) const { ...@@ -781,4 +781,119 @@ void TopKQueryResult::printTo(std::ostream& out) const {
out << ")"; out << ")";
} }
TopKQueryBinResult::~TopKQueryBinResult() throw() {
}
void TopKQueryBinResult::__set_id_array(const std::string& val) {
this->id_array = val;
}
void TopKQueryBinResult::__set_distance_array(const std::string& val) {
this->distance_array = val;
}
std::ostream& operator<<(std::ostream& out, const TopKQueryBinResult& obj)
{
obj.printTo(out);
return out;
}
uint32_t TopKQueryBinResult::read(::apache::thrift::protocol::TProtocol* iprot) {
::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
uint32_t xfer = 0;
std::string fname;
::apache::thrift::protocol::TType ftype;
int16_t fid;
xfer += iprot->readStructBegin(fname);
using ::apache::thrift::protocol::TProtocolException;
bool isset_id_array = false;
bool isset_distance_array = false;
while (true)
{
xfer += iprot->readFieldBegin(fname, ftype, fid);
if (ftype == ::apache::thrift::protocol::T_STOP) {
break;
}
switch (fid)
{
case 1:
if (ftype == ::apache::thrift::protocol::T_STRING) {
xfer += iprot->readBinary(this->id_array);
isset_id_array = true;
} else {
xfer += iprot->skip(ftype);
}
break;
case 2:
if (ftype == ::apache::thrift::protocol::T_STRING) {
xfer += iprot->readBinary(this->distance_array);
isset_distance_array = true;
} else {
xfer += iprot->skip(ftype);
}
break;
default:
xfer += iprot->skip(ftype);
break;
}
xfer += iprot->readFieldEnd();
}
xfer += iprot->readStructEnd();
if (!isset_id_array)
throw TProtocolException(TProtocolException::INVALID_DATA);
if (!isset_distance_array)
throw TProtocolException(TProtocolException::INVALID_DATA);
return xfer;
}
uint32_t TopKQueryBinResult::write(::apache::thrift::protocol::TProtocol* oprot) const {
uint32_t xfer = 0;
::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
xfer += oprot->writeStructBegin("TopKQueryBinResult");
xfer += oprot->writeFieldBegin("id_array", ::apache::thrift::protocol::T_STRING, 1);
xfer += oprot->writeBinary(this->id_array);
xfer += oprot->writeFieldEnd();
xfer += oprot->writeFieldBegin("distance_array", ::apache::thrift::protocol::T_STRING, 2);
xfer += oprot->writeBinary(this->distance_array);
xfer += oprot->writeFieldEnd();
xfer += oprot->writeFieldStop();
xfer += oprot->writeStructEnd();
return xfer;
}
void swap(TopKQueryBinResult &a, TopKQueryBinResult &b) {
using ::std::swap;
swap(a.id_array, b.id_array);
swap(a.distance_array, b.distance_array);
}
TopKQueryBinResult::TopKQueryBinResult(const TopKQueryBinResult& other19) {
id_array = other19.id_array;
distance_array = other19.distance_array;
}
TopKQueryBinResult& TopKQueryBinResult::operator=(const TopKQueryBinResult& other20) {
id_array = other20.id_array;
distance_array = other20.distance_array;
return *this;
}
void TopKQueryBinResult::printTo(std::ostream& out) const {
using ::apache::thrift::to_string;
out << "TopKQueryBinResult(";
out << "id_array=" << to_string(id_array);
out << ", " << "distance_array=" << to_string(distance_array);
out << ")";
}
}} // namespace }} // namespace
...@@ -63,6 +63,8 @@ class QueryResult; ...@@ -63,6 +63,8 @@ class QueryResult;
class TopKQueryResult; class TopKQueryResult;
class TopKQueryBinResult;
typedef struct _Exception__isset { typedef struct _Exception__isset {
_Exception__isset() : code(false), reason(false) {} _Exception__isset() : code(false), reason(false) {}
bool code :1; bool code :1;
...@@ -346,6 +348,47 @@ void swap(TopKQueryResult &a, TopKQueryResult &b); ...@@ -346,6 +348,47 @@ void swap(TopKQueryResult &a, TopKQueryResult &b);
std::ostream& operator<<(std::ostream& out, const TopKQueryResult& obj); std::ostream& operator<<(std::ostream& out, const TopKQueryResult& obj);
class TopKQueryBinResult : public virtual ::apache::thrift::TBase {
public:
TopKQueryBinResult(const TopKQueryBinResult&);
TopKQueryBinResult& operator=(const TopKQueryBinResult&);
TopKQueryBinResult() : id_array(), distance_array() {
}
virtual ~TopKQueryBinResult() throw();
std::string id_array;
std::string distance_array;
void __set_id_array(const std::string& val);
void __set_distance_array(const std::string& val);
bool operator == (const TopKQueryBinResult & rhs) const
{
if (!(id_array == rhs.id_array))
return false;
if (!(distance_array == rhs.distance_array))
return false;
return true;
}
bool operator != (const TopKQueryBinResult &rhs) const {
return !(*this == rhs);
}
bool operator < (const TopKQueryBinResult & ) const;
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
virtual void printTo(std::ostream& out) const;
};
void swap(TopKQueryBinResult &a, TopKQueryBinResult &b);
std::ostream& operator<<(std::ostream& out, const TopKQueryBinResult& obj);
}} // namespace }} // namespace
#endif #endif
...@@ -84,6 +84,14 @@ struct TopKQueryResult { ...@@ -84,6 +84,14 @@ struct TopKQueryResult {
1: list<QueryResult> query_result_arrays; ///< TopK query result 1: list<QueryResult> query_result_arrays; ///< TopK query result
} }
/**
* @brief TopK query binary result
*/
struct TopKQueryBinResult {
1: required binary id_array; ///< id array, interger array
2: required binary distance_array; ///< distance array, double array
}
service MilvusService { service MilvusService {
/** /**
* @brief Create table method * @brief Create table method
...@@ -158,6 +166,23 @@ service MilvusService { ...@@ -158,6 +166,23 @@ service MilvusService {
4: list<Range> query_range_array, 4: list<Range> query_range_array,
5: i64 topk) throws(1: Exception e); 5: i64 topk) throws(1: Exception e);
/**
* @brief Query vector
*
* This method is used to query vector in table.
*
* @param table_name, table_name is queried.
* @param query_record_array, all vector are going to be queried.
* @param query_range_array, optional ranges for conditional search. If not specified, search whole table
* @param topk, how many similarity vectors will be searched.
*
* @return query binary result array.
*/
list<TopKQueryBinResult> SearchVector2(2: string table_name,
3: list<RowRecord> query_record_array,
4: list<Range> query_range_array,
5: i64 topk) throws(1: Exception e);
/** /**
* @brief Internal use query interface * @brief Internal use query interface
* *
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册