未验证 提交 2264aab0 编写于 作者: Y yukun 提交者: GitHub

Add new hybrid search api (#2445)

* Add json-string-dsl hybrid search api
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Add C++ sdk for json-string-dsl hybrid search
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Add C++ examples for new hybrid search api
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Add unit tests for the new hybrid search API
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>
上级 353eb5e8
......@@ -168,8 +168,8 @@ class DB {
virtual Status
HybridQuery(const std::shared_ptr<server::Context>& context, const std::string& collection_id,
const std::vector<std::string>& partition_tags, context::HybridSearchContextPtr hybrid_search_context,
query::GeneralQueryPtr general_query, std::vector<std::string>& field_name,
const std::vector<std::string>& partition_tags, query::GeneralQueryPtr general_query,
query::QueryPtr query_ptr, std::vector<std::string>& field_name,
std::unordered_map<std::string, engine::meta::hybrid::DataType>& attr_type,
engine::QueryResult& result) = 0;
}; // DB
......
......@@ -1783,9 +1783,8 @@ DBImpl::QueryByIDs(const std::shared_ptr<server::Context>& context, const std::s
Status
DBImpl::HybridQuery(const std::shared_ptr<server::Context>& context, const std::string& collection_id,
const std::vector<std::string>& partition_tags,
context::HybridSearchContextPtr hybrid_search_context, query::GeneralQueryPtr general_query,
std::vector<std::string>& field_names,
const std::vector<std::string>& partition_tags, query::GeneralQueryPtr general_query,
query::QueryPtr query_ptr, std::vector<std::string>& field_names,
std::unordered_map<std::string, engine::meta::hybrid::DataType>& attr_type,
engine::QueryResult& result) {
auto query_ctx = context->Child("Query");
......@@ -1837,8 +1836,8 @@ DBImpl::HybridQuery(const std::shared_ptr<server::Context>& context, const std::
}
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query
status = HybridQueryAsync(query_ctx, collection_id, files_holder, hybrid_search_context, general_query, field_names,
attr_type, result);
status = HybridQueryAsync(query_ctx, collection_id, files_holder, general_query, query_ptr, field_names, attr_type,
result);
if (!status.ok()) {
return status;
}
......@@ -1999,8 +1998,8 @@ DBImpl::QueryAsync(const std::shared_ptr<server::Context>& context, meta::FilesH
Status
DBImpl::HybridQueryAsync(const std::shared_ptr<server::Context>& context, const std::string& collection_id,
meta::FilesHolder& files_holder, context::HybridSearchContextPtr hybrid_search_context,
query::GeneralQueryPtr general_query, std::vector<std::string>& field_names,
meta::FilesHolder& files_holder, query::GeneralQueryPtr general_query,
query::QueryPtr query_ptr, std::vector<std::string>& field_names,
std::unordered_map<std::string, engine::meta::hybrid::DataType>& attr_type,
engine::QueryResult& result) {
auto query_async_ctx = context->Child("Query Async");
......@@ -2030,7 +2029,7 @@ DBImpl::HybridQueryAsync(const std::shared_ptr<server::Context>& context, const
milvus::engine::meta::SegmentsSchema& files = files_holder.HoldFiles();
LOG_ENGINE_DEBUG_ << LogOut("Engine query begin, index file count: %ld", files_holder.HoldFiles().size());
scheduler::SearchJobPtr job =
std::make_shared<scheduler::SearchJob>(query_async_ctx, general_query, attr_type, vectors);
std::make_shared<scheduler::SearchJob>(query_async_ctx, general_query, query_ptr, attr_type, vectors);
for (auto& file : files) {
scheduler::SegmentSchemaPtr file_ptr = std::make_shared<meta::SegmentSchema>(file);
job->AddIndexFile(file_ptr);
......
......@@ -160,8 +160,8 @@ class DBImpl : public DB, public server::CacheConfigHandler, public server::Engi
Status
HybridQuery(const std::shared_ptr<server::Context>& context, const std::string& collection_id,
const std::vector<std::string>& partition_tags, context::HybridSearchContextPtr hybrid_search_context,
query::GeneralQueryPtr general_query, std::vector<std::string>& field_names,
const std::vector<std::string>& partition_tags, query::GeneralQueryPtr general_query,
query::QueryPtr query_ptr, std::vector<std::string>& field_names,
std::unordered_map<std::string, engine::meta::hybrid::DataType>& attr_type,
engine::QueryResult& result) override;
......@@ -198,8 +198,8 @@ class DBImpl : public DB, public server::CacheConfigHandler, public server::Engi
Status
HybridQueryAsync(const std::shared_ptr<server::Context>& context, const std::string& collection_id,
meta::FilesHolder& files_holder, context::HybridSearchContextPtr hybrid_search_context,
query::GeneralQueryPtr general_query, std::vector<std::string>& field_names,
meta::FilesHolder& files_holder, query::GeneralQueryPtr general_query, query::QueryPtr query_ptr,
std::vector<std::string>& field_names,
std::unordered_map<std::string, engine::meta::hybrid::DataType>& attr_type,
engine::QueryResult& result);
......
......@@ -117,12 +117,11 @@ class ExecutionEngine {
virtual Status
ExecBinaryQuery(query::GeneralQueryPtr general_query, faiss::ConcurrentBitsetPtr& bitset,
std::unordered_map<std::string, DataType>& attr_type,
milvus::query::VectorQueryPtr& vector_query) = 0;
std::unordered_map<std::string, DataType>& attr_type, std::string& vector_placeholder) = 0;
virtual Status
HybridSearch(query::GeneralQueryPtr general_query, std::unordered_map<std::string, DataType>& attr_type,
uint64_t& nq, uint64_t& topk, std::vector<float>& distances, std::vector<int64_t>& search_ids) = 0;
query::QueryPtr query_ptr, std::vector<float>& distances, std::vector<int64_t>& search_ids) = 0;
virtual Status
Search(int64_t n, const float* data, int64_t k, const milvus::json& extra_params, float* distances, int64_t* labels,
......
......@@ -787,12 +787,15 @@ ExecutionEngineImpl::ProcessRangeQuery(std::vector<T> data, T value, query::Comp
}
Status
ExecutionEngineImpl::HybridSearch(milvus::query::GeneralQueryPtr general_query,
std::unordered_map<std::string, DataType>& attr_type, uint64_t& nq, uint64_t& topk,
ExecutionEngineImpl::HybridSearch(query::GeneralQueryPtr general_query,
std::unordered_map<std::string, DataType>& attr_type, query::QueryPtr query_ptr,
std::vector<float>& distances, std::vector<int64_t>& search_ids) {
faiss::ConcurrentBitsetPtr bitset;
milvus::query::VectorQueryPtr vector_query;
auto status = ExecBinaryQuery(general_query, bitset, attr_type, vector_query);
std::string vector_placeholder;
auto status = ExecBinaryQuery(general_query, bitset, attr_type, vector_placeholder);
if (!status.ok()) {
return status;
}
// Do search
faiss::ConcurrentBitsetPtr list;
......@@ -804,8 +807,10 @@ ExecutionEngineImpl::HybridSearch(milvus::query::GeneralQueryPtr general_query,
}
}
index_->SetBlacklist(list);
topk = vector_query->topk;
nq = vector_query->query_vector.float_data.size() / dim_;
auto vector_query = query_ptr->vectors.at(vector_placeholder);
int64_t topk = vector_query->topk;
int64_t nq = vector_query->query_vector.float_data.size() / dim_;
distances.resize(nq * topk);
search_ids.resize(nq * topk);
......@@ -822,18 +827,18 @@ ExecutionEngineImpl::HybridSearch(milvus::query::GeneralQueryPtr general_query,
Status
ExecutionEngineImpl::ExecBinaryQuery(milvus::query::GeneralQueryPtr general_query, faiss::ConcurrentBitsetPtr& bitset,
std::unordered_map<std::string, DataType>& attr_type,
milvus::query::VectorQueryPtr& vector_query) {
std::string& vector_placeholder) {
if (general_query->leaf == nullptr) {
Status status;
faiss::ConcurrentBitsetPtr left_bitset, right_bitset;
if (general_query->bin->left_query != nullptr) {
status = ExecBinaryQuery(general_query->bin->left_query, left_bitset, attr_type, vector_query);
status = ExecBinaryQuery(general_query->bin->left_query, left_bitset, attr_type, vector_placeholder);
if (!status.ok()) {
return status;
}
}
if (general_query->bin->right_query != nullptr) {
status = ExecBinaryQuery(general_query->bin->right_query, right_bitset, attr_type, vector_query);
status = ExecBinaryQuery(general_query->bin->right_query, right_bitset, attr_type, vector_placeholder);
if (!status.ok()) {
return status;
}
......@@ -1099,9 +1104,9 @@ ExecutionEngineImpl::ExecBinaryQuery(milvus::query::GeneralQueryPtr general_quer
}
return Status::OK();
}
if (general_query->leaf->vector_query != nullptr) {
if (general_query->leaf->vector_placeholder.size() > 0) {
// skip vector query
vector_query = general_query->leaf->vector_query;
vector_placeholder = general_query->leaf->vector_placeholder;
bitset = nullptr;
return Status::OK();
}
......
......@@ -71,13 +71,11 @@ class ExecutionEngineImpl : public ExecutionEngine {
Status
ExecBinaryQuery(query::GeneralQueryPtr general_query, faiss::ConcurrentBitsetPtr& bitset,
std::unordered_map<std::string, DataType>& attr_type,
milvus::query::VectorQueryPtr& vector_query) override;
std::unordered_map<std::string, DataType>& attr_type, std::string& vector_placeholder) override;
Status
HybridSearch(query::GeneralQueryPtr general_query, std::unordered_map<std::string, DataType>& attr_type,
uint64_t& nq, uint64_t& topk, std::vector<float>& distances,
std::vector<int64_t>& search_ids) override;
query::QueryPtr query_ptr, std::vector<float>& distances, std::vector<int64_t>& search_ids) override;
Status
Search(int64_t n, const float* data, int64_t k, const milvus::json& extra_params, float* distances, int64_t* labels,
......
......@@ -54,6 +54,7 @@ static const char* MilvusService_method_names[] = {
"/milvus.grpc.MilvusService/ShowHybridCollectionInfo",
"/milvus.grpc.MilvusService/PreloadHybridCollection",
"/milvus.grpc.MilvusService/InsertEntity",
"/milvus.grpc.MilvusService/HybridSearchPB",
"/milvus.grpc.MilvusService/HybridSearch",
"/milvus.grpc.MilvusService/HybridSearchInSegments",
"/milvus.grpc.MilvusService/GetEntityByID",
......@@ -102,11 +103,12 @@ MilvusService::Stub::Stub(const std::shared_ptr< ::grpc::ChannelInterface>& chan
, rpcmethod_ShowHybridCollectionInfo_(MilvusService_method_names[31], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
, rpcmethod_PreloadHybridCollection_(MilvusService_method_names[32], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
, rpcmethod_InsertEntity_(MilvusService_method_names[33], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
, rpcmethod_HybridSearch_(MilvusService_method_names[34], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
, rpcmethod_HybridSearchInSegments_(MilvusService_method_names[35], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
, rpcmethod_GetEntityByID_(MilvusService_method_names[36], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
, rpcmethod_GetEntityIDs_(MilvusService_method_names[37], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
, rpcmethod_DeleteEntitiesByID_(MilvusService_method_names[38], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
, rpcmethod_HybridSearchPB_(MilvusService_method_names[34], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
, rpcmethod_HybridSearch_(MilvusService_method_names[35], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
, rpcmethod_HybridSearchInSegments_(MilvusService_method_names[36], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
, rpcmethod_GetEntityByID_(MilvusService_method_names[37], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
, rpcmethod_GetEntityIDs_(MilvusService_method_names[38], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
, rpcmethod_DeleteEntitiesByID_(MilvusService_method_names[39], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
{}
::grpc::Status MilvusService::Stub::CreateCollection(::grpc::ClientContext* context, const ::milvus::grpc::CollectionSchema& request, ::milvus::grpc::Status* response) {
......@@ -1061,6 +1063,34 @@ void MilvusService::Stub::experimental_async::InsertEntity(::grpc::ClientContext
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::HEntityIDs>::Create(channel_.get(), cq, rpcmethod_InsertEntity_, context, request, false);
}
// Blocking client stub for the HybridSearchPB RPC: sends `request` over the stub's channel
// with BlockingUnaryCall and returns the gRPC status of the call; `response` is the
// caller-supplied result object handed to that call.
::grpc::Status MilvusService::Stub::HybridSearchPB(::grpc::ClientContext* context, const ::milvus::grpc::HSearchParamPB& request, ::milvus::grpc::HQueryResult* response) {
  auto* const rpc_channel = channel_.get();
  return ::grpc::internal::BlockingUnaryCall(rpc_channel, rpcmethod_HybridSearchPB_, context, request, response);
}
// Callback-style HybridSearchPB taking a typed HSearchParamPB request. The completion
// callback `f` is moved into CallbackUnaryCall together with the owning stub's channel and
// the HybridSearchPB method descriptor.
void MilvusService::Stub::experimental_async::HybridSearchPB(::grpc::ClientContext* context, const ::milvus::grpc::HSearchParamPB* request, ::milvus::grpc::HQueryResult* response, std::function<void(::grpc::Status)> f) {
  auto* const rpc_channel = stub_->channel_.get();
  const auto& rpc_method = stub_->rpcmethod_HybridSearchPB_;
  ::grpc_impl::internal::CallbackUnaryCall(rpc_channel, rpc_method, context, request, response, std::move(f));
}
// Callback-style HybridSearchPB overload whose request is a raw ::grpc::ByteBuffer rather
// than a typed message. Otherwise identical in shape to the typed overload: the callback `f`
// is moved into CallbackUnaryCall with the owning stub's channel and method descriptor.
void MilvusService::Stub::experimental_async::HybridSearchPB(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::HQueryResult* response, std::function<void(::grpc::Status)> f) {
  auto* const rpc_channel = stub_->channel_.get();
  const auto& rpc_method = stub_->rpcmethod_HybridSearchPB_;
  ::grpc_impl::internal::CallbackUnaryCall(rpc_channel, rpc_method, context, request, response, std::move(f));
}
// Reactor-style HybridSearchPB taking a typed HSearchParamPB request. Hands the caller's
// ClientUnaryReactor to ClientCallbackUnaryFactory::Create along with the owning stub's
// channel and the HybridSearchPB method descriptor.
void MilvusService::Stub::experimental_async::HybridSearchPB(::grpc::ClientContext* context, const ::milvus::grpc::HSearchParamPB* request, ::milvus::grpc::HQueryResult* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
  auto* const rpc_channel = stub_->channel_.get();
  const auto& rpc_method = stub_->rpcmethod_HybridSearchPB_;
  ::grpc_impl::internal::ClientCallbackUnaryFactory::Create(rpc_channel, rpc_method, context, request, response, reactor);
}
// Reactor-style HybridSearchPB overload whose request is a raw ::grpc::ByteBuffer. Hands the
// caller's ClientUnaryReactor to ClientCallbackUnaryFactory::Create along with the owning
// stub's channel and the HybridSearchPB method descriptor.
void MilvusService::Stub::experimental_async::HybridSearchPB(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::HQueryResult* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
  auto* const rpc_channel = stub_->channel_.get();
  const auto& rpc_method = stub_->rpcmethod_HybridSearchPB_;
  ::grpc_impl::internal::ClientCallbackUnaryFactory::Create(rpc_channel, rpc_method, context, request, response, reactor);
}
// Creates the completion-queue based response reader for HybridSearchPB. The factory's final
// argument is passed as `true` here (PrepareAsyncHybridSearchPBRaw passes `false`); in the
// gRPC C++ API that flag selects whether the call is started at creation time.
::grpc::ClientAsyncResponseReader< ::milvus::grpc::HQueryResult>* MilvusService::Stub::AsyncHybridSearchPBRaw(::grpc::ClientContext* context, const ::milvus::grpc::HSearchParamPB& request, ::grpc::CompletionQueue* cq) {
  constexpr bool kStartNow = true;
  auto* const rpc_channel = channel_.get();
  return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::HQueryResult>::Create(rpc_channel, cq, rpcmethod_HybridSearchPB_, context, request, kStartNow);
}
// Creates the completion-queue based response reader for HybridSearchPB with the factory's
// final argument set to `false` (AsyncHybridSearchPBRaw passes `true`); in the gRPC C++ API
// that flag selects whether the call is started at creation time.
::grpc::ClientAsyncResponseReader< ::milvus::grpc::HQueryResult>* MilvusService::Stub::PrepareAsyncHybridSearchPBRaw(::grpc::ClientContext* context, const ::milvus::grpc::HSearchParamPB& request, ::grpc::CompletionQueue* cq) {
  constexpr bool kStartNow = false;
  auto* const rpc_channel = channel_.get();
  return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::HQueryResult>::Create(rpc_channel, cq, rpcmethod_HybridSearchPB_, context, request, kStartNow);
}
// Blocking client stub for the HybridSearch RPC (HSearchParam request): sends `request` over
// the stub's channel with BlockingUnaryCall and returns the gRPC status of the call;
// `response` is the caller-supplied result object handed to that call.
::grpc::Status MilvusService::Stub::HybridSearch(::grpc::ClientContext* context, const ::milvus::grpc::HSearchParam& request, ::milvus::grpc::HQueryResult* response) {
  auto* const rpc_channel = channel_.get();
  return ::grpc::internal::BlockingUnaryCall(rpc_channel, rpcmethod_HybridSearch_, context, request, response);
}
......@@ -1375,25 +1405,30 @@ MilvusService::Service::Service() {
AddMethod(new ::grpc::internal::RpcServiceMethod(
MilvusService_method_names[34],
::grpc::internal::RpcMethod::NORMAL_RPC,
new ::grpc::internal::RpcMethodHandler< MilvusService::Service, ::milvus::grpc::HSearchParamPB, ::milvus::grpc::HQueryResult>(
std::mem_fn(&MilvusService::Service::HybridSearchPB), this)));
AddMethod(new ::grpc::internal::RpcServiceMethod(
MilvusService_method_names[35],
::grpc::internal::RpcMethod::NORMAL_RPC,
new ::grpc::internal::RpcMethodHandler< MilvusService::Service, ::milvus::grpc::HSearchParam, ::milvus::grpc::HQueryResult>(
std::mem_fn(&MilvusService::Service::HybridSearch), this)));
AddMethod(new ::grpc::internal::RpcServiceMethod(
MilvusService_method_names[35],
MilvusService_method_names[36],
::grpc::internal::RpcMethod::NORMAL_RPC,
new ::grpc::internal::RpcMethodHandler< MilvusService::Service, ::milvus::grpc::HSearchInSegmentsParam, ::milvus::grpc::TopKQueryResult>(
std::mem_fn(&MilvusService::Service::HybridSearchInSegments), this)));
AddMethod(new ::grpc::internal::RpcServiceMethod(
MilvusService_method_names[36],
MilvusService_method_names[37],
::grpc::internal::RpcMethod::NORMAL_RPC,
new ::grpc::internal::RpcMethodHandler< MilvusService::Service, ::milvus::grpc::VectorsIdentity, ::milvus::grpc::HEntity>(
std::mem_fn(&MilvusService::Service::GetEntityByID), this)));
AddMethod(new ::grpc::internal::RpcServiceMethod(
MilvusService_method_names[37],
MilvusService_method_names[38],
::grpc::internal::RpcMethod::NORMAL_RPC,
new ::grpc::internal::RpcMethodHandler< MilvusService::Service, ::milvus::grpc::HGetEntityIDsParam, ::milvus::grpc::HEntityIDs>(
std::mem_fn(&MilvusService::Service::GetEntityIDs), this)));
AddMethod(new ::grpc::internal::RpcServiceMethod(
MilvusService_method_names[38],
MilvusService_method_names[39],
::grpc::internal::RpcMethod::NORMAL_RPC,
new ::grpc::internal::RpcMethodHandler< MilvusService::Service, ::milvus::grpc::HDeleteByIDParam, ::milvus::grpc::Status>(
std::mem_fn(&MilvusService::Service::DeleteEntitiesByID), this)));
......@@ -1640,6 +1675,13 @@ MilvusService::Service::~Service() {
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
}
// Base-class handler for the HybridSearchPB RPC. It ignores all three arguments and reports
// UNIMPLEMENTED with an empty message; GrpcRequestHandler declares an override of this method.
::grpc::Status MilvusService::Service::HybridSearchPB(::grpc::ServerContext* context, const ::milvus::grpc::HSearchParamPB* request, ::milvus::grpc::HQueryResult* response) {
  // Explicitly discard the parameters so unused-parameter warnings stay quiet.
  static_cast<void>(context);
  static_cast<void>(request);
  static_cast<void>(response);
  const ::grpc::Status unimplemented(::grpc::StatusCode::UNIMPLEMENTED, "");
  return unimplemented;
}
::grpc::Status MilvusService::Service::HybridSearch(::grpc::ServerContext* context, const ::milvus::grpc::HSearchParam* request, ::milvus::grpc::HQueryResult* response) {
(void) context;
(void) request;
......
......@@ -342,7 +342,20 @@ message GeneralQuery {
}
}
// One vector payload of a hybrid search request; HSearchParam carries a repeated list of these.
message VectorParam {
// JSON text associated with the vectors below.
// NOTE(review): presumably the vector-query description (placeholder/field, topk, params)
// that the server parses alongside the DSL - confirm against the server-side deserializer.
string json = 1;
// The vector rows themselves, one RowRecord per row.
repeated RowRecord row_record = 2;
}
// Request message of the HybridSearch RPC (rpc HybridSearch(HSearchParam) returns (HQueryResult)).
// The query is carried as a DSL string plus separately attached vector payloads, in contrast
// to HSearchParamPB, which embeds a structured GeneralQuery message.
message HSearchParam {
// Name of the target collection.
string collection_name = 1;
// Partition tags for the request. NOTE(review): presumably limits the search to these
// partitions, with an empty list meaning all partitions - confirm against the server.
repeated string partition_tag_array = 2;
// Vector payloads that accompany the DSL.
repeated VectorParam vector_param = 3;
// The query DSL as a string. NOTE(review): the commit message calls this a JSON-string DSL;
// the exact schema is not visible here.
string dsl = 4;
// Additional free-form key/value parameters.
repeated KeyValuePair extra_params = 5;
}
message HSearchParamPB {
string collection_name = 1;
repeated string partition_tag_array = 2;
GeneralQuery general_query = 3;
......@@ -351,7 +364,7 @@ message HSearchParam {
message HSearchInSegmentsParam {
repeated string segment_id_array = 1;
HSearchParam search_param = 2;
HSearchParamPB search_param = 2;
}
///////////////////////////////////////////////////////////////////
......@@ -664,16 +677,18 @@ service MilvusService {
///////////////////////////////////////////////////////////////////
// rpc CreateIndex(IndexParam) returns (Status) {}
//
// rpc DescribeIndex(CollectionName) returns (IndexParam) {}
//
// rpc DropIndex(CollectionName) returns (Status) {}
// rpc CreateIndex(IndexParam) returns (Status) {}
//
// rpc DescribeIndex(CollectionName) returns (IndexParam) {}
//
// rpc DropIndex(CollectionName) returns (Status) {}
///////////////////////////////////////////////////////////////////
rpc InsertEntity(HInsertParam) returns (HEntityIDs) {}
rpc HybridSearchPB(HSearchParamPB) returns (HQueryResult) {}
rpc HybridSearch(HSearchParam) returns (HQueryResult) {}
rpc HybridSearchInSegments(HSearchInSegmentsParam) returns (TopKQueryResult) {}
......
......@@ -67,7 +67,7 @@ GenBinaryQuery(BooleanQueryPtr query, BinaryQueryPtr& binary_query) {
// Put VectorQuery to the end of leaf queries
auto query_size = query->getLeafQueries().size();
for (uint64_t i = 0; i < query_size; ++i) {
if (query->getLeafQueries()[i]->vector_query != nullptr) {
if (query->getLeafQueries()[i]->vector_placeholder.size() > 0) {
std::swap(query->getLeafQueries()[i], query->getLeafQueries()[0]);
break;
}
......
......@@ -14,6 +14,7 @@
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "utils/Json.h"
......@@ -92,7 +93,7 @@ using GeneralQueryPtr = std::shared_ptr<GeneralQuery>;
struct LeafQuery {
TermQueryPtr term_query;
RangeQueryPtr range_query;
VectorQueryPtr vector_query;
std::string vector_placeholder;
float query_boost;
};
......@@ -103,5 +104,11 @@ struct BinaryQuery {
float query_boost;
};
// Top-level description of a hybrid search: the binary query tree together with the vector
// queries that the tree refers to by placeholder name.
struct Query {
// Root node of the binary query tree.
BinaryQueryPtr root;
// Vector queries keyed by placeholder string. A LeafQuery names its vector query through
// its `vector_placeholder` member, and the engine resolves that name in this map.
std::unordered_map<std::string, VectorQueryPtr> vectors;
};
// Shared-ownership handle to a Query.
using QueryPtr = std::shared_ptr<Query>;
} // namespace query
} // namespace milvus
......@@ -22,9 +22,15 @@ SearchJob::SearchJob(const std::shared_ptr<server::Context>& context, uint64_t t
}
SearchJob::SearchJob(const std::shared_ptr<server::Context>& context, milvus::query::GeneralQueryPtr general_query,
query::QueryPtr query_ptr,
std::unordered_map<std::string, engine::meta::hybrid::DataType>& attr_type,
const engine::VectorsData& vectors)
: Job(JobType::SEARCH), context_(context), general_query_(general_query), attr_type_(attr_type), vectors_(vectors) {
: Job(JobType::SEARCH),
context_(context),
general_query_(general_query),
query_ptr_(query_ptr),
attr_type_(attr_type),
vectors_(vectors) {
}
bool
......
......@@ -46,7 +46,7 @@ class SearchJob : public Job {
const engine::VectorsData& vectors);
SearchJob(const std::shared_ptr<server::Context>& context, query::GeneralQueryPtr general_query,
std::unordered_map<std::string, engine::meta::hybrid::DataType>& attr_type,
query::QueryPtr query_ptr, std::unordered_map<std::string, engine::meta::hybrid::DataType>& attr_type,
const engine::VectorsData& vectorsData);
public:
......@@ -110,6 +110,11 @@ class SearchJob : public Job {
return general_query_;
}
// Returns (by value, i.e. as another shared_ptr owner) the Query handle that was passed to
// this job's hybrid-search constructor.
query::QueryPtr
query_ptr() {
return query_ptr_;
}
std::unordered_map<std::string, engine::meta::hybrid::DataType>&
attr_type() {
return attr_type_;
......@@ -135,6 +140,7 @@ class SearchJob : public Job {
Status status_;
query::GeneralQueryPtr general_query_;
query::QueryPtr query_ptr_;
std::unordered_map<std::string, engine::meta::hybrid::DataType> attr_type_;
uint64_t vector_count_;
......
......@@ -267,7 +267,13 @@ XSearchTask::Execute() {
for (; type_it != attr_type.end(); type_it++) {
types.insert(std::make_pair(type_it->first, (engine::DataType)(type_it->second)));
}
s = index_engine_->HybridSearch(general_query, types, nq, topk, output_distance, output_ids);
auto query_ptr = search_job->query_ptr();
s = index_engine_->HybridSearch(general_query, types, query_ptr, output_distance, output_ids);
auto vector_query = query_ptr->vectors.begin()->second;
topk = vector_query->topk;
nq = vector_query->query_vector.float_data.size() / file_->dimension_;
if (!s.ok()) {
search_job->GetStatus() = s;
......
......@@ -316,14 +316,12 @@ RequestHandler::GetEntityByID(const std::shared_ptr<Context>& context, const std
}
Status
RequestHandler::HybridSearch(const std::shared_ptr<Context>& context,
context::HybridSearchContextPtr hybrid_search_context, const std::string& collection_name,
std::vector<std::string>& partition_list, milvus::query::GeneralQueryPtr& general_query,
milvus::json& json_params, std::vector<std::string>& field_names,
engine::QueryResult& result) {
BaseRequestPtr request_ptr =
HybridSearchRequest::Create(context, hybrid_search_context, collection_name, partition_list, general_query,
json_params, field_names, result);
RequestHandler::HybridSearch(const std::shared_ptr<Context>& context, const std::string& collection_name,
std::vector<std::string>& partition_list, query::GeneralQueryPtr& general_query,
query::QueryPtr& query_ptr, milvus::json& json_params,
std::vector<std::string>& field_names, engine::QueryResult& result) {
BaseRequestPtr request_ptr = HybridSearchRequest::Create(context, collection_name, partition_list, general_query,
query_ptr, json_params, field_names, result);
RequestScheduler::ExecRequest(request_ptr);
......
......@@ -146,10 +146,10 @@ class RequestHandler {
std::vector<engine::VectorsData>& vectors);
Status
HybridSearch(const std::shared_ptr<Context>& context, context::HybridSearchContextPtr hybrid_search_context,
const std::string& collection_name, std::vector<std::string>& partition_list,
query::GeneralQueryPtr& general_query, milvus::json& json_params,
std::vector<std::string>& field_names, engine::QueryResult& result);
HybridSearch(const std::shared_ptr<Context>& context, const std::string& collection_name,
std::vector<std::string>& partition_list, query::GeneralQueryPtr& general_query,
query::QueryPtr& query_ptr, milvus::json& json_params, std::vector<std::string>& field_names,
engine::QueryResult& result);
};
} // namespace server
......
......@@ -31,28 +31,26 @@ namespace milvus {
namespace server {
HybridSearchRequest::HybridSearchRequest(const std::shared_ptr<milvus::server::Context>& context,
context::HybridSearchContextPtr& hybrid_search_context,
const std::string& collection_name, std::vector<std::string>& partition_list,
milvus::query::GeneralQueryPtr& general_query, milvus::json& json_params,
std::vector<std::string>& field_names, engine::QueryResult& result)
query::GeneralQueryPtr& general_query, query::QueryPtr& query_ptr,
milvus::json& json_params, std::vector<std::string>& field_names,
engine::QueryResult& result)
: BaseRequest(context, BaseRequest::kHybridSearch),
hybrid_search_context_(hybrid_search_context),
collection_name_(collection_name),
partition_list_(partition_list),
general_query_(general_query),
query_ptr_(query_ptr),
field_names_(field_names),
result_(result) {
}
BaseRequestPtr
HybridSearchRequest::Create(const std::shared_ptr<milvus::server::Context>& context,
context::HybridSearchContextPtr& hybrid_search_context, const std::string& collection_name,
std::vector<std::string>& partition_list, milvus::query::GeneralQueryPtr& general_query,
milvus::json& json_params, std::vector<std::string>& field_names,
engine::QueryResult& result) {
return std::shared_ptr<BaseRequest>(new HybridSearchRequest(context, hybrid_search_context, collection_name,
partition_list, general_query, json_params, field_names,
result));
HybridSearchRequest::Create(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name,
std::vector<std::string>& partition_list, query::GeneralQueryPtr& general_query,
query::QueryPtr& query_ptr, milvus::json& json_params,
std::vector<std::string>& field_names, engine::QueryResult& result) {
return std::shared_ptr<BaseRequest>(new HybridSearchRequest(context, collection_name, partition_list, general_query,
query_ptr, json_params, field_names, result));
}
Status
......@@ -106,8 +104,8 @@ HybridSearchRequest::OnExecute() {
}
}
status = DBWrapper::DB()->HybridQuery(context_, collection_name_, partition_list_, hybrid_search_context_,
general_query_, field_names_, attr_type, result_);
status = DBWrapper::DB()->HybridQuery(context_, collection_name_, partition_list_, general_query_, query_ptr_,
field_names_, attr_type, result_);
#ifdef ENABLE_CPU_PROFILING
ProfilerStop();
......
......@@ -24,25 +24,24 @@ namespace server {
class HybridSearchRequest : public BaseRequest {
public:
static BaseRequestPtr
Create(const std::shared_ptr<milvus::server::Context>& context,
context::HybridSearchContextPtr& hybrid_search_context, const std::string& collection_name,
std::vector<std::string>& partition_list, milvus::query::GeneralQueryPtr& general_query,
Create(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name,
std::vector<std::string>& partition_list, query::GeneralQueryPtr& general_query, query::QueryPtr& query_ptr,
milvus::json& json_params, std::vector<std::string>& field_names, engine::QueryResult& result);
protected:
HybridSearchRequest(const std::shared_ptr<milvus::server::Context>& context,
context::HybridSearchContextPtr& hybrid_search_context, const std::string& collection_name,
std::vector<std::string>& partition_list, milvus::query::GeneralQueryPtr& general_query,
milvus::json& json_params, std::vector<std::string>& field_names, engine::QueryResult& result);
HybridSearchRequest(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name,
std::vector<std::string>& partition_list, query::GeneralQueryPtr& general_query,
query::QueryPtr& query_ptr, milvus::json& json_params, std::vector<std::string>& field_names,
engine::QueryResult& result);
Status
OnExecute() override;
private:
context::HybridSearchContextPtr hybrid_search_context_;
const std::string collection_name_;
std::vector<std::string> partition_list_;
milvus::query::GeneralQueryPtr general_query_;
milvus::query::QueryPtr query_ptr_;
milvus::json json_params;
std::vector<std::string>& field_names_;
engine::QueryResult& result_;
......
......@@ -360,6 +360,10 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service,
InsertEntity(::grpc::ServerContext* context, const ::milvus::grpc::HInsertParam* request,
::milvus::grpc::HEntityIDs* response) override;
::grpc::Status
HybridSearchPB(::grpc::ServerContext* context, const ::milvus::grpc::HSearchParamPB* request,
::milvus::grpc::HQueryResult* response) override;
::grpc::Status
HybridSearch(::grpc::ServerContext* context, const ::milvus::grpc::HSearchParam* request,
::milvus::grpc::HQueryResult* response) override;
......@@ -391,12 +395,24 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service,
request_handler_ = handler;
}
Status
DeserializeJsonToBoolQuery(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorParam>& vector_params,
const std::string& dsl_string, query::BooleanQueryPtr& boolean_query,
std::unordered_map<std::string, query::VectorQueryPtr>& query_ptr);
Status
ProcessBooleanQueryJson(const nlohmann::json& query_json, query::BooleanQueryPtr& boolean_query);
Status
ProcessLeafQueryJson(const nlohmann::json& json, query::BooleanQueryPtr& query);
private:
RequestHandler request_handler_;
// std::unordered_map<::grpc::ServerContext*, std::shared_ptr<Context>> context_map_;
std::unordered_map<std::string, std::shared_ptr<Context>> context_map_;
std::shared_ptr<opentracing::Tracer> tracer_;
std::unordered_map<std::string, engine::meta::hybrid::DataType> field_type_;
// std::unordered_map<::grpc::ServerContext*, std::unique_ptr<opentracing::Span>> span_map_;
mutable std::mt19937_64 random_num_generator_;
......
......@@ -611,7 +611,10 @@ WebRequestHandler::ProcessLeafQueryJson(const nlohmann::json& json, milvus::quer
vector_query->topk = vector_json["topk"].get<int64_t>();
vector_query->extra_params = vector_json["extra_params"];
leaf_query->vector_query = vector_query;
// TODO(yukun): remove hardcode here
std::string vector_placeholder = "placeholder_1";
query_ptr_->vectors.insert(std::make_pair(vector_placeholder, vector_query));
leaf_query->vector_placeholder = vector_placeholder;
query->AddLeafQuery(leaf_query);
}
return Status::OK();
......@@ -814,20 +817,22 @@ WebRequestHandler::HybridSearch(const std::string& collection_name, const nlohma
if (query_json.contains("bool")) {
auto boolean_query_json = query_json["bool"];
query::BooleanQueryPtr boolean_query = std::make_shared<query::BooleanQuery>();
auto boolean_query = std::make_shared<query::BooleanQuery>();
query_ptr_ = std::make_shared<query::Query>();
status = ProcessBoolQueryJson(boolean_query_json, boolean_query);
if (!status.ok()) {
return status;
}
query::GeneralQueryPtr general_query = std::make_shared<query::GeneralQuery>();
auto general_query = std::make_shared<query::GeneralQuery>();
query::GenBinaryQuery(boolean_query, general_query->bin);
context::HybridSearchContextPtr hybrid_search_context = std::make_shared<context::HybridSearchContext>();
query_ptr_->root = general_query->bin;
engine::QueryResult result;
std::vector<std::string> field_names;
status = request_handler_.HybridSearch(context_ptr_, hybrid_search_context, collection_name, partition_tags,
general_query, extra_params, field_names, result);
status = request_handler_.HybridSearch(context_ptr_, collection_name, partition_tags, general_query, query_ptr_,
extra_params, field_names, result);
if (!status.ok()) {
return status;
......
......@@ -250,6 +250,7 @@ class WebRequestHandler {
private:
std::shared_ptr<Context> context_ptr_;
RequestHandler request_handler_;
query::QueryPtr query_ptr_;
std::unordered_map<std::string, engine::meta::hybrid::DataType> field_type_;
};
......
......@@ -84,6 +84,7 @@ constexpr ErrorCode SERVER_INVALID_INDEX_FILE_SIZE = ToServerErrorCode(116);
constexpr ErrorCode SERVER_OUT_OF_MEMORY = ToServerErrorCode(117);
constexpr ErrorCode SERVER_INVALID_PARTITION_TAG = ToServerErrorCode(118);
constexpr ErrorCode SERVER_INVALID_BINARY_QUERY = ToServerErrorCode(119);
constexpr ErrorCode SERVER_INVALID_DSL_PARAMETER = ToServerErrorCode(120);
// db error code
constexpr ErrorCode DB_META_TRANSACTION_FAILED = ToDbErrorCode(1);
......
......@@ -116,7 +116,7 @@ BuildEntity(uint64_t n, uint64_t batch_index, milvus::engine::Entity& entity) {
}
void
ConstructGeneralQuery(milvus::query::GeneralQueryPtr& general_query) {
ConstructGeneralQuery(milvus::query::GeneralQueryPtr& general_query, milvus::query::QueryPtr& query_ptr) {
general_query->bin->relation = milvus::query::QueryRelation::AND;
general_query->bin->left_query = std::make_shared<milvus::query::GeneralQuery>();
general_query->bin->right_query = std::make_shared<milvus::query::GeneralQuery>();
......@@ -167,7 +167,12 @@ ConstructGeneralQuery(milvus::query::GeneralQueryPtr& general_query) {
left->bin->right_query->leaf->range_query = range_query;
right->leaf = std::make_shared<milvus::query::LeafQuery>();
right->leaf->vector_query = vector_query;
std::string vector_placeholder = "placeholder_1";
right->leaf->vector_placeholder = vector_placeholder;
query_ptr->root = general_query->bin;
query_ptr->vectors.insert(std::make_pair(vector_placeholder, vector_query));
}
} // namespace
......@@ -234,14 +239,14 @@ TEST_F(DBTest, HYBRID_SEARCH_TEST) {
ASSERT_TRUE(stat.ok());
// Construct general query
milvus::query::GeneralQueryPtr general_query = std::make_shared<milvus::query::GeneralQuery>();
ConstructGeneralQuery(general_query);
auto general_query = std::make_shared<milvus::query::GeneralQuery>();
auto query_ptr = std::make_shared<milvus::query::Query>();
ConstructGeneralQuery(general_query, query_ptr);
std::vector<std::string> tags;
milvus::context::HybridSearchContextPtr hybrid_context = std::make_shared<milvus::context::HybridSearchContext>();
milvus::engine::QueryResult result;
stat = db_->HybridQuery(dummy_context_, COLLECTION_NAME, tags, hybrid_context, general_query, field_names,
attr_type, result);
stat = db_->HybridQuery(dummy_context_, COLLECTION_NAME, tags, general_query, query_ptr, field_names, attr_type,
result);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result.row_num_, NQ);
ASSERT_EQ(result.result_ids_.size(), NQ * TOPK);
......
......@@ -1026,7 +1026,7 @@ TEST_F(RpcHandlerTest, HYBRID_TEST) {
uint64_t nq = 10;
uint64_t topk = 10;
milvus::grpc::HSearchParam search_param;
milvus::grpc::HSearchParamPB search_param;
auto general_query = search_param.mutable_general_query();
auto boolean_query_1 = general_query->mutable_boolean_query();
boolean_query_1->set_occur(milvus::grpc::Occur::MUST);
......@@ -1069,7 +1069,46 @@ TEST_F(RpcHandlerTest, HYBRID_TEST) {
search_extra_param->set_value("");
milvus::grpc::HQueryResult topk_query_result;
handler->HybridSearch(&context, &search_param, &topk_query_result);
handler->HybridSearchPB(&context, &search_param, &topk_query_result);
// Test new HybridSearch
milvus::grpc::HSearchParam new_search_param;
new_search_param.set_collection_name("test_hybrid");
nlohmann::json dsl_json, bool_json, term_json, range_json, vector_json;
term_json["term"]["field_name"] = "field_0";
term_json["term"]["values"] = term_value;
bool_json["must"].push_back(term_json);
range_json["range"]["field_name"] = "field_0";
nlohmann::json comp_json;
comp_json["gte"] = "0";
comp_json["lte"] = "100000";
range_json["range"]["values"] = comp_json;
bool_json["must"].push_back(range_json);
std::string placeholder = "placeholder_1";
vector_json["vector"] = placeholder;
bool_json["must"].push_back(vector_json);
dsl_json["bool"] = bool_json;
nlohmann::json vector_param_json, vector_extra_params;
vector_param_json[placeholder]["field_name"] = "field_1";
vector_param_json[placeholder]["topk"] = topk;
vector_extra_params["nprobe"] = 64;
vector_param_json[placeholder]["params"] = vector_extra_params;
new_search_param.set_dsl(dsl_json.dump());
auto vector_param = new_search_param.add_vector_param();
for (auto record : query_vector) {
auto row_record = vector_param->add_row_record();
CopyRowRecord(row_record, record);
}
vector_param->set_json(vector_param_json.dump());
milvus::grpc::HQueryResult new_query_result;
handler->HybridSearch(&context, &new_search_param, &new_query_result);
}
//////////////////////////////////////////////////////////////////////
......
......@@ -18,15 +18,15 @@
#include <unistd.h>
#include <iostream>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include <unordered_map>
namespace {
// Owning copy of the generated collection name. Calling c_str() directly on the value returned
// by GenCollectionName() would, if that value is a temporary string, leave the pointer dangling
// once the initializing expression finishes; keeping the string alive here avoids that.
const std::string COLLECTION_NAME_STORAGE = milvus_sdk::Utils::GenCollectionName();
// Non-owning view into COLLECTION_NAME_STORAGE. Declared after it so that, within this
// translation unit, the storage is initialized before the pointer is taken.
const char* COLLECTION_NAME = COLLECTION_NAME_STORAGE.c_str();
constexpr int64_t COLLECTION_DIMENSION = 512;
constexpr int64_t COLLECTION_DIMENSION = 128;
constexpr int64_t COLLECTION_INDEX_FILE_SIZE = 1024;
constexpr milvus::MetricType COLLECTION_METRIC_TYPE = milvus::MetricType::L2;
constexpr int64_t BATCH_ENTITY_COUNT = 100000;
......@@ -43,9 +43,8 @@ void
PrintHybridQueryResult(const std::vector<int64_t>& id_array, const milvus::HybridQueryResult& result) {
    // For every id, write "No.<index> id:<id>" followed by a tab and an opening bracket to stdout.
    for (size_t i = 0; i < id_array.size(); ++i) {
        const std::string prefix = "No." + std::to_string(i) + " id:" + std::to_string(id_array[i]);
        std::cout << prefix << "\t[";
        // Walk the attribute records with their own index. The loop must advance `j`: advancing
        // the outer index `i` here would leave the condition on `j` unchanged forever whenever
        // `attr_records` is non-empty. The body is intentionally empty; no per-attribute output
        // is produced by this function.
        for (size_t j = 0; j < result.attr_records.size(); ++j) {
        }
    }
}
......@@ -127,7 +126,7 @@ ClientTest::InsertHybridEntities(std::string& collection_name, int64_t row_num)
}
void
ClientTest::HybridSearch(std::string& collection_name) {
ClientTest::HybridSearchPB(std::string& collection_name) {
std::vector<std::string> partition_tags;
milvus::TopKHybridQueryResult topk_query_result;
......@@ -144,27 +143,29 @@ ClientTest::HybridSearch(std::string& collection_name) {
std::string extra_params;
milvus::Status status =
conn_->HybridSearch(collection_name, partition_tags, query_clause, extra_params, topk_query_result);
for (uint64_t i = 0; i < topk_query_result.size(); i++) {
for (auto attr : topk_query_result[i].attr_records) {
std::cout << "Field: " << attr.first << std::endl;
if (attr.second.int_record.size() > 0) {
for (auto record : attr.second.int_record) {
std::cout << record << "\t";
}
} else if (attr.second.double_record.size() > 0) {
for (auto record : attr.second.double_record) {
std::cout << record << "\t";
}
}
std::cout << std::endl;
}
}
conn_->HybridSearchPB(collection_name, partition_tags, query_clause, extra_params, topk_query_result);
for (uint64_t i = 0; i < topk_query_result.size(); ++i) {
std::cout << topk_query_result[i].ids[1] << " --------- " << topk_query_result[i].distances[1] << std::endl;
milvus_sdk::Utils::PrintTopKHybridQueryResult(topk_query_result);
std::cout << "HybridSearch function call status: " << status.message() << std::endl;
}
// Runs one hybrid search against `collection_name` through the JSON-string DSL API and prints
// the returned top-k result followed by the status message of the call.
void
ClientTest::HybridSearch(std::string& collection_name) {
    // Both JSON objects are handed to GenDSLJson by non-const reference and are serialized
    // with dump() below as the DSL string and the vector-parameter string.
    nlohmann::json dsl_json, vector_param_json;
    milvus_sdk::Utils::GenDSLJson(dsl_json, vector_param_json);

    // Query vectors: ConstructVector is called with nq = NQ and dimension = COLLECTION_DIMENSION.
    std::vector<milvus::Entity> entity_array;
    milvus_sdk::Utils::ConstructVector(NQ, COLLECTION_DIMENSION, entity_array);

    // No partition tags are supplied; the list is passed empty.
    std::vector<std::string> partition_tags;
    milvus::TopKHybridQueryResult topk_query_result;
    auto status = conn_->HybridSearch(collection_name, partition_tags, dsl_json.dump(), vector_param_json.dump(),
                                      entity_array, topk_query_result);

    // The result is printed unconditionally, then the status message of the call.
    milvus_sdk::Utils::PrintTopKHybridQueryResult(topk_query_result);
    std::cout << "HybridSearch function call status: " << status.message() << std::endl;
}
......@@ -187,5 +188,6 @@ ClientTest::TestHybrid() {
InsertHybridEntities(collection_name, 10000);
Flush(collection_name);
sleep(2);
// HybridSearchPB(collection_name);
HybridSearch(collection_name);
}
......@@ -36,6 +36,9 @@ class ClientTest {
void
InsertHybridEntities(std::string&, int64_t);
void
HybridSearchPB(std::string&);
void
HybridSearch(std::string&);
......
此差异已折叠。
......@@ -11,8 +11,8 @@
#pragma once
#include "MilvusApi.h"
#include "BooleanQuery.h"
#include "MilvusApi.h"
#include "thirdparty/nlohmann/json.hpp"
#include <memory>
......@@ -54,8 +54,8 @@ class Utils {
PrintIndexParam(const milvus::IndexParam& index_param);
static void
BuildEntities(int64_t from, int64_t to, std::vector<milvus::Entity>& entity_array,
std::vector<int64_t>& entity_ids, int64_t dimension);
BuildEntities(int64_t from, int64_t to, std::vector<milvus::Entity>& entity_array, std::vector<int64_t>& entity_ids,
int64_t dimension);
static void
PrintSearchResult(const std::vector<std::pair<int64_t, milvus::Entity>>& entity_array,
......@@ -71,8 +71,17 @@ class Utils {
const std::vector<std::pair<int64_t, milvus::Entity>>& entity_array,
milvus::TopKQueryResult& topk_query_result);
static void
ConstructVector(uint64_t nq, uint64_t dimension, std::vector<milvus::Entity>& query_vector);
static std::vector<milvus::LeafQueryPtr>
GenLeafQuery();
static void
GenDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json);
static void
PrintTopKHybridQueryResult(milvus::TopKHybridQueryResult& topk_query_result);
};
} // namespace milvus_sdk
此差异已折叠。
此差异已折叠。
......@@ -67,8 +67,7 @@ class ClientProxy : public Connection {
const std::vector<Entity>& entity_array, std::vector<int64_t>& id_array) override;
Status
GetEntityByID(const std::string& collection_name,
const std::vector<int64_t>& id_array,
GetEntityByID(const std::string& collection_name, const std::vector<int64_t>& id_array,
std::vector<Entity>& entities_data) override;
Status
......@@ -131,10 +130,15 @@ class ClientProxy : public Connection {
InsertEntity(const std::string& collection_name, const std::string& partition_tag, HEntity& entities,
std::vector<uint64_t>& id_array) override;
Status
HybridSearchPB(const std::string& collection_name, const std::vector<std::string>& partition_list,
BooleanQueryPtr& boolean_query, const std::string& extra_params,
TopKHybridQueryResult& topk_query_result) override;
Status
HybridSearch(const std::string& collection_name, const std::vector<std::string>& partition_list,
BooleanQueryPtr& boolean_query, const std::string& extra_params,
TopKHybridQueryResult& topk_query_result) override;
const std::string& dsl, const std::string& vector_param, const std::vector<Entity>& entity_array,
TopKHybridQueryResult& query_result) override;
Status
GetHEntityByID(const std::string& collection_name, const std::vector<int64_t>& id_array,
......
此差异已折叠。
......@@ -114,6 +114,9 @@ class GrpcClient {
Status
InsertEntities(milvus::grpc::HInsertParam& entities, milvus::grpc::HEntityIDs& ids);
Status
HybridSearchPB(milvus::grpc::HSearchParamPB& search_param, milvus::grpc::HQueryResult& result);
Status
HybridSearch(milvus::grpc::HSearchParam& search_param, milvus::grpc::HQueryResult& result);
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册