未验证 提交 67d8a9b5 编写于 作者: Y yukun 提交者: GitHub

Add http interface for hybrid search (#2079)

* Add http interface for hybrid search
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Add unittest for http hybrid interface
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* clang format
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Fix codacy quality
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>
上级 6d87dc3c
......@@ -40,6 +40,7 @@
#include "server/delivery/request/ShowPartitionsRequest.h"
#include "server/delivery/hybrid_request/CreateHybridCollectionRequest.h"
#include "server/delivery/hybrid_request/DescribeHybridCollectionRequest.h"
#include "server/delivery/hybrid_request/HybridSearchRequest.h"
#include "server/delivery/hybrid_request/InsertEntityRequest.h"
......@@ -266,6 +267,15 @@ RequestHandler::CreateHybridCollection(const std::shared_ptr<Context>& context,
return request_ptr->status();
}
/**
 * Creates a DescribeHybridCollectionRequest for `collection_name`, executes it through
 * RequestScheduler::ExecRequest, and returns the status recorded on the request.
 *
 * @param context          request context forwarded to the request object
 * @param collection_name  name of the hybrid collection to describe
 * @param field_types      forwarded to the request by non-const reference
 * @return the status reported by the executed request
 */
Status
RequestHandler::DescribeHybridCollection(const std::shared_ptr<Context>& context, const std::string& collection_name,
                                         std::unordered_map<std::string, engine::meta::hybrid::DataType>& field_types) {
    BaseRequestPtr describe_request =
        DescribeHybridCollectionRequest::Create(context, collection_name, field_types);

    RequestScheduler::ExecRequest(describe_request);

    Status exec_status = describe_request->status();
    return exec_status;
}
Status
RequestHandler::HasHybridCollection(const std::shared_ptr<Context>& context, std::string& collection_name,
bool& has_collection) {
......
......@@ -121,6 +121,10 @@ class RequestHandler {
std::vector<std::pair<std::string, uint64_t>>& vector_dimensions,
std::vector<std::pair<std::string, std::string>>& field_extra_params);
/**
 * Describes the hybrid collection named `collection_name` and returns the resulting status.
 * `field_types` is an in/out map keyed by string with hybrid DataType values
 * (presumably field name -> field type filled in by the request; confirm against
 * DescribeHybridCollectionRequest).
 */
Status
DescribeHybridCollection(const std::shared_ptr<Context>& context, const std::string& collection_name,
std::unordered_map<std::string, engine::meta::hybrid::DataType>& field_types);
Status
HasHybridCollection(const std::shared_ptr<Context>& context, std::string& collection_name, bool& has_collection);
......
......@@ -49,6 +49,7 @@ RequestGroup(BaseRequest::RequestType type) {
{BaseRequest::kDropCollection, DDL_DML_REQUEST_GROUP},
{BaseRequest::kPreloadCollection, DQL_REQUEST_GROUP},
{BaseRequest::kCreateHybridCollection, DDL_DML_REQUEST_GROUP},
{BaseRequest::kDescribeHybridCollection, INFO_REQUEST_GROUP},
// partition operations
{BaseRequest::kCreatePartition, DDL_DML_REQUEST_GROUP},
......
......@@ -799,6 +799,13 @@ GrpcRequestHandler::CreateHybridCollection(::grpc::ServerContext* context, const
return ::grpc::Status::OK;
}
/**
 * gRPC entry point for describing a hybrid collection.
 *
 * The handler is currently a stub: after the null-request check nothing is written to
 * `response`. The original body fell off the end of a non-void function (undefined
 * behaviour); an explicit UNIMPLEMENTED status is returned instead so clients get a
 * well-defined answer until the real implementation lands.
 *
 * @param context   gRPC server context (unused by the stub)
 * @param request   collection name message; checked by CHECK_NULLPTR_RETURN
 * @param response  mapping message to fill (left untouched by the stub)
 */
::grpc::Status
GrpcRequestHandler::DescribeHybridCollection(::grpc::ServerContext* context,
                                             const ::milvus::grpc::CollectionName* request,
                                             ::milvus::grpc::Mapping* response) {
    CHECK_NULLPTR_RETURN(request);

    // TODO: dispatch to RequestHandler::DescribeHybridCollection and populate `response`.
    return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "DescribeHybridCollection is not implemented");
}
::grpc::Status
GrpcRequestHandler::InsertEntity(::grpc::ServerContext* context, const ::milvus::grpc::HInsertParam* request,
::milvus::grpc::HEntityIDs* response) {
......@@ -916,7 +923,6 @@ GrpcRequestHandler::HybridSearch(::grpc::ServerContext* context, const ::milvus:
DeSerialization(request->general_query(), boolean_query);
query::GeneralQueryPtr general_query = std::make_shared<query::GeneralQuery>();
general_query->bin = std::make_shared<query::BinaryQuery>();
query::GenBinaryQuery(boolean_query, general_query->bin);
Status status;
......
......@@ -320,10 +320,9 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service,
// const ::milvus::grpc::CollectionName* request,
// ::milvus::grpc::Status* response) override;
//
// ::grpc::Status
// DescribeCollection(::grpc::ServerContext* context,
// const ::milvus::grpc::CollectionName* request,
// ::milvus::grpc::Mapping* response) override;
::grpc::Status
DescribeHybridCollection(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
::milvus::grpc::Mapping* response) override;
//
// ::grpc::Status
// CountCollection(::grpc::ServerContext* context,
......
......@@ -11,13 +11,13 @@
#pragma once
#include <string>
#include <iostream>
#include <string>
#include <oatpp/web/server/api/ApiController.hpp>
#include <oatpp/parser/json/mapping/ObjectMapper.hpp>
#include <oatpp/core/macro/codegen.hpp>
#include <oatpp/core/macro/component.hpp>
#include <oatpp/parser/json/mapping/ObjectMapper.hpp>
#include <oatpp/web/server/api/ApiController.hpp>
#include "utils/Log.h"
#include "utils/TimeRecorder.h"
......@@ -37,11 +37,12 @@ namespace web {
class WebController : public oatpp::web::server::api::ApiController {
public:
WebController(const std::shared_ptr<ObjectMapper>& objectMapper)
: oatpp::web::server::api::ApiController(objectMapper) {}
: oatpp::web::server::api::ApiController(objectMapper) {
}
public:
static std::shared_ptr<WebController> createShared(
OATPP_COMPONENT(std::shared_ptr<ObjectMapper>, objectMapper)) {
static std::shared_ptr<WebController>
createShared(OATPP_COMPONENT(std::shared_ptr<ObjectMapper>, objectMapper)) {
return std::make_shared<WebController>(objectMapper);
}
......@@ -84,8 +85,8 @@ class WebController : public oatpp::web::server::api::ApiController {
response = createDtoResponse(Status::CODE_400, status_dto);
}
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue())
+ ", reason = " + status_dto->message->std_str() + ". Total cost");
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost");
return response;
}
......@@ -115,8 +116,8 @@ class WebController : public oatpp::web::server::api::ApiController {
response = createDtoResponse(Status::CODE_400, status_dto);
}
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue())
+ ", reason = " + status_dto->message->std_str() + ". Total cost");
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost");
return response;
}
......@@ -139,8 +140,8 @@ class WebController : public oatpp::web::server::api::ApiController {
response = createDtoResponse(Status::CODE_400, status_dto);
}
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue())
+ ", reason = " + status_dto->message->std_str() + ". Total cost");
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost");
return response;
}
......@@ -172,8 +173,8 @@ class WebController : public oatpp::web::server::api::ApiController {
response = createDtoResponse(Status::CODE_400, status_dto);
}
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue())
+ ", reason = " + status_dto->message->std_str() + ". Total cost";
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost";
tr.ElapseFromBegin(ttr);
return response;
......@@ -197,8 +198,8 @@ class WebController : public oatpp::web::server::api::ApiController {
response = createDtoResponse(Status::CODE_400, status_dto);
}
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue())
+ ", reason = " + status_dto->message->std_str() + ". Total cost";
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost";
tr.ElapseFromBegin(ttr);
return response;
}
......@@ -229,8 +230,8 @@ class WebController : public oatpp::web::server::api::ApiController {
response = createDtoResponse(Status::CODE_400, status_dto);
}
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue())
+ ", reason = " + status_dto->message->std_str() + ". Total cost";
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost";
tr.ElapseFromBegin(ttr);
return response;
}
......@@ -243,7 +244,6 @@ class WebController : public oatpp::web::server::api::ApiController {
WebRequestHandler handler = WebRequestHandler();
String result;
auto status_dto = handler.ShowTables(query_params, result);
std::shared_ptr<OutgoingResponse> response;
......@@ -255,8 +255,8 @@ class WebController : public oatpp::web::server::api::ApiController {
response = createDtoResponse(Status::CODE_400, status_dto);
}
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue())
+ ", reason = " + status_dto->message->std_str() + ". Total cost";
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost";
tr.ElapseFromBegin(ttr);
return response;
......@@ -270,8 +270,8 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(GetTable)
ENDPOINT("GET", "/collections/{collection_name}", GetTable,
PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) {
ENDPOINT("GET", "/collections/{collection_name}", GetTable, PATH(String, collection_name),
QUERIES(const QueryParams&, query_params)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + "\'");
tr.RecordSection("Received request.");
......@@ -292,8 +292,8 @@ class WebController : public oatpp::web::server::api::ApiController {
response = createDtoResponse(Status::CODE_400, status_dto);
}
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue())
+ ", reason = " + status_dto->message->std_str() + ". Total cost";
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost";
tr.ElapseFromBegin(ttr);
return response;
......@@ -320,8 +320,8 @@ class WebController : public oatpp::web::server::api::ApiController {
response = createDtoResponse(Status::CODE_400, status_dto);
}
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue())
+ ", reason = " + status_dto->message->std_str() + ". Total cost";
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost";
tr.ElapseFromBegin(ttr);
return response;
......@@ -335,8 +335,8 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(CreateIndex)
ENDPOINT("POST", "/collections/{collection_name}/indexes", CreateIndex,
PATH(String, collection_name), BODY_STRING(String, body)) {
ENDPOINT("POST", "/collections/{collection_name}/indexes", CreateIndex, PATH(String, collection_name),
BODY_STRING(String, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables/" + collection_name->std_str() + "/indexes\'");
tr.RecordSection("Received request.");
......@@ -355,8 +355,8 @@ class WebController : public oatpp::web::server::api::ApiController {
response = createDtoResponse(Status::CODE_400, status_dto);
}
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue())
+ ", reason = " + status_dto->message->std_str() + ". Total cost";
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost";
tr.ElapseFromBegin(ttr);
return response;
......@@ -365,7 +365,8 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(GetIndex)
ENDPOINT("GET", "/collections/{collection_name}/indexes", GetIndex, PATH(String, collection_name)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + "/indexes\'");
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() +
"/indexes\'");
tr.RecordSection("Received request.");
auto handler = WebRequestHandler();
......@@ -385,8 +386,8 @@ class WebController : public oatpp::web::server::api::ApiController {
response = createDtoResponse(Status::CODE_400, status_dto);
}
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue())
+ ", reason = " + status_dto->message->std_str() + ". Total cost";
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost";
tr.ElapseFromBegin(ttr);
return response;
......@@ -395,7 +396,8 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(DropIndex)
ENDPOINT("DELETE", "/collections/{collection_name}/indexes", DropIndex, PATH(String, collection_name)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() + "/indexes\'");
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() +
"/indexes\'");
tr.RecordSection("Received request.");
auto handler = WebRequestHandler();
......@@ -413,8 +415,8 @@ class WebController : public oatpp::web::server::api::ApiController {
response = createDtoResponse(Status::CODE_400, status_dto);
}
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue())
+ ", reason = " + status_dto->message->std_str() + ". Total cost";
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost";
tr.ElapseFromBegin(ttr);
return response;
......@@ -428,9 +430,10 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(CreatePartition)
ENDPOINT("POST", "/collections/{collection_name}/partitions",
CreatePartition, PATH(String, collection_name), BODY_DTO(PartitionRequestDto::ObjectWrapper, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() + "/partitions\'");
ENDPOINT("POST", "/collections/{collection_name}/partitions", CreatePartition, PATH(String, collection_name),
BODY_DTO(PartitionRequestDto::ObjectWrapper, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() +
"/partitions\'");
tr.RecordSection("Received request.");
auto handler = WebRequestHandler();
......@@ -448,17 +451,18 @@ class WebController : public oatpp::web::server::api::ApiController {
response = createDtoResponse(Status::CODE_400, status_dto);
}
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue())
+ ", reason = " + status_dto->message->std_str() + ". Total cost");
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost");
return response;
}
ADD_CORS(ShowPartitions)
ENDPOINT("GET", "/collections/{collection_name}/partitions", ShowPartitions,
PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + "/partitions\'");
ENDPOINT("GET", "/collections/{collection_name}/partitions", ShowPartitions, PATH(String, collection_name),
QUERIES(const QueryParams&, query_params)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() +
"/partitions\'");
tr.RecordSection("Received request.");
auto offset = query_params.get("offset");
......@@ -476,21 +480,22 @@ class WebController : public oatpp::web::server::api::ApiController {
case StatusCode::COLLECTION_NOT_EXISTS:
response = createDtoResponse(Status::CODE_404, status_dto);
break;
default:response = createDtoResponse(Status::CODE_400, status_dto);
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue())
+ ", reason = " + status_dto->message->std_str() + ". Total cost");
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost");
return response;
}
ADD_CORS(DropPartition)
ENDPOINT("DELETE", "/collections/{collection_name}/partitions", DropPartition,
PATH(String, collection_name), BODY_STRING(String, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) +
"DELETE \'/collections/" + collection_name->std_str() + "/partitions\'");
ENDPOINT("DELETE", "/collections/{collection_name}/partitions", DropPartition, PATH(String, collection_name),
BODY_STRING(String, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() +
"/partitions\'");
tr.RecordSection("Received request.");
auto handler = WebRequestHandler();
......@@ -508,16 +513,16 @@ class WebController : public oatpp::web::server::api::ApiController {
response = createDtoResponse(Status::CODE_400, status_dto);
}
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue())
+ ", reason = " + status_dto->message->std_str() + ". Total cost");
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost");
return response;
}
ADD_CORS(ShowSegments)
ENDPOINT("GET", "/collections/{collection_name}/segments", ShowSegments,
PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) {
ENDPOINT("GET", "/collections/{collection_name}/segments", ShowSegments, PATH(String, collection_name),
QUERIES(const QueryParams&, query_params)) {
auto offset = query_params.get("offset");
auto page_size = query_params.get("page_size");
......@@ -541,7 +546,8 @@ class WebController : public oatpp::web::server::api::ApiController {
* GetSegmentVector
*/
ENDPOINT("GET", "/collections/{collection_name}/segments/{segment_name}/{info}", GetSegmentInfo,
PATH(String, collection_name), PATH(String, segment_name), PATH(String, info), QUERIES(const QueryParams&, query_params)) {
PATH(String, collection_name), PATH(String, segment_name), PATH(String, info),
QUERIES(const QueryParams&, query_params)) {
auto offset = query_params.get("offset");
auto page_size = query_params.get("page_size");
......@@ -570,8 +576,8 @@ class WebController : public oatpp::web::server::api::ApiController {
*
* GetVectorByID ?id=
*/
ENDPOINT("GET", "/collections/{collection_name}/vectors", GetVectors,
PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) {
ENDPOINT("GET", "/collections/{collection_name}/vectors", GetVectors, PATH(String, collection_name),
QUERIES(const QueryParams&, query_params)) {
auto handler = WebRequestHandler();
String response;
auto status_dto = handler.GetVector(collection_name, query_params, response);
......@@ -588,9 +594,10 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(Insert)
ENDPOINT("POST", "/collections/{collection_name}/vectors", Insert,
PATH(String, collection_name), BODY_STRING(String, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() + "/vectors\'");
ENDPOINT("POST", "/collections/{collection_name}/vectors", Insert, PATH(String, collection_name),
BODY_STRING(String, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() +
"/vectors\'");
tr.RecordSection("Received request.");
auto ids_dto = VectorIdsDto::createShared();
......@@ -609,17 +616,48 @@ class WebController : public oatpp::web::server::api::ApiController {
response = createDtoResponse(Status::CODE_400, status_dto);
}
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue())
+ ", reason = " + status_dto->message->std_str() + ". Total cost");
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost");
return response;
}
/*
 * POST /hybrid_collections/{collection_name}/entities
 *
 * Hands the raw request body to WebRequestHandler::InsertEntity and maps the outcome to HTTP:
 *   SUCCESS               -> 201 with the ids DTO
 *   COLLECTION_NOT_EXISTS -> 404 with the status DTO
 *   anything else         -> 400 with the status DTO
 */
ADD_CORS(InsertEntity)
ENDPOINT("POST", "/hybrid_collections/{collection_name}/entities", InsertEntity, PATH(String, collection_name),
         BODY_STRING(String, body)) {
    TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/hybrid_collections/" + collection_name->std_str() +
                    "/entities\'");
    tr.RecordSection("Received request.");

    auto ids_dto = VectorIdsDto::createShared();
    WebRequestHandler handler = WebRequestHandler();
    std::shared_ptr<OutgoingResponse> response;

    auto status_dto = handler.InsertEntity(collection_name, body, ids_dto);
    const auto result_code = status_dto->code->getValue();
    if (result_code == StatusCode::SUCCESS) {
        response = createDtoResponse(Status::CODE_201, ids_dto);
    } else if (result_code == StatusCode::COLLECTION_NOT_EXISTS) {
        response = createDtoResponse(Status::CODE_404, status_dto);
    } else {
        response = createDtoResponse(Status::CODE_400, status_dto);
    }

    std::string summary = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
                          ", reason = " + status_dto->message->std_str() + ". Total cost";
    tr.ElapseFromBegin(summary);

    return response;
}
ADD_CORS(VectorsOp)
ENDPOINT("PUT", "/collections/{collection_name}/vectors", VectorsOp,
PATH(String, collection_name), BODY_STRING(String, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "PUT \'/collections/" + collection_name->std_str() + "/vectors\'");
ENDPOINT("PUT", "/collections/{collection_name}/vectors", VectorsOp, PATH(String, collection_name),
BODY_STRING(String, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "PUT \'/collections/" + collection_name->std_str() +
"/vectors\'");
tr.RecordSection("Received request.");
WebRequestHandler handler = WebRequestHandler();
......@@ -638,8 +676,8 @@ class WebController : public oatpp::web::server::api::ApiController {
response = createDtoResponse(Status::CODE_400, status_dto);
}
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue())
+ ", reason = " + status_dto->message->std_str() + ". Total cost");
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost");
return response;
}
......@@ -668,8 +706,8 @@ class WebController : public oatpp::web::server::api::ApiController {
response = createDtoResponse(Status::CODE_400, status_dto);
}
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue())
+ ", reason = " + status_dto->message->std_str() + ". Total cost");
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost");
return response;
}
......@@ -694,19 +732,41 @@ class WebController : public oatpp::web::server::api::ApiController {
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue())
+ ", reason = " + status_dto->message->std_str() + ". Total cost");
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost");
return response;
}
/*
 * POST /hybrid_collections
 *
 * Passes the raw request body to WebRequestHandler::CreateHybridCollection.
 * Replies 201 on SUCCESS and 400 for every other status code; the status DTO is the body in both cases.
 */
ADD_CORS(CreateHybridCollection)
ENDPOINT("POST", "/hybrid_collections", CreateHybridCollection, BODY_STRING(String, body_str)) {
    TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/hybrid_collections\'");
    tr.RecordSection("Received request.");

    WebRequestHandler handler = WebRequestHandler();
    auto status_dto = handler.CreateHybridCollection(body_str);

    const bool created = (status_dto->code->getValue() == StatusCode::SUCCESS);
    const auto& http_status = created ? Status::CODE_201 : Status::CODE_400;
    std::shared_ptr<OutgoingResponse> response = createDtoResponse(http_status, status_dto);

    tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
                       ", reason = " + status_dto->message->std_str() + ". Total cost");

    return response;
}
/**
* Finish ENDPOINTs generation ('ApiController' codegen)
*/
#include OATPP_CODEGEN_END(ApiController)
};
} // namespace web
} // namespace server
} // namespace milvus
} // namespace web
} // namespace server
} // namespace milvus
......@@ -15,6 +15,7 @@
#include <cmath>
#include <ctime>
#include <string>
#include <unordered_map>
#include <vector>
#include "config/Config.h"
......@@ -567,6 +568,251 @@ WebRequestHandler::Search(const std::string& collection_name, const nlohmann::js
return Status::OK();
}
/**
 * Converts one leaf clause of a hybrid-search request into a query::LeafQuery and appends it to `query`.
 *
 * Exactly one of the keys "term", "range" or "vector" is handled, checked in that order:
 *   - "term":   { "field_name": <string>, "values": [<number>, ...] }
 *               Values are packed into term_query->field_value as 8-byte elements: int64_t for the
 *               INT8/INT16/INT32/INT64 field types, double for FLOAT/DOUBLE. The field type is looked
 *               up in field_type_.
 *   - "range":  { "field_name": <string>, "values": { "lt"|"lte"|"eq"|"ne"|"gt"|"gte": <string>, ... } }
 *               One CompareExpr is appended per operator present, in the order listed.
 *   - "vector": { "field_name": <string>, "values": ..., "topk": <integer>, "extra_params": ... }
 *
 * Malformed clauses are reported as BODY_PARSE_FAIL instead of surfacing as exceptions from the
 * JSON library or from the field-type lookup.
 *
 * @param json   the leaf clause
 * @param query  boolean query that receives the new leaf
 * @return Status::OK() on success, BODY_PARSE_FAIL describing the first problem otherwise
 */
Status
WebRequestHandler::ProcessLeafQueryJson(const nlohmann::json& json, milvus::query::BooleanQueryPtr& query) {
    if (json.contains("term")) {
        const auto& term_json = json["term"];
        if (!term_json.contains("field_name") || !term_json["field_name"].is_string()) {
            return Status{BODY_PARSE_FAIL, "Term query requires a string \"field_name\""};
        }
        const std::string field_name = term_json["field_name"].get<std::string>();

        if (!term_json.contains("values") || !term_json["values"].is_array()) {
            std::string msg = "Term json string is not an array";
            return Status{BODY_PARSE_FAIL, msg};
        }
        const auto& term_value_json = term_json["values"];
        for (const auto& value : term_value_json) {
            if (!value.is_number()) {
                return Status{BODY_PARSE_FAIL, "Term values of field \"" + field_name + "\" must be numbers"};
            }
        }

        // Look the field up explicitly: at() would throw std::out_of_range for an unknown name.
        const auto type_iter = field_type_.find(field_name);
        if (type_iter == field_type_.end()) {
            return Status{BODY_PARSE_FAIL, "Unknown field \"" + field_name + "\" in term query"};
        }

        const auto term_size = term_value_json.size();
        auto term_query = std::make_shared<query::TermQuery>();
        term_query->field_name = field_name;
        term_query->field_value.resize(term_size * sizeof(int64_t));

        switch (type_iter->second) {
            case engine::meta::hybrid::DataType::INT8:
            case engine::meta::hybrid::DataType::INT16:
            case engine::meta::hybrid::DataType::INT32:
            case engine::meta::hybrid::DataType::INT64: {
                std::vector<int64_t> term_value(term_size, 0);
                for (size_t i = 0; i < term_size; ++i) {
                    term_value[i] = term_value_json[i].get<int64_t>();
                }
                memcpy(term_query->field_value.data(), term_value.data(), term_size * sizeof(int64_t));
                break;
            }
            case engine::meta::hybrid::DataType::FLOAT:
            case engine::meta::hybrid::DataType::DOUBLE: {
                std::vector<double> term_value(term_size, 0);
                for (size_t i = 0; i < term_size; ++i) {
                    term_value[i] = term_value_json[i].get<double>();
                }
                memcpy(term_query->field_value.data(), term_value.data(), term_size * sizeof(double));
                break;
            }
            default: {
                // Previously fell through and produced a zero-filled buffer for any other type.
                return Status{BODY_PARSE_FAIL, "Field \"" + field_name + "\" does not support term query"};
            }
        }

        auto leaf_query = std::make_shared<query::LeafQuery>();
        leaf_query->term_query = term_query;
        query->AddLeafQuery(leaf_query);
        return Status::OK();
    }

    if (json.contains("range")) {
        const auto& range_json = json["range"];
        if (!range_json.contains("field_name") || !range_json["field_name"].is_string()) {
            return Status{BODY_PARSE_FAIL, "Range query requires a string \"field_name\""};
        }
        if (!range_json.contains("values") || !range_json["values"].is_object()) {
            return Status{BODY_PARSE_FAIL, "Range query requires an object \"values\""};
        }
        const auto& range_value_json = range_json["values"];

        auto range_query = std::make_shared<query::RangeQuery>();
        range_query->field_name = range_json["field_name"].get<std::string>();

        // Operators are visited in this fixed order, which is the order compare_expr is filled in.
        struct RangeOperator {
            const char* key;
            query::CompareOperator op;
        };
        static const RangeOperator kRangeOperators[] = {
            {"lt", query::CompareOperator::LT},   {"lte", query::CompareOperator::LTE},
            {"eq", query::CompareOperator::EQ},   {"ne", query::CompareOperator::NE},
            {"gt", query::CompareOperator::GT},   {"gte", query::CompareOperator::GTE},
        };
        for (const auto& range_operator : kRangeOperators) {
            const std::string key = range_operator.key;
            if (!range_value_json.contains(key)) {
                continue;
            }
            const auto& operand_json = range_value_json[key];
            if (!operand_json.is_string()) {
                return Status{BODY_PARSE_FAIL, "Range operand \"" + key + "\" must be a string"};
            }
            query::CompareExpr compare_expr;
            compare_expr.compare_operator = range_operator.op;
            compare_expr.operand = operand_json.get<std::string>();
            range_query->compare_expr.emplace_back(compare_expr);
        }

        auto leaf_query = std::make_shared<query::LeafQuery>();
        leaf_query->range_query = range_query;
        query->AddLeafQuery(leaf_query);
        return Status::OK();
    }

    if (json.contains("vector")) {
        const auto& vector_json = json["vector"];
        if (!vector_json.contains("field_name") || !vector_json["field_name"].is_string()) {
            return Status{BODY_PARSE_FAIL, "Vector query requires a string \"field_name\""};
        }
        if (!vector_json.contains("values")) {
            return Status{BODY_PARSE_FAIL, "Vector query requires \"values\""};
        }
        if (!vector_json.contains("topk") || !vector_json["topk"].is_number_integer()) {
            return Status{BODY_PARSE_FAIL, "Vector query requires an integer \"topk\""};
        }

        auto vector_query = std::make_shared<query::VectorQuery>();
        vector_query->field_name = vector_json["field_name"].get<std::string>();

        engine::VectorsData vectors;
        // TODO(yukun): process binary vector
        // NOTE(review): any result reported by CopyRecordsFromJson is not inspected here, as in the
        // original code - confirm its return type and propagate failures if it reports them.
        CopyRecordsFromJson(vector_json["values"], vectors, false);
        vector_query->query_vector.float_data = vectors.float_data_;
        vector_query->query_vector.binary_data = vectors.binary_data_;

        vector_query->topk = vector_json["topk"].get<int64_t>();
        if (vector_json.contains("extra_params")) {
            vector_query->extra_params = vector_json["extra_params"];
        }

        auto leaf_query = std::make_shared<query::LeafQuery>();
        leaf_query->vector_query = vector_query;
        query->AddLeafQuery(leaf_query);
        return Status::OK();
    }

    // None of the supported leaf kinds: report it instead of silently dropping the clause.
    return Status{BODY_PARSE_FAIL, "Leaf query must contain \"term\", \"range\" or \"vector\""};
}
/**
 * Recursively converts a boolean clause of a hybrid-search request into `boolean_query`.
 *
 * The clause is expected to hold one of the keys "must", "should" or "must_not" (checked in that
 * order; only the first one present is processed). Its value must be an array. Each array element
 * that itself contains one of those keys becomes a nested BooleanQuery; every other element is
 * treated as a leaf and handed to ProcessLeafQueryJson.
 *
 * Unlike the previous version, a failure reported by a nested clause or by a leaf stops processing
 * and is returned to the caller instead of being discarded.
 *
 * @param query_json     the boolean clause
 * @param boolean_query  query object to populate; its occur type is set from the matched key
 * @return Status::OK() on success, otherwise the first error encountered
 */
Status
WebRequestHandler::ProcessBoolQueryJson(const nlohmann::json& query_json, query::BooleanQueryPtr& boolean_query) {
    struct OccurClause {
        const char* key;
        query::Occur occur;
        const char* not_array_msg;
    };
    static const OccurClause kOccurClauses[] = {
        {"must", query::Occur::MUST, "Must json string is not an array"},
        {"should", query::Occur::SHOULD, "Should json string is not an array"},
        {"must_not", query::Occur::MUST_NOT, "Must_not json string is not an array"},
    };

    for (const auto& clause : kOccurClauses) {
        const std::string key = clause.key;
        if (!query_json.contains(key)) {
            continue;
        }

        boolean_query->SetOccur(clause.occur);
        const auto& clause_json = query_json[key];
        if (!clause_json.is_array()) {
            std::string msg = clause.not_array_msg;
            return Status{BODY_PARSE_FAIL, msg};
        }

        for (const auto& item : clause_json) {
            const bool is_nested = item.contains("must") || item.contains("should") || item.contains("must_not");
            if (is_nested) {
                auto nested_query = std::make_shared<query::BooleanQuery>();
                Status status = ProcessBoolQueryJson(item, nested_query);
                if (!status.ok()) {
                    return status;
                }
                boolean_query->AddBooleanQuery(nested_query);
            } else {
                Status status = ProcessLeafQueryJson(item, boolean_query);
                if (!status.ok()) {
                    return status;
                }
            }
        }
        return Status::OK();
    }

    std::string msg = "Must json string doesnot include right query";
    return Status{BODY_PARSE_FAIL, msg};
}
/**
 * Executes a hybrid search described by a JSON request body.
 *
 * Steps:
 *   1. Loads the collection's field types into field_type_ via DescribeHybridCollection
 *      (ProcessLeafQueryJson reads field_type_ when parsing term clauses).
 *   2. Reads the optional "partition_tags" array of strings.
 *   3. Parses the mandatory "bool" clause into a boolean query and converts it to a binary query.
 *   4. Runs the search and serialises the result into `result_str` as
 *      { "num": <row count>, "result": [[{"id": "...", "distance": "..."}, ...], ...] }.
 *
 * @param collection_name  hybrid collection to search
 * @param json             parsed request body
 * @param result_str       receives the serialised JSON result on success
 * @return Status::OK() on success; the failing status from describe/parse/search otherwise;
 *         BODY_PARSE_FAIL when the body is malformed
 */
Status
WebRequestHandler::HybridSearch(const std::string& collection_name, const nlohmann::json& json,
                                std::string& result_str) {
    // Return the underlying status (e.g. collection not found) rather than masking it with a generic error.
    Status status = request_handler_.DescribeHybridCollection(context_ptr_, collection_name, field_type_);
    if (!status.ok()) {
        return status;
    }

    std::vector<std::string> partition_tags;
    if (json.contains("partition_tags")) {
        const auto& tags = json["partition_tags"];
        if (!tags.is_null() && !tags.is_array()) {
            return Status(BODY_PARSE_FAIL, "Field \"partition_tags\" must be a array");
        }
        for (const auto& tag : tags) {
            if (!tag.is_string()) {
                return Status(BODY_PARSE_FAIL, "Field \"partition_tags\" must contain only strings");
            }
            partition_tags.emplace_back(tag.get<std::string>());
        }
    }

    // Without a "bool" clause there is nothing to search; previously this returned OK and left
    // result_str untouched.
    if (!json.contains("bool")) {
        return Status(BODY_PARSE_FAIL, "Field \"bool\" is required");
    }

    query::BooleanQueryPtr boolean_query = std::make_shared<query::BooleanQuery>();
    status = ProcessBoolQueryJson(json["bool"], boolean_query);
    if (!status.ok()) {
        return status;
    }

    query::GeneralQueryPtr general_query = std::make_shared<query::GeneralQuery>();
    // The gRPC handler assigns general_query->bin before calling GenBinaryQuery; do the same here
    // when it is not already set, so GenBinaryQuery never receives a null binary query.
    if (general_query->bin == nullptr) {
        general_query->bin = std::make_shared<query::BinaryQuery>();
    }
    query::GenBinaryQuery(boolean_query, general_query->bin);

    context::HybridSearchContextPtr hybrid_search_context = std::make_shared<context::HybridSearchContext>();
    TopKQueryResult result;
    status = request_handler_.HybridSearch(context_ptr_, hybrid_search_context, collection_name, partition_tags,
                                           general_query, result);
    if (!status.ok()) {
        return status;
    }

    nlohmann::json result_json;
    result_json["num"] = result.row_num_;
    if (result.row_num_ == 0) {
        result_json["result"] = std::vector<int64_t>();
        result_str = result_json.dump();
        return Status::OK();
    }

    // Results are stored flat; each query row owns `step` consecutive ids/distances.
    const auto step = result.id_list_.size() / result.row_num_;
    nlohmann::json search_result_json;
    for (size_t i = 0; i < result.row_num_; i++) {
        nlohmann::json raw_result_json;
        for (size_t j = 0; j < step; j++) {
            nlohmann::json one_result_json;
            one_result_json["id"] = std::to_string(result.id_list_.at(i * step + j));
            one_result_json["distance"] = std::to_string(result.distance_list_.at(i * step + j));
            raw_result_json.emplace_back(one_result_json);
        }
        search_result_json.emplace_back(raw_result_json);
    }
    result_json["result"] = search_result_json;
    result_str = result_json.dump();

    return Status::OK();
}
Status
WebRequestHandler::DeleteByIDs(const std::string& collection_name, const nlohmann::json& json,
std::string& result_str) {
......@@ -930,6 +1176,50 @@ WebRequestHandler::CreateTable(const TableRequestDto::ObjectWrapper& collection_
ASSIGN_RETURN_STATUS_DTO(status)
}
// Creates a hybrid collection from a JSON request body.
//
// body: JSON text with "collection_name" and a "fields" list; every field carries
//       "field_name", "field_type" (int8/int16/int32/int64/float/double/vector) and
//       "extra_params". A "dimension" entry inside "extra_params" is recorded as the
//       field's vector dimension.
//
// Returns a status DTO: BODY_FIELD_LOSS for a missing/empty body, BODY_PARSE_FAIL for an
// unknown field_type, otherwise the status of the underlying create request.
StatusDto::ObjectWrapper
WebRequestHandler::CreateHybridCollection(const milvus::server::web::OString& body) {
    // Same guard InsertEntity applies: never dereference a missing payload.
    if (nullptr == body.get() || body->getSize() == 0) {
        RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Request payload is required.")
    }

    using DataType = engine::meta::hybrid::DataType;
    // Scalar type names accepted in "field_type". "vector" is handled separately because a
    // vector field is described by its dimension rather than by an entry in field_types.
    static const std::unordered_map<std::string, DataType> kScalarFieldTypes = {
        {"int8", DataType::INT8},   {"int16", DataType::INT16}, {"int32", DataType::INT32},
        {"int64", DataType::INT64}, {"float", DataType::FLOAT}, {"double", DataType::DOUBLE},
    };

    auto json_str = nlohmann::json::parse(body->c_str());
    std::string collection_name = json_str["collection_name"];

    // TODO(yukun): do checking

    std::vector<std::pair<std::string, DataType>> field_types;
    std::vector<std::pair<std::string, std::string>> field_extra_params;
    std::vector<std::pair<std::string, uint64_t>> vector_dimensions;
    for (auto& field : json_str["fields"]) {
        std::string field_name = field["field_name"];
        std::string field_type = field["field_type"];
        // Reference instead of copy; a missing key still materializes as JSON null.
        auto& extra_params = field["extra_params"];

        const auto type_it = kScalarFieldTypes.find(field_type);
        if (type_it != kScalarFieldTypes.end()) {
            field_types.emplace_back(field_name, type_it->second);
        } else if (field_type != "vector") {
            std::string msg = field_name + " has wrong field_type";
            RETURN_STATUS_DTO(BODY_PARSE_FAIL, msg.c_str());
        }

        field_extra_params.emplace_back(field_name, extra_params.dump());

        if (extra_params.contains("dimension")) {
            vector_dimensions.emplace_back(field_name, extra_params["dimension"].get<uint64_t>());
        }
    }

    auto status = request_handler_.CreateHybridCollection(context_ptr_, collection_name, field_types, vector_dimensions,
                                                          field_extra_params);

    ASSIGN_RETURN_STATUS_DTO(status)
}
StatusDto::ObjectWrapper
WebRequestHandler::ShowTables(const OQueryParams& query_params, OString& result) {
int64_t offset = 0;
......@@ -1347,6 +1637,106 @@ WebRequestHandler::Insert(const OString& collection_name, const OString& body, V
ASSIGN_RETURN_STATUS_DTO(status)
}
StatusDto::ObjectWrapper
WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::server::web::OString& body,
VectorIdsDto::ObjectWrapper& ids_dto) {
if (nullptr == body.get() || body->getSize() == 0) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Request payload is required.")
}
auto body_json = nlohmann::json::parse(body->c_str());
std::string partition_tag = body_json["partition_tag"];
uint64_t row_num = body_json["row_num"];
std::unordered_map<std::string, engine::meta::hybrid::DataType> field_types;
auto status = request_handler_.DescribeHybridCollection(context_ptr_, collection_name->c_str(), field_types);
auto entities = body_json["entity"];
if (!entities.is_array()) {
RETURN_STATUS_DTO(ILLEGAL_BODY, "An entity must be an array");
}
std::vector<std::string> field_names;
std::vector<std::vector<uint8_t>> attr_values;
size_t attr_size = 0;
std::unordered_map<std::string, engine::VectorsData> vector_datas;
for (auto& entity : entities) {
std::string field_name = entity["field_name"];
field_names.emplace_back(field_name);
auto field_value = entity["field_value"];
std::vector<uint8_t> attr_value;
switch (field_types.at(field_name)) {
case engine::meta::hybrid::DataType::INT8:
case engine::meta::hybrid::DataType::INT16:
case engine::meta::hybrid::DataType::INT32:
case engine::meta::hybrid::DataType::INT64: {
std::vector<int64_t> value;
auto size = field_value.size();
value.resize(size);
attr_value.resize(size * sizeof(int64_t));
size_t offset = 0;
for (auto data : field_value) {
value[offset] = data.get<int64_t>();
++offset;
}
memcpy(attr_value.data(), value.data(), size * sizeof(int64_t));
attr_size += size * sizeof(int64_t);
attr_values.emplace_back(attr_value);
break;
}
case engine::meta::hybrid::DataType::FLOAT:
case engine::meta::hybrid::DataType::DOUBLE: {
std::vector<double> value;
auto size = field_value.size();
value.resize(size);
attr_value.resize(size * sizeof(double));
size_t offset = 0;
for (auto data : field_value) {
value[offset] = data.get<double>();
++offset;
}
memcpy(attr_value.data(), value.data(), size * sizeof(double));
attr_size += size * sizeof(double);
attr_values.emplace_back(attr_value);
break;
}
case engine::meta::hybrid::DataType::VECTOR: {
bool bin_flag;
status = IsBinaryTable(collection_name->c_str(), bin_flag);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
engine::VectorsData vectors;
CopyRecordsFromJson(field_value, vectors, bin_flag);
vector_datas.insert(std::make_pair(field_name, vectors));
}
default: {}
}
}
std::vector<uint8_t> attrs(attr_size, 0);
size_t attr_offset = 0;
for (auto& data : attr_values) {
memcpy(attrs.data() + attr_offset, data.data(), data.size());
attr_offset += data.size();
}
status = request_handler_.InsertEntity(context_ptr_, collection_name->c_str(), partition_tag, row_num, field_names,
attrs, vector_datas);
if (status.ok()) {
ids_dto->ids = ids_dto->ids->createShared();
for (auto& id : vector_datas.begin()->second.id_array_) {
ids_dto->ids->pushBack(std::to_string(id).c_str());
}
}
ASSIGN_RETURN_STATUS_DTO(status)
}
StatusDto::ObjectWrapper
WebRequestHandler::GetVector(const OString& collection_name, const OQueryParams& query_params, OString& response) {
int64_t id = 0;
......@@ -1389,6 +1779,8 @@ WebRequestHandler::VectorsOp(const OString& collection_name, const OString& payl
status = DeleteByIDs(collection_name->std_str(), payload_json["delete"], result_str);
} else if (payload_json.contains("search")) {
status = Search(collection_name->std_str(), payload_json["search"], result_str);
} else if (payload_json.contains("query")) {
status = HybridSearch(collection_name->c_str(), payload_json["query"], result_str);
} else {
status = Status(ILLEGAL_BODY, "Unknown body");
}
......
......@@ -14,6 +14,7 @@
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
......@@ -141,6 +142,15 @@ class WebRequestHandler {
Status
Search(const std::string& collection_name, const nlohmann::json& json, std::string& result_str);
// Handles one query entry that is not itself a boolean clause and records it in boolean_query.
// ProcessBoolQueryJson calls this for entries that carry none of the "must" / "should" /
// "must_not" keys.
// NOTE(review): presumably covers the "term", "range" and "vector" clauses - confirm against
// the definition.
Status
ProcessLeafQueryJson(const nlohmann::json& json, query::BooleanQueryPtr& boolean_query);
// Fills boolean_query from the JSON of a boolean clause. Entries that contain "must",
// "should" or "must_not" are parsed recursively into a child BooleanQuery that is added to
// boolean_query; other entries are handed to ProcessLeafQueryJson.
// Returns BODY_PARSE_FAIL when query_json does not contain a recognized clause.
// NOTE(review): described from the "should" handling; presumably "must" and "must_not" follow
// the same pattern - confirm.
Status
ProcessBoolQueryJson(const nlohmann::json& query_json, query::BooleanQueryPtr& boolean_query);
// Executes a hybrid search on collection_name. `json` is the object found under the "query"
// key of the request: optional "partition_tags" (array or null) plus a "bool" clause.
// On success result_str receives the serialized JSON {"num": ..., "result": [...]}, where each
// result row is a list of {"id", "distance"} objects with both values rendered as strings.
Status
HybridSearch(const std::string& collection_name, const nlohmann::json& json, std::string& result_str);
Status
DeleteByIDs(const std::string& collection_name, const nlohmann::json& json, std::string& result_str);
......@@ -176,6 +186,9 @@ class WebRequestHandler {
StatusDto::ObjectWrapper
ShowTables(const OQueryParams& query_params, OString& result);
// Creates a hybrid collection from a JSON body holding "collection_name" and a "fields" list
// (each field: "field_name", "field_type", "extra_params"). An unknown field_type is answered
// with BODY_PARSE_FAIL.
StatusDto::ObjectWrapper
CreateHybridCollection(const OString& body);
StatusDto::ObjectWrapper
GetTable(const OString& collection_name, const OQueryParams& query_params, OString& result);
......@@ -219,6 +232,9 @@ class WebRequestHandler {
StatusDto::ObjectWrapper
Insert(const OString& collection_name, const OString& body, VectorIdsDto::ObjectWrapper& ids_dto);
// Inserts entities into a hybrid collection. The JSON body holds "partition_tag", "row_num"
// and an "entity" array of {"field_name", "field_value"} items. On success ids_dto receives
// the generated ids as strings.
StatusDto::ObjectWrapper
InsertEntity(const OString& collection_name, const OString& body, VectorIdsDto::ObjectWrapper& ids_dto);
StatusDto::ObjectWrapper
GetVector(const OString& collection_name, const OQueryParams& query_params, OString& response);
......@@ -244,6 +260,7 @@ class WebRequestHandler {
private:
// Request context passed as the first argument of every request_handler_ call.
std::shared_ptr<Context> context_ptr_;
// Backend entry point the web handlers delegate to.
RequestHandler request_handler_;
// Field name -> data type of a hybrid collection. HybridSearch passes it to
// DescribeHybridCollection as the output map before the query is parsed.
std::unordered_map<std::string, engine::meta::hybrid::DataType> field_type_;
};
} // namespace web
......
......@@ -156,6 +156,17 @@ RandomBinRecordsJson(int64_t dim, int64_t num) {
return json;
}
// Builds a JSON array of `row_num` pseudo-random integers drawn uniformly from [0, 1000].
// The engine is default-seeded, so every call produces the same sequence.
// A count of zero or less leaves the returned JSON value default-constructed.
nlohmann::json
RandomAttrRecordsJson(int64_t row_num) {
    nlohmann::json json;
    std::default_random_engine e;
    std::uniform_int_distribution<unsigned> u(0, 1000);
    // Signed index: comparing a size_t index with row_num would turn a negative count into a
    // huge unsigned bound.
    for (int64_t i = 0; i < row_num; ++i) {
        json.push_back(u(e));
    }
    return json;
}
std::string
RandomName() {
unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
......@@ -697,6 +708,12 @@ class TestClient : public oatpp::web::client::ApiClient {
API_CALL("PUT", "/system/{op}", op, PATH(String, cmd_str, "op"), BODY_STRING(String, body))
// POST /hybrid_collections with a raw JSON body; HYBRID_TEST uses it to create the collection.
API_CALL("POST", "/hybrid_collections", createHybridCollection, BODY_STRING(String, body_str))
// POST /hybrid_collections/{collection_name}/entities with a raw JSON body; HYBRID_TEST uses it
// to insert entities.
API_CALL("POST", "/hybrid_collections/{collection_name}/entities", InsertEntity, PATH(String, collection_name), BODY_STRING(String, body))
// Disabled endpoint: HYBRID_TEST sends its hybrid query through vectorsOp instead.
// API_CALL("POST", "/hybrid_collections/{collection_name}/vectors", HybridSearch, PATH(String, collection_name), BODY_STRING(String, body))
#include OATPP_CODEGEN_END(ApiClient)
};
......@@ -967,6 +984,92 @@ TEST_F(WebControllerTest, CREATE_COLLECTION) {
ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode());
}
// End-to-end check of the hybrid HTTP interface: create a collection with one int64 field and
// one vector field, insert entities, flush, then run a hybrid query (term + range + vector)
// and verify the shape of the answer.
TEST_F(WebControllerTest, HYBRID_TEST) {
    // Single source of truth for the vector dimension used by create, insert and query.
    const int64_t dimension = 128;
    const int64_t row_num = 1000;

    // --- create the collection ---
    nlohmann::json create_json;
    create_json["collection_name"] = "test_hybrid";
    nlohmann::json field_json_0, field_json_1;
    field_json_0["field_name"] = "field_0";
    field_json_0["field_type"] = "int64";
    field_json_0["extra_params"] = "";

    field_json_1["field_name"] = "field_1";
    field_json_1["field_type"] = "vector";
    nlohmann::json extra_params;
    extra_params["dimension"] = dimension;
    field_json_1["extra_params"] = extra_params;

    create_json["fields"].push_back(field_json_0);
    create_json["fields"].push_back(field_json_1);

    auto response = client_ptr->createHybridCollection(create_json.dump().c_str());
    ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode());
    auto result_dto = response->readBodyToDto<milvus::server::web::StatusDto>(object_mapper.get());
    ASSERT_EQ(milvus::server::web::StatusCode::SUCCESS, result_dto->code->getValue()) << result_dto->message->std_str();

    // --- insert one attribute column and one vector column ---
    nlohmann::json insert_json;
    insert_json["partition_tag"] = "";
    nlohmann::json entity_0, entity_1;
    entity_0["field_name"] = "field_0";
    entity_0["field_value"] = RandomAttrRecordsJson(row_num);
    entity_1["field_name"] = "field_1";
    entity_1["field_value"] = RandomRecordsJson(dimension, row_num);

    insert_json["entity"].push_back(entity_0);
    insert_json["entity"].push_back(entity_1);
    insert_json["row_num"] = row_num;

    OString collection_name = "test_hybrid";
    response = client_ptr->InsertEntity(collection_name, insert_json.dump().c_str(), conncetion_ptr);
    ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode());
    auto vector_dto = response->readBodyToDto<milvus::server::web::VectorIdsDto>(object_mapper.get());
    ASSERT_EQ(row_num, vector_dto->ids->count());

    auto status = FlushTable(client_ptr, conncetion_ptr, collection_name);
    ASSERT_TRUE(status.ok()) << status.message();

    // TODO(yukun): when hybrid operation is added to wal, the sleep() can be deleted
    sleep(2);

    // --- hybrid query: term + range on field_0, vector search on field_1 ---
    const int64_t nq = 10;
    const int64_t topk = 100;
    nlohmann::json query_json, bool_json, term_json, range_json, vector_json;
    term_json["term"]["field_name"] = "field_0";
    term_json["term"]["values"] = RandomAttrRecordsJson(nq);
    bool_json["must"].push_back(term_json);

    range_json["range"]["field_name"] = "field_0";
    nlohmann::json comp_json;
    comp_json["gte"] = "0";
    comp_json["lte"] = "100000";
    range_json["range"]["values"] = comp_json;
    bool_json["must"].push_back(range_json);

    vector_json["vector"]["field_name"] = "field_1";
    vector_json["vector"]["topk"] = topk;
    vector_json["vector"]["nq"] = nq;
    vector_json["vector"]["values"] = RandomRecordsJson(dimension, nq);
    bool_json["must"].push_back(vector_json);

    query_json["query"]["bool"] = bool_json;

    response = client_ptr->vectorsOp(collection_name, query_json.dump().c_str(), conncetion_ptr);
    ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode());

    // --- verify the answer: one row per query vector, topk hits in the first row ---
    auto result_json = nlohmann::json::parse(response->readBodyToString()->std_str());
    ASSERT_TRUE(result_json.contains("num"));
    ASSERT_TRUE(result_json["num"].is_number());
    ASSERT_EQ(nq, result_json["num"].get<int64_t>());

    ASSERT_TRUE(result_json.contains("result"));
    ASSERT_TRUE(result_json["result"].is_array());

    auto result0_json = result_json["result"][0];
    ASSERT_TRUE(result0_json.is_array());
    // Compare unsigned with unsigned to avoid a sign-compare inside the assertion.
    ASSERT_EQ(static_cast<size_t>(topk), result0_json.size());
}
TEST_F(WebControllerTest, GET_COLLECTION_META) {
OString collection_name = "web_test_create_collection" + OString(RandomName().c_str());
GenTable(client_ptr, conncetion_ptr, collection_name, 10, 10, "L2");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册