Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
milvus
提交
4beb0549
milvus
项目概览
BaiXuePrincess
/
milvus
与 Fork 源项目一致
从无法访问的项目Fork
通知
7
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
milvus
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
4beb0549
编写于
8月 14, 2020
作者:
Y
yukun
提交者:
GitHub
8月 14, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add web server interface (#3257)
Signed-off-by:
N
fishpenguin
<
kun.yu@zilliz.com
>
上级
6f5be4b5
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
558 addition
and
855 deletion
+558
-855
core/src/server/grpc_impl/GrpcRequestHandler.cpp
core/src/server/grpc_impl/GrpcRequestHandler.cpp
+5
-5
core/src/server/web_impl/Types.h
core/src/server/web_impl/Types.h
+8
-0
core/src/server/web_impl/controller/WebController.hpp
core/src/server/web_impl/controller/WebController.hpp
+207
-372
core/src/server/web_impl/handler/WebRequestHandler.cpp
core/src/server/web_impl/handler/WebRequestHandler.cpp
+323
-464
core/src/server/web_impl/handler/WebRequestHandler.h
core/src/server/web_impl/handler/WebRequestHandler.h
+14
-13
sdk/examples/simple/src/ClientTest.cpp
sdk/examples/simple/src/ClientTest.cpp
+1
-1
未找到文件。
core/src/server/grpc_impl/GrpcRequestHandler.cpp
浏览文件 @
4beb0549
...
...
@@ -1698,17 +1698,17 @@ GrpcRequestHandler::DeserializeJsonToBoolQuery(
if
(
vector_param_it
!=
it
.
value
().
end
())
{
const
std
::
string
&
field_name
=
vector_param_it
.
key
();
vector_query
->
field_name
=
field_name
;
nlohmann
::
json
vector
_json
=
vector_param_it
.
value
();
int64_t
topk
=
vector
_json
[
"topk"
];
nlohmann
::
json
param
_json
=
vector_param_it
.
value
();
int64_t
topk
=
param
_json
[
"topk"
];
status
=
server
::
ValidateSearchTopk
(
topk
);
if
(
!
status
.
ok
())
{
return
status
;
}
vector_query
->
topk
=
topk
;
if
(
vector
_json
.
contains
(
"metric_type"
))
{
std
::
string
metric_type
=
vector
_json
[
"metric_type"
];
if
(
param
_json
.
contains
(
"metric_type"
))
{
std
::
string
metric_type
=
param
_json
[
"metric_type"
];
vector_query
->
metric_type
=
metric_type
;
query_ptr
->
metric_types
.
insert
({
field_name
,
vector
_json
[
"metric_type"
]});
query_ptr
->
metric_types
.
insert
({
field_name
,
param
_json
[
"metric_type"
]});
}
if
(
!
vector_param_it
.
value
()[
"params"
].
empty
())
{
vector_query
->
extra_params
=
vector_param_it
.
value
()[
"params"
];
...
...
core/src/server/web_impl/Types.h
浏览文件 @
4beb0549
...
...
@@ -11,6 +11,7 @@
#pragma once
#include <map>
#include <string>
#include <unordered_map>
...
...
@@ -72,6 +73,13 @@ enum StatusCode : int {
MAX
=
ILLEGAL_QUERY_PARAM
};
static
std
::
map
<
std
::
string
,
engine
::
DataType
>
str2type
=
{{
"int32"
,
engine
::
DataType
::
INT32
},
{
"int64"
,
engine
::
DataType
::
INT64
},
{
"float"
,
engine
::
DataType
::
FLOAT
},
{
"double"
,
engine
::
DataType
::
DOUBLE
},
{
"vector_float"
,
engine
::
DataType
::
VECTOR_FLOAT
},
{
"vector_binary"
,
engine
::
DataType
::
VECTOR_BINARY
}};
}
// namespace web
}
// namespace server
}
// namespace milvus
core/src/server/web_impl/controller/WebController.hpp
浏览文件 @
4beb0549
...
...
@@ -215,74 +215,71 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS
(
CreateCollection
)
ENDPOINT
(
"POST"
,
"/collections"
,
CreateCollection
,
BODY_DTO
(
CollectionRequestDto
::
ObjectWrapper
,
body
))
{
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections\'");
// tr.RecordSection("Received request.");
//
// WebRequestHandler handler = WebRequestHandler();
//
// std::shared_ptr<OutgoingResponse> response;
// auto status_dto = handler.CreateCollection(body);
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// response = createDtoResponse(Status::CODE_201, status_dto);
// break;
// default:
// response = createDtoResponse(Status::CODE_400, status_dto);
// }
//
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
// ", reason = " + status_dto->message->std_str() + ". Total cost";
// tr.ElapseFromBegin(ttr);
StatusDto
::
ObjectWrapper
status
;
auto
response
=
createDtoResponse
(
Status
::
CODE_200
,
status
);
ENDPOINT
(
"POST"
,
"/collections"
,
CreateCollection
,
BODY_STRING
(
String
,
body_str
))
{
TimeRecorder
tr
(
std
::
string
(
WEB_LOG_PREFIX
)
+
"POST
\'
/collections
\'
"
);
tr
.
RecordSection
(
"Received request."
);
WebRequestHandler
handler
=
WebRequestHandler
();
std
::
shared_ptr
<
OutgoingResponse
>
response
;
auto
status_dto
=
handler
.
CreateCollection
(
body_str
);
switch
(
status_dto
->
code
->
getValue
())
{
case
StatusCode
::
SUCCESS
:
response
=
createDtoResponse
(
Status
::
CODE_201
,
status_dto
);
break
;
default:
response
=
createDtoResponse
(
Status
::
CODE_400
,
status_dto
);
}
std
::
string
ttr
=
"Done. Status: code = "
+
std
::
to_string
(
status_dto
->
code
->
getValue
())
+
", reason = "
+
status_dto
->
message
->
std_str
()
+
". Total cost"
;
tr
.
ElapseFromBegin
(
ttr
);
return
response
;
}
ADD_CORS
(
ShowCollections
)
ENDPOINT
(
"GET"
,
"/collections"
,
ShowCollections
,
QUERIES
(
const
QueryParams
&
,
query_params
))
{
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections\'");
// tr.RecordSection("Received request.");
//
// WebRequestHandler handler = WebRequestHandler();
//
// String result;
// auto status_dto = handler.ShowCollections(query_params, result);
// std::shared_ptr<OutgoingResponse> response;
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// response = createResponse(Status::CODE_200, result);
// break;
// default:
// response = createDtoResponse(Status::CODE_400, status_dto);
// }
//
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
// ", reason = " + status_dto->message->std_str() + ". Total cost";
// tr.ElapseFromBegin(ttr);
json
result_json
=
R"({
"collections": [
{
"collection_name": "test_collection",
"fields": [
{
"field_name": "field_vec",
"field_type": "VECTOR_FLOAT",
"index_params": {"name": "index_1", "index_type": "IVFFLAT", "nlist": 4096},
"extra_params": {"dimension": 128, "metric_type": "L2"}
}
],
"segment_size": 1024
}
],
"count": 58
})"
;
String
result
=
result_json
.
dump
().
c_str
();
auto
response
=
createResponse
(
Status
::
CODE_200
,
result
);
TimeRecorder
tr
(
std
::
string
(
WEB_LOG_PREFIX
)
+
"GET
\'
/collections
\'
"
);
tr
.
RecordSection
(
"Received request."
);
WebRequestHandler
handler
=
WebRequestHandler
();
String
result
;
auto
status_dto
=
handler
.
ShowCollections
(
query_params
,
result
);
std
::
shared_ptr
<
OutgoingResponse
>
response
;
switch
(
status_dto
->
code
->
getValue
())
{
case
StatusCode
::
SUCCESS
:
response
=
createResponse
(
Status
::
CODE_200
,
result
);
break
;
default:
response
=
createDtoResponse
(
Status
::
CODE_400
,
status_dto
);
}
std
::
string
ttr
=
"Done. Status: code = "
+
std
::
to_string
(
status_dto
->
code
->
getValue
())
+
", reason = "
+
status_dto
->
message
->
std_str
()
+
". Total cost"
;
tr
.
ElapseFromBegin
(
ttr
);
// json result_json = R"({
// "collections": [
// {
// "collection_name": "test_collection",
// "fields": [
// {
// "field_name": "field_vec",
// "field_type": "VECTOR_FLOAT",
// "index_params": {"name": "index_1", "index_type": "IVFFLAT", "nlist": 4096},
// "extra_params": {"dimension": 128, "metric_type": "L2"}
// }
// ],
// "segment_size": 1024
// }
// ],
// "count": 58
// })";
response
=
createResponse
(
Status
::
CODE_200
,
result
);
return
response
;
}
...
...
@@ -296,74 +293,71 @@ class WebController : public oatpp::web::server::api::ApiController {
ENDPOINT
(
"GET"
,
"/collections/{collection_name}"
,
GetCollection
,
PATH
(
String
,
collection_name
),
QUERIES
(
const
QueryParams
&
,
query_params
))
{
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() +
// "\'"); tr.RecordSection("Received request.");
//
// WebRequestHandler handler = WebRequestHandler();
//
// String response_str;
// auto status_dto = handler.GetCollection(collection_name, query_params, response_str);
//
// std::shared_ptr<OutgoingResponse> response;
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// response = createResponse(Status::CODE_200, response_str);
// break;
// case StatusCode::COLLECTION_NOT_EXISTS:
// response = createDtoResponse(Status::CODE_404, status_dto);
// break;
// default:
// response = createDtoResponse(Status::CODE_400, status_dto);
// }
//
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
// ", reason = " + status_dto->message->std_str() + ". Total cost";
// tr.ElapseFromBegin(ttr);
json
result_json
=
R"({
"collection_name": "test_collection",
"fields": [
{
"field_name": "field_vec",
"field_type": "VECTOR_FLOAT",
"index_params": {"name": "index_1", "index_type": "IVFFLAT", "nlist": 4096},
"extra_params": {"dimension": 128, "metric_type": "L2"}
}
],
"row_count": 10000
})"
;
auto
response
=
createResponse
(
Status
::
CODE_200
,
result_json
.
dump
().
c_str
());
TimeRecorder
tr
(
std
::
string
(
WEB_LOG_PREFIX
)
+
"GET
\'
/collections/"
+
collection_name
->
std_str
()
+
"
\'
"
);
tr
.
RecordSection
(
"Received request."
);
WebRequestHandler
handler
=
WebRequestHandler
();
String
response_str
;
auto
status_dto
=
handler
.
GetCollection
(
collection_name
,
query_params
,
response_str
);
std
::
shared_ptr
<
OutgoingResponse
>
response
;
switch
(
status_dto
->
code
->
getValue
())
{
case
StatusCode
::
SUCCESS
:
response
=
createResponse
(
Status
::
CODE_200
,
response_str
);
break
;
case
StatusCode
::
COLLECTION_NOT_EXISTS
:
response
=
createDtoResponse
(
Status
::
CODE_404
,
status_dto
);
break
;
default:
response
=
createDtoResponse
(
Status
::
CODE_400
,
status_dto
);
}
std
::
string
ttr
=
"Done. Status: code = "
+
std
::
to_string
(
status_dto
->
code
->
getValue
())
+
", reason = "
+
status_dto
->
message
->
std_str
()
+
". Total cost"
;
tr
.
ElapseFromBegin
(
ttr
);
// json result_json = R"({
// "collection_name": "test_collection",
// "fields": [
// {
// "field_name": "field_vec",
// "field_type": "VECTOR_FLOAT",
// "index_params": {"name": "index_1", "index_type": "IVFFLAT", "nlist": 4096},
// "extra_params": {"dimension": 128, "metric_type": "L2"}
// }
// ],
// "row_count": 10000
// })";
return
response
;
}
ADD_CORS
(
DropCollection
)
ENDPOINT
(
"DELETE"
,
"/collections/{collection_name}"
,
DropCollection
,
PATH
(
String
,
collection_name
))
{
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() +
// "\'"); tr.RecordSection("Received request.");
//
// WebRequestHandler handler = WebRequestHandler();
//
// std::shared_ptr<OutgoingResponse> response;
// auto status_dto = handler.DropCollection(collection_name);
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// response = createDtoResponse(Status::CODE_204, status_dto);
// break;
// case StatusCode::COLLECTION_NOT_EXISTS:
// response = createDtoResponse(Status::CODE_404, status_dto);
// break;
// default:
// response = createDtoResponse(Status::CODE_400, status_dto);
// }
//
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
// ", reason = " + status_dto->message->std_str() + ". Total cost";
// tr.ElapseFromBegin(ttr);
StatusDto
::
ObjectWrapper
status
;
auto
response
=
createDtoResponse
(
Status
::
CODE_201
,
status
);
TimeRecorder
tr
(
std
::
string
(
WEB_LOG_PREFIX
)
+
"DELETE
\'
/collections/"
+
collection_name
->
std_str
()
+
"
\'
"
);
tr
.
RecordSection
(
"Received request."
);
WebRequestHandler
handler
=
WebRequestHandler
();
std
::
shared_ptr
<
OutgoingResponse
>
response
;
auto
status_dto
=
handler
.
DropCollection
(
collection_name
);
switch
(
status_dto
->
code
->
getValue
())
{
case
StatusCode
::
SUCCESS
:
response
=
createDtoResponse
(
Status
::
CODE_204
,
status_dto
);
break
;
case
StatusCode
::
COLLECTION_NOT_EXISTS
:
response
=
createDtoResponse
(
Status
::
CODE_404
,
status_dto
);
break
;
default:
response
=
createDtoResponse
(
Status
::
CODE_400
,
status_dto
);
}
std
::
string
ttr
=
"Done. Status: code = "
+
std
::
to_string
(
status_dto
->
code
->
getValue
())
+
", reason = "
+
status_dto
->
message
->
std_str
()
+
". Total cost"
;
tr
.
ElapseFromBegin
(
ttr
);
return
response
;
}
...
...
@@ -378,97 +372,90 @@ class WebController : public oatpp::web::server::api::ApiController {
ENDPOINT
(
"POST"
,
"/collections/{collection_name}/fields/{field_name}/indexes/{index_name}"
,
CreateIndex
,
PATH
(
String
,
collection_name
),
PATH
(
String
,
field_name
),
PATH
(
String
,
index_name
),
BODY_STRING
(
String
,
body
))
{
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables/" + collection_name->std_str() +
// "/indexes\'"); tr.RecordSection("Received request.");
//
// auto handler = WebRequestHandler();
//
// std::shared_ptr<OutgoingResponse> response;
// auto status_dto = handler.CreateIndex(collection_name, body);
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// response = createDtoResponse(Status::CODE_201, status_dto);
// break;
// case StatusCode::COLLECTION_NOT_EXISTS:
// response = createDtoResponse(Status::CODE_404, status_dto);
// break;
// default:
// response = createDtoResponse(Status::CODE_400, status_dto);
// }
//
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
// ", reason = " + status_dto->message->std_str() + ". Total cost";
// tr.ElapseFromBegin(ttr);
StatusDto
::
ObjectWrapper
status
;
auto
response
=
createDtoResponse
(
Status
::
CODE_201
,
status
);
return
response
;
}
TimeRecorder
tr
(
std
::
string
(
WEB_LOG_PREFIX
)
+
"POST
\'
/tables/"
+
collection_name
->
std_str
()
+
"/indexes
\'
"
);
tr
.
RecordSection
(
"Received request."
);
auto
handler
=
WebRequestHandler
();
std
::
shared_ptr
<
OutgoingResponse
>
response
;
auto
status_dto
=
handler
.
CreateIndex
(
collection_name
,
field_name
,
body
);
switch
(
status_dto
->
code
->
getValue
())
{
case
StatusCode
::
SUCCESS
:
response
=
createDtoResponse
(
Status
::
CODE_201
,
status_dto
);
break
;
case
StatusCode
::
COLLECTION_NOT_EXISTS
:
response
=
createDtoResponse
(
Status
::
CODE_404
,
status_dto
);
break
;
default:
response
=
createDtoResponse
(
Status
::
CODE_400
,
status_dto
);
}
std
::
string
ttr
=
"Done. Status: code = "
+
std
::
to_string
(
status_dto
->
code
->
getValue
())
+
", reason = "
+
status_dto
->
message
->
std_str
()
+
". Total cost"
;
tr
.
ElapseFromBegin
(
ttr
);
ADD_CORS
(
GetIndex
)
ENDPOINT
(
"GET"
,
"/collections/{collection_name}/fields/{field_name}/indexes"
,
GetIndex
,
PATH
(
String
,
collection_name
),
PATH
(
String
,
field_name
))
{
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() +
// "/indexes\'");
// tr.RecordSection("Received request.");
//
// auto handler = WebRequestHandler();
//
// OString result;
// auto status_dto = handler.GetIndex(collection_name, result);
//
// std::shared_ptr<OutgoingResponse> response;
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// response = createResponse(Status::CODE_200, result);
// break;
// case StatusCode::COLLECTION_NOT_EXISTS:
// response = createDtoResponse(Status::CODE_404, status_dto);
// break;
// default:
// response = createDtoResponse(Status::CODE_400, status_dto);
// }
//
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
// ", reason = " + status_dto->message->std_str() + ". Total cost";
// tr.ElapseFromBegin(ttr);
json
result
=
R"({ "index_name": "FLAT", "params": {"index_type": "IVF_FLAT", "nlist": 4096 } })"
;
auto
response
=
createResponse
(
Status
::
CODE_200
,
result
.
dump
().
c_str
());
return
response
;
}
// ADD_CORS(GetIndex)
//
// ENDPOINT("GET", "/collections/{collection_name}/fields/{field_name}/indexes", GetIndex,
// PATH(String, collection_name), PATH(String, field_name)) {
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() +
// "/indexes\'");
// tr.RecordSection("Received request.");
//
// auto handler = WebRequestHandler();
//
// OString result;
// auto status_dto = handler.GetIndex(collection_name, result);
//
// std::shared_ptr<OutgoingResponse> response;
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// response = createResponse(Status::CODE_200, result);
// break;
// case StatusCode::COLLECTION_NOT_EXISTS:
// response = createDtoResponse(Status::CODE_404, status_dto);
// break;
// default:
// response = createDtoResponse(Status::CODE_400, status_dto);
// }
//
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
// ", reason = " + status_dto->message->std_str() + ". Total cost";
// tr.ElapseFromBegin(ttr);
//
// return response;
// }
ADD_CORS
(
DropIndex
)
ENDPOINT
(
"DELETE"
,
"/collections/{collection_name}/fields/{field_name}/indexes/{index_name}"
,
DropIndex
,
PATH
(
String
,
collection_name
),
PATH
(
String
,
field_name
),
PATH
(
String
,
index_name
))
{
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() +
// "/indexes\'");
// tr.RecordSection("Received request.");
//
// auto handler = WebRequestHandler();
//
// std::shared_ptr<OutgoingResponse> response;
// auto status_dto = handler.DropIndex(collection_name);
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// response = createDtoResponse(Status::CODE_204, status_dto);
// break;
// case StatusCode::COLLECTION_NOT_EXISTS:
// response = createDtoResponse(Status::CODE_404, status_dto);
// break;
// default:
// response = createDtoResponse(Status::CODE_400, status_dto);
// }
//
// std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
// ", reason = " + status_dto->message->std_str() + ". Total cost";
// tr.ElapseFromBegin(ttr);
StatusDto
::
ObjectWrapper
status
;
auto
response
=
createDtoResponse
(
Status
::
CODE_204
,
status
);
TimeRecorder
tr
(
std
::
string
(
WEB_LOG_PREFIX
)
+
"DELETE
\'
/collections/"
+
collection_name
->
std_str
()
+
"/indexes
\'
"
);
tr
.
RecordSection
(
"Received request."
);
auto
handler
=
WebRequestHandler
();
std
::
shared_ptr
<
OutgoingResponse
>
response
;
auto
status_dto
=
handler
.
DropIndex
(
collection_name
,
field_name
);
switch
(
status_dto
->
code
->
getValue
())
{
case
StatusCode
::
SUCCESS
:
response
=
createDtoResponse
(
Status
::
CODE_204
,
status_dto
);
break
;
case
StatusCode
::
COLLECTION_NOT_EXISTS
:
response
=
createDtoResponse
(
Status
::
CODE_404
,
status_dto
);
break
;
default:
response
=
createDtoResponse
(
Status
::
CODE_400
,
status_dto
);
}
std
::
string
ttr
=
"Done. Status: code = "
+
std
::
to_string
(
status_dto
->
code
->
getValue
())
+
", reason = "
+
status_dto
->
message
->
std_str
()
+
". Total cost"
;
tr
.
ElapseFromBegin
(
ttr
);
return
response
;
}
...
...
@@ -574,23 +561,18 @@ class WebController : public oatpp::web::server::api::ApiController {
ENDPOINT
(
"GET"
,
"/collections/{collection_name}/partitions/{partition_tag}/entities"
,
GetEntities
,
PATH
(
String
,
collection_name
),
PATH
(
String
,
partition_tag
),
QUERIES
(
const
QueryParams
&
,
query_params
),
BODY_STRING
(
String
,
body
))
{
json
result
=
R"({
"entities": [
{
"__id": "1578989029645098000",
"field_1": 1,
"field_vec": []
},
{
"__id": "1578989029645098001",
"field_1": 2,
"field_vec": []
}
]
})"
;
auto
response
=
createResponse
(
Status
::
CODE_200
,
result
.
dump
().
c_str
());
return
response
;
auto
handler
=
WebRequestHandler
();
String
response
;
auto
status_dto
=
handler
.
GetEntity
(
collection_name
,
query_params
,
response
);
switch
(
status_dto
->
code
->
getValue
())
{
case
StatusCode
::
SUCCESS
:
return
createResponse
(
Status
::
CODE_200
,
response
);
case
StatusCode
::
COLLECTION_NOT_EXISTS
:
return
createDtoResponse
(
Status
::
CODE_404
,
status_dto
);
default:
return
createDtoResponse
(
Status
::
CODE_400
,
status_dto
);
}
}
ADD_CORS
(
ShowSegments
)
...
...
@@ -645,75 +627,6 @@ class WebController : public oatpp::web::server::api::ApiController {
return
createResponse
(
Status
::
CODE_204
,
"No Content"
);
}
ADD_CORS
(
GetVectors
)
/**
*
* GetVectorByID ?id=
*/
ENDPOINT
(
"GET"
,
"/collections/{collection_name}/Entities"
,
GetVectors
,
PATH
(
String
,
collection_name
),
QUERIES
(
const
QueryParams
&
,
query_params
))
{
// auto handler = WebRequestHandler();
// String response;
// auto status_dto = handler.GetVector(collection_name, query_params, response);
//
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// return createResponse(Status::CODE_200, response);
// case StatusCode::COLLECTION_NOT_EXISTS:
// return createDtoResponse(Status::CODE_404, status_dto);
// default:
// return createDtoResponse(Status::CODE_400, status_dto);
// }
json
result
=
R"({
"entities": [
{
"__id": "1578989029645098000",
"field_1": 1,
"field_vec": []
},
{
"__id": "1578989029645098001",
"field_1": 2,
"field_vec": []
}
]
})"
;
auto
response
=
createResponse
(
Status
::
CODE_200
,
result
.
dump
().
c_str
());
return
response
;
}
ADD_CORS
(
Insert
)
ENDPOINT
(
"POST"
,
"/collections/{collection_name}/entities"
,
Insert
,
PATH
(
String
,
collection_name
),
BODY_STRING
(
String
,
body
))
{
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() +
// "/vectors\'");
// tr.RecordSection("Received request.");
//
// auto ids_dto = VectorIdsDto::createShared();
// WebRequestHandler handler = WebRequestHandler();
//
// std::shared_ptr<OutgoingResponse> response;
// auto status_dto = handler.Insert(collection_name, body, ids_dto);
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// response = createDtoResponse(Status::CODE_201, ids_dto);
// break;
// case StatusCode::COLLECTION_NOT_EXISTS:
// response = createDtoResponse(Status::CODE_404, status_dto);
// break;
// default:
// response = createDtoResponse(Status::CODE_400, status_dto);
// }
//
// tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
// ", reason = " + status_dto->message->std_str() + ". Total cost");
StatusDto
::
ObjectWrapper
status
;
auto
response
=
createDtoResponse
(
Status
::
CODE_201
,
status
);
return
response
;
}
ADD_CORS
(
InsertEntity
)
ENDPOINT
(
"POST"
,
"/hybrid_collections/{collection_name}/entities"
,
InsertEntity
,
PATH
(
String
,
collection_name
),
...
...
@@ -756,7 +669,7 @@ class WebController : public oatpp::web::server::api::ApiController {
OString
result
;
std
::
shared_ptr
<
OutgoingResponse
>
response
;
auto
status_dto
=
handler
.
Vectors
Op
(
collection_name
,
body
,
result
);
auto
status_dto
=
handler
.
Entity
Op
(
collection_name
,
body
,
result
);
switch
(
status_dto
->
code
->
getValue
())
{
case
StatusCode
::
SUCCESS
:
response
=
createResponse
(
Status
::
CODE_200
,
result
);
...
...
@@ -774,61 +687,6 @@ class WebController : public oatpp::web::server::api::ApiController {
return
response
;
}
ADD_CORS
(
VectorsOp
)
ENDPOINT
(
"PUT"
,
"/collections/{collection_name}/entities"
,
VectorsOp
,
PATH
(
String
,
collection_name
),
BODY_STRING
(
String
,
body
))
{
// TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "PUT \'/collections/" + collection_name->std_str() +
// "/vectors\'");
// tr.RecordSection("Received request.");
//
// WebRequestHandler handler = WebRequestHandler();
//
// OString result;
// std::shared_ptr<OutgoingResponse> response;
// auto status_dto = handler.VectorsOp(collection_name, body, result);
// switch (status_dto->code->getValue()) {
// case StatusCode::SUCCESS:
// response = createResponse(Status::CODE_200, result);
// break;
// case StatusCode::COLLECTION_NOT_EXISTS:
// response = createDtoResponse(Status::CODE_404, status_dto);
// break;
// default:
// response = createDtoResponse(Status::CODE_400, status_dto);
// }
//
// tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
// ", reason = " + status_dto->message->std_str() + ". Total cost");
json
result
=
R"({
"num": 2,
"results": [
[
{
"id": "1578989029645098000",
"distance": "0.000000",
"entity": {
"field_1": 1,
"field_2": 2,
"field_vec": []
}
},
{
"id": "1578989029645098001",
"distance": "0.010000",
"entity": {
"field_1": 10,
"field_2": 20,
"field_vec": []
}
}
]
]
})"
;
auto
response
=
createResponse
(
Status
::
CODE_200
,
result
.
dump
().
c_str
());
return
response
;
}
ADD_CORS
(
SystemOptions
)
ENDPOINT
(
"OPTIONS"
,
"/system/{info}"
,
SystemOptions
)
{
...
...
@@ -885,29 +743,6 @@ class WebController : public oatpp::web::server::api::ApiController {
return
response
;
}
ADD_CORS
(
CreateHybridCollection
)
ENDPOINT
(
"POST"
,
"/hybrid_collections"
,
CreateHybridCollection
,
BODY_STRING
(
String
,
body_str
))
{
TimeRecorder
tr
(
std
::
string
(
WEB_LOG_PREFIX
)
+
"POST
\'
/hybrid_collections
\'
"
);
tr
.
RecordSection
(
"Received request."
);
WebRequestHandler
handler
=
WebRequestHandler
();
std
::
shared_ptr
<
OutgoingResponse
>
response
;
auto
status_dto
=
handler
.
CreateHybridCollection
(
body_str
);
switch
(
status_dto
->
code
->
getValue
())
{
case
StatusCode
::
SUCCESS
:
response
=
createDtoResponse
(
Status
::
CODE_201
,
status_dto
);
break
;
default:
response
=
createDtoResponse
(
Status
::
CODE_400
,
status_dto
);
}
std
::
string
ttr
=
"Done. Status: code = "
+
std
::
to_string
(
status_dto
->
code
->
getValue
())
+
", reason = "
+
status_dto
->
message
->
std_str
()
+
". Total cost"
;
tr
.
ElapseFromBegin
(
ttr
);
return
response
;
}
/**
* Finish ENDPOINTs generation ('ApiController' codegen)
*/
...
...
core/src/server/web_impl/handler/WebRequestHandler.cpp
浏览文件 @
4beb0549
...
...
@@ -13,6 +13,7 @@
#include <algorithm>
#include <ctime>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
...
...
@@ -23,6 +24,7 @@
#include "db/Utils.h"
#include "metrics/SystemInfo.h"
#include "query/BinaryQuery.h"
#include "server/ValidationUtil.h"
#include "server/delivery/request/BaseReq.h"
#include "server/web_impl/Constants.h"
#include "server/web_impl/Types.h"
...
...
@@ -117,29 +119,31 @@ WebRequestHandler::IsBinaryCollection(const std::string& collection_name, bool&
}
Status
WebRequestHandler
::
CopyRecordsFromJson
(
const
nlohmann
::
json
&
json
,
engine
::
VectorsData
&
vectors
,
bool
bin
)
{
WebRequestHandler
::
CopyRecordsFromJson
(
const
nlohmann
::
json
&
json
,
std
::
vector
<
uint8_t
>&
vectors_data
,
bool
bin
)
{
if
(
!
json
.
is_array
())
{
return
Status
(
ILLEGAL_BODY
,
"field
\"
vectors
\"
must be a array"
);
}
vectors
.
vector_count_
=
json
.
size
();
std
::
vector
<
float
>
float_vector
;
if
(
!
bin
)
{
for
(
auto
&
vec
:
json
)
{
if
(
!
vec
.
is_array
())
{
return
Status
(
ILLEGAL_BODY
,
"A vector in field
\"
vectors
\"
must be a float array"
);
}
for
(
auto
&
data
:
vec
)
{
vectors
.
float_data_
.
emplace_back
(
data
.
get
<
float
>
());
float_vector
.
emplace_back
(
data
.
get
<
float
>
());
}
}
auto
size
=
float_vector
.
size
()
*
sizeof
(
float
);
vectors_data
.
resize
(
size
);
memcpy
(
vectors_data
.
data
(),
float_vector
.
data
(),
size
);
}
else
{
for
(
auto
&
vec
:
json
)
{
if
(
!
vec
.
is_array
())
{
return
Status
(
ILLEGAL_BODY
,
"A vector in field
\"
vectors
\"
must be a float array"
);
}
for
(
auto
&
data
:
vec
)
{
vectors
.
binary_data_
.
emplace_back
(
data
.
get
<
uint8_t
>
());
vectors
_data
.
emplace_back
(
data
.
get
<
uint8_t
>
());
}
}
}
...
...
@@ -147,6 +151,79 @@ WebRequestHandler::CopyRecordsFromJson(const nlohmann::json& json, engine::Vecto
return
Status
::
OK
();
}
Status
WebRequestHandler
::
CopyData2Json
(
const
milvus
::
engine
::
DataChunkPtr
&
data_chunk
,
const
milvus
::
engine
::
snapshot
::
FieldElementMappings
&
field_mappings
,
const
std
::
vector
<
int64_t
>&
id_array
,
nlohmann
::
json
&
json_res
)
{
int64_t
id_size
=
id_array
.
size
();
for
(
int
i
=
0
;
i
<
id_size
;
i
++
)
{
nlohmann
::
json
one_json
;
nlohmann
::
json
entity_json
;
for
(
const
auto
&
it
:
field_mappings
)
{
auto
type
=
it
.
first
->
GetFtype
();
std
::
string
name
=
it
.
first
->
GetName
();
engine
::
BinaryDataPtr
data
=
data_chunk
->
fixed_fields_
[
name
];
if
(
data
==
nullptr
||
data
->
data_
.
empty
())
continue
;
auto
single_size
=
data
->
data_
.
size
()
/
id_size
;
switch
(
type
)
{
case
engine
::
DataType
::
INT32
:
{
int32_t
int32_value
;
int64_t
offset
=
sizeof
(
int32_t
)
*
i
;
memcpy
(
&
int32_value
,
data
->
data_
.
data
()
+
offset
,
sizeof
(
int32_t
));
entity_json
[
name
]
=
int32_value
;
break
;
}
case
engine
::
DataType
::
INT64
:
{
int64_t
int64_value
;
int64_t
offset
=
sizeof
(
int64_t
)
*
i
;
memcpy
(
&
int64_value
,
data
->
data_
.
data
()
+
offset
,
sizeof
(
int64_t
));
entity_json
[
name
]
=
int64_value
;
break
;
}
case
engine
::
DataType
::
FLOAT
:
{
float
float_value
;
int64_t
offset
=
sizeof
(
float
)
*
i
;
memcpy
(
&
float_value
,
data
->
data_
.
data
()
+
offset
,
sizeof
(
float
));
entity_json
[
name
]
=
float_value
;
break
;
}
case
engine
::
DataType
::
DOUBLE
:
{
double
double_value
;
int64_t
offset
=
sizeof
(
double
)
*
i
;
memcpy
(
&
double_value
,
data
->
data_
.
data
()
+
offset
,
sizeof
(
double
));
entity_json
[
name
]
=
double_value
;
break
;
}
case
engine
::
DataType
::
VECTOR_BINARY
:
{
std
::
vector
<
int8_t
>
binary_vector
;
auto
vector_size
=
single_size
*
sizeof
(
int8_t
)
/
sizeof
(
int8_t
);
binary_vector
.
resize
(
vector_size
);
int64_t
offset
=
vector_size
*
i
;
memcpy
(
binary_vector
.
data
(),
data
->
data_
.
data
()
+
offset
,
vector_size
);
entity_json
[
name
]
=
binary_vector
;
break
;
}
case
engine
::
DataType
::
VECTOR_FLOAT
:
{
std
::
vector
<
float
>
float_vector
;
auto
vector_size
=
single_size
*
sizeof
(
int8_t
)
/
sizeof
(
float
);
float_vector
.
resize
(
vector_size
);
int64_t
offset
=
vector_size
*
i
;
memcpy
(
float_vector
.
data
(),
data
->
data_
.
data
()
+
offset
,
vector_size
);
entity_json
[
name
]
=
float_vector
;
break
;
}
}
}
one_json
[
"entity"
]
=
entity_json
;
one_json
[
"id"
]
=
id_array
[
i
];
json_res
.
push_back
(
one_json
);
}
}
///////////////////////// WebRequestHandler methods ///////////////////////////////////////
Status
WebRequestHandler
::
GetCollectionMetaInfo
(
const
std
::
string
&
collection_name
,
nlohmann
::
json
&
json_out
)
{
...
...
@@ -157,12 +234,14 @@ WebRequestHandler::GetCollectionMetaInfo(const std::string& collection_name, nlo
STATUS_CHECK
(
req_handler_
.
CountEntities
(
context_ptr_
,
collection_name
,
count
));
json_out
[
"collection_name"
]
=
schema
.
collection_name_
;
json_out
[
"dimension"
]
=
schema
.
extra_params_
[
engine
::
PARAM_DIMENSION
].
get
<
int64_t
>
();
json_out
[
"segment_row_count"
]
=
schema
.
extra_params_
[
engine
::
PARAM_SEGMENT_ROW_COUNT
].
get
<
int64_t
>
();
json_out
[
"metric_type"
]
=
schema
.
extra_params_
[
engine
::
PARAM_INDEX_METRIC_TYPE
].
get
<
int64_t
>
();
json_out
[
"index_params"
]
=
schema
.
extra_params_
[
engine
::
PARAM_INDEX_EXTRA_PARAMS
].
get
<
std
::
string
>
();
json_out
[
"count"
]
=
count
;
for
(
const
auto
&
field
:
schema
.
fields_
)
{
nlohmann
::
json
field_json
;
field_json
[
"field_name"
]
=
field
.
first
;
field_json
[
"field_type"
]
=
field
.
second
.
field_type_
;
field_json
[
"index_params"
]
=
field
.
second
.
index_params_
;
field_json
[
"extra_params"
]
=
field
.
second
.
field_params_
;
json_out
[
"field"
].
push_back
(
field_json
);
}
return
Status
::
OK
();
}
...
...
@@ -194,7 +273,7 @@ WebRequestHandler::GetSegmentVectors(const std::string& collection_name, int64_t
auto
new_ids
=
std
::
vector
<
int64_t
>
(
vector_ids
.
begin
()
+
ids_begin
,
vector_ids
.
begin
()
+
ids_end
);
nlohmann
::
json
vectors_json
;
auto
status
=
GetVectorsByIDs
(
collection_name
,
new_ids
,
vectors_json
);
//
auto status = GetVectorsByIDs(collection_name, new_ids, vectors_json);
nlohmann
::
json
result_json
;
if
(
vectors_json
.
empty
())
{
...
...
@@ -204,7 +283,7 @@ WebRequestHandler::GetSegmentVectors(const std::string& collection_name, int64_t
}
json_out
[
"count"
]
=
vector_ids
.
size
();
AddStatusToJson
(
json_out
,
status
.
code
(),
status
.
message
());
//
AddStatusToJson(json_out, status.code(), status.message());
return
Status
::
OK
();
}
...
...
@@ -406,287 +485,162 @@ WebRequestHandler::SetConfig(const nlohmann::json& json, std::string& result_str
}
Status
WebRequestHandler
::
ProcessLeafQueryJson
(
const
nlohmann
::
json
&
json
,
milvus
::
query
::
BooleanQueryPtr
&
query
)
{
WebRequestHandler
::
ProcessLeafQueryJson
(
const
nlohmann
::
json
&
json
,
milvus
::
query
::
BooleanQueryPtr
&
query
,
std
::
string
&
field_name
,
query
::
QueryPtr
&
query_ptr
)
{
auto
status
=
Status
::
OK
();
if
(
json
.
contains
(
"term"
))
{
auto
leaf_query
=
std
::
make_shared
<
query
::
LeafQuery
>
();
auto
term_json
=
json
[
"term"
];
std
::
string
field_name
=
term_json
[
"field_name"
];
auto
term_value_json
=
term_json
[
"values"
];
if
(
!
term_value_json
.
is_array
())
{
std
::
string
msg
=
"Term json string is not an array"
;
return
Status
{
BODY_PARSE_FAIL
,
msg
};
}
// auto term_size = term_value_json.size();
// auto term_query = std::make_shared<query::TermQuery>();
// term_query->field_name = field_name;
// term_query->field_value.resize(term_size * sizeof(int64_t));
//
// switch (field_type_.at(field_name)) {
// case engine::DataType::INT8:
// case engine::DataType::INT16:
// case engine::DataType::INT32:
// case engine::DataType::INT64: {
// std::vector<int64_t> term_value(term_size, 0);
// for (uint64_t i = 0; i < term_size; ++i) {
// term_value[i] = term_value_json[i].get<int64_t>();
// }
// memcpy(term_query->field_value.data(), term_value.data(), term_size * sizeof(int64_t));
// break;
// }
// case engine::DataType::FLOAT:
// case engine::DataType::DOUBLE: {
// std::vector<double> term_value(term_size, 0);
// for (uint64_t i = 0; i < term_size; ++i) {
// term_value[i] = term_value_json[i].get<double>();
// }
// memcpy(term_query->field_value.data(), term_value.data(), term_size * sizeof(double));
// break;
// }
// default:
// break;
// }
//
// leaf_query->term_query = term_query;
// query->AddLeafQuery(leaf_query);
// } else if (json.contains("range")) {
// auto leaf_query = std::make_shared<query::LeafQuery>();
// auto range_query = std::make_shared<query::RangeQuery>();
//
// auto range_json = json["range"];
// std::string field_name = range_json["field_name"];
// range_query->field_name = field_name;
//
// auto range_value_json = range_json["values"];
// if (range_value_json.contains("lt")) {
// query::CompareExpr compare_expr;
// compare_expr.compare_operator = query::CompareOperator::LT;
// compare_expr.operand = range_value_json["lt"].get<std::string>();
// range_query->compare_expr.emplace_back(compare_expr);
// }
// if (range_value_json.contains("lte")) {
// query::CompareExpr compare_expr;
// compare_expr.compare_operator = query::CompareOperator::LTE;
// compare_expr.operand = range_value_json["lte"].get<std::string>();
// range_query->compare_expr.emplace_back(compare_expr);
// }
// if (range_value_json.contains("eq")) {
// query::CompareExpr compare_expr;
// compare_expr.compare_operator = query::CompareOperator::EQ;
// compare_expr.operand = range_value_json["eq"].get<std::string>();
// range_query->compare_expr.emplace_back(compare_expr);
// }
// if (range_value_json.contains("ne")) {
// query::CompareExpr compare_expr;
// compare_expr.compare_operator = query::CompareOperator::NE;
// compare_expr.operand = range_value_json["ne"].get<std::string>();
// range_query->compare_expr.emplace_back(compare_expr);
// }
// if (range_value_json.contains("gt")) {
// query::CompareExpr compare_expr;
// compare_expr.compare_operator = query::CompareOperator::GT;
// compare_expr.operand = range_value_json["gt"].get<std::string>();
// range_query->compare_expr.emplace_back(compare_expr);
// }
// if (range_value_json.contains("gte")) {
// query::CompareExpr compare_expr;
// compare_expr.compare_operator = query::CompareOperator::GTE;
// compare_expr.operand = range_value_json["gte"].get<std::string>();
// range_query->compare_expr.emplace_back(compare_expr);
// }
//
// leaf_query->range_query = range_query;
// query->AddLeafQuery(leaf_query);
// } else if (json.contains("vector")) {
// auto leaf_query = std::make_shared<query::LeafQuery>();
// auto vector_query = std::make_shared<query::VectorQuery>();
//
// auto vector_json = json["vector"];
// std::string field_name = vector_json["field_name"];
// vector_query->field_name = field_name;
//
// engine::VectorsData vectors;
// // TODO(yukun): process binary vector
// CopyRecordsFromJson(vector_json["values"], vectors, false);
//
// vector_query->query_vector.float_data = vectors.float_data_;
// vector_query->query_vector.binary_data = vectors.binary_data_;
//
// vector_query->topk = vector_json["topk"].get<int64_t>();
// vector_query->extra_params = vector_json["extra_params"];
//
// // TODO(yukun): remove hardcode here
// std::string vector_placeholder = "placeholder_1";
// query_ptr_->vectors.insert(std::make_pair(vector_placeholder, vector_query));
// leaf_query->vector_placeholder = vector_placeholder;
// query->AddLeafQuery(leaf_query);
}
return
Status
::
OK
();
}
Status
WebRequestHandler
::
ProcessBoolQueryJson
(
const
nlohmann
::
json
&
query_json
,
query
::
BooleanQueryPtr
&
boolean_query
)
{
if
(
query_json
.
contains
(
"must"
))
{
boolean_query
->
SetOccur
(
query
::
Occur
::
MUST
);
auto
must_json
=
query_json
[
"must"
];
if
(
!
must_json
.
is_array
())
{
std
::
string
msg
=
"Must json string is not an array"
;
return
Status
{
BODY_PARSE_FAIL
,
msg
};
}
for
(
auto
&
json
:
must_json
)
{
auto
must_query
=
std
::
make_shared
<
query
::
BooleanQuery
>
();
if
(
json
.
contains
(
"must"
)
||
json
.
contains
(
"should"
)
||
json
.
contains
(
"must_not"
))
{
ProcessBoolQueryJson
(
json
,
must_query
);
boolean_query
->
AddBooleanQuery
(
must_query
);
}
else
{
ProcessLeafQueryJson
(
json
,
boolean_query
);
auto
term_query
=
std
::
make_shared
<
query
::
TermQuery
>
();
nlohmann
::
json
json_obj
=
json
[
"term"
];
JSON_NULL_CHECK
(
json_obj
);
JSON_OBJECT_CHECK
(
json_obj
);
term_query
->
json_obj
=
json_obj
;
nlohmann
::
json
::
iterator
json_it
=
json_obj
.
begin
();
field_name
=
json_it
.
key
();
leaf_query
->
term_query
=
term_query
;
query
->
AddLeafQuery
(
leaf_query
);
}
else
if
(
json
.
contains
(
"range"
))
{
auto
leaf_query
=
std
::
make_shared
<
query
::
LeafQuery
>
();
auto
range_query
=
std
::
make_shared
<
query
::
RangeQuery
>
();
nlohmann
::
json
json_obj
=
json
[
"range"
];
JSON_NULL_CHECK
(
json_obj
);
JSON_OBJECT_CHECK
(
json_obj
);
range_query
->
json_obj
=
json_obj
;
nlohmann
::
json
::
iterator
json_it
=
json_obj
.
begin
();
field_name
=
json_it
.
key
();
leaf_query
->
range_query
=
range_query
;
query
->
AddLeafQuery
(
leaf_query
);
}
else
if
(
json
.
contains
(
"vector"
))
{
auto
leaf_query
=
std
::
make_shared
<
query
::
LeafQuery
>
();
auto
vector_json
=
json
[
"vector"
];
JSON_NULL_CHECK
(
vector_json
);
std
::
random_device
dev
;
std
::
mt19937
rng
(
dev
());
std
::
uniform_int_distribution
<
std
::
mt19937
::
result_type
>
dist
(
0
,
64
);
int64_t
place_number
=
dist
(
rng
);
std
::
string
placeholder
=
"placeholder"
+
std
::
to_string
(
place_number
);
leaf_query
->
vector_placeholder
=
placeholder
;
query
->
AddLeafQuery
(
leaf_query
);
auto
vector_query
=
std
::
make_shared
<
query
::
VectorQuery
>
();
json
::
iterator
vector_param_it
=
vector_json
.
begin
();
if
(
vector_param_it
!=
vector_json
.
end
())
{
const
std
::
string
&
vector_name
=
vector_param_it
.
key
();
vector_query
->
field_name
=
vector_name
;
nlohmann
::
json
param_json
=
vector_param_it
.
value
();
int64_t
topk
=
param_json
[
"topk"
];
status
=
server
::
ValidateSearchTopk
(
topk
);
if
(
!
status
.
ok
())
{
return
status
;
}
}
return
Status
::
OK
();
}
else
if
(
query_json
.
contains
(
"should"
))
{
boolean_query
->
SetOccur
(
query
::
Occur
::
SHOULD
);
auto
should_json
=
query_json
[
"should"
];
if
(
!
should_json
.
is_array
())
{
std
::
string
msg
=
"Should json string is not an array"
;
return
Status
{
BODY_PARSE_FAIL
,
msg
};
}
for
(
auto
&
json
:
should_json
)
{
if
(
json
.
contains
(
"must"
)
||
json
.
contains
(
"should"
)
||
json
.
contains
(
"must_not"
))
{
auto
should_query
=
std
::
make_shared
<
query
::
BooleanQuery
>
();
ProcessBoolQueryJson
(
json
,
should_query
);
boolean_query
->
AddBooleanQuery
(
should_query
);
}
else
{
ProcessLeafQueryJson
(
json
,
boolean_query
);
vector_query
->
topk
=
topk
;
if
(
param_json
.
contains
(
"metric_type"
))
{
std
::
string
metric_type
=
param_json
[
"metric_type"
];
vector_query
->
metric_type
=
metric_type
;
query_ptr
->
metric_types
.
insert
({
vector_name
,
param_json
[
"metric_type"
]});
}
}
return
Status
::
OK
();
}
else
if
(
query_json
.
contains
(
"must_not"
))
{
boolean_query
->
SetOccur
(
query
::
Occur
::
MUST_NOT
);
auto
should_json
=
query_json
[
"must_not"
];
if
(
!
should_json
.
is_array
())
{
std
::
string
msg
=
"Must_not json string is not an array"
;
return
Status
{
BODY_PARSE_FAIL
,
msg
};
}
for
(
auto
&
json
:
should_json
)
{
if
(
json
.
contains
(
"must"
)
||
json
.
contains
(
"should"
)
||
json
.
contains
(
"must_not"
))
{
auto
must_not_query
=
std
::
make_shared
<
query
::
BooleanQuery
>
();
ProcessBoolQueryJson
(
json
,
must_not_query
);
boolean_query
->
AddBooleanQuery
(
must_not_query
);
}
else
{
ProcessLeafQueryJson
(
json
,
boolean_query
);
if
(
!
vector_param_it
.
value
()[
"params"
].
empty
())
{
vector_query
->
extra_params
=
vector_param_it
.
value
()[
"params"
];
}
engine
::
VectorsData
vector_data
;
for
(
auto
&
vector_records
:
vector_param_it
.
value
()[
"values"
])
{
// TODO: Binary vector???
for
(
auto
&
data
:
vector_records
)
{
vector_query
->
query_vector
.
float_data
.
emplace_back
(
data
.
get
<
float
>
());
}
}
query_ptr
->
index_fields
.
insert
(
vector_name
);
}
return
Status
::
OK
();
query_ptr
->
vectors
.
insert
(
std
::
make_pair
(
placeholder
,
vector_query
));
}
else
{
std
::
string
msg
=
"Must json string doesnot include right query"
;
return
Status
{
BODY_PARSE_FAIL
,
msg
};
return
Status
{
SERVER_INVALID_ARGUMENT
,
"Leaf query get wrong key"
};
}
return
status
;
}
void
ConvertRowToColumnJson
(
const
std
::
vector
<
engine
::
AttrsData
>&
row_attrs
,
const
std
::
vector
<
std
::
string
>&
field_names
,
const
int64_t
row_num
,
nlohmann
::
json
&
column_attrs_json
)
{
// if (field_names.size() == 0) {
// if (row_attrs.size() > 0) {
// auto attr_it = row_attrs[0].attr_type_.begin();
// for (; attr_it != row_attrs[0].attr_type_.end(); attr_it++) {
// field_names.emplace_back(attr_it->first);
// }
// }
// }
Status
WebRequestHandler
::
ProcessBooleanQueryJson
(
const
nlohmann
::
json
&
query_json
,
query
::
BooleanQueryPtr
&
boolean_query
,
query
::
QueryPtr
&
query_ptr
)
{
auto
status
=
Status
::
OK
();
if
(
query_json
.
empty
())
{
return
Status
{
SERVER_INVALID_ARGUMENT
,
"BoolQuery is null"
};
}
for
(
auto
&
el
:
query_json
.
items
())
{
if
(
el
.
key
()
==
"must"
)
{
boolean_query
->
SetOccur
(
query
::
Occur
::
MUST
);
auto
must_json
=
el
.
value
();
if
(
!
must_json
.
is_array
())
{
std
::
string
msg
=
"Must json string is not an array"
;
return
Status
{
SERVER_INVALID_DSL_PARAMETER
,
msg
};
}
for
(
uint64_t
i
=
0
;
i
<
field_names
.
size
()
-
1
;
i
++
)
{
std
::
vector
<
int64_t
>
int_data
;
std
::
vector
<
double
>
double_data
;
for
(
auto
&
attr
:
row_attrs
)
{
int64_t
int_value
;
double
double_value
;
auto
attr_data
=
attr
.
attr_data_
.
at
(
field_names
[
i
]);
switch
(
attr
.
attr_type_
.
at
(
field_names
[
i
]))
{
case
engine
::
DataType
::
INT8
:
{
if
(
attr_data
.
size
()
==
sizeof
(
int8_t
))
{
int_value
=
attr_data
[
0
];
int_data
.
emplace_back
(
int_value
);
}
break
;
}
case
engine
::
DataType
::
INT16
:
{
if
(
attr_data
.
size
()
==
sizeof
(
int16_t
))
{
memcpy
(
&
int_value
,
attr_data
.
data
(),
sizeof
(
int16_t
));
int_data
.
emplace_back
(
int_value
);
}
break
;
}
case
engine
::
DataType
::
INT32
:
{
if
(
attr_data
.
size
()
==
sizeof
(
int32_t
))
{
memcpy
(
&
int_value
,
attr_data
.
data
(),
sizeof
(
int32_t
));
int_data
.
emplace_back
(
int_value
);
}
break
;
}
case
engine
::
DataType
::
INT64
:
{
if
(
attr_data
.
size
()
==
sizeof
(
int64_t
))
{
memcpy
(
&
int_value
,
attr_data
.
data
(),
sizeof
(
int64_t
));
int_data
.
emplace_back
(
int_value
);
for
(
auto
&
json
:
must_json
)
{
auto
must_query
=
std
::
make_shared
<
query
::
BooleanQuery
>
();
if
(
json
.
contains
(
"must"
)
||
json
.
contains
(
"should"
)
||
json
.
contains
(
"must_not"
))
{
STATUS_CHECK
(
ProcessBooleanQueryJson
(
json
,
must_query
,
query_ptr
));
boolean_query
->
AddBooleanQuery
(
must_query
);
}
else
{
std
::
string
field_name
;
STATUS_CHECK
(
ProcessLeafQueryJson
(
json
,
boolean_query
,
field_name
,
query_ptr
));
if
(
!
field_name
.
empty
())
{
query_ptr
->
index_fields
.
insert
(
field_name
);
}
break
;
}
case
engine
::
DataType
::
FLOAT
:
{
if
(
attr_data
.
size
()
==
sizeof
(
float
))
{
float
float_value
;
memcpy
(
&
float_value
,
attr_data
.
data
(),
sizeof
(
float
));
double_value
=
float_value
;
double_data
.
emplace_back
(
double_value
);
}
break
;
}
case
engine
::
DataType
::
DOUBLE
:
{
if
(
attr_data
.
size
()
==
sizeof
(
double
))
{
memcpy
(
&
double_value
,
attr_data
.
data
(),
sizeof
(
double
));
double_data
.
emplace_back
(
double_value
);
}
}
else
if
(
el
.
key
()
==
"should"
)
{
boolean_query
->
SetOccur
(
query
::
Occur
::
SHOULD
);
auto
should_json
=
el
.
value
();
if
(
!
should_json
.
is_array
())
{
std
::
string
msg
=
"Should json string is not an array"
;
return
Status
{
SERVER_INVALID_DSL_PARAMETER
,
msg
};
}
for
(
auto
&
json
:
should_json
)
{
auto
should_query
=
std
::
make_shared
<
query
::
BooleanQuery
>
();
if
(
json
.
contains
(
"must"
)
||
json
.
contains
(
"should"
)
||
json
.
contains
(
"must_not"
))
{
STATUS_CHECK
(
ProcessBooleanQueryJson
(
json
,
should_query
,
query_ptr
));
boolean_query
->
AddBooleanQuery
(
should_query
);
}
else
{
std
::
string
field_name
;
STATUS_CHECK
(
ProcessLeafQueryJson
(
json
,
boolean_query
,
field_name
,
query_ptr
));
if
(
!
field_name
.
empty
())
{
query_ptr
->
index_fields
.
insert
(
field_name
);
}
break
;
}
default:
{
return
;
}
}
}
if
(
int_data
.
size
()
>
0
)
{
if
(
row_num
==
-
1
)
{
nlohmann
::
json
int_data_json
(
int_data
);
column_attrs_json
[
field_names
[
i
]]
=
int_data_json
;
}
else
{
nlohmann
::
json
topk_int_result
;
int64_t
topk
=
int_data
.
size
()
/
row_num
;
for
(
int64_t
j
=
0
;
j
<
row_num
;
j
++
)
{
std
::
vector
<
int64_t
>
one_int_result
(
topk
);
memcpy
(
one_int_result
.
data
(),
int_data
.
data
()
+
j
*
topk
,
sizeof
(
int64_t
)
*
topk
);
nlohmann
::
json
one_int_result_json
(
one_int_result
);
std
::
string
tag
=
"top"
+
std
::
to_string
(
j
);
topk_int_result
[
tag
]
=
one_int_result_json
;
}
column_attrs_json
[
field_names
[
i
]]
=
topk_int_result
;
}
else
if
(
el
.
key
()
==
"must_not"
)
{
boolean_query
->
SetOccur
(
query
::
Occur
::
MUST_NOT
);
auto
should_json
=
el
.
value
();
if
(
!
should_json
.
is_array
())
{
std
::
string
msg
=
"Must_not json string is not an array"
;
return
Status
{
SERVER_INVALID_DSL_PARAMETER
,
msg
};
}
}
else
if
(
double_data
.
size
()
>
0
)
{
if
(
row_num
==
-
1
)
{
nlohmann
::
json
double_data_json
(
double_data
);
column_attrs_json
[
field_names
[
i
]]
=
double_data_json
;
}
else
{
nlohmann
::
json
topk_double_result
;
int64_t
topk
=
int_data
.
size
()
/
row_num
;
for
(
int64_t
j
=
0
;
j
<
row_num
;
j
++
)
{
std
::
vector
<
double
>
one_double_result
(
topk
);
memcpy
(
one_double_result
.
data
(),
double_data
.
data
()
+
j
*
topk
,
sizeof
(
double
)
*
topk
);
nlohmann
::
json
one_double_result_json
(
one_double_result
);
std
::
string
tag
=
"top"
+
std
::
to_string
(
j
);
topk_double_result
[
tag
]
=
one_double_result_json
;
for
(
auto
&
json
:
should_json
)
{
if
(
json
.
contains
(
"must"
)
||
json
.
contains
(
"should"
)
||
json
.
contains
(
"must_not"
))
{
auto
must_not_query
=
std
::
make_shared
<
query
::
BooleanQuery
>
();
STATUS_CHECK
(
ProcessBooleanQueryJson
(
json
,
must_not_query
,
query_ptr
));
boolean_query
->
AddBooleanQuery
(
must_not_query
);
}
else
{
std
::
string
field_name
;
STATUS_CHECK
(
ProcessLeafQueryJson
(
json
,
boolean_query
,
field_name
,
query_ptr
));
if
(
!
field_name
.
empty
())
{
query_ptr
->
index_fields
.
insert
(
field_name
);
}
}
column_attrs_json
[
field_names
[
i
]]
=
topk_double_result
;
}
}
else
{
std
::
string
msg
=
"BoolQuery json string does not include bool query"
;
return
Status
{
SERVER_INVALID_DSL_PARAMETER
,
msg
};
}
}
return
status
;
}
Status
...
...
@@ -724,7 +678,7 @@ WebRequestHandler::Search(const std::string& collection_name, const nlohmann::js
auto
boolean_query
=
std
::
make_shared
<
query
::
BooleanQuery
>
();
query_ptr_
=
std
::
make_shared
<
query
::
Query
>
();
status
=
ProcessBool
QueryJson
(
boolean_query_json
,
boolean_query
);
status
=
ProcessBool
eanQueryJson
(
boolean_query_json
,
boolean_query
,
query_ptr_
);
if
(
!
status
.
ok
())
{
return
status
;
}
...
...
@@ -749,22 +703,75 @@ WebRequestHandler::Search(const std::string& collection_name, const nlohmann::js
return
Status
::
OK
();
}
auto
step
=
result
->
result_ids_
.
size
()
/
result
->
row_num_
;
nlohmann
::
json
search_result_json
;
auto
step
=
result
->
result_ids_
.
size
()
/
result
->
row_num_
;
// topk
auto
field_data
=
result
->
data_chunk_
->
fixed_fields_
;
for
(
int64_t
i
=
0
;
i
<
result
->
row_num_
;
i
++
)
{
nlohmann
::
json
raw_result_json
;
for
(
size_t
j
=
0
;
j
<
step
;
j
++
)
{
nlohmann
::
json
one_result_json
;
one_result_json
[
"id"
]
=
std
::
to_string
(
result
->
result_ids_
.
at
(
i
*
step
+
j
));
one_result_json
[
"distance"
]
=
std
::
to_string
(
result
->
result_distances_
.
at
(
i
*
step
+
j
));
raw_result_json
.
emplace_back
(
one_result_json
);
nlohmann
::
json
one_entity_json
;
for
(
const
auto
&
field
:
field_mappings
)
{
auto
field_name
=
field
.
first
->
GetName
();
switch
((
int64_t
)
field
.
first
->
GetFtype
())
{
case
engine
::
DataType
::
INT32
:
{
int32_t
int32_value
;
int64_t
offset
=
(
i
*
step
+
j
)
*
sizeof
(
int32_t
);
memcpy
(
&
int32_value
,
field_data
.
at
(
field_name
)
->
data_
.
data
()
+
offset
,
sizeof
(
int32_t
));
one_entity_json
[
field_name
]
=
int32_value
;
break
;
}
case
engine
::
DataType
::
INT64
:
{
int64_t
int64_value
;
int64_t
offset
=
(
i
*
step
+
j
)
*
sizeof
(
int64_t
);
memcpy
(
&
int64_value
,
field_data
.
at
(
field_name
)
->
data_
.
data
()
+
offset
,
sizeof
(
int64_t
));
one_entity_json
[
field_name
]
=
int64_value
;
break
;
}
case
engine
::
DataType
::
FLOAT
:
{
float
float_value
;
int64_t
offset
=
(
i
*
step
+
j
)
*
sizeof
(
float
);
memcpy
(
&
float_value
,
field_data
.
at
(
field_name
)
->
data_
.
data
()
+
offset
,
sizeof
(
float
));
one_entity_json
[
field_name
]
=
float_value
;
break
;
}
case
engine
::
DataType
::
DOUBLE
:
{
double
double_value
;
int64_t
offset
=
(
i
*
step
+
j
)
*
sizeof
(
double
);
memcpy
(
&
double_value
,
field_data
.
at
(
field_name
)
->
data_
.
data
()
+
offset
,
sizeof
(
double
));
one_entity_json
[
field_name
]
=
double_value
;
break
;
}
case
engine
::
DataType
::
VECTOR_FLOAT
:
{
std
::
vector
<
float
>
float_vector
;
auto
dim
=
field_data
.
at
(
field_name
)
->
data_
.
size
()
/
(
result
->
result_ids_
.
size
()
*
sizeof
(
float
));
int64_t
offset
=
(
i
*
step
+
j
)
*
dim
*
sizeof
(
float
);
float_vector
.
resize
(
dim
);
memcpy
(
float_vector
.
data
(),
field_data
.
at
(
field_name
)
->
data_
.
data
()
+
offset
,
dim
*
sizeof
(
float
));
one_entity_json
[
field_name
]
=
float_vector
;
break
;
}
case
engine
::
DataType
::
VECTOR_BINARY
:
{
std
::
vector
<
int8_t
>
binary_vector
;
auto
dim
=
field_data
.
at
(
field_name
)
->
data_
.
size
()
/
(
result
->
result_ids_
.
size
());
int64_t
offset
=
(
i
*
step
+
j
)
*
dim
;
binary_vector
.
resize
(
dim
);
memcpy
(
binary_vector
.
data
(),
field_data
.
at
(
field_name
)
->
data_
.
data
()
+
offset
,
dim
*
sizeof
(
int8_t
));
one_entity_json
[
field_name
]
=
binary_vector
;
break
;
}
default:
{
return
Status
(
SERVER_UNEXPECTED_ERROR
,
"Return field data type is wrong"
);
}
}
}
one_result_json
[
"entity"
]
=
one_entity_json
;
raw_result_json
.
push_back
(
one_result_json
);
}
search_
result_json
.
emplace_back
(
raw_result_json
);
result_json
.
emplace_back
(
raw_result_json
);
}
nlohmann
::
json
attr_json
;
// ConvertRowToColumnJson(result->attrs_, query_ptr_->field_names, result->row_num_, attr_json);
result_json
[
"Entity"
]
=
attr_json
;
result_json
[
"result"
]
=
search_result_json
;
result_str
=
result_json
.
dump
();
}
...
...
@@ -774,7 +781,7 @@ WebRequestHandler::Search(const std::string& collection_name, const nlohmann::js
Status
WebRequestHandler
::
DeleteByIDs
(
const
std
::
string
&
collection_name
,
const
nlohmann
::
json
&
json
,
std
::
string
&
result_str
)
{
std
::
vector
<
int64_t
>
vector
_ids
;
std
::
vector
<
int64_t
>
entity
_ids
;
if
(
!
json
.
contains
(
"ids"
))
{
return
Status
(
BODY_FIELD_LOSS
,
"Field
\"
delete
\"
must contains
\"
ids
\"
"
);
}
...
...
@@ -788,10 +795,10 @@ WebRequestHandler::DeleteByIDs(const std::string& collection_name, const nlohman
if
(
!
ValidateStringIsNumber
(
id_str
).
ok
())
{
return
Status
(
ILLEGAL_BODY
,
"Members in
\"
ids
\"
must be integer string"
);
}
vector
_ids
.
emplace_back
(
std
::
stol
(
id_str
));
entity
_ids
.
emplace_back
(
std
::
stol
(
id_str
));
}
auto
status
=
req_handler_
.
DeleteEntityByID
(
context_ptr_
,
collection_name
,
vector
_ids
);
auto
status
=
req_handler_
.
DeleteEntityByID
(
context_ptr_
,
collection_name
,
entity
_ids
);
nlohmann
::
json
result_json
;
AddStatusToJson
(
result_json
,
status
.
code
(),
status
.
message
());
...
...
@@ -807,89 +814,23 @@ WebRequestHandler::GetEntityByIDs(const std::string& collection_name, const std:
engine
::
DataChunkPtr
data_chunk
;
engine
::
snapshot
::
FieldElementMappings
field_mappings
;
std
::
vector
<
engine
::
AttrsData
>
attr_batch
;
std
::
vector
<
engine
::
VectorsData
>
vector_batch
;
auto
status
=
req_handler_
.
GetEntityByID
(
context_ptr_
,
collection_name
,
ids
,
field_names
,
valid_row
,
field_mappings
,
data_chunk
);
if
(
!
status
.
ok
())
{
return
status
;
}
std
::
vector
<
uint8_t
>
id_array
=
data_chunk
->
fixed_fields_
[
engine
::
FIELD_UID
]
->
data_
;
for
(
const
auto
&
it
:
field_mappings
)
{
std
::
string
name
=
it
.
first
->
GetName
();
uint64_t
type
=
it
.
first
->
GetFtype
();
std
::
vector
<
uint8_t
>&
data
=
data_chunk
->
fixed_fields_
[
name
]
->
data_
;
if
(
type
==
engine
::
DataType
::
VECTOR_BINARY
)
{
engine
::
VectorsData
vectors_data
;
memcpy
(
vectors_data
.
binary_data_
.
data
(),
data
.
data
(),
data
.
size
());
memcpy
(
vectors_data
.
id_array_
.
data
(),
id_array
.
data
(),
id_array
.
size
());
vector_batch
.
emplace_back
(
vectors_data
);
}
else
if
(
type
==
engine
::
DataType
::
VECTOR_FLOAT
)
{
engine
::
VectorsData
vectors_data
;
memcpy
(
vectors_data
.
float_data_
.
data
(),
data
.
data
(),
data
.
size
());
memcpy
(
vectors_data
.
id_array_
.
data
(),
id_array
.
data
(),
id_array
.
size
());
vector_batch
.
emplace_back
(
vectors_data
);
}
else
{
engine
::
AttrsData
attrs_data
;
attrs_data
.
attr_type_
[
name
]
=
static_cast
<
engine
::
DataType
>
(
type
);
attrs_data
.
attr_data_
[
name
]
=
data
;
memcpy
(
attrs_data
.
id_array_
.
data
(),
id_array
.
data
(),
id_array
.
size
());
attr_batch
.
emplace_back
(
attrs_data
);
}
}
bool
bin
;
status
=
IsBinaryCollection
(
collection_name
,
bin
);
if
(
!
status
.
ok
())
{
return
status
;
}
nlohmann
::
json
vectors_json
,
attrs_json
;
for
(
size_t
i
=
0
;
i
<
vector_batch
.
size
();
i
++
)
{
nlohmann
::
json
vector_json
;
if
(
bin
)
{
vector_json
[
"vector"
]
=
vector_batch
.
at
(
i
).
binary_data_
;
}
else
{
vector_json
[
"vector"
]
=
vector_batch
.
at
(
i
).
float_data_
;
}
vector_json
[
"id"
]
=
std
::
to_string
(
ids
[
i
]);
vectors_json
.
push_back
(
vector_json
);
}
ConvertRowToColumnJson
(
attr_batch
,
field_names
,
-
1
,
attrs_json
);
json_out
[
"vectors"
]
=
vectors_json
;
json_out
[
"attributes"
]
=
attrs_json
;
return
Status
::
OK
();
}
Status
WebRequestHandler
::
GetVectorsByIDs
(
const
std
::
string
&
collection_name
,
const
std
::
vector
<
int64_t
>&
ids
,
nlohmann
::
json
&
json_out
)
{
std
::
vector
<
engine
::
VectorsData
>
vector_batch
;
auto
status
=
Status
::
OK
();
// auto status = req_handler_.GetVectorsByID(context_ptr_, collection_name, ids, vector_batch);
if
(
!
status
.
ok
())
{
return
status
;
}
bool
bin
;
status
=
IsBinaryCollection
(
collection_name
,
bin
);
if
(
!
status
.
ok
())
{
return
status
;
}
nlohmann
::
json
vectors_json
;
for
(
size_t
i
=
0
;
i
<
vector_batch
.
size
();
i
++
)
{
nlohmann
::
json
vector_json
;
if
(
bin
)
{
vector_json
[
"vector"
]
=
vector_batch
.
at
(
i
).
binary_data_
;
}
else
{
vector_json
[
"vector"
]
=
vector_batch
.
at
(
i
).
float_data_
;
int64_t
valid_size
=
0
;
for
(
auto
row
:
valid_row
)
{
if
(
row
)
{
valid_size
++
;
}
vector_json
[
"id"
]
=
std
::
to_string
(
ids
[
i
]);
json_out
.
push_back
(
vector_json
);
}
std
::
vector
<
uint8_t
>
id_data
=
data_chunk
->
fixed_fields_
[
engine
::
FIELD_UID
]
->
data_
;
std
::
vector
<
int64_t
>
id_array
(
valid_size
);
memcpy
(
id_array
.
data
(),
id_data
.
data
(),
valid_size
*
sizeof
(
int64_t
));
CopyData2Json
(
data_chunk
,
field_mappings
,
id_array
,
json_out
);
return
Status
::
OK
();
}
...
...
@@ -1169,34 +1110,7 @@ WebRequestHandler::SetGpuConfig(const GPUConfigDto::ObjectWrapper& gpu_config_dt
* Collection {
*/
StatusDto
::
ObjectWrapper
WebRequestHandler
::
CreateCollection
(
const
CollectionRequestDto
::
ObjectWrapper
&
collection_schema
)
{
if
(
nullptr
==
collection_schema
->
collection_name
.
get
())
{
RETURN_STATUS_DTO
(
BODY_FIELD_LOSS
,
"Field
\'
collection_name
\'
is missing"
)
}
if
(
nullptr
==
collection_schema
->
dimension
.
get
())
{
RETURN_STATUS_DTO
(
BODY_FIELD_LOSS
,
"Field
\'
dimension
\'
is missing"
)
}
if
(
nullptr
==
collection_schema
->
index_file_size
.
get
())
{
RETURN_STATUS_DTO
(
BODY_FIELD_LOSS
,
"Field
\'
index_file_size
\'
is missing"
)
}
if
(
nullptr
==
collection_schema
->
metric_type
.
get
())
{
RETURN_STATUS_DTO
(
BODY_FIELD_LOSS
,
"Field
\'
metric_type
\'
is missing"
)
}
auto
status
=
Status
::
OK
();
// auto status = req_handler_.CreateCollection(
// context_ptr_, collection_schema->collection_name->std_str(), collection_schema->dimension,
// collection_schema->index_file_size,
// static_cast<int64_t>(MetricNameMap.at(collection_schema->metric_type->std_str())));
ASSIGN_RETURN_STATUS_DTO
(
status
)
}
StatusDto
::
ObjectWrapper
WebRequestHandler
::
CreateHybridCollection
(
const
milvus
::
server
::
web
::
OString
&
body
)
{
WebRequestHandler
::
CreateCollection
(
const
milvus
::
server
::
web
::
OString
&
body
)
{
auto
json_str
=
nlohmann
::
json
::
parse
(
body
->
c_str
());
std
::
string
collection_name
=
json_str
[
"collection_name"
];
...
...
@@ -1208,24 +1122,14 @@ WebRequestHandler::CreateHybridCollection(const milvus::server::web::OString& bo
field_schema
.
field_params_
=
field
[
"extra_params"
];
const
std
::
string
&
field_type
=
field
[
"field_type"
];
if
(
field_type
==
"int8"
)
{
field_schema
.
field_type_
=
engine
::
DataType
::
INT8
;
}
else
if
(
field_type
==
"int16"
)
{
field_schema
.
field_type_
=
engine
::
DataType
::
INT16
;
}
else
if
(
field_type
==
"int32"
)
{
field_schema
.
field_type_
=
engine
::
DataType
::
INT32
;
}
else
if
(
field_type
==
"int64"
)
{
field_schema
.
field_type_
=
engine
::
DataType
::
INT64
;
}
else
if
(
field_type
==
"float"
)
{
field_schema
.
field_type_
=
engine
::
DataType
::
FLOAT
;
}
else
if
(
field_type
==
"double"
)
{
field_schema
.
field_type_
=
engine
::
DataType
::
DOUBLE
;
}
else
if
(
field_type
==
"vector"
)
{
}
else
{
std
::
string
field_type
=
field
[
"field_type"
];
std
::
transform
(
field_type
.
begin
(),
field_type
.
end
(),
field_type
.
begin
(),
::
tolower
);
if
(
str2type
.
find
(
field_type
)
==
str2type
.
end
())
{
std
::
string
msg
=
field_name
+
" has wrong field_type"
;
RETURN_STATUS_DTO
(
BODY_PARSE_FAIL
,
msg
.
c_str
());
}
field_schema
.
field_type_
=
str2type
.
at
(
field_type
);
fields
[
field_name
]
=
field_schema
;
}
...
...
@@ -1336,18 +1240,15 @@ WebRequestHandler::DropCollection(const OString& collection_name) {
*/
StatusDto
::
ObjectWrapper
WebRequestHandler
::
CreateIndex
(
const
OString
&
collection_name
,
const
OString
&
body
)
{
WebRequestHandler
::
CreateIndex
(
const
OString
&
collection_name
,
const
OString
&
field_name
,
const
OString
&
body
)
{
try
{
auto
request_json
=
nlohmann
::
json
::
parse
(
body
->
std_str
());
std
::
string
field_name
,
index_name
;
if
(
!
request_json
.
contains
(
"index_type"
))
{
RETURN_STATUS_DTO
(
BODY_FIELD_LOSS
,
"Field
\'
index_type
\'
is required"
);
}
auto
status
=
Status
::
OK
();
// auto status =
// req_handler_.CreateIndex(context_ptr_, collection_name->std_str(), index,
// request_json["params"]);
auto
status
=
req_handler_
.
CreateIndex
(
context_ptr_
,
collection_name
->
std_str
(),
field_name
->
std_str
(),
""
,
request_json
);
ASSIGN_RETURN_STATUS_DTO
(
status
);
}
catch
(
nlohmann
::
detail
::
parse_error
&
e
)
{
RETURN_STATUS_DTO
(
BODY_PARSE_FAIL
,
e
.
what
())
...
...
@@ -1359,10 +1260,8 @@ WebRequestHandler::CreateIndex(const OString& collection_name, const OString& bo
}
StatusDto
::
ObjectWrapper
WebRequestHandler
::
DropIndex
(
const
OString
&
collection_name
)
{
auto
status
=
Status
::
OK
();
// auto status = req_handler_.DropIndex(context_ptr_, collection_name->std_str());
WebRequestHandler
::
DropIndex
(
const
OString
&
collection_name
,
const
OString
&
field_name
)
{
auto
status
=
req_handler_
.
DropIndex
(
context_ptr_
,
collection_name
->
std_str
(),
field_name
->
std_str
(),
""
);
ASSIGN_RETURN_STATUS_DTO
(
status
)
}
...
...
@@ -1583,9 +1482,12 @@ WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::se
std
::
string
partition_name
=
body_json
[
"partition_tag"
];
int32_t
row_num
=
body_json
[
"row_num"
];
CollectionSchema
collection_schema
;
std
::
unordered_map
<
std
::
string
,
engine
::
DataType
>
field_types
;
auto
status
=
Status
::
OK
();
// auto status = req_handler_.DescribeHybridCollection(context_ptr_, collection_name->c_str(), field_types);
auto
status
=
req_handler_
.
GetCollectionInfo
(
context_ptr_
,
collection_name
->
std_str
(),
collection_schema
);
for
(
const
auto
&
field
:
collection_schema
.
fields_
)
{
field_types
.
insert
({
field
.
first
,
field
.
second
.
field_type_
});
}
auto
entities
=
body_json
[
"entity"
];
if
(
!
entities
.
is_array
())
{
...
...
@@ -1621,15 +1523,12 @@ WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::se
break
;
}
case
engine
::
DataType
::
VECTOR_FLOAT
:
{
bool
bin_flag
;
status
=
IsBinaryCollection
(
collection_name
->
c_str
(),
bin_flag
);
if
(
!
status
.
ok
())
{
ASSIGN_RETURN_STATUS_DTO
(
status
)
}
// engine::VectorsData vectors;
// CopyRecordsFromJson(field_value, vectors, bin_flag);
// vector_datas.insert(std::make_pair(field_name, vectors));
CopyRecordsFromJson
(
field_value
,
temp_data
,
false
);
break
;
}
case
engine
::
DataType
::
VECTOR_BINARY
:
{
CopyRecordsFromJson
(
field_value
,
temp_data
,
true
);
break
;
}
default:
{}
}
...
...
@@ -1702,47 +1601,7 @@ WebRequestHandler::GetEntity(const milvus::server::web::OString& collection_name
}
StatusDto
::
ObjectWrapper
WebRequestHandler
::
GetVector
(
const
OString
&
collection_name
,
const
OQueryParams
&
query_params
,
OString
&
response
)
{
auto
status
=
Status
::
OK
();
try
{
auto
query_ids
=
query_params
.
get
(
"ids"
);
if
(
query_ids
==
nullptr
||
query_ids
.
get
()
==
nullptr
)
{
RETURN_STATUS_DTO
(
QUERY_PARAM_LOSS
,
"Query param ids is required."
);
}
std
::
vector
<
std
::
string
>
ids
;
StringHelpFunctions
::
SplitStringByDelimeter
(
query_ids
->
c_str
(),
","
,
ids
);
std
::
vector
<
int64_t
>
vector_ids
;
for
(
auto
&
id
:
ids
)
{
vector_ids
.
push_back
(
std
::
stol
(
id
));
}
engine
::
VectorsData
vectors
;
nlohmann
::
json
vectors_json
;
status
=
GetVectorsByIDs
(
collection_name
->
std_str
(),
vector_ids
,
vectors_json
);
if
(
!
status
.
ok
())
{
response
=
"NULL"
;
ASSIGN_RETURN_STATUS_DTO
(
status
)
}
FloatJson
json
;
json
[
"code"
]
=
(
int64_t
)
status
.
code
();
json
[
"message"
]
=
status
.
message
();
if
(
vectors_json
.
empty
())
{
json
[
"vectors"
]
=
std
::
vector
<
int64_t
>
();
}
else
{
json
[
"vectors"
]
=
vectors_json
;
}
response
=
json
.
dump
().
c_str
();
}
catch
(
std
::
exception
&
e
)
{
RETURN_STATUS_DTO
(
SERVER_UNEXPECTED_ERROR
,
e
.
what
());
}
ASSIGN_RETURN_STATUS_DTO
(
status
);
}
StatusDto
::
ObjectWrapper
WebRequestHandler
::
VectorsOp
(
const
OString
&
collection_name
,
const
OString
&
payload
,
OString
&
response
)
{
WebRequestHandler
::
EntityOp
(
const
OString
&
collection_name
,
const
OString
&
payload
,
OString
&
response
)
{
auto
status
=
Status
::
OK
();
std
::
string
result_str
;
...
...
core/src/server/web_impl/handler/WebRequestHandler.h
浏览文件 @
4beb0549
...
...
@@ -85,7 +85,11 @@ class WebRequestHandler {
IsBinaryCollection
(
const
std
::
string
&
collection_name
,
bool
&
bin
);
Status
CopyRecordsFromJson
(
const
nlohmann
::
json
&
json
,
engine
::
VectorsData
&
vectors
,
bool
bin
);
CopyRecordsFromJson
(
const
nlohmann
::
json
&
json
,
std
::
vector
<
uint8_t
>&
vectors_data
,
bool
bin
);
Status
CopyData2Json
(
const
engine
::
DataChunkPtr
&
data_chunk
,
const
engine
::
snapshot
::
FieldElementMappings
&
field_mappings
,
const
std
::
vector
<
int64_t
>&
id_array
,
nlohmann
::
json
&
json_res
);
protected:
Status
...
...
@@ -124,10 +128,12 @@ class WebRequestHandler {
SetConfig
(
const
nlohmann
::
json
&
json
,
std
::
string
&
result_str
);
Status
ProcessLeafQueryJson
(
const
nlohmann
::
json
&
json
,
query
::
BooleanQueryPtr
&
boolean_query
);
ProcessLeafQueryJson
(
const
nlohmann
::
json
&
json
,
query
::
BooleanQueryPtr
&
boolean_query
,
std
::
string
&
field_name
,
query
::
QueryPtr
&
query_ptr
);
Status
ProcessBoolQueryJson
(
const
nlohmann
::
json
&
query_json
,
query
::
BooleanQueryPtr
&
boolean_query
);
ProcessBooleanQueryJson
(
const
nlohmann
::
json
&
query_json
,
query
::
BooleanQueryPtr
&
boolean_query
,
query
::
QueryPtr
&
query_ptr
);
Status
Search
(
const
std
::
string
&
collection_name
,
const
nlohmann
::
json
&
json
,
std
::
string
&
result_str
);
...
...
@@ -135,9 +141,6 @@ class WebRequestHandler {
Status
DeleteByIDs
(
const
std
::
string
&
collection_name
,
const
nlohmann
::
json
&
json
,
std
::
string
&
result_str
);
Status
GetVectorsByIDs
(
const
std
::
string
&
collection_name
,
const
std
::
vector
<
int64_t
>&
ids
,
nlohmann
::
json
&
json_out
);
Status
GetEntityByIDs
(
const
std
::
string
&
collection_name
,
const
std
::
vector
<
int64_t
>&
ids
,
std
::
vector
<
std
::
string
>&
field_names
,
nlohmann
::
json
&
json_out
);
...
...
@@ -167,12 +170,10 @@ class WebRequestHandler {
#endif
StatusDto
::
ObjectWrapper
CreateCollection
(
const
CollectionRequestDto
::
ObjectWrapper
&
table_schema
);
StatusDto
::
ObjectWrapper
ShowCollections
(
const
OQueryParams
&
query_params
,
OString
&
result
);
CreateCollection
(
const
milvus
::
server
::
web
::
OString
&
body
);
StatusDto
::
ObjectWrapper
CreateHybridCollection
(
const
OString
&
body
);
ShowCollections
(
const
OQueryParams
&
query_params
,
OString
&
result
);
StatusDto
::
ObjectWrapper
GetCollection
(
const
OString
&
collection_name
,
const
OQueryParams
&
query_params
,
OString
&
result
);
...
...
@@ -181,10 +182,10 @@ class WebRequestHandler {
DropCollection
(
const
OString
&
collection_name
);
StatusDto
::
ObjectWrapper
CreateIndex
(
const
OString
&
collection_name
,
const
OString
&
body
);
CreateIndex
(
const
OString
&
collection_name
,
const
OString
&
field_name
,
const
OString
&
body
);
StatusDto
::
ObjectWrapper
DropIndex
(
const
OString
&
collection_name
);
DropIndex
(
const
OString
&
collection_name
,
const
OString
&
field_name
);
StatusDto
::
ObjectWrapper
CreatePartition
(
const
OString
&
collection_name
,
const
PartitionRequestDto
::
ObjectWrapper
&
param
);
...
...
@@ -221,7 +222,7 @@ class WebRequestHandler {
GetVector
(
const
OString
&
collection_name
,
const
OQueryParams
&
query_params
,
OString
&
response
);
StatusDto
::
ObjectWrapper
Vectors
Op
(
const
OString
&
collection_name
,
const
OString
&
payload
,
OString
&
response
);
Entity
Op
(
const
OString
&
collection_name
,
const
OString
&
payload
,
OString
&
response
);
/**
*
...
...
sdk/examples/simple/src/ClientTest.cpp
浏览文件 @
4beb0549
...
...
@@ -27,7 +27,7 @@ const char* COLLECTION_NAME = milvus_sdk::Utils::GenCollectionName().c_str();
constexpr
int64_t
COLLECTION_DIMENSION
=
512
;
constexpr
int64_t
COLLECTION_INDEX_FILE_SIZE
=
1024
;
constexpr
milvus
::
MetricType
COLLECTION_METRIC_TYPE
=
milvus
::
MetricType
::
L2
;
constexpr
int64_t
BATCH_ENTITY_COUNT
=
4
000
;
constexpr
int64_t
BATCH_ENTITY_COUNT
=
10
000
;
constexpr
int64_t
NQ
=
5
;
constexpr
int64_t
TOP_K
=
10
;
constexpr
int64_t
NPROBE
=
32
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录