未验证 提交 4d7fb16e 编写于 作者: G godchen0212 提交者: GitHub

add GetEntityById realization (#3004)

* add web_impl testcase
Signed-off-by: Ngodchen0212 <qingxiang.chen@zilliz.com>

* add GetEntityById realization
Signed-off-by: Ngodchen0212 <qingxiang.chen@zilliz.com>

* merge files
Signed-off-by: Ngodchen0212 <qingxiang.chen@zilliz.com>

* fix bug
Signed-off-by: Ngodchen0212 <qingxiang.chen@zilliz.com>

* fix bug
Signed-off-by: Ngodchen0212 <qingxiang.chen@zilliz.com>
上级 cd281453
......@@ -290,10 +290,10 @@ RequestHandler::InsertEntity(const std::shared_ptr<Context>& context, const std:
Status
RequestHandler::GetEntityByID(const std::shared_ptr<Context>& context, const std::string& collection_name,
std::vector<std::string>& field_names, const std::vector<int64_t>& ids,
std::vector<engine::AttrsData>& attrs, std::vector<engine::VectorsData>& vectors) {
const engine::IDNumbers& ids, std::vector<std::string>& field_names,
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk) {
BaseRequestPtr request_ptr =
GetEntityByIDRequest::Create(context, collection_name, field_names, ids, attrs, vectors);
GetEntityByIDRequest::Create(context, collection_name, ids, field_names, field_mappings, data_chunk);
RequestScheduler::ExecRequest(request_ptr);
return request_ptr->status();
......
......@@ -11,6 +11,8 @@
#pragma once
#include <src/db/snapshot/Context.h>
#include <src/segment/Segment.h>
#include <memory>
#include <string>
#include <unordered_map>
......@@ -134,8 +136,8 @@ class RequestHandler {
Status
GetEntityByID(const std::shared_ptr<Context>& context, const std::string& collection_name,
std::vector<std::string>& field_names, const std::vector<int64_t>& ids,
std::vector<engine::AttrsData>& attrs, std::vector<engine::VectorsData>& vectors);
const engine::IDNumbers& ids, std::vector<std::string>& field_names,
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk);
Status
HybridSearch(const std::shared_ptr<milvus::server::Context>& context, const query::QueryPtr& query_ptr,
......
......@@ -30,24 +30,24 @@ namespace server {
constexpr uint64_t MAX_COUNT_RETURNED = 1000;
GetEntityByIDRequest::GetEntityByIDRequest(const std::shared_ptr<milvus::server::Context>& context,
const std::string& collection_name, std::vector<std::string>& field_names,
const std::vector<int64_t>& ids, std::vector<engine::AttrsData>& attrs,
std::vector<engine::VectorsData>& vectors)
const std::string collection_name, const engine::IDNumbers& id_array,
const std::vector<std::string>& field_names,
engine::snapshot::CollectionMappings& field_mappings,
engine::DataChunkPtr& data_chunk)
: BaseRequest(context, BaseRequest::kGetVectorByID),
collection_name_(collection_name),
id_array_(id_array),
field_names_(field_names),
ids_(ids),
attrs_(attrs),
vectors_(vectors) {
field_mappings_(field_mappings),
data_chunk_(data_chunk) {
}
BaseRequestPtr
GetEntityByIDRequest::Create(const std::shared_ptr<milvus::server::Context>& context,
const std::string& collection_name, std::vector<std::string>& field_names,
const std::vector<int64_t>& ids, std::vector<engine::AttrsData>& attrs,
std::vector<engine::VectorsData>& vectors) {
GetEntityByIDRequest::Create(const std::shared_ptr<milvus::server::Context>& context, std::string collection_name,
const engine::IDNumbers& id_array, const std::vector<std::string>& field_names_,
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk) {
return std::shared_ptr<BaseRequest>(
new GetEntityByIDRequest(context, collection_name, field_names, ids, attrs, vectors));
new GetEntityByIDRequest(context, collection_name, id_array, field_names_, field_mappings, data_chunk));
}
Status
......@@ -57,11 +57,11 @@ GetEntityByIDRequest::OnExecute() {
TimeRecorderAuto rc(hdr);
// step 1: check arguments
if (ids_.empty()) {
if (id_array_.empty()) {
return Status(SERVER_INVALID_ARGUMENT, "No entity id specified");
}
if (ids_.size() > MAX_COUNT_RETURNED) {
if (id_array_.size() > MAX_COUNT_RETURNED) {
std::string msg = "Input id array size cannot exceed: " + std::to_string(MAX_COUNT_RETURNED);
return Status(SERVER_INVALID_ARGUMENT, msg);
}
......@@ -74,33 +74,26 @@ GetEntityByIDRequest::OnExecute() {
// TODO(yukun) ValidateFieldNames
// only process root collection, ignore partition collection
engine::meta::CollectionSchema collection_schema;
engine::meta::hybrid::FieldsSchema fields_schema;
collection_schema.collection_id_ = collection_name_;
status = DBWrapper::DB()->DescribeHybridCollection(collection_schema, fields_schema);
if (!status.ok()) {
if (status.code() == DB_NOT_FOUND) {
return Status(SERVER_COLLECTION_NOT_EXIST, CollectionNotExistMsg(collection_name_));
} else {
return status;
}
} else {
if (!collection_schema.owner_collection_.empty()) {
return Status(SERVER_INVALID_COLLECTION_NAME, CollectionNotExistMsg(collection_name_));
}
engine::snapshot::CollectionPtr collectionPtr;
status = DBWrapper::SSDB()->DescribeCollection(collection_name_, collectionPtr, field_mappings_);
if (collectionPtr == nullptr) {
return Status(SERVER_INVALID_COLLECTION_NAME, CollectionNotExistMsg(collection_name_));
}
if (field_names_.empty()) {
for (const auto& schema : fields_schema.fields_schema_) {
field_names_.emplace_back(schema.field_name_);
}
for (const auto& schema : field_mappings_)
for (const auto& it : schema.second) {
field_names_.emplace_back(it->GetName());
}
} else {
for (const auto& name : field_names_) {
bool find_field_name = false;
for (const auto& schema : fields_schema.fields_schema_) {
if (name == schema.field_name_) {
find_field_name = true;
break;
for (const auto& schema : field_mappings_) {
for (const auto& it : schema.second) {
if (name == it->GetName()) {
find_field_name = true;
break;
}
}
}
if (not find_field_name) {
......@@ -110,7 +103,10 @@ GetEntityByIDRequest::OnExecute() {
}
// step 2: get vector data, now only support get one id
return DBWrapper::DB()->GetEntitiesByID(collection_name_, ids_, field_names_, vectors_, attrs_);
status = DBWrapper::SSDB()->GetEntityByID(collection_name_, id_array_, field_names_, data_chunk_);
if (!status.ok()) {
return Status(SERVER_INVALID_COLLECTION_NAME, CollectionNotExistMsg(collection_name_));
}
} catch (std::exception& ex) {
return Status(SERVER_UNEXPECTED_ERROR, ex.what());
}
......
......@@ -19,6 +19,9 @@
#include "server/delivery/request/BaseRequest.h"
#include <src/db/snapshot/Context.h>
#include <src/db/snapshot/Resources.h>
#include <src/segment/Segment.h>
#include <memory>
#include <string>
#include <vector>
......@@ -29,24 +32,24 @@ namespace server {
class GetEntityByIDRequest : public BaseRequest {
public:
static BaseRequestPtr
Create(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name,
std::vector<std::string>& field_names, const std::vector<int64_t>& ids,
std::vector<engine::AttrsData>& attrs, std::vector<engine::VectorsData>& vectors);
Create(const std::shared_ptr<milvus::server::Context>& context, std::string collection_name,
const engine::IDNumbers& id_array, const std::vector<std::string>& field_names_,
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk);
protected:
GetEntityByIDRequest(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name,
std::vector<std::string>& field_names, const std::vector<int64_t>& ids,
std::vector<engine::AttrsData>& attrs, std::vector<engine::VectorsData>& vectors);
GetEntityByIDRequest(const std::shared_ptr<milvus::server::Context>& context, std::string collection_name,
const engine::IDNumbers& id_array, const std::vector<std::string>& field_names,
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk);
Status
OnExecute() override;
private:
std::string collection_name_;
std::vector<std::string>& field_names_;
std::vector<int64_t> ids_;
std::vector<engine::AttrsData>& attrs_;
std::vector<engine::VectorsData>& vectors_;
engine::IDNumbers id_array_;
std::vector<std::string> field_names_;
engine::snapshot::CollectionMappings field_mappings_;
engine::DataChunkPtr data_chunk_;
};
} // namespace server
......
......@@ -19,7 +19,6 @@
#include <utility>
#include <vector>
#include "context/HybridSearchContext.h"
#include "query/BinaryQuery.h"
#include "server/ValidationUtil.h"
#include "server/context/ConnectionContext.h"
......@@ -27,7 +26,6 @@
#include "tracing/TracerUtil.h"
#include "utils/Log.h"
#include "utils/LogUtil.h"
#include "utils/TimeRecorder.h"
namespace milvus {
namespace server {
......@@ -739,7 +737,7 @@ GrpcRequestHandler::GetEntityByID(::grpc::ServerContext* context, const ::milvus
CHECK_NULLPTR_RETURN(request);
LOG_SERVER_INFO_ << LogOut("Request [%s] %s begin.", GetContext(context)->RequestID().c_str(), __func__);
std::vector<int64_t> vector_ids;
engine::IDNumbers vector_ids;
vector_ids.reserve(request->id_array_size());
for (int i = 0; i < request->id_array_size(); i++) {
vector_ids.push_back(request->id_array(i));
......@@ -750,10 +748,39 @@ GrpcRequestHandler::GetEntityByID(::grpc::ServerContext* context, const ::milvus
field_names[i] = request->field_names(i);
}
engine::DataChunkPtr data_chunk;
engine::snapshot::CollectionMappings field_mappings;
std::vector<engine::AttrsData> attrs;
std::vector<engine::VectorsData> vectors;
Status status = request_handler_.GetEntityByID(GetContext(context), request->collection_name(), field_names,
vector_ids, attrs, vectors);
Status status = request_handler_.GetEntityByID(GetContext(context), request->collection_name(), vector_ids,
field_names, field_mappings, data_chunk);
std::vector<uint8_t> id_array = data_chunk->fixed_fields_[engine::DEFAULT_UID_NAME];
for (const auto& it : field_mappings) {
std::string name = it.first->GetName();
uint64_t type = it.first->GetFtype();
std::vector<uint8_t> data = data_chunk->fixed_fields_[name];
if (type == engine::FieldType::VECTOR_BINARY) {
engine::VectorsData vectors_data;
memcpy(vectors_data.binary_data_.data(), data.data(), data.size());
memcpy(vectors_data.id_array_.data(), id_array.data(), id_array.size());
vectors.emplace_back(vectors_data);
} else if (type == engine::FieldType::VECTOR_FLOAT) {
engine::VectorsData vectors_data;
memcpy(vectors_data.float_data_.data(), data.data(), data.size());
memcpy(vectors_data.id_array_.data(), id_array.data(), id_array.size());
vectors.emplace_back(vectors_data);
} else {
engine::AttrsData attrs_data;
attrs_data.attr_type_[name] = static_cast<engine::meta::hybrid::DataType>(type);
attrs_data.attr_data_[name] = data;
memcpy(attrs_data.id_array_.data(), id_array.data(), id_array.size());
attrs.emplace_back(attrs_data);
}
}
ConstructEntityResults(attrs, vectors, field_names, response);
......
......@@ -921,13 +921,40 @@ WebRequestHandler::DeleteByIDs(const std::string& collection_name, const nlohman
Status
WebRequestHandler::GetEntityByIDs(const std::string& collection_name, const std::vector<int64_t>& ids,
std::vector<std::string>& field_names, nlohmann::json& json_out) {
std::vector<engine::VectorsData> vector_batch;
engine::DataChunkPtr data_chunk;
engine::snapshot::CollectionMappings field_mappings;
std::vector<engine::AttrsData> attr_batch;
std::vector<engine::VectorsData> vector_batch;
auto status =
request_handler_.GetEntityByID(context_ptr_, collection_name, field_names, ids, attr_batch, vector_batch);
request_handler_.GetEntityByID(context_ptr_, collection_name, ids, field_names, field_mappings, data_chunk);
if (!status.ok()) {
return status;
}
std::vector<uint8_t> id_array = data_chunk->fixed_fields_[engine::DEFAULT_UID_NAME];
for (const auto& it : field_mappings) {
std::string name = it.first->GetName();
uint64_t type = it.first->GetFtype();
std::vector<uint8_t> data = data_chunk->fixed_fields_[name];
if (type == engine::FieldType::VECTOR_BINARY) {
engine::VectorsData vectors_data;
memcpy(vectors_data.binary_data_.data(), data.data(), data.size());
memcpy(vectors_data.id_array_.data(), id_array.data(), id_array.size());
vector_batch.emplace_back(vectors_data);
} else if (type == engine::FieldType::VECTOR_FLOAT) {
engine::VectorsData vectors_data;
memcpy(vectors_data.float_data_.data(), data.data(), data.size());
memcpy(vectors_data.id_array_.data(), id_array.data(), id_array.size());
vector_batch.emplace_back(vectors_data);
} else {
engine::AttrsData attrs_data;
attrs_data.attr_type_[name] = static_cast<engine::meta::hybrid::DataType>(type);
attrs_data.attr_data_[name] = data;
memcpy(attrs_data.id_array_.data(), id_array.data(), id_array.size());
attr_batch.emplace_back(attrs_data);
}
}
bool bin;
status = IsBinaryCollection(collection_name, bin);
......@@ -949,6 +976,7 @@ WebRequestHandler::GetEntityByIDs(const std::string& collection_name, const std:
ConvertRowToColumnJson(attr_batch, field_names, -1, attrs_json);
json_out["vectors"] = vectors_json;
json_out["attributes"] = attrs_json;
return Status::OK();
}
Status
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册