未验证 提交 4beb0549 编写于 作者: Y yukun 提交者: GitHub

Add web server interface (#3257)

Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>
上级 6f5be4b5
......@@ -1698,17 +1698,17 @@ GrpcRequestHandler::DeserializeJsonToBoolQuery(
if (vector_param_it != it.value().end()) {
const std::string& field_name = vector_param_it.key();
vector_query->field_name = field_name;
nlohmann::json vector_json = vector_param_it.value();
int64_t topk = vector_json["topk"];
nlohmann::json param_json = vector_param_it.value();
int64_t topk = param_json["topk"];
status = server::ValidateSearchTopk(topk);
if (!status.ok()) {
return status;
}
vector_query->topk = topk;
if (vector_json.contains("metric_type")) {
std::string metric_type = vector_json["metric_type"];
if (param_json.contains("metric_type")) {
std::string metric_type = param_json["metric_type"];
vector_query->metric_type = metric_type;
query_ptr->metric_types.insert({field_name, vector_json["metric_type"]});
query_ptr->metric_types.insert({field_name, param_json["metric_type"]});
}
if (!vector_param_it.value()["params"].empty()) {
vector_query->extra_params = vector_param_it.value()["params"];
......
......@@ -11,6 +11,7 @@
#pragma once
#include <map>
#include <string>
#include <unordered_map>
......@@ -72,6 +73,13 @@ enum StatusCode : int {
MAX = ILLEGAL_QUERY_PARAM
};
static std::map<std::string, engine::DataType> str2type = {{"int32", engine::DataType::INT32},
{"int64", engine::DataType::INT64},
{"float", engine::DataType::FLOAT},
{"double", engine::DataType::DOUBLE},
{"vector_float", engine::DataType::VECTOR_FLOAT},
{"vector_binary", engine::DataType::VECTOR_BINARY}};
} // namespace web
} // namespace server
} // namespace milvus
......@@ -215,74 +215,71 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(CreateCollection)
ENDPOINT("POST", "/collections", CreateCollection, BODY_DTO(CollectionRequestDto::ObjectWrapper, body)) {
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections\'");
// tr.RecordSection("Received request.");
//
// WebRequestHandler handler = WebRequestHandler();
//
// std::shared_ptr<OutgoingResponse> response;
// auto status_dto = handler.CreateCollection(body);
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// response = createDtoResponse(Status::CODE_201, status_dto);
// break;
// default:
// response = createDtoResponse(Status::CODE_400, status_dto);
// }
//
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
// ", reason = " + status_dto->message->std_str() + ". Total cost";
// tr.ElapseFromBegin(ttr);
StatusDto::ObjectWrapper status;
auto response = createDtoResponse(Status::CODE_200, status);
ENDPOINT("POST", "/collections", CreateCollection, BODY_STRING(String, body_str)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections\'");
tr.RecordSection("Received request.");
WebRequestHandler handler = WebRequestHandler();
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.CreateCollection(body_str);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_201, status_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost";
tr.ElapseFromBegin(ttr);
return response;
}
ADD_CORS(ShowCollections)
ENDPOINT("GET", "/collections", ShowCollections, QUERIES(const QueryParams&, query_params)) {
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections\'");
// tr.RecordSection("Received request.");
//
// WebRequestHandler handler = WebRequestHandler();
//
// String result;
// auto status_dto = handler.ShowCollections(query_params, result);
// std::shared_ptr<OutgoingResponse> response;
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// response = createResponse(Status::CODE_200, result);
// break;
// default:
// response = createDtoResponse(Status::CODE_400, status_dto);
// }
//
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
// ", reason = " + status_dto->message->std_str() + ". Total cost";
// tr.ElapseFromBegin(ttr);
json result_json = R"({
"collections": [
{
"collection_name": "test_collection",
"fields": [
{
"field_name": "field_vec",
"field_type": "VECTOR_FLOAT",
"index_params": {"name": "index_1", "index_type": "IVFFLAT", "nlist": 4096},
"extra_params": {"dimension": 128, "metric_type": "L2"}
}
],
"segment_size": 1024
}
],
"count": 58
})";
String result = result_json.dump().c_str();
auto response = createResponse(Status::CODE_200, result);
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections\'");
tr.RecordSection("Received request.");
WebRequestHandler handler = WebRequestHandler();
String result;
auto status_dto = handler.ShowCollections(query_params, result);
std::shared_ptr<OutgoingResponse> response;
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createResponse(Status::CODE_200, result);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost";
tr.ElapseFromBegin(ttr);
// json result_json = R"({
// "collections": [
// {
// "collection_name": "test_collection",
// "fields": [
// {
// "field_name": "field_vec",
// "field_type": "VECTOR_FLOAT",
// "index_params": {"name": "index_1", "index_type": "IVFFLAT", "nlist": 4096},
// "extra_params": {"dimension": 128, "metric_type": "L2"}
// }
// ],
// "segment_size": 1024
// }
// ],
// "count": 58
// })";
response = createResponse(Status::CODE_200, result);
return response;
}
......@@ -296,74 +293,71 @@ class WebController : public oatpp::web::server::api::ApiController {
ENDPOINT("GET", "/collections/{collection_name}", GetCollection, PATH(String, collection_name),
QUERIES(const QueryParams&, query_params)) {
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() +
// "\'"); tr.RecordSection("Received request.");
//
// WebRequestHandler handler = WebRequestHandler();
//
// String response_str;
// auto status_dto = handler.GetCollection(collection_name, query_params, response_str);
//
// std::shared_ptr<OutgoingResponse> response;
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// response = createResponse(Status::CODE_200, response_str);
// break;
// case StatusCode::COLLECTION_NOT_EXISTS:
// response = createDtoResponse(Status::CODE_404, status_dto);
// break;
// default:
// response = createDtoResponse(Status::CODE_400, status_dto);
// }
//
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
// ", reason = " + status_dto->message->std_str() + ". Total cost";
// tr.ElapseFromBegin(ttr);
json result_json = R"({
"collection_name": "test_collection",
"fields": [
{
"field_name": "field_vec",
"field_type": "VECTOR_FLOAT",
"index_params": {"name": "index_1", "index_type": "IVFFLAT", "nlist": 4096},
"extra_params": {"dimension": 128, "metric_type": "L2"}
}
],
"row_count": 10000
})";
auto response = createResponse(Status::CODE_200, result_json.dump().c_str());
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + "\'");
tr.RecordSection("Received request.");
WebRequestHandler handler = WebRequestHandler();
String response_str;
auto status_dto = handler.GetCollection(collection_name, query_params, response_str);
std::shared_ptr<OutgoingResponse> response;
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createResponse(Status::CODE_200, response_str);
break;
case StatusCode::COLLECTION_NOT_EXISTS:
response = createDtoResponse(Status::CODE_404, status_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost";
tr.ElapseFromBegin(ttr);
// json result_json = R"({
// "collection_name": "test_collection",
// "fields": [
// {
// "field_name": "field_vec",
// "field_type": "VECTOR_FLOAT",
// "index_params": {"name": "index_1", "index_type": "IVFFLAT", "nlist": 4096},
// "extra_params": {"dimension": 128, "metric_type": "L2"}
// }
// ],
// "row_count": 10000
// })";
return response;
}
ADD_CORS(DropCollection)
ENDPOINT("DELETE", "/collections/{collection_name}", DropCollection, PATH(String, collection_name)) {
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() +
// "\'"); tr.RecordSection("Received request.");
//
// WebRequestHandler handler = WebRequestHandler();
//
// std::shared_ptr<OutgoingResponse> response;
// auto status_dto = handler.DropCollection(collection_name);
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// response = createDtoResponse(Status::CODE_204, status_dto);
// break;
// case StatusCode::COLLECTION_NOT_EXISTS:
// response = createDtoResponse(Status::CODE_404, status_dto);
// break;
// default:
// response = createDtoResponse(Status::CODE_400, status_dto);
// }
//
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
// ", reason = " + status_dto->message->std_str() + ". Total cost";
// tr.ElapseFromBegin(ttr);
StatusDto::ObjectWrapper status;
auto response = createDtoResponse(Status::CODE_201, status);
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() + "\'");
tr.RecordSection("Received request.");
WebRequestHandler handler = WebRequestHandler();
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.DropCollection(collection_name);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_204, status_dto);
break;
case StatusCode::COLLECTION_NOT_EXISTS:
response = createDtoResponse(Status::CODE_404, status_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost";
tr.ElapseFromBegin(ttr);
return response;
}
......@@ -378,97 +372,90 @@ class WebController : public oatpp::web::server::api::ApiController {
ENDPOINT("POST", "/collections/{collection_name}/fields/{field_name}/indexes/{index_name}", CreateIndex,
PATH(String, collection_name), PATH(String, field_name), PATH(String, index_name),
BODY_STRING(String, body)) {
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables/" + collection_name->std_str() +
// "/indexes\'"); tr.RecordSection("Received request.");
//
// auto handler = WebRequestHandler();
//
// std::shared_ptr<OutgoingResponse> response;
// auto status_dto = handler.CreateIndex(collection_name, body);
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// response = createDtoResponse(Status::CODE_201, status_dto);
// break;
// case StatusCode::COLLECTION_NOT_EXISTS:
// response = createDtoResponse(Status::CODE_404, status_dto);
// break;
// default:
// response = createDtoResponse(Status::CODE_400, status_dto);
// }
//
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
// ", reason = " + status_dto->message->std_str() + ". Total cost";
// tr.ElapseFromBegin(ttr);
StatusDto::ObjectWrapper status;
auto response = createDtoResponse(Status::CODE_201, status);
return response;
}
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables/" + collection_name->std_str() + "/indexes\'");
tr.RecordSection("Received request.");
auto handler = WebRequestHandler();
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.CreateIndex(collection_name, field_name, body);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_201, status_dto);
break;
case StatusCode::COLLECTION_NOT_EXISTS:
response = createDtoResponse(Status::CODE_404, status_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost";
tr.ElapseFromBegin(ttr);
ADD_CORS(GetIndex)
ENDPOINT("GET", "/collections/{collection_name}/fields/{field_name}/indexes", GetIndex,
PATH(String, collection_name), PATH(String, field_name)) {
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() +
// "/indexes\'");
// tr.RecordSection("Received request.");
//
// auto handler = WebRequestHandler();
//
// OString result;
// auto status_dto = handler.GetIndex(collection_name, result);
//
// std::shared_ptr<OutgoingResponse> response;
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// response = createResponse(Status::CODE_200, result);
// break;
// case StatusCode::COLLECTION_NOT_EXISTS:
// response = createDtoResponse(Status::CODE_404, status_dto);
// break;
// default:
// response = createDtoResponse(Status::CODE_400, status_dto);
// }
//
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
// ", reason = " + status_dto->message->std_str() + ". Total cost";
// tr.ElapseFromBegin(ttr);
json result = R"({ "index_name": "FLAT", "params": {"index_type": "IVF_FLAT", "nlist": 4096 } })";
auto response = createResponse(Status::CODE_200, result.dump().c_str());
return response;
}
// ADD_CORS(GetIndex)
//
// ENDPOINT("GET", "/collections/{collection_name}/fields/{field_name}/indexes", GetIndex,
// PATH(String, collection_name), PATH(String, field_name)) {
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() +
// "/indexes\'");
// tr.RecordSection("Received request.");
//
// auto handler = WebRequestHandler();
//
// OString result;
// auto status_dto = handler.GetIndex(collection_name, result);
//
// std::shared_ptr<OutgoingResponse> response;
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// response = createResponse(Status::CODE_200, result);
// break;
// case StatusCode::COLLECTION_NOT_EXISTS:
// response = createDtoResponse(Status::CODE_404, status_dto);
// break;
// default:
// response = createDtoResponse(Status::CODE_400, status_dto);
// }
//
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
// ", reason = " + status_dto->message->std_str() + ". Total cost";
// tr.ElapseFromBegin(ttr);
//
// return response;
// }
ADD_CORS(DropIndex)
ENDPOINT("DELETE", "/collections/{collection_name}/fields/{field_name}/indexes/{index_name}", DropIndex,
PATH(String, collection_name), PATH(String, field_name), PATH(String, index_name)) {
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() +
// "/indexes\'");
// tr.RecordSection("Received request.");
//
// auto handler = WebRequestHandler();
//
// std::shared_ptr<OutgoingResponse> response;
// auto status_dto = handler.DropIndex(collection_name);
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// response = createDtoResponse(Status::CODE_204, status_dto);
// break;
// case StatusCode::COLLECTION_NOT_EXISTS:
// response = createDtoResponse(Status::CODE_404, status_dto);
// break;
// default:
// response = createDtoResponse(Status::CODE_400, status_dto);
// }
//
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
// ", reason = " + status_dto->message->std_str() + ". Total cost";
// tr.ElapseFromBegin(ttr);
StatusDto::ObjectWrapper status;
auto response = createDtoResponse(Status::CODE_204, status);
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() +
"/indexes\'");
tr.RecordSection("Received request.");
auto handler = WebRequestHandler();
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.DropIndex(collection_name, field_name);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_204, status_dto);
break;
case StatusCode::COLLECTION_NOT_EXISTS:
response = createDtoResponse(Status::CODE_404, status_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost";
tr.ElapseFromBegin(ttr);
return response;
}
......@@ -574,23 +561,18 @@ class WebController : public oatpp::web::server::api::ApiController {
ENDPOINT("GET", "/collections/{collection_name}/partitions/{partition_tag}/entities", GetEntities,
PATH(String, collection_name), PATH(String, partition_tag), QUERIES(const QueryParams&, query_params),
BODY_STRING(String, body)) {
json result = R"({
"entities": [
{
"__id": "1578989029645098000",
"field_1": 1,
"field_vec": []
},
{
"__id": "1578989029645098001",
"field_1": 2,
"field_vec": []
}
]
})";
auto response = createResponse(Status::CODE_200, result.dump().c_str());
return response;
auto handler = WebRequestHandler();
String response;
auto status_dto = handler.GetEntity(collection_name, query_params, response);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
return createResponse(Status::CODE_200, response);
case StatusCode::COLLECTION_NOT_EXISTS:
return createDtoResponse(Status::CODE_404, status_dto);
default:
return createDtoResponse(Status::CODE_400, status_dto);
}
}
ADD_CORS(ShowSegments)
......@@ -645,75 +627,6 @@ class WebController : public oatpp::web::server::api::ApiController {
return createResponse(Status::CODE_204, "No Content");
}
ADD_CORS(GetVectors)
/**
*
* GetVectorByID ?id=
*/
ENDPOINT("GET", "/collections/{collection_name}/Entities", GetVectors, PATH(String, collection_name),
QUERIES(const QueryParams&, query_params)) {
// auto handler = WebRequestHandler();
// String response;
// auto status_dto = handler.GetVector(collection_name, query_params, response);
//
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// return createResponse(Status::CODE_200, response);
// case StatusCode::COLLECTION_NOT_EXISTS:
// return createDtoResponse(Status::CODE_404, status_dto);
// default:
// return createDtoResponse(Status::CODE_400, status_dto);
// }
json result = R"({
"entities": [
{
"__id": "1578989029645098000",
"field_1": 1,
"field_vec": []
},
{
"__id": "1578989029645098001",
"field_1": 2,
"field_vec": []
}
]
})";
auto response = createResponse(Status::CODE_200, result.dump().c_str());
return response;
}
ADD_CORS(Insert)
ENDPOINT("POST", "/collections/{collection_name}/entities", Insert, PATH(String, collection_name),
BODY_STRING(String, body)) {
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() +
// "/vectors\'");
// tr.RecordSection("Received request.");
//
// auto ids_dto = VectorIdsDto::createShared();
// WebRequestHandler handler = WebRequestHandler();
//
// std::shared_ptr<OutgoingResponse> response;
// auto status_dto = handler.Insert(collection_name, body, ids_dto);
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// response = createDtoResponse(Status::CODE_201, ids_dto);
// break;
// case StatusCode::COLLECTION_NOT_EXISTS:
// response = createDtoResponse(Status::CODE_404, status_dto);
// break;
// default:
// response = createDtoResponse(Status::CODE_400, status_dto);
// }
//
// tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
// ", reason = " + status_dto->message->std_str() + ". Total cost");
StatusDto::ObjectWrapper status;
auto response = createDtoResponse(Status::CODE_201, status);
return response;
}
ADD_CORS(InsertEntity)
ENDPOINT("POST", "/hybrid_collections/{collection_name}/entities", InsertEntity, PATH(String, collection_name),
......@@ -756,7 +669,7 @@ class WebController : public oatpp::web::server::api::ApiController {
OString result;
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.VectorsOp(collection_name, body, result);
auto status_dto = handler.EntityOp(collection_name, body, result);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createResponse(Status::CODE_200, result);
......@@ -774,61 +687,6 @@ class WebController : public oatpp::web::server::api::ApiController {
return response;
}
ADD_CORS(VectorsOp)
ENDPOINT("PUT", "/collections/{collection_name}/entities", VectorsOp, PATH(String, collection_name),
BODY_STRING(String, body)) {
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "PUT \'/collections/" + collection_name->std_str() +
// "/vectors\'");
// tr.RecordSection("Received request.");
//
// WebRequestHandler handler = WebRequestHandler();
//
// OString result;
// std::shared_ptr<OutgoingResponse> response;
// auto status_dto = handler.VectorsOp(collection_name, body, result);
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// response = createResponse(Status::CODE_200, result);
// break;
// case StatusCode::COLLECTION_NOT_EXISTS:
// response = createDtoResponse(Status::CODE_404, status_dto);
// break;
// default:
// response = createDtoResponse(Status::CODE_400, status_dto);
// }
//
// tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
// ", reason = " + status_dto->message->std_str() + ". Total cost");
json result = R"({
"num": 2,
"results": [
[
{
"id": "1578989029645098000",
"distance": "0.000000",
"entity": {
"field_1": 1,
"field_2": 2,
"field_vec": []
}
},
{
"id": "1578989029645098001",
"distance": "0.010000",
"entity": {
"field_1": 10,
"field_2": 20,
"field_vec": []
}
}
]
]
})";
auto response = createResponse(Status::CODE_200, result.dump().c_str());
return response;
}
ADD_CORS(SystemOptions)
ENDPOINT("OPTIONS", "/system/{info}", SystemOptions) {
......@@ -885,29 +743,6 @@ class WebController : public oatpp::web::server::api::ApiController {
return response;
}
ADD_CORS(CreateHybridCollection)
ENDPOINT("POST", "/hybrid_collections", CreateHybridCollection, BODY_STRING(String, body_str)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/hybrid_collections\'");
tr.RecordSection("Received request.");
WebRequestHandler handler = WebRequestHandler();
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.CreateHybridCollection(body_str);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_201, status_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost";
tr.ElapseFromBegin(ttr);
return response;
}
/**
* Finish ENDPOINTs generation ('ApiController' codegen)
*/
......
......@@ -13,6 +13,7 @@
#include <algorithm>
#include <ctime>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
......@@ -23,6 +24,7 @@
#include "db/Utils.h"
#include "metrics/SystemInfo.h"
#include "query/BinaryQuery.h"
#include "server/ValidationUtil.h"
#include "server/delivery/request/BaseReq.h"
#include "server/web_impl/Constants.h"
#include "server/web_impl/Types.h"
......@@ -117,29 +119,31 @@ WebRequestHandler::IsBinaryCollection(const std::string& collection_name, bool&
}
Status
WebRequestHandler::CopyRecordsFromJson(const nlohmann::json& json, engine::VectorsData& vectors, bool bin) {
WebRequestHandler::CopyRecordsFromJson(const nlohmann::json& json, std::vector<uint8_t>& vectors_data, bool bin) {
if (!json.is_array()) {
return Status(ILLEGAL_BODY, "field \"vectors\" must be a array");
}
vectors.vector_count_ = json.size();
std::vector<float> float_vector;
if (!bin) {
for (auto& vec : json) {
if (!vec.is_array()) {
return Status(ILLEGAL_BODY, "A vector in field \"vectors\" must be a float array");
}
for (auto& data : vec) {
vectors.float_data_.emplace_back(data.get<float>());
float_vector.emplace_back(data.get<float>());
}
}
auto size = float_vector.size() * sizeof(float);
vectors_data.resize(size);
memcpy(vectors_data.data(), float_vector.data(), size);
} else {
for (auto& vec : json) {
if (!vec.is_array()) {
return Status(ILLEGAL_BODY, "A vector in field \"vectors\" must be a float array");
}
for (auto& data : vec) {
vectors.binary_data_.emplace_back(data.get<uint8_t>());
vectors_data.emplace_back(data.get<uint8_t>());
}
}
}
......@@ -147,6 +151,79 @@ WebRequestHandler::CopyRecordsFromJson(const nlohmann::json& json, engine::Vecto
return Status::OK();
}
Status
WebRequestHandler::CopyData2Json(const milvus::engine::DataChunkPtr& data_chunk,
const milvus::engine::snapshot::FieldElementMappings& field_mappings,
const std::vector<int64_t>& id_array, nlohmann::json& json_res) {
int64_t id_size = id_array.size();
for (int i = 0; i < id_size; i++) {
nlohmann::json one_json;
nlohmann::json entity_json;
for (const auto& it : field_mappings) {
auto type = it.first->GetFtype();
std::string name = it.first->GetName();
engine::BinaryDataPtr data = data_chunk->fixed_fields_[name];
if (data == nullptr || data->data_.empty())
continue;
auto single_size = data->data_.size() / id_size;
switch (type) {
case engine::DataType::INT32: {
int32_t int32_value;
int64_t offset = sizeof(int32_t) * i;
memcpy(&int32_value, data->data_.data() + offset, sizeof(int32_t));
entity_json[name] = int32_value;
break;
}
case engine::DataType::INT64: {
int64_t int64_value;
int64_t offset = sizeof(int64_t) * i;
memcpy(&int64_value, data->data_.data() + offset, sizeof(int64_t));
entity_json[name] = int64_value;
break;
}
case engine::DataType::FLOAT: {
float float_value;
int64_t offset = sizeof(float) * i;
memcpy(&float_value, data->data_.data() + offset, sizeof(float));
entity_json[name] = float_value;
break;
}
case engine::DataType::DOUBLE: {
double double_value;
int64_t offset = sizeof(double) * i;
memcpy(&double_value, data->data_.data() + offset, sizeof(double));
entity_json[name] = double_value;
break;
}
case engine::DataType::VECTOR_BINARY: {
std::vector<int8_t> binary_vector;
auto vector_size = single_size * sizeof(int8_t) / sizeof(int8_t);
binary_vector.resize(vector_size);
int64_t offset = vector_size * i;
memcpy(binary_vector.data(), data->data_.data() + offset, vector_size);
entity_json[name] = binary_vector;
break;
}
case engine::DataType::VECTOR_FLOAT: {
std::vector<float> float_vector;
auto vector_size = single_size * sizeof(int8_t) / sizeof(float);
float_vector.resize(vector_size);
int64_t offset = vector_size * i;
memcpy(float_vector.data(), data->data_.data() + offset, vector_size);
entity_json[name] = float_vector;
break;
}
}
}
one_json["entity"] = entity_json;
one_json["id"] = id_array[i];
json_res.push_back(one_json);
}
}
///////////////////////// WebRequestHandler methods ///////////////////////////////////////
Status
WebRequestHandler::GetCollectionMetaInfo(const std::string& collection_name, nlohmann::json& json_out) {
......@@ -157,12 +234,14 @@ WebRequestHandler::GetCollectionMetaInfo(const std::string& collection_name, nlo
STATUS_CHECK(req_handler_.CountEntities(context_ptr_, collection_name, count));
json_out["collection_name"] = schema.collection_name_;
json_out["dimension"] = schema.extra_params_[engine::PARAM_DIMENSION].get<int64_t>();
json_out["segment_row_count"] = schema.extra_params_[engine::PARAM_SEGMENT_ROW_COUNT].get<int64_t>();
json_out["metric_type"] = schema.extra_params_[engine::PARAM_INDEX_METRIC_TYPE].get<int64_t>();
json_out["index_params"] = schema.extra_params_[engine::PARAM_INDEX_EXTRA_PARAMS].get<std::string>();
json_out["count"] = count;
for (const auto& field : schema.fields_) {
nlohmann::json field_json;
field_json["field_name"] = field.first;
field_json["field_type"] = field.second.field_type_;
field_json["index_params"] = field.second.index_params_;
field_json["extra_params"] = field.second.field_params_;
json_out["field"].push_back(field_json);
}
return Status::OK();
}
......@@ -194,7 +273,7 @@ WebRequestHandler::GetSegmentVectors(const std::string& collection_name, int64_t
auto new_ids = std::vector<int64_t>(vector_ids.begin() + ids_begin, vector_ids.begin() + ids_end);
nlohmann::json vectors_json;
auto status = GetVectorsByIDs(collection_name, new_ids, vectors_json);
// auto status = GetVectorsByIDs(collection_name, new_ids, vectors_json);
nlohmann::json result_json;
if (vectors_json.empty()) {
......@@ -204,7 +283,7 @@ WebRequestHandler::GetSegmentVectors(const std::string& collection_name, int64_t
}
json_out["count"] = vector_ids.size();
AddStatusToJson(json_out, status.code(), status.message());
// AddStatusToJson(json_out, status.code(), status.message());
return Status::OK();
}
......@@ -406,287 +485,162 @@ WebRequestHandler::SetConfig(const nlohmann::json& json, std::string& result_str
}
Status
WebRequestHandler::ProcessLeafQueryJson(const nlohmann::json& json, milvus::query::BooleanQueryPtr& query) {
WebRequestHandler::ProcessLeafQueryJson(const nlohmann::json& json, milvus::query::BooleanQueryPtr& query,
std::string& field_name, query::QueryPtr& query_ptr) {
auto status = Status::OK();
if (json.contains("term")) {
auto leaf_query = std::make_shared<query::LeafQuery>();
auto term_json = json["term"];
std::string field_name = term_json["field_name"];
auto term_value_json = term_json["values"];
if (!term_value_json.is_array()) {
std::string msg = "Term json string is not an array";
return Status{BODY_PARSE_FAIL, msg};
}
// auto term_size = term_value_json.size();
// auto term_query = std::make_shared<query::TermQuery>();
// term_query->field_name = field_name;
// term_query->field_value.resize(term_size * sizeof(int64_t));
//
// switch (field_type_.at(field_name)) {
// case engine::DataType::INT8:
// case engine::DataType::INT16:
// case engine::DataType::INT32:
// case engine::DataType::INT64: {
// std::vector<int64_t> term_value(term_size, 0);
// for (uint64_t i = 0; i < term_size; ++i) {
// term_value[i] = term_value_json[i].get<int64_t>();
// }
// memcpy(term_query->field_value.data(), term_value.data(), term_size * sizeof(int64_t));
// break;
// }
// case engine::DataType::FLOAT:
// case engine::DataType::DOUBLE: {
// std::vector<double> term_value(term_size, 0);
// for (uint64_t i = 0; i < term_size; ++i) {
// term_value[i] = term_value_json[i].get<double>();
// }
// memcpy(term_query->field_value.data(), term_value.data(), term_size * sizeof(double));
// break;
// }
// default:
// break;
// }
//
// leaf_query->term_query = term_query;
// query->AddLeafQuery(leaf_query);
// } else if (json.contains("range")) {
// auto leaf_query = std::make_shared<query::LeafQuery>();
// auto range_query = std::make_shared<query::RangeQuery>();
//
// auto range_json = json["range"];
// std::string field_name = range_json["field_name"];
// range_query->field_name = field_name;
//
// auto range_value_json = range_json["values"];
// if (range_value_json.contains("lt")) {
// query::CompareExpr compare_expr;
// compare_expr.compare_operator = query::CompareOperator::LT;
// compare_expr.operand = range_value_json["lt"].get<std::string>();
// range_query->compare_expr.emplace_back(compare_expr);
// }
// if (range_value_json.contains("lte")) {
// query::CompareExpr compare_expr;
// compare_expr.compare_operator = query::CompareOperator::LTE;
// compare_expr.operand = range_value_json["lte"].get<std::string>();
// range_query->compare_expr.emplace_back(compare_expr);
// }
// if (range_value_json.contains("eq")) {
// query::CompareExpr compare_expr;
// compare_expr.compare_operator = query::CompareOperator::EQ;
// compare_expr.operand = range_value_json["eq"].get<std::string>();
// range_query->compare_expr.emplace_back(compare_expr);
// }
// if (range_value_json.contains("ne")) {
// query::CompareExpr compare_expr;
// compare_expr.compare_operator = query::CompareOperator::NE;
// compare_expr.operand = range_value_json["ne"].get<std::string>();
// range_query->compare_expr.emplace_back(compare_expr);
// }
// if (range_value_json.contains("gt")) {
// query::CompareExpr compare_expr;
// compare_expr.compare_operator = query::CompareOperator::GT;
// compare_expr.operand = range_value_json["gt"].get<std::string>();
// range_query->compare_expr.emplace_back(compare_expr);
// }
// if (range_value_json.contains("gte")) {
// query::CompareExpr compare_expr;
// compare_expr.compare_operator = query::CompareOperator::GTE;
// compare_expr.operand = range_value_json["gte"].get<std::string>();
// range_query->compare_expr.emplace_back(compare_expr);
// }
//
// leaf_query->range_query = range_query;
// query->AddLeafQuery(leaf_query);
// } else if (json.contains("vector")) {
// auto leaf_query = std::make_shared<query::LeafQuery>();
// auto vector_query = std::make_shared<query::VectorQuery>();
//
// auto vector_json = json["vector"];
// std::string field_name = vector_json["field_name"];
// vector_query->field_name = field_name;
//
// engine::VectorsData vectors;
// // TODO(yukun): process binary vector
// CopyRecordsFromJson(vector_json["values"], vectors, false);
//
// vector_query->query_vector.float_data = vectors.float_data_;
// vector_query->query_vector.binary_data = vectors.binary_data_;
//
// vector_query->topk = vector_json["topk"].get<int64_t>();
// vector_query->extra_params = vector_json["extra_params"];
//
// // TODO(yukun): remove hardcode here
// std::string vector_placeholder = "placeholder_1";
// query_ptr_->vectors.insert(std::make_pair(vector_placeholder, vector_query));
// leaf_query->vector_placeholder = vector_placeholder;
// query->AddLeafQuery(leaf_query);
}
return Status::OK();
}
Status
WebRequestHandler::ProcessBoolQueryJson(const nlohmann::json& query_json, query::BooleanQueryPtr& boolean_query) {
if (query_json.contains("must")) {
boolean_query->SetOccur(query::Occur::MUST);
auto must_json = query_json["must"];
if (!must_json.is_array()) {
std::string msg = "Must json string is not an array";
return Status{BODY_PARSE_FAIL, msg};
}
for (auto& json : must_json) {
auto must_query = std::make_shared<query::BooleanQuery>();
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
ProcessBoolQueryJson(json, must_query);
boolean_query->AddBooleanQuery(must_query);
} else {
ProcessLeafQueryJson(json, boolean_query);
auto term_query = std::make_shared<query::TermQuery>();
nlohmann::json json_obj = json["term"];
JSON_NULL_CHECK(json_obj);
JSON_OBJECT_CHECK(json_obj);
term_query->json_obj = json_obj;
nlohmann::json::iterator json_it = json_obj.begin();
field_name = json_it.key();
leaf_query->term_query = term_query;
query->AddLeafQuery(leaf_query);
} else if (json.contains("range")) {
auto leaf_query = std::make_shared<query::LeafQuery>();
auto range_query = std::make_shared<query::RangeQuery>();
nlohmann::json json_obj = json["range"];
JSON_NULL_CHECK(json_obj);
JSON_OBJECT_CHECK(json_obj);
range_query->json_obj = json_obj;
nlohmann::json::iterator json_it = json_obj.begin();
field_name = json_it.key();
leaf_query->range_query = range_query;
query->AddLeafQuery(leaf_query);
} else if (json.contains("vector")) {
auto leaf_query = std::make_shared<query::LeafQuery>();
auto vector_json = json["vector"];
JSON_NULL_CHECK(vector_json);
std::random_device dev;
std::mt19937 rng(dev());
std::uniform_int_distribution<std::mt19937::result_type> dist(0, 64);
int64_t place_number = dist(rng);
std::string placeholder = "placeholder" + std::to_string(place_number);
leaf_query->vector_placeholder = placeholder;
query->AddLeafQuery(leaf_query);
auto vector_query = std::make_shared<query::VectorQuery>();
json::iterator vector_param_it = vector_json.begin();
if (vector_param_it != vector_json.end()) {
const std::string& vector_name = vector_param_it.key();
vector_query->field_name = vector_name;
nlohmann::json param_json = vector_param_it.value();
int64_t topk = param_json["topk"];
status = server::ValidateSearchTopk(topk);
if (!status.ok()) {
return status;
}
}
return Status::OK();
} else if (query_json.contains("should")) {
boolean_query->SetOccur(query::Occur::SHOULD);
auto should_json = query_json["should"];
if (!should_json.is_array()) {
std::string msg = "Should json string is not an array";
return Status{BODY_PARSE_FAIL, msg};
}
for (auto& json : should_json) {
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
auto should_query = std::make_shared<query::BooleanQuery>();
ProcessBoolQueryJson(json, should_query);
boolean_query->AddBooleanQuery(should_query);
} else {
ProcessLeafQueryJson(json, boolean_query);
vector_query->topk = topk;
if (param_json.contains("metric_type")) {
std::string metric_type = param_json["metric_type"];
vector_query->metric_type = metric_type;
query_ptr->metric_types.insert({vector_name, param_json["metric_type"]});
}
}
return Status::OK();
} else if (query_json.contains("must_not")) {
boolean_query->SetOccur(query::Occur::MUST_NOT);
auto should_json = query_json["must_not"];
if (!should_json.is_array()) {
std::string msg = "Must_not json string is not an array";
return Status{BODY_PARSE_FAIL, msg};
}
for (auto& json : should_json) {
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
auto must_not_query = std::make_shared<query::BooleanQuery>();
ProcessBoolQueryJson(json, must_not_query);
boolean_query->AddBooleanQuery(must_not_query);
} else {
ProcessLeafQueryJson(json, boolean_query);
if (!vector_param_it.value()["params"].empty()) {
vector_query->extra_params = vector_param_it.value()["params"];
}
engine::VectorsData vector_data;
for (auto& vector_records : vector_param_it.value()["values"]) {
// TODO: Binary vector???
for (auto& data : vector_records) {
vector_query->query_vector.float_data.emplace_back(data.get<float>());
}
}
query_ptr->index_fields.insert(vector_name);
}
return Status::OK();
query_ptr->vectors.insert(std::make_pair(placeholder, vector_query));
} else {
std::string msg = "Must json string doesnot include right query";
return Status{BODY_PARSE_FAIL, msg};
return Status{SERVER_INVALID_ARGUMENT, "Leaf query get wrong key"};
}
return status;
}
void
ConvertRowToColumnJson(const std::vector<engine::AttrsData>& row_attrs, const std::vector<std::string>& field_names,
const int64_t row_num, nlohmann::json& column_attrs_json) {
// if (field_names.size() == 0) {
// if (row_attrs.size() > 0) {
// auto attr_it = row_attrs[0].attr_type_.begin();
// for (; attr_it != row_attrs[0].attr_type_.end(); attr_it++) {
// field_names.emplace_back(attr_it->first);
// }
// }
// }
Status
WebRequestHandler::ProcessBooleanQueryJson(const nlohmann::json& query_json, query::BooleanQueryPtr& boolean_query,
query::QueryPtr& query_ptr) {
auto status = Status::OK();
if (query_json.empty()) {
return Status{SERVER_INVALID_ARGUMENT, "BoolQuery is null"};
}
for (auto& el : query_json.items()) {
if (el.key() == "must") {
boolean_query->SetOccur(query::Occur::MUST);
auto must_json = el.value();
if (!must_json.is_array()) {
std::string msg = "Must json string is not an array";
return Status{SERVER_INVALID_DSL_PARAMETER, msg};
}
for (uint64_t i = 0; i < field_names.size() - 1; i++) {
std::vector<int64_t> int_data;
std::vector<double> double_data;
for (auto& attr : row_attrs) {
int64_t int_value;
double double_value;
auto attr_data = attr.attr_data_.at(field_names[i]);
switch (attr.attr_type_.at(field_names[i])) {
case engine::DataType::INT8: {
if (attr_data.size() == sizeof(int8_t)) {
int_value = attr_data[0];
int_data.emplace_back(int_value);
}
break;
}
case engine::DataType::INT16: {
if (attr_data.size() == sizeof(int16_t)) {
memcpy(&int_value, attr_data.data(), sizeof(int16_t));
int_data.emplace_back(int_value);
}
break;
}
case engine::DataType::INT32: {
if (attr_data.size() == sizeof(int32_t)) {
memcpy(&int_value, attr_data.data(), sizeof(int32_t));
int_data.emplace_back(int_value);
}
break;
}
case engine::DataType::INT64: {
if (attr_data.size() == sizeof(int64_t)) {
memcpy(&int_value, attr_data.data(), sizeof(int64_t));
int_data.emplace_back(int_value);
for (auto& json : must_json) {
auto must_query = std::make_shared<query::BooleanQuery>();
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
STATUS_CHECK(ProcessBooleanQueryJson(json, must_query, query_ptr));
boolean_query->AddBooleanQuery(must_query);
} else {
std::string field_name;
STATUS_CHECK(ProcessLeafQueryJson(json, boolean_query, field_name, query_ptr));
if (!field_name.empty()) {
query_ptr->index_fields.insert(field_name);
}
break;
}
case engine::DataType::FLOAT: {
if (attr_data.size() == sizeof(float)) {
float float_value;
memcpy(&float_value, attr_data.data(), sizeof(float));
double_value = float_value;
double_data.emplace_back(double_value);
}
break;
}
case engine::DataType::DOUBLE: {
if (attr_data.size() == sizeof(double)) {
memcpy(&double_value, attr_data.data(), sizeof(double));
double_data.emplace_back(double_value);
}
} else if (el.key() == "should") {
boolean_query->SetOccur(query::Occur::SHOULD);
auto should_json = el.value();
if (!should_json.is_array()) {
std::string msg = "Should json string is not an array";
return Status{SERVER_INVALID_DSL_PARAMETER, msg};
}
for (auto& json : should_json) {
auto should_query = std::make_shared<query::BooleanQuery>();
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
STATUS_CHECK(ProcessBooleanQueryJson(json, should_query, query_ptr));
boolean_query->AddBooleanQuery(should_query);
} else {
std::string field_name;
STATUS_CHECK(ProcessLeafQueryJson(json, boolean_query, field_name, query_ptr));
if (!field_name.empty()) {
query_ptr->index_fields.insert(field_name);
}
break;
}
default: { return; }
}
}
if (int_data.size() > 0) {
if (row_num == -1) {
nlohmann::json int_data_json(int_data);
column_attrs_json[field_names[i]] = int_data_json;
} else {
nlohmann::json topk_int_result;
int64_t topk = int_data.size() / row_num;
for (int64_t j = 0; j < row_num; j++) {
std::vector<int64_t> one_int_result(topk);
memcpy(one_int_result.data(), int_data.data() + j * topk, sizeof(int64_t) * topk);
nlohmann::json one_int_result_json(one_int_result);
std::string tag = "top" + std::to_string(j);
topk_int_result[tag] = one_int_result_json;
}
column_attrs_json[field_names[i]] = topk_int_result;
} else if (el.key() == "must_not") {
boolean_query->SetOccur(query::Occur::MUST_NOT);
auto should_json = el.value();
if (!should_json.is_array()) {
std::string msg = "Must_not json string is not an array";
return Status{SERVER_INVALID_DSL_PARAMETER, msg};
}
} else if (double_data.size() > 0) {
if (row_num == -1) {
nlohmann::json double_data_json(double_data);
column_attrs_json[field_names[i]] = double_data_json;
} else {
nlohmann::json topk_double_result;
int64_t topk = int_data.size() / row_num;
for (int64_t j = 0; j < row_num; j++) {
std::vector<double> one_double_result(topk);
memcpy(one_double_result.data(), double_data.data() + j * topk, sizeof(double) * topk);
nlohmann::json one_double_result_json(one_double_result);
std::string tag = "top" + std::to_string(j);
topk_double_result[tag] = one_double_result_json;
for (auto& json : should_json) {
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
auto must_not_query = std::make_shared<query::BooleanQuery>();
STATUS_CHECK(ProcessBooleanQueryJson(json, must_not_query, query_ptr));
boolean_query->AddBooleanQuery(must_not_query);
} else {
std::string field_name;
STATUS_CHECK(ProcessLeafQueryJson(json, boolean_query, field_name, query_ptr));
if (!field_name.empty()) {
query_ptr->index_fields.insert(field_name);
}
}
column_attrs_json[field_names[i]] = topk_double_result;
}
} else {
std::string msg = "BoolQuery json string does not include bool query";
return Status{SERVER_INVALID_DSL_PARAMETER, msg};
}
}
return status;
}
Status
......@@ -724,7 +678,7 @@ WebRequestHandler::Search(const std::string& collection_name, const nlohmann::js
auto boolean_query = std::make_shared<query::BooleanQuery>();
query_ptr_ = std::make_shared<query::Query>();
status = ProcessBoolQueryJson(boolean_query_json, boolean_query);
status = ProcessBooleanQueryJson(boolean_query_json, boolean_query, query_ptr_);
if (!status.ok()) {
return status;
}
......@@ -749,22 +703,75 @@ WebRequestHandler::Search(const std::string& collection_name, const nlohmann::js
return Status::OK();
}
auto step = result->result_ids_.size() / result->row_num_;
nlohmann::json search_result_json;
auto step = result->result_ids_.size() / result->row_num_; // topk
auto field_data = result->data_chunk_->fixed_fields_;
for (int64_t i = 0; i < result->row_num_; i++) {
nlohmann::json raw_result_json;
for (size_t j = 0; j < step; j++) {
nlohmann::json one_result_json;
one_result_json["id"] = std::to_string(result->result_ids_.at(i * step + j));
one_result_json["distance"] = std::to_string(result->result_distances_.at(i * step + j));
raw_result_json.emplace_back(one_result_json);
nlohmann::json one_entity_json;
for (const auto& field : field_mappings) {
auto field_name = field.first->GetName();
switch ((int64_t)field.first->GetFtype()) {
case engine::DataType::INT32: {
int32_t int32_value;
int64_t offset = (i * step + j) * sizeof(int32_t);
memcpy(&int32_value, field_data.at(field_name)->data_.data() + offset, sizeof(int32_t));
one_entity_json[field_name] = int32_value;
break;
}
case engine::DataType::INT64: {
int64_t int64_value;
int64_t offset = (i * step + j) * sizeof(int64_t);
memcpy(&int64_value, field_data.at(field_name)->data_.data() + offset, sizeof(int64_t));
one_entity_json[field_name] = int64_value;
break;
}
case engine::DataType::FLOAT: {
float float_value;
int64_t offset = (i * step + j) * sizeof(float);
memcpy(&float_value, field_data.at(field_name)->data_.data() + offset, sizeof(float));
one_entity_json[field_name] = float_value;
break;
}
case engine::DataType::DOUBLE: {
double double_value;
int64_t offset = (i * step + j) * sizeof(double);
memcpy(&double_value, field_data.at(field_name)->data_.data() + offset, sizeof(double));
one_entity_json[field_name] = double_value;
break;
}
case engine::DataType::VECTOR_FLOAT: {
std::vector<float> float_vector;
auto dim =
field_data.at(field_name)->data_.size() / (result->result_ids_.size() * sizeof(float));
int64_t offset = (i * step + j) * dim * sizeof(float);
float_vector.resize(dim);
memcpy(float_vector.data(), field_data.at(field_name)->data_.data() + offset,
dim * sizeof(float));
one_entity_json[field_name] = float_vector;
break;
}
case engine::DataType::VECTOR_BINARY: {
std::vector<int8_t> binary_vector;
auto dim = field_data.at(field_name)->data_.size() / (result->result_ids_.size());
int64_t offset = (i * step + j) * dim;
binary_vector.resize(dim);
memcpy(binary_vector.data(), field_data.at(field_name)->data_.data() + offset,
dim * sizeof(int8_t));
one_entity_json[field_name] = binary_vector;
break;
}
default: { return Status(SERVER_UNEXPECTED_ERROR, "Return field data type is wrong"); }
}
}
one_result_json["entity"] = one_entity_json;
raw_result_json.push_back(one_result_json);
}
search_result_json.emplace_back(raw_result_json);
result_json.emplace_back(raw_result_json);
}
nlohmann::json attr_json;
// ConvertRowToColumnJson(result->attrs_, query_ptr_->field_names, result->row_num_, attr_json);
result_json["Entity"] = attr_json;
result_json["result"] = search_result_json;
result_str = result_json.dump();
}
......@@ -774,7 +781,7 @@ WebRequestHandler::Search(const std::string& collection_name, const nlohmann::js
Status
WebRequestHandler::DeleteByIDs(const std::string& collection_name, const nlohmann::json& json,
std::string& result_str) {
std::vector<int64_t> vector_ids;
std::vector<int64_t> entity_ids;
if (!json.contains("ids")) {
return Status(BODY_FIELD_LOSS, "Field \"delete\" must contains \"ids\"");
}
......@@ -788,10 +795,10 @@ WebRequestHandler::DeleteByIDs(const std::string& collection_name, const nlohman
if (!ValidateStringIsNumber(id_str).ok()) {
return Status(ILLEGAL_BODY, "Members in \"ids\" must be integer string");
}
vector_ids.emplace_back(std::stol(id_str));
entity_ids.emplace_back(std::stol(id_str));
}
auto status = req_handler_.DeleteEntityByID(context_ptr_, collection_name, vector_ids);
auto status = req_handler_.DeleteEntityByID(context_ptr_, collection_name, entity_ids);
nlohmann::json result_json;
AddStatusToJson(result_json, status.code(), status.message());
......@@ -807,89 +814,23 @@ WebRequestHandler::GetEntityByIDs(const std::string& collection_name, const std:
engine::DataChunkPtr data_chunk;
engine::snapshot::FieldElementMappings field_mappings;
std::vector<engine::AttrsData> attr_batch;
std::vector<engine::VectorsData> vector_batch;
auto status = req_handler_.GetEntityByID(context_ptr_, collection_name, ids, field_names, valid_row, field_mappings,
data_chunk);
if (!status.ok()) {
return status;
}
std::vector<uint8_t> id_array = data_chunk->fixed_fields_[engine::FIELD_UID]->data_;
for (const auto& it : field_mappings) {
std::string name = it.first->GetName();
uint64_t type = it.first->GetFtype();
std::vector<uint8_t>& data = data_chunk->fixed_fields_[name]->data_;
if (type == engine::DataType::VECTOR_BINARY) {
engine::VectorsData vectors_data;
memcpy(vectors_data.binary_data_.data(), data.data(), data.size());
memcpy(vectors_data.id_array_.data(), id_array.data(), id_array.size());
vector_batch.emplace_back(vectors_data);
} else if (type == engine::DataType::VECTOR_FLOAT) {
engine::VectorsData vectors_data;
memcpy(vectors_data.float_data_.data(), data.data(), data.size());
memcpy(vectors_data.id_array_.data(), id_array.data(), id_array.size());
vector_batch.emplace_back(vectors_data);
} else {
engine::AttrsData attrs_data;
attrs_data.attr_type_[name] = static_cast<engine::DataType>(type);
attrs_data.attr_data_[name] = data;
memcpy(attrs_data.id_array_.data(), id_array.data(), id_array.size());
attr_batch.emplace_back(attrs_data);
}
}
bool bin;
status = IsBinaryCollection(collection_name, bin);
if (!status.ok()) {
return status;
}
nlohmann::json vectors_json, attrs_json;
for (size_t i = 0; i < vector_batch.size(); i++) {
nlohmann::json vector_json;
if (bin) {
vector_json["vector"] = vector_batch.at(i).binary_data_;
} else {
vector_json["vector"] = vector_batch.at(i).float_data_;
}
vector_json["id"] = std::to_string(ids[i]);
vectors_json.push_back(vector_json);
}
ConvertRowToColumnJson(attr_batch, field_names, -1, attrs_json);
json_out["vectors"] = vectors_json;
json_out["attributes"] = attrs_json;
return Status::OK();
}
Status
WebRequestHandler::GetVectorsByIDs(const std::string& collection_name, const std::vector<int64_t>& ids,
nlohmann::json& json_out) {
std::vector<engine::VectorsData> vector_batch;
auto status = Status::OK();
// auto status = req_handler_.GetVectorsByID(context_ptr_, collection_name, ids, vector_batch);
if (!status.ok()) {
return status;
}
bool bin;
status = IsBinaryCollection(collection_name, bin);
if (!status.ok()) {
return status;
}
nlohmann::json vectors_json;
for (size_t i = 0; i < vector_batch.size(); i++) {
nlohmann::json vector_json;
if (bin) {
vector_json["vector"] = vector_batch.at(i).binary_data_;
} else {
vector_json["vector"] = vector_batch.at(i).float_data_;
int64_t valid_size = 0;
for (auto row : valid_row) {
if (row) {
valid_size++;
}
vector_json["id"] = std::to_string(ids[i]);
json_out.push_back(vector_json);
}
std::vector<uint8_t> id_data = data_chunk->fixed_fields_[engine::FIELD_UID]->data_;
std::vector<int64_t> id_array(valid_size);
memcpy(id_array.data(), id_data.data(), valid_size * sizeof(int64_t));
CopyData2Json(data_chunk, field_mappings, id_array, json_out);
return Status::OK();
}
......@@ -1169,34 +1110,7 @@ WebRequestHandler::SetGpuConfig(const GPUConfigDto::ObjectWrapper& gpu_config_dt
* Collection {
*/
StatusDto::ObjectWrapper
WebRequestHandler::CreateCollection(const CollectionRequestDto::ObjectWrapper& collection_schema) {
if (nullptr == collection_schema->collection_name.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'collection_name\' is missing")
}
if (nullptr == collection_schema->dimension.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'dimension\' is missing")
}
if (nullptr == collection_schema->index_file_size.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'index_file_size\' is missing")
}
if (nullptr == collection_schema->metric_type.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'metric_type\' is missing")
}
auto status = Status::OK();
// auto status = req_handler_.CreateCollection(
// context_ptr_, collection_schema->collection_name->std_str(), collection_schema->dimension,
// collection_schema->index_file_size,
// static_cast<int64_t>(MetricNameMap.at(collection_schema->metric_type->std_str())));
ASSIGN_RETURN_STATUS_DTO(status)
}
StatusDto::ObjectWrapper
WebRequestHandler::CreateHybridCollection(const milvus::server::web::OString& body) {
WebRequestHandler::CreateCollection(const milvus::server::web::OString& body) {
auto json_str = nlohmann::json::parse(body->c_str());
std::string collection_name = json_str["collection_name"];
......@@ -1208,24 +1122,14 @@ WebRequestHandler::CreateHybridCollection(const milvus::server::web::OString& bo
field_schema.field_params_ = field["extra_params"];
const std::string& field_type = field["field_type"];
if (field_type == "int8") {
field_schema.field_type_ = engine::DataType::INT8;
} else if (field_type == "int16") {
field_schema.field_type_ = engine::DataType::INT16;
} else if (field_type == "int32") {
field_schema.field_type_ = engine::DataType::INT32;
} else if (field_type == "int64") {
field_schema.field_type_ = engine::DataType::INT64;
} else if (field_type == "float") {
field_schema.field_type_ = engine::DataType::FLOAT;
} else if (field_type == "double") {
field_schema.field_type_ = engine::DataType::DOUBLE;
} else if (field_type == "vector") {
} else {
std::string field_type = field["field_type"];
std::transform(field_type.begin(), field_type.end(), field_type.begin(), ::tolower);
if (str2type.find(field_type) == str2type.end()) {
std::string msg = field_name + " has wrong field_type";
RETURN_STATUS_DTO(BODY_PARSE_FAIL, msg.c_str());
}
field_schema.field_type_ = str2type.at(field_type);
fields[field_name] = field_schema;
}
......@@ -1336,18 +1240,15 @@ WebRequestHandler::DropCollection(const OString& collection_name) {
*/
StatusDto::ObjectWrapper
WebRequestHandler::CreateIndex(const OString& collection_name, const OString& body) {
WebRequestHandler::CreateIndex(const OString& collection_name, const OString& field_name, const OString& body) {
try {
auto request_json = nlohmann::json::parse(body->std_str());
std::string field_name, index_name;
if (!request_json.contains("index_type")) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'index_type\' is required");
}
auto status = Status::OK();
// auto status =
// req_handler_.CreateIndex(context_ptr_, collection_name->std_str(), index,
// request_json["params"]);
auto status =
req_handler_.CreateIndex(context_ptr_, collection_name->std_str(), field_name->std_str(), "", request_json);
ASSIGN_RETURN_STATUS_DTO(status);
} catch (nlohmann::detail::parse_error& e) {
RETURN_STATUS_DTO(BODY_PARSE_FAIL, e.what())
......@@ -1359,10 +1260,8 @@ WebRequestHandler::CreateIndex(const OString& collection_name, const OString& bo
}
StatusDto::ObjectWrapper
WebRequestHandler::DropIndex(const OString& collection_name) {
auto status = Status::OK();
// auto status = req_handler_.DropIndex(context_ptr_, collection_name->std_str());
WebRequestHandler::DropIndex(const OString& collection_name, const OString& field_name) {
auto status = req_handler_.DropIndex(context_ptr_, collection_name->std_str(), field_name->std_str(), "");
ASSIGN_RETURN_STATUS_DTO(status)
}
......@@ -1583,9 +1482,12 @@ WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::se
std::string partition_name = body_json["partition_tag"];
int32_t row_num = body_json["row_num"];
CollectionSchema collection_schema;
std::unordered_map<std::string, engine::DataType> field_types;
auto status = Status::OK();
// auto status = req_handler_.DescribeHybridCollection(context_ptr_, collection_name->c_str(), field_types);
auto status = req_handler_.GetCollectionInfo(context_ptr_, collection_name->std_str(), collection_schema);
for (const auto& field : collection_schema.fields_) {
field_types.insert({field.first, field.second.field_type_});
}
auto entities = body_json["entity"];
if (!entities.is_array()) {
......@@ -1621,15 +1523,12 @@ WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::se
break;
}
case engine::DataType::VECTOR_FLOAT: {
bool bin_flag;
status = IsBinaryCollection(collection_name->c_str(), bin_flag);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
// engine::VectorsData vectors;
// CopyRecordsFromJson(field_value, vectors, bin_flag);
// vector_datas.insert(std::make_pair(field_name, vectors));
CopyRecordsFromJson(field_value, temp_data, false);
break;
}
case engine::DataType::VECTOR_BINARY: {
CopyRecordsFromJson(field_value, temp_data, true);
break;
}
default: {}
}
......@@ -1702,47 +1601,7 @@ WebRequestHandler::GetEntity(const milvus::server::web::OString& collection_name
}
StatusDto::ObjectWrapper
WebRequestHandler::GetVector(const OString& collection_name, const OQueryParams& query_params, OString& response) {
auto status = Status::OK();
try {
auto query_ids = query_params.get("ids");
if (query_ids == nullptr || query_ids.get() == nullptr) {
RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param ids is required.");
}
std::vector<std::string> ids;
StringHelpFunctions::SplitStringByDelimeter(query_ids->c_str(), ",", ids);
std::vector<int64_t> vector_ids;
for (auto& id : ids) {
vector_ids.push_back(std::stol(id));
}
engine::VectorsData vectors;
nlohmann::json vectors_json;
status = GetVectorsByIDs(collection_name->std_str(), vector_ids, vectors_json);
if (!status.ok()) {
response = "NULL";
ASSIGN_RETURN_STATUS_DTO(status)
}
FloatJson json;
json["code"] = (int64_t)status.code();
json["message"] = status.message();
if (vectors_json.empty()) {
json["vectors"] = std::vector<int64_t>();
} else {
json["vectors"] = vectors_json;
}
response = json.dump().c_str();
} catch (std::exception& e) {
RETURN_STATUS_DTO(SERVER_UNEXPECTED_ERROR, e.what());
}
ASSIGN_RETURN_STATUS_DTO(status);
}
StatusDto::ObjectWrapper
WebRequestHandler::VectorsOp(const OString& collection_name, const OString& payload, OString& response) {
WebRequestHandler::EntityOp(const OString& collection_name, const OString& payload, OString& response) {
auto status = Status::OK();
std::string result_str;
......
......@@ -85,7 +85,11 @@ class WebRequestHandler {
IsBinaryCollection(const std::string& collection_name, bool& bin);
Status
CopyRecordsFromJson(const nlohmann::json& json, engine::VectorsData& vectors, bool bin);
CopyRecordsFromJson(const nlohmann::json& json, std::vector<uint8_t>& vectors_data, bool bin);
Status
CopyData2Json(const engine::DataChunkPtr& data_chunk, const engine::snapshot::FieldElementMappings& field_mappings,
const std::vector<int64_t>& id_array, nlohmann::json& json_res);
protected:
Status
......@@ -124,10 +128,12 @@ class WebRequestHandler {
SetConfig(const nlohmann::json& json, std::string& result_str);
Status
ProcessLeafQueryJson(const nlohmann::json& json, query::BooleanQueryPtr& boolean_query);
ProcessLeafQueryJson(const nlohmann::json& json, query::BooleanQueryPtr& boolean_query, std::string& field_name,
query::QueryPtr& query_ptr);
Status
ProcessBoolQueryJson(const nlohmann::json& query_json, query::BooleanQueryPtr& boolean_query);
ProcessBooleanQueryJson(const nlohmann::json& query_json, query::BooleanQueryPtr& boolean_query,
query::QueryPtr& query_ptr);
Status
Search(const std::string& collection_name, const nlohmann::json& json, std::string& result_str);
......@@ -135,9 +141,6 @@ class WebRequestHandler {
Status
DeleteByIDs(const std::string& collection_name, const nlohmann::json& json, std::string& result_str);
Status
GetVectorsByIDs(const std::string& collection_name, const std::vector<int64_t>& ids, nlohmann::json& json_out);
Status
GetEntityByIDs(const std::string& collection_name, const std::vector<int64_t>& ids,
std::vector<std::string>& field_names, nlohmann::json& json_out);
......@@ -167,12 +170,10 @@ class WebRequestHandler {
#endif
StatusDto::ObjectWrapper
CreateCollection(const CollectionRequestDto::ObjectWrapper& table_schema);
StatusDto::ObjectWrapper
ShowCollections(const OQueryParams& query_params, OString& result);
CreateCollection(const milvus::server::web::OString& body);
StatusDto::ObjectWrapper
CreateHybridCollection(const OString& body);
ShowCollections(const OQueryParams& query_params, OString& result);
StatusDto::ObjectWrapper
GetCollection(const OString& collection_name, const OQueryParams& query_params, OString& result);
......@@ -181,10 +182,10 @@ class WebRequestHandler {
DropCollection(const OString& collection_name);
StatusDto::ObjectWrapper
CreateIndex(const OString& collection_name, const OString& body);
CreateIndex(const OString& collection_name, const OString& field_name, const OString& body);
StatusDto::ObjectWrapper
DropIndex(const OString& collection_name);
DropIndex(const OString& collection_name, const OString& field_name);
StatusDto::ObjectWrapper
CreatePartition(const OString& collection_name, const PartitionRequestDto::ObjectWrapper& param);
......@@ -221,7 +222,7 @@ class WebRequestHandler {
GetVector(const OString& collection_name, const OQueryParams& query_params, OString& response);
StatusDto::ObjectWrapper
VectorsOp(const OString& collection_name, const OString& payload, OString& response);
EntityOp(const OString& collection_name, const OString& payload, OString& response);
/**
*
......
......@@ -27,7 +27,7 @@ const char* COLLECTION_NAME = milvus_sdk::Utils::GenCollectionName().c_str();
constexpr int64_t COLLECTION_DIMENSION = 512;
constexpr int64_t COLLECTION_INDEX_FILE_SIZE = 1024;
constexpr milvus::MetricType COLLECTION_METRIC_TYPE = milvus::MetricType::L2;
constexpr int64_t BATCH_ENTITY_COUNT = 4000;
constexpr int64_t BATCH_ENTITY_COUNT = 10000;
constexpr int64_t NQ = 5;
constexpr int64_t TOP_K = 10;
constexpr int64_t NPROBE = 32;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册