未验证 提交 68a2918c 编写于 作者: Y yukun 提交者: GitHub

Add WebServer unittest (#3321)

* Add web server interface
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Add unittest/server
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Add web server ut
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>
上级 729c6ebd
......@@ -62,10 +62,10 @@ set(link_lib
milvus_engine
config
metrics
log
oatpp
query
utils
log
)
......
......@@ -294,7 +294,7 @@ message Command {
* @index_name: a name for index provided by user, unique within this field
* @extra_params: index parameters in json format
* for vector field:
* extra_params["index_type"] = one of the values: IDMAP, IVFLAT, IVFSQ8, NSGMIX, IVFSQ8H,
* extra_params["index_type"] = one of the values: FLAT, IVF_FLAT, IVF_SQ8, NSGMIX, IVFSQ8H,
* PQ, HNSW, HNSW_SQ8NM, ANNOY
* extra_params["metric_type"] = one of the values: L2, IP, HAMMING, JACCARD, TANIMOTO
* SUBSTRUCTURE, SUPERSTRUCTURE
......
......@@ -16,16 +16,16 @@ namespace server {
namespace web {
const char* NAME_ENGINE_TYPE_FLAT = "FLAT";
const char* NAME_ENGINE_TYPE_IVFFLAT = "IVFFLAT";
const char* NAME_ENGINE_TYPE_IVFSQ8 = "IVFSQ8";
const char* NAME_ENGINE_TYPE_IVFSQ8H = "IVFSQ8H";
const char* NAME_ENGINE_TYPE_IVFFLAT = "IVF_FLAT";
const char* NAME_ENGINE_TYPE_IVFSQ8 = "IVF_SQ8";
const char* NAME_ENGINE_TYPE_IVFSQ8H = "IVF_SQ8H";
const char* NAME_ENGINE_TYPE_RNSG = "RNSG";
const char* NAME_ENGINE_TYPE_IVFPQ = "IVFPQ";
const char* NAME_ENGINE_TYPE_IVFPQ = "IVF_PQ";
const char* NAME_ENGINE_TYPE_HNSW = "HNSW";
const char* NAME_ENGINE_TYPE_ANNOY = "ANNOY";
const char* NAME_ENGINE_TYPE_RHNSWFLAT = "RHNSWFLAT";
const char* NAME_ENGINE_TYPE_RHNSWPQ = "RHNSWPQ";
const char* NAME_ENGINE_TYPE_RHNSWSQ8 = "RHNSWSQ8";
const char* NAME_ENGINE_TYPE_RHNSWFLAT = "RHNSW_FLAT";
const char* NAME_ENGINE_TYPE_RHNSWPQ = "RHNSW_PQ";
const char* NAME_ENGINE_TYPE_RHNSWSQ8 = "RHNSW_SQ8";
const char* NAME_METRIC_TYPE_L2 = "L2";
const char* NAME_METRIC_TYPE_IP = "IP";
......@@ -47,6 +47,8 @@ const int64_t VALUE_INDEX_NLIST_DEFAULT = 16384;
const int64_t VALUE_CONFIG_CPU_CACHE_CAPACITY_DEFAULT = 4;
const bool VALUE_CONFIG_CACHE_INSERT_DATA_DEFAULT = false;
const char* NAME_ID = "__id";
/////////////////////////////////////////////////////
} // namespace web
......
......@@ -46,6 +46,7 @@ extern const char* VALUE_INDEX_INDEX_TYPE_DEFAULT;
extern const int64_t VALUE_INDEX_NLIST_DEFAULT;
extern const int64_t VALUE_CONFIG_CPU_CACHE_CAPACITY_DEFAULT;
extern const bool VALUE_CONFIG_CACHE_INSERT_DATA_DEFAULT;
extern const char* NAME_ID;
/////////////////////////////////////////////////////
......
......@@ -676,9 +676,10 @@ Updates the index type and nlist of a collection.
<tr><td>Header </td><td><pre><code>accept: application/json</code></pre> </td></tr>
<tr><td>Body</td><td><pre><code>
{
"index_type": string,
"index_type": "IVF_FLAT",
"metric_type": "IP",
"params": {
......
"nlist": 1024
}
}
</code></pre> </td></tr>
......
......@@ -80,6 +80,13 @@ static std::map<std::string, engine::DataType> str2type = {{"int32", engine::Dat
{"vector_float", engine::DataType::VECTOR_FLOAT},
{"vector_binary", engine::DataType::VECTOR_BINARY}};
// Lookup table from engine::DataType to its textual type name
// (the opposite direction of the str2type table).
static std::map<engine::DataType, std::string> type2str = {
    {engine::DataType::INT32, "int32"},
    {engine::DataType::INT64, "int64"},
    {engine::DataType::FLOAT, "float"},
    {engine::DataType::DOUBLE, "double"},
    {engine::DataType::VECTOR_FLOAT, "vector_float"},
    {engine::DataType::VECTOR_BINARY, "vector_binary"},
};
} // namespace web
} // namespace server
} // namespace milvus
......@@ -33,8 +33,7 @@ namespace milvus::server::web {
#define WEB_LOG_PREFIX "[Web] "
#define ADD_DEFAULT_CORS(endpoint) \
ADD_CORS(endpoint, "*", "OPTIONS, GET, POST, PUT, DELETE")
#define ADD_DEFAULT_CORS(endpoint) ADD_CORS(endpoint, "*", "OPTIONS, GET, POST, PUT, DELETE")
class WebController : public oatpp::web::server::api::ApiController {
public:
......@@ -280,7 +279,6 @@ class WebController : public oatpp::web::server::api::ApiController {
// "count": 58
// })";
response = createResponse(Status::CODE_200, result);
return response;
}
......@@ -364,16 +362,16 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_DEFAULT_CORS(IndexOptions)
ENDPOINT("OPTIONS", "/collections/{collection_name}/indexes", IndexOptions) {
ENDPOINT("OPTIONS", "/collections/{collection_name}/fields/{field_name}/indexes", IndexOptions) {
return createResponse(Status::CODE_204, "No Content");
}
ADD_DEFAULT_CORS(CreateIndex)
ENDPOINT("POST", "/collections/{collection_name}/fields/{field_name}/indexes/{index_name}", CreateIndex,
PATH(String, collection_name), PATH(String, field_name), PATH(String, index_name),
BODY_STRING(String, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables/" + collection_name->std_str() + "/indexes\'");
ENDPOINT("POST", "/collections/{collection_name}/fields/{field_name}/indexes", CreateIndex,
PATH(String, collection_name), PATH(String, field_name), BODY_STRING(String, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() +
"/indexes\'");
tr.RecordSection("Received request.");
auto handler = WebRequestHandler();
......@@ -432,10 +430,10 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_DEFAULT_CORS(DropIndex)
ENDPOINT("DELETE", "/collections/{collection_name}/fields/{field_name}/indexes/{index_name}", DropIndex,
PATH(String, collection_name), PATH(String, field_name), PATH(String, index_name)) {
ENDPOINT("DELETE", "/collections/{collection_name}/fields/{field_name}/indexes", DropIndex,
PATH(String, collection_name), PATH(String, field_name)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() +
"/indexes\'");
"/fields/" + field_name->std_str() + "/indexes\'");
tr.RecordSection("Received request.");
auto handler = WebRequestHandler();
......@@ -559,12 +557,12 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_DEFAULT_CORS(GetEntities)
ENDPOINT("GET", "/collections/{collection_name}/partitions/{partition_tag}/entities", GetEntities,
PATH(String, collection_name), PATH(String, partition_tag), QUERIES(QueryParams, query_params),
BODY_STRING(String, body)) {
ENDPOINT("GET", "/collections/{collection_name}/entities", GetEntities, PATH(String, collection_name),
QUERIES(QueryParams, query_params)) {
auto handler = WebRequestHandler();
String response;
String response;
auto status_dto = handler.GetEntity(collection_name, query_params, response);
switch (*(status_dto->code)) {
case StatusCode::SUCCESS:
......@@ -622,21 +620,21 @@ class WebController : public oatpp::web::server::api::ApiController {
}
}
ADD_DEFAULT_CORS(VectorsOptions)
ADD_CORS(EntityOptions)
ENDPOINT("OPTIONS", "/collections/{collection_name}/entities", VectorsOptions) {
ENDPOINT("OPTIONS", "/collections/{collection_name}/entities", EntityOptions) {
return createResponse(Status::CODE_204, "No Content");
}
ADD_DEFAULT_CORS(InsertEntity)
ADD_DEFAULT_CORS(Insert)
ENDPOINT("POST", "/hybrid_collections/{collection_name}/entities", InsertEntity, PATH(String, collection_name),
ENDPOINT("POST", "/collections/{collection_name}/entities", Insert, PATH(String, collection_name),
BODY_STRING(String, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/hybrid_collections/" + collection_name->std_str() +
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() +
"/entities\'");
tr.RecordSection("Received request.");
auto ids_dto = VectorIdsDto::createShared();
auto ids_dto = EntityIdsDto::createShared();
WebRequestHandler handler = WebRequestHandler();
std::shared_ptr<OutgoingResponse> response;
......@@ -751,4 +749,3 @@ class WebController : public oatpp::web::server::api::ApiController {
};
} // namespace milvus::server::web
......@@ -18,14 +18,14 @@ namespace milvus::server::web {
#include OATPP_CODEGEN_BEGIN(DTO)
class VectorIdsDto : public ODTO {
DTO_INIT(VectorIdsDto, DTO)
class EntityIdsDto : public ODTO {
DTO_INIT(EntityIdsDto, DTO)
DTO_FIELD(List<String>, ids);
};
#include OATPP_CODEGEN_END(DTO)
using VectorIdsDtoT = oatpp::Object<VectorIdsDto>;
using EntityIdsDtoT = oatpp::Object<EntityIdsDto>;
} // namespace milvus::server::web
......@@ -97,6 +97,43 @@ CopyStructuredData(const nlohmann::json& json, std::vector<uint8_t>& raw) {
memcpy(raw.data(), values.data(), size * sizeof(T));
}
/**
 * Serialize the elements of one JSON vector value into raw bytes.
 *
 * @param json          JSON value whose elements are the vector components.
 * @param vectors_data  Output byte buffer.
 * @param bin           false: every element is read as float and the floats' in-memory
 *                      bytes REPLACE the contents of vectors_data;
 *                      true: every element is read as uint8_t and APPENDED to vectors_data.
 *
 * NOTE: whether `json` is actually an array is not checked here; element conversion is
 * left to nlohmann::json::get<>().
 */
void
CopyRowVectorFromJson(const nlohmann::json& json, std::vector<uint8_t>& vectors_data, bool bin) {
    if (bin) {
        // Binary vector: one byte per JSON element, added after any existing content.
        for (const auto& element : json) {
            vectors_data.emplace_back(element.get<uint8_t>());
        }
        return;
    }

    // Float vector: gather all components first, then copy their bytes in one shot so
    // the output buffer is only touched once every element has been converted.
    std::vector<float> components;
    for (const auto& element : json) {
        components.emplace_back(element.get<float>());
    }
    const auto byte_count = components.size() * sizeof(float);
    vectors_data.resize(byte_count);
    memcpy(vectors_data.data(), components.data(), byte_count);
}
/**
 * Store one scalar field value of one entity row into a column-wise byte buffer.
 *
 * The column for `field_name` lives in `chunk_data` as `row_num * sizeof(T)` bytes;
 * it is created zero-filled the first time the field is seen. The value read from
 * `entity_json` is then copied into the slot of row `offset`.
 *
 * @tparam T           Scalar type to read from the JSON value and to store.
 * @param entity_json  JSON value holding the field value (converted with get<T>()).
 * @param field_name   Key of the column inside chunk_data.
 * @param offset       Zero-based row index of the entity being written.
 *                     The caller must guarantee 0 <= offset < row_num.
 * @param row_num      Total number of rows; used to size a newly created column.
 * @param chunk_data   Column buffers keyed by field name; updated in place.
 */
template <typename T>
void
CopyRowStructuredData(const nlohmann::json& entity_json, const std::string& field_name, const int64_t offset,
                      const int64_t row_num, std::unordered_map<std::string, std::vector<uint8_t>>& chunk_data) {
    const T value = entity_json.get<T>();

    auto column = chunk_data.find(field_name);
    if (column == chunk_data.end()) {
        // First value seen for this field: allocate a zero-filled column for all rows.
        column = chunk_data.emplace(field_name, std::vector<uint8_t>(row_num * sizeof(T), 0)).first;
    }

    // Write into the row's own slot even when the column was just created, so a field
    // that first shows up in a later row does not end up in row 0.
    const int64_t byte_offset = offset * sizeof(T);
    memcpy(column->second.data() + byte_offset, &value, sizeof(T));
}
using FloatJson = nlohmann::basic_json<std::map, std::vector, std::string, bool, std::int64_t, std::uint64_t, float>;
/////////////////////////////////// Private methods ///////////////////////////////////////
......@@ -236,12 +273,15 @@ WebRequestHandler::GetCollectionMetaInfo(const std::string& collection_name, nlo
json_out["collection_name"] = schema.collection_name_;
for (const auto& field : schema.fields_) {
if (field.first == engine::FIELD_UID) {
continue;
}
nlohmann::json field_json;
field_json["field_name"] = field.first;
field_json["field_type"] = field.second.field_type_;
field_json["field_type"] = type2str.at(field.second.field_type_);
field_json["index_params"] = field.second.index_params_;
field_json["extra_params"] = field.second.field_params_;
json_out["field"].push_back(field_json);
json_out["fields"].push_back(field_json);
}
return Status::OK();
}
......@@ -548,9 +588,16 @@ WebRequestHandler::ProcessLeafQueryJson(const nlohmann::json& json, milvus::quer
}
engine::VectorsData vector_data;
for (auto& vector_records : vector_param_it.value()["values"]) {
// TODO: Binary vector???
for (auto& data : vector_records) {
vector_query->query_vector.float_data.emplace_back(data.get<float>());
if (field_type_.find(vector_name) != field_type_.end()) {
if (field_type_.at(vector_name) == engine::DataType::VECTOR_FLOAT) {
for (auto& data : vector_records) {
vector_query->query_vector.float_data.emplace_back(data.get<float>());
}
} else if (field_type_.at(vector_name) == engine::DataType::VECTOR_BINARY) {
for (auto& data : vector_records) {
vector_query->query_vector.binary_data.emplace_back(data.get<int8_t>());
}
}
}
}
query_ptr->index_fields.insert(vector_name);
......@@ -653,6 +700,9 @@ WebRequestHandler::Search(const std::string& collection_name, const nlohmann::js
if (!status.ok()) {
return Status{UNEXPECTED_ERROR, "DescribeHybridCollection failed"};
}
for (const auto& field : collection_schema.fields_) {
field_type_.insert({field.first, field.second.field_type_});
}
milvus::json extra_params;
if (json.contains("fields")) {
......@@ -1136,6 +1186,12 @@ WebRequestHandler::CreateCollection(const milvus::server::web::OString& body) {
}
milvus::json json_params;
if (json_str.contains(engine::PARAM_SEGMENT_ROW_COUNT)) {
json_params[engine::PARAM_SEGMENT_ROW_COUNT] = json_str[engine::PARAM_SEGMENT_ROW_COUNT];
}
if (json_str.contains(engine::PARAM_UID_AUTOGEN)) {
json_params[engine::PARAM_UID_AUTOGEN] = json_str[engine::PARAM_UID_AUTOGEN];
}
auto status = req_handler_.CreateCollection(context_ptr_, collection_name, fields, json_params);
......@@ -1198,8 +1254,8 @@ WebRequestHandler::ShowCollections(const OQueryParams& query_params, OString& re
result_json["collections"] = collections_json;
}
AddStatusToJson(result_json, status.code(), status.message());
result = result_json.dump().c_str();
ASSIGN_RETURN_STATUS_DTO(status)
}
......@@ -1474,36 +1530,107 @@ WebRequestHandler::GetSegmentInfo(const OString& collection_name, const OString&
*/
StatusDtoT
WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::server::web::OString& body,
VectorIdsDtoT& ids_dto) {
EntityIdsDtoT& ids_dto) {
if (nullptr == body.get() || body->getSize() == 0) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Request payload is required.")
}
auto body_json = nlohmann::json::parse(body->c_str());
std::string partition_name = body_json["partition_tag"];
int32_t row_num = body_json["row_num"];
std::string partition_name;
if (body_json.contains("partition_tag")) {
partition_name = body_json["partition_tag"];
}
CollectionSchema collection_schema;
std::unordered_map<std::string, engine::DataType> field_types;
auto status = req_handler_.GetCollectionInfo(context_ptr_, collection_name->std_str(), collection_schema);
if (!status.ok()) {
auto msg = "Collection " + collection_name->std_str() + " not exist";
RETURN_STATUS_DTO(COLLECTION_NOT_EXISTS, msg.c_str());
}
for (const auto& field : collection_schema.fields_) {
field_types.insert({field.first, field.second.field_type_});
}
auto entities = body_json["entity"];
if (!entities.is_array()) {
RETURN_STATUS_DTO(ILLEGAL_BODY, "An entity must be an array");
std::unordered_map<std::string, std::vector<uint8_t>> chunk_data;
int64_t row_num;
auto entities_json = body_json["entities"];
if (!entities_json.is_array()) {
RETURN_STATUS_DTO(ILLEGAL_ARGUMENT, "Entities is not an array");
}
row_num = entities_json.size();
int64_t offset = 0;
std::vector<uint8_t> ids;
for (auto& one_entity : entities_json) {
for (auto& entity : one_entity.items()) {
std::string field_name = entity.key();
if (field_name == NAME_ID) {
if (ids.empty()) {
ids.resize(row_num * sizeof(int64_t));
}
int64_t id = entity.value().get<int64_t>();
int64_t id_offset = offset * sizeof(int64_t);
memcpy(ids.data() + id_offset, &id, sizeof(int64_t));
continue;
}
std::vector<uint8_t> temp_data;
switch (field_types.at(field_name)) {
case engine::DataType::INT32: {
CopyRowStructuredData<int32_t>(entity.value(), field_name, offset, row_num, chunk_data);
break;
}
case engine::DataType::INT64: {
CopyRowStructuredData<int64_t>(entity.value(), field_name, offset, row_num, chunk_data);
break;
}
case engine::DataType::FLOAT: {
CopyRowStructuredData<float>(entity.value(), field_name, offset, row_num, chunk_data);
break;
}
case engine::DataType::DOUBLE: {
CopyRowStructuredData<double>(entity.value(), field_name, offset, row_num, chunk_data);
break;
}
case engine::DataType::VECTOR_FLOAT:
case engine::DataType::VECTOR_BINARY: {
bool is_bin = !(field_types.at(field_name) == engine::DataType::VECTOR_FLOAT);
CopyRowVectorFromJson(entity.value(), temp_data, is_bin);
auto size = temp_data.size();
if (chunk_data.find(field_name) == chunk_data.end()) {
std::vector<uint8_t> vector_data(row_num * size, 0);
memcpy(vector_data.data(), temp_data.data(), size);
chunk_data.insert({field_name, vector_data});
} else {
int64_t vector_offset = offset * size;
memcpy(chunk_data.at(field_name).data() + vector_offset, temp_data.data(), size);
}
break;
}
default: {}
}
}
offset++;
}
std::unordered_map<std::string, std::vector<uint8_t>> chunk_data;
if (!ids.empty()) {
chunk_data.insert({engine::FIELD_UID, ids});
}
for (auto& entity : entities) {
std::string field_name = entity["field_name"];
auto field_value = entity["field_value"];
auto size = field_value.size();
if (size != row_num) {
RETURN_STATUS_DTO(ILLEGAL_ROWRECORD, "Field row count inconsist");
#if 0
for (auto& entity : body_json["entities"].items()) {
std::string field_name = entity.key();
auto field_value = entity.value();
if (!field_value.is_array()) {
RETURN_STATUS_DTO(ILLEGAL_ROWRECORD, "Field value is not an array");
}
if (field_name == NAME_ID) {
std::vector<uint8_t> temp_data(field_value.size() * sizeof(int64_t), 0);
CopyStructuredData<int64_t>(field_value, temp_data);
chunk_data.insert({engine::FIELD_UID, temp_data});
continue;
}
row_num = field_value.size();
std::vector<uint8_t> temp_data;
switch (field_types.at(field_name)) {
......@@ -1536,6 +1663,7 @@ WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::se
chunk_data.insert(std::make_pair(field_name, temp_data));
}
#endif
status = req_handler_.Insert(context_ptr_, collection_name->c_str(), partition_name, row_num, chunk_data);
if (!status.ok()) {
......@@ -1547,6 +1675,7 @@ WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::se
if (pair != chunk_data.end()) {
int64_t count = pair->second.size() / 8;
int64_t* pdata = reinterpret_cast<int64_t*>(pair->second.data());
ids_dto->ids = ids_dto->ids.createShared();
for (int64_t i = 0; i < count; ++i) {
ids_dto->ids->push_back(std::to_string(pdata[i]).c_str());
}
......
......@@ -213,7 +213,7 @@ class WebRequestHandler {
* Vector
*/
StatusDtoT
InsertEntity(const OString& collection_name, const OString& body, VectorIdsDtoT& ids_dto);
InsertEntity(const OString& collection_name, const OString& body, EntityIdsDtoT& ids_dto);
StatusDtoT
GetEntity(const OString& collection_name, const OQueryParams& query_params, OString& response);
......
......@@ -38,6 +38,7 @@ set( SCHEDULER_FILES ${SCHEDULER_MAIN_FILES}
${SCHEDULER_SELECTOR_FILES}
${SCHEDULER_RESOURCE_FILES}
${SCHEDULER_TASK_FILES}
)
set( ENTRY_FILE ${CMAKE_CURRENT_SOURCE_DIR}/main.cpp )
......@@ -80,6 +81,7 @@ if ( MILVUS_WITH_AWS )
endif ()
add_subdirectory(db)
add_subdirectory(server)
#add_subdirectory(metrics)
#add_subdirectory(scheduler)
#add_subdirectory(thirdparty)
......
#-------------------------------------------------------------------------------
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License.
#-------------------------------------------------------------------------------
# Sources that make up the server unit-test binary.
set( TEST_FILES
${CMAKE_CURRENT_SOURCE_DIR}/test_web.cpp
)
# Build the test_server executable from the test sources plus ${SCHEDULER_FILES}.
# SCHEDULER_FILES is not set in this file; it is expected to come from an enclosing
# CMakeLists (TODO confirm). The other source groups below are currently disabled.
add_executable( test_server
${TEST_FILES}
${SCHEDULER_FILES}
# ${grpc_server_files}
# ${grpc_service_files}
# ${web_server_files}
# ${server_delivery_files}
# ${server_files}
# ${server_init_files}
)
# Read the include directories of the `server` target into `var`.
# NOTE(review): in the lines shown here `var` is only consumed by the commented-out
# set_target_properties() call below, so this lookup currently has no visible effect.
get_target_property( var server INCLUDE_DIRECTORIES)
#set_target_properties( server PROPERTIES INTERFACE_INCLUDE_DIRECTORIES $<BUILD_INTERFACE:${var}>)
# Link the test binary against the unit-test libraries and the project libraries.
# Keep the order as is: link order can matter for static libraries.
target_link_libraries( test_server
${UNITTEST_LIBS}
server
milvus_engine
metrics
config
stdc++
utils
tracing
query
log
)
# Install the test binary into the `unittest` destination directory.
install( TARGETS test_server DESTINATION unittest )
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册