未验证 提交 3d9a8106 编写于 作者: G godchen0212 提交者: GitHub

Improvement in GetEntityById grpc interface (#3019)

* Fix bug in GetEntityById
Signed-off-by: Ngodchen0212 <qingxiang.chen@zilliz.com>

* format code and delete useless code
Signed-off-by: Ngodchen0212 <qingxiang.chen@zilliz.com>
上级 10799ce1
......@@ -16,6 +16,7 @@
// under the License.
#include "server/delivery/hybrid_request/GetEntityByIDRequest.h"
#include "db/meta/MetaTypes.h"
#include "server/DBWrapper.h"
#include "server/ValidationUtil.h"
#include "utils/Log.h"
......@@ -30,8 +31,8 @@ namespace server {
constexpr uint64_t MAX_COUNT_RETURNED = 1000;
GetEntityByIDRequest::GetEntityByIDRequest(const std::shared_ptr<milvus::server::Context>& context,
const std::string collection_name, const engine::IDNumbers& id_array,
const std::vector<std::string>& field_names,
const std::string& collection_name, const engine::IDNumbers& id_array,
std::vector<std::string>& field_names,
engine::snapshot::CollectionMappings& field_mappings,
engine::DataChunkPtr& data_chunk)
: BaseRequest(context, BaseRequest::kGetVectorByID),
......@@ -43,8 +44,9 @@ GetEntityByIDRequest::GetEntityByIDRequest(const std::shared_ptr<milvus::server:
}
BaseRequestPtr
GetEntityByIDRequest::Create(const std::shared_ptr<milvus::server::Context>& context, std::string collection_name,
const engine::IDNumbers& id_array, const std::vector<std::string>& field_names_,
GetEntityByIDRequest::Create(const std::shared_ptr<milvus::server::Context>& context,
const std::string& collection_name, const engine::IDNumbers& id_array,
std::vector<std::string>& field_names_,
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk) {
return std::shared_ptr<BaseRequest>(
new GetEntityByIDRequest(context, collection_name, id_array, field_names_, field_mappings, data_chunk));
......@@ -81,19 +83,17 @@ GetEntityByIDRequest::OnExecute() {
}
if (field_names_.empty()) {
for (const auto& schema : field_mappings_)
for (const auto& it : schema.second) {
field_names_.emplace_back(it->GetName());
}
for (const auto& schema : field_mappings_) {
if (schema.first->GetFtype() != engine::meta::hybrid::DataType::UID)
field_names_.emplace_back(schema.first->GetName());
}
} else {
for (const auto& name : field_names_) {
bool find_field_name = false;
for (const auto& schema : field_mappings_) {
for (const auto& it : schema.second) {
if (name == it->GetName()) {
find_field_name = true;
break;
}
if (name == schema.first->GetName()) {
find_field_name = true;
break;
}
}
if (not find_field_name) {
......@@ -107,6 +107,7 @@ GetEntityByIDRequest::OnExecute() {
if (!status.ok()) {
return Status(SERVER_INVALID_COLLECTION_NAME, CollectionNotExistMsg(collection_name_));
}
return Status::OK();
} catch (std::exception& ex) {
return Status(SERVER_UNEXPECTED_ERROR, ex.what());
}
......
......@@ -32,13 +32,13 @@ namespace server {
class GetEntityByIDRequest : public BaseRequest {
public:
static BaseRequestPtr
Create(const std::shared_ptr<milvus::server::Context>& context, std::string collection_name,
const engine::IDNumbers& id_array, const std::vector<std::string>& field_names_,
Create(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name,
const engine::IDNumbers& id_array, std::vector<std::string>& field_names_,
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk);
protected:
GetEntityByIDRequest(const std::shared_ptr<milvus::server::Context>& context, std::string collection_name,
const engine::IDNumbers& id_array, const std::vector<std::string>& field_names,
GetEntityByIDRequest(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name,
const engine::IDNumbers& id_array, std::vector<std::string>& field_names,
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk);
Status
......@@ -47,9 +47,9 @@ class GetEntityByIDRequest : public BaseRequest {
private:
std::string collection_name_;
engine::IDNumbers id_array_;
std::vector<std::string> field_names_;
engine::snapshot::CollectionMappings field_mappings_;
engine::DataChunkPtr data_chunk_;
std::vector<std::string>& field_names_;
engine::snapshot::CollectionMappings& field_mappings_;
engine::DataChunkPtr& data_chunk_;
};
} // namespace server
......
......@@ -85,6 +85,7 @@ RequestMap(BaseRequest::RequestType request_type) {
{BaseRequest::kSearchByID, "SearchByID"},
{BaseRequest::kHybridSearch, "HybridSearch"},
{BaseRequest::kFlush, "Flush"},
{BaseRequest::kGetEntityByID, "GetEntityByID"},
{BaseRequest::kCompact, "Compact"},
};
......@@ -756,33 +757,80 @@ GrpcRequestHandler::GetEntityByID(::grpc::ServerContext* context, const ::milvus
Status status = request_handler_.GetEntityByID(GetContext(context), request->collection_name(), vector_ids,
field_names, field_mappings, data_chunk);
std::vector<uint8_t> id_array = data_chunk->fixed_fields_[engine::DEFAULT_UID_NAME];
for (const auto& it : field_mappings) {
auto type = it.first->GetFtype();
std::string name = it.first->GetName();
uint64_t type = it.first->GetFtype();
std::vector<uint8_t> data = data_chunk->fixed_fields_[name];
if (type == engine::FieldType::VECTOR_BINARY) {
engine::VectorsData vectors_data;
memcpy(vectors_data.binary_data_.data(), data.data(), data.size());
memcpy(vectors_data.id_array_.data(), id_array.data(), id_array.size());
vectors.emplace_back(vectors_data);
} else if (type == engine::FieldType::VECTOR_FLOAT) {
engine::VectorsData vectors_data;
memcpy(vectors_data.float_data_.data(), data.data(), data.size());
memcpy(vectors_data.id_array_.data(), id_array.data(), id_array.size());
vectors.emplace_back(vectors_data);
if (type == engine::meta::hybrid::DataType::UID) {
response->mutable_ids()->Resize(data.size(), 0);
memcpy(response->mutable_ids()->mutable_data(), data.data(), data.size() * sizeof(uint64_t));
continue;
}
auto field_value = response->add_fields();
auto vector_record = field_value->mutable_vector_record();
field_value->set_field_name(name);
field_value->set_type(static_cast<milvus::grpc::DataType>(type));
// general data
if (type == engine::meta::hybrid::DataType::VECTOR_BINARY) {
// add binary vector data
auto vector_row_record = vector_record->add_records();
std::vector<int8_t> binary_vector;
binary_vector.resize(data.size());
memcpy(binary_vector.data(), data.data(), data.size());
vector_row_record->mutable_binary_data()->resize(binary_vector.size());
memcpy(vector_row_record->mutable_binary_data()->data(), binary_vector.data(), binary_vector.size());
continue;
} else if (type == engine::meta::hybrid::DataType::VECTOR_FLOAT) {
// add float vector data
auto vector_row_record = vector_record->add_records();
std::vector<float> float_vector;
float_vector.resize(data.size() * sizeof(int8_t) / sizeof(float));
memcpy(float_vector.data(), data.data(), data.size());
vector_row_record->mutable_float_data()->Resize(float_vector.size(), 0.0);
memcpy(vector_row_record->mutable_float_data()->mutable_data(), float_vector.data(),
float_vector.size() * sizeof(float));
continue;
} else {
engine::AttrsData attrs_data;
attrs_data.attr_type_[name] = static_cast<engine::meta::hybrid::DataType>(type);
attrs_data.attr_data_[name] = data;
memcpy(attrs_data.id_array_.data(), id_array.data(), id_array.size());
attrs.emplace_back(attrs_data);
// add attribute data
auto attr_record = field_value->mutable_attr_record();
if (type == engine::meta::hybrid::DataType::INT32) {
// add int32 data
std::vector<int32_t> int32_value;
int32_value.resize(data.size() * sizeof(int8_t) / sizeof(int32_t));
memcpy(int32_value.data(), data.data(), data.size());
attr_record->mutable_int32_value()->Resize(int32_value.size(), 0);
memcpy(attr_record->mutable_int32_value()->mutable_data(), int32_value.data(), int32_value.size());
} else if (type == engine::meta::hybrid::DataType::INT64) {
// add int64 data
std::vector<int64_t> int64_value;
int64_value.resize(data.size() * sizeof(int8_t) / sizeof(int64_t));
memcpy(int64_value.data(), data.data(), data.size());
attr_record->mutable_int64_value()->Resize(int64_value.size(), 0);
memcpy(attr_record->mutable_int64_value()->mutable_data(), int64_value.data(), int64_value.size());
} else if (type == engine::meta::hybrid::DataType::DOUBLE) {
// add double data
std::vector<double> double_value;
double_value.resize(data.size() * sizeof(int8_t) / sizeof(double));
memcpy(double_value.data(), data.data(), data.size());
attr_record->mutable_double_value()->Resize(double_value.size(), 0.0);
memcpy(attr_record->mutable_double_value()->mutable_data(), double_value.data(), double_value.size());
} else if (type == engine::meta::hybrid::DataType::FLOAT) {
// add float data
std::vector<float> float_value;
float_value.resize(data.size() * sizeof(int8_t) / sizeof(float));
memcpy(float_value.data(), data.data(), data.size());
attr_record->mutable_float_value()->Resize(float_value.size(), 0.0);
memcpy(attr_record->mutable_float_value()->mutable_data(), float_value.data(), float_value.size());
}
}
}
ConstructEntityResults(attrs, vectors, field_names, response);
LOG_SERVER_INFO_ << LogOut("Request [%s] %s end.", GetContext(context)->RequestID().c_str(), __func__);
SET_RESPONSE(response->mutable_status(), status, context);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册