未验证 提交 79cc9780 编写于 作者: Y yukun 提交者: GitHub

Merge pull request #2757 from fishpenguin/scalar-field-filtering

Fix C++ sdk
......@@ -789,8 +789,8 @@ CopyToAttr(const std::vector<uint8_t>& record, int64_t row_num, const std::vecto
std::vector<uint8_t> data;
data.resize(row_num * sizeof(int8_t));
std::vector<int64_t> attr_value(row_num, 0);
memcpy(attr_value.data(), record.data() + offset, row_num * sizeof(int64_t));
std::vector<int32_t> attr_value(row_num, 0);
memcpy(attr_value.data(), record.data() + offset, row_num * sizeof(int32_t));
std::vector<int8_t> raw_value(row_num, 0);
for (uint64_t i = 0; i < row_num; ++i) {
......@@ -809,8 +809,8 @@ CopyToAttr(const std::vector<uint8_t>& record, int64_t row_num, const std::vecto
std::vector<uint8_t> data;
data.resize(row_num * sizeof(int16_t));
std::vector<int64_t> attr_value(row_num, 0);
memcpy(attr_value.data(), record.data() + offset, row_num * sizeof(int64_t));
std::vector<int32_t> attr_value(row_num, 0);
memcpy(attr_value.data(), record.data() + offset, row_num * sizeof(int32_t));
std::vector<int16_t> raw_value(row_num, 0);
for (uint64_t i = 0; i < row_num; ++i) {
......@@ -829,15 +829,10 @@ CopyToAttr(const std::vector<uint8_t>& record, int64_t row_num, const std::vecto
std::vector<uint8_t> data;
data.resize(row_num * sizeof(int32_t));
std::vector<int64_t> attr_value(row_num, 0);
memcpy(attr_value.data(), record.data() + offset, row_num * sizeof(int64_t));
std::vector<int32_t> attr_value(row_num, 0);
memcpy(attr_value.data(), record.data() + offset, row_num * sizeof(int32_t));
std::vector<int32_t> raw_value(row_num, 0);
for (uint64_t i = 0; i < row_num; ++i) {
raw_value[i] = attr_value[i];
}
memcpy(data.data(), raw_value.data(), row_num * sizeof(int32_t));
memcpy(data.data(), attr_value.data(), row_num * sizeof(int32_t));
attr_datas.insert(std::make_pair(name, data));
attr_nbytes.insert(std::make_pair(name, sizeof(int32_t)));
......@@ -863,15 +858,10 @@ CopyToAttr(const std::vector<uint8_t>& record, int64_t row_num, const std::vecto
std::vector<uint8_t> data;
data.resize(row_num * sizeof(float));
std::vector<double> attr_value(row_num, 0);
memcpy(attr_value.data(), record.data() + offset, row_num * sizeof(double));
std::vector<float> attr_value(row_num, 0);
memcpy(attr_value.data(), record.data() + offset, row_num * sizeof(float));
std::vector<float> raw_value(row_num, 0);
for (uint64_t i = 0; i < row_num; ++i) {
raw_value[i] = attr_value[i];
}
memcpy(data.data(), raw_value.data(), row_num * sizeof(float));
memcpy(data.data(), attr_value.data(), row_num * sizeof(float));
attr_datas.insert(std::make_pair(name, data));
attr_nbytes.insert(std::make_pair(name, sizeof(float)));
......@@ -1395,7 +1385,7 @@ DBImpl::GetEntitiesByID(const std::string& collection_id, const milvus::engine::
return status;
}
std::unordered_map<std::string, engine::meta::hybrid::DataType> attr_type;
for (auto schema : fields_schema.fields_schema_) {
for (const auto& schema : fields_schema.fields_schema_) {
if (schema.field_type_ == (int32_t)engine::meta::hybrid::DataType::FLOAT_VECTOR ||
schema.field_type_ == (int32_t)engine::meta::hybrid::DataType::BINARY_VECTOR) {
continue;
......@@ -1705,27 +1695,27 @@ DBImpl::GetEntitiesByIdHelper(const std::string& collection_id, const milvus::en
size_t num_bytes;
switch (attr_it->second) {
case engine::meta::hybrid::DataType::INT8: {
num_bytes = 1;
num_bytes = sizeof(int8_t);
break;
}
case engine::meta::hybrid::DataType::INT16: {
num_bytes = 2;
num_bytes = sizeof(int16_t);
break;
}
case engine::meta::hybrid::DataType::INT32: {
num_bytes = 4;
num_bytes = sizeof(int32_t);
break;
}
case engine::meta::hybrid::DataType::INT64: {
num_bytes = 8;
num_bytes = sizeof(int64_t);
break;
}
case engine::meta::hybrid::DataType::FLOAT: {
num_bytes = 4;
num_bytes = sizeof(float);
break;
}
case engine::meta::hybrid::DataType::DOUBLE: {
num_bytes = 8;
num_bytes = sizeof(double);
break;
}
default: {
......@@ -1775,11 +1765,11 @@ DBImpl::GetEntitiesByIdHelper(const std::string& collection_id, const milvus::en
if (data.vector_count_ > 0) {
data.float_data_ = vector_ref.float_data_; // copy data since there could be duplicated id
data.binary_data_ = vector_ref.binary_data_; // copy data since there could be duplicated id
}
data.id_array_.emplace_back(id);
vectors.emplace_back(data);
data.id_array_.emplace_back(id);
vectors.emplace_back(data);
attrs.emplace_back(map_id2attr[id]);
attrs.emplace_back(map_id2attr[id]);
}
}
if (vectors.empty()) {
......@@ -3170,7 +3160,7 @@ DBImpl::ExecWalRecord(const wal::MXLogRecord& record) {
vectors.vector_type_ = Vectors::FLOAT;
vectors.float_vector = (const float*)record.data;
status = mem_mgr_->InsertEntities(target_collection_name, record.length, record.ids,
(record.data_size / record.length / sizeof(uint8_t)), vectors,
(record.data_size / record.length / sizeof(float)), vectors,
record.attr_nbytes, record.attr_data_size, record.attr_data, record.lsn);
// status = mem_mgr_->InsertVectors(target_collection_name, record.length, record.ids,
......@@ -3237,11 +3227,6 @@ DBImpl::ExecWalRecord(const wal::MXLogRecord& record) {
break;
}
flushed_collections.insert(collection_id);
// status = FlushAttrsIndex(collection_id);
// if (!status.ok()) {
// return status;
// }
}
collections_flushed(record.collection_id, flushed_collections);
......
......@@ -46,11 +46,6 @@ InstanceStructuredIndex::CreateStructuredIndex(const std::string& collection_id,
return Status::OK();
}
std::unordered_map<std::string, std::vector<uint8_t>> attr_datas;
std::unordered_map<std::string, int64_t> attr_sizes;
std::unordered_map<std::string, engine::meta::hybrid::DataType> attr_types;
std::vector<std::string> field_names;
for (auto& segment_schema : files_holder.HoldFiles()) {
std::string segment_dir;
engine::utils::GetParentPath(segment_schema.location_, segment_dir);
......@@ -63,6 +58,11 @@ InstanceStructuredIndex::CreateStructuredIndex(const std::string& collection_id,
return status;
}
std::unordered_map<std::string, std::vector<uint8_t>> attr_datas;
std::unordered_map<std::string, int64_t> attr_sizes;
std::unordered_map<std::string, engine::meta::hybrid::DataType> attr_types;
std::vector<std::string> field_names;
for (auto& field_schema : fields_schema.fields_schema_) {
if (field_schema.field_type_ != (int32_t)engine::meta::hybrid::DataType::FLOAT_VECTOR) {
attr_types.insert(
......
......@@ -2152,7 +2152,7 @@ SqliteMetaImpl::CreateHybridCollection(meta::CollectionSchema& collection_schema
// multi-threads call sqlite update may get exception('bad logic', etc), so we add a lock here
std::lock_guard<std::mutex> meta_lock(meta_mutex_);
if (collection_schema.collection_id_ == "") {
if (collection_schema.collection_id_.empty()) {
NextCollectionId(collection_schema.collection_id_);
} else {
fiu_do_on("SqliteMetaImpl.CreateCollection.throw_exception", throw std::exception());
......
......@@ -64,6 +64,8 @@ DescribeHybridCollectionRequest::OnExecute() {
for (auto schema : fields_schema.fields_schema_) {
field_types_.insert(std::make_pair(schema.field_name_, (engine::meta::hybrid::DataType)schema.field_type_));
milvus::json json_param = milvus::json::parse(schema.index_param_);
index_params_.insert(std::make_pair("index_params", json_param));
}
} catch (std::exception& ex) {
return Status(SERVER_UNEXPECTED_ERROR, ex.what());
......
......@@ -101,11 +101,11 @@ HybridSearchRequest::OnExecute() {
}
}
if (field_names_.empty()) {
for (const auto& field : fields_schema.fields_schema_) {
field_names_.emplace_back(field.field_name_);
}
}
// if (field_names_.empty()) {
// for (const auto& field : fields_schema.fields_schema_) {
// field_names_.emplace_back(field.field_name_);
// }
// }
status = DBWrapper::DB()->HybridQuery(context_, collection_name_, partition_list_, general_query_, query_ptr_,
field_names_, attr_type, result_);
......
......@@ -139,6 +139,7 @@ InsertEntityRequest::OnExecute() {
entity.attr_value_ = attr_values_;
entity.vector_data_.insert(std::make_pair(vector_datas_it->first, vector_datas_it->second));
entity.id_array_ = std::move(vector_datas_it->second.id_array_);
rc.RecordSection("prepare vectors data");
status = DBWrapper::DB()->InsertEntities(collection_name_, partition_tag_, field_names_, entity, field_types);
......
......@@ -268,6 +268,9 @@ ConstructEntityResults(const std::vector<engine::AttrsData>& attrs, const std::v
std::vector<float> float_data;
std::vector<double> double_data;
for (auto& attr : attrs) {
if (attr.attr_data_.find(field_name) == attr.attr_data_.end()) {
continue;
}
auto attr_data = attr.attr_data_.at(field_name);
int32_t grpc_int32_data;
int64_t grpc_int64_data;
......@@ -374,6 +377,7 @@ ConstructEntityResults(const std::vector<engine::AttrsData>& attrs, const std::v
memcpy(response->mutable_ids()->mutable_data(), id_array.data(), size * sizeof(int64_t));
auto grpc_field = response->add_fields();
grpc_field->set_field_name(vector_field_name);
::milvus::grpc::VectorRecord* grpc_vector_data = grpc_field->mutable_vector_record();
for (auto& vector : vectors) {
auto grpc_data = grpc_vector_data->add_records();
......@@ -1534,8 +1538,8 @@ GrpcRequestHandler::DeserializeJsonToBoolQuery(
nlohmann::json dsl_json = json::parse(dsl_string);
auto status = Status::OK();
for (size_t i = 0; i < vector_params.size(); i++) {
std::string vector_string = vector_params.at(i).json();
for (const auto& vector_param : vector_params) {
std::string vector_string = vector_param.json();
nlohmann::json vector_json = json::parse(vector_string);
json::iterator it = vector_json.begin();
std::string placeholder = it.key();
......@@ -1554,7 +1558,7 @@ GrpcRequestHandler::DeserializeJsonToBoolQuery(
}
engine::VectorsData vector_data;
CopyRowRecords(vector_params.at(i).row_record().records(),
CopyRowRecords(vector_param.row_record().records(),
google::protobuf::RepeatedField<google::protobuf::int64>(), vector_data);
vector_query->query_vector.binary_data = vector_data.binary_data_;
vector_query->query_vector.float_data = vector_data.float_data_;
......
......@@ -27,12 +27,12 @@ const char* COLLECTION_NAME = milvus_sdk::Utils::GenCollectionName().c_str();
constexpr int64_t COLLECTION_DIMENSION = 512;
constexpr int64_t COLLECTION_INDEX_FILE_SIZE = 1024;
constexpr milvus::MetricType COLLECTION_METRIC_TYPE = milvus::MetricType::L2;
constexpr int64_t BATCH_ENTITY_COUNT = 100000;
constexpr int64_t BATCH_ENTITY_COUNT = 10000;
constexpr int64_t NQ = 5;
constexpr int64_t TOP_K = 10;
constexpr int64_t NPROBE = 32;
constexpr int64_t SEARCH_TARGET = BATCH_ENTITY_COUNT / 2; // change this value, result is different
constexpr int64_t ADD_ENTITY_LOOP = 5;
constexpr int64_t ADD_ENTITY_LOOP = 1;
constexpr milvus::IndexType INDEX_TYPE = milvus::IndexType::IVFFLAT;
constexpr int32_t NLIST = 16384;
......@@ -114,7 +114,7 @@ ClientTest::CreateCollection(const std::string& collection_name) {
field_ptr3->extra_params = extra_params_3.dump();
JSON extra_params;
extra_params["segment_size"] = " ";
extra_params["segment_size"] = 1024;
milvus::Mapping mapping = {collection_name, {field_ptr1, field_ptr2, field_ptr3}};
milvus::Status stat = conn_->CreateCollection(mapping, extra_params.dump());
......@@ -126,30 +126,21 @@ ClientTest::GetCollectionInfo(const std::string& collection_name) {
}
void
ClientTest::InsertEntities(const std::string& collection_name, int64_t row_num) {
milvus::FieldValue field_value;
std::vector<int64_t> value1;
std::vector<float> value2;
value1.resize(row_num);
value2.resize(row_num);
for (uint64_t i = 0; i < row_num; ++i) {
value1[i] = i;
value2[i] = (float)(i + row_num);
}
field_value.int64_value.insert(std::make_pair("field_1", value1));
field_value.float_value.insert(std::make_pair("field_2", value2));
std::unordered_map<std::string, std::vector<milvus::VectorData>> vector_value;
std::vector<milvus::VectorData> entity_array;
std::vector<int64_t> record_ids;
{ // generate vectors
milvus_sdk::Utils::BuildEntities(0, row_num, entity_array, record_ids, COLLECTION_DIMENSION);
ClientTest::InsertEntities(const std::string& collection_name) {
for (int64_t i = 0; i < ADD_ENTITY_LOOP; i++) {
milvus::FieldValue field_value;
std::vector<int64_t> entity_ids;
int64_t begin_index = i * BATCH_ENTITY_COUNT;
{
milvus_sdk::TimeRecorder rc("Build entities No." + std::to_string(i));
milvus_sdk::Utils::BuildEntities(begin_index, begin_index + BATCH_ENTITY_COUNT, field_value, entity_ids,
COLLECTION_DIMENSION);
}
milvus::Status status = conn_->Insert(collection_name, "", field_value, entity_ids);
std::cout << "InsertEntities function call status: " << status.message() << std::endl;
std::cout << "Returned id array count: " << entity_ids.size() << std::endl;
}
field_value.vector_value.insert(std::make_pair("field_vec", entity_array));
milvus::Status status = conn_->Insert(collection_name, "", field_value, record_ids);
std::cout << "InsertEntities function call status: " << status.message() << std::endl;
}
void
......@@ -168,15 +159,29 @@ ClientTest::GetCollectionStats(const std::string& collection_name) {
std::cout << "GetCollectionStats function call status: " << stat.message() << std::endl;
}
void
ClientTest::BuildVectors(int64_t nq, int64_t dimension) {
search_entity_array_.clear();
search_id_array_.clear();
for (int64_t i = 0; i < nq; i++) {
std::vector<milvus::VectorData> entity_array;
std::vector<int64_t> record_ids;
int64_t index = i * BATCH_ENTITY_COUNT + SEARCH_TARGET;
milvus_sdk::Utils::ConstructVectors(index, index + 1, entity_array, record_ids, dimension);
search_entity_array_.push_back(std::make_pair(record_ids[0], entity_array[0]));
search_id_array_.push_back(record_ids[0]);
}
}
void
ClientTest::GetEntityByID(const std::string& collection_name, const std::vector<int64_t>& id_array) {
std::string result;
{
milvus_sdk::TimeRecorder rc("GetHybridEntityByID");
milvus_sdk::TimeRecorder rc("GetEntityByID");
milvus::Status stat = conn_->GetEntityByID(collection_name, id_array, result);
std::cout << "GetEntitiesByID function call status: " << stat.message() << std::endl;
std::cout << "GetEntityByID function call status: " << stat.message() << std::endl;
}
std::cout << "GetEntityByID function result: " << result << std::endl;
}
......@@ -185,18 +190,19 @@ ClientTest::SearchEntities(const std::string& collection_name, int64_t topk, int
nlohmann::json dsl_json, vector_param_json;
milvus_sdk::Utils::GenDSLJson(dsl_json, vector_param_json);
std::vector<milvus::VectorData> entity_array;
std::vector<int64_t> record_ids;
{ // generate vectors
milvus_sdk::Utils::ConstructVector(NQ, COLLECTION_DIMENSION, entity_array);
std::vector<milvus::VectorData> temp_entity_array;
for (auto& pair : search_entity_array_) {
temp_entity_array.push_back(pair.second);
}
milvus::VectorParam vector_param = {vector_param_json.dump(), entity_array};
milvus::VectorParam vector_param = {vector_param_json.dump(), temp_entity_array};
std::vector<std::string> partition_tags;
milvus::TopKQueryResult topk_query_result;
auto status = conn_->Search(collection_name, partition_tags, dsl_json.dump(), vector_param, topk_query_result);
std::cout << "Search function call result: " << std::endl;
milvus_sdk::Utils::PrintTopKQueryResult(topk_query_result);
std::cout << "Search function call status: " << status.message() << std::endl;
}
......@@ -312,19 +318,21 @@ ClientTest::Test() {
CreateCollection(collection_name);
GetCollectionInfo(collection_name);
InsertEntities(collection_name, 10000);
InsertEntities(collection_name);
Flush(collection_name);
GetCollectionStats(collection_name);
BuildVectors(NQ, COLLECTION_DIMENSION);
GetEntityByID(collection_name, search_id_array_);
SearchEntities(collection_name, TOP_K, NPROBE);
CreateIndex(collection_name, NLIST);
SearchEntities(collection_name, TOP_K, NPROBE);
// CreateIndex(collection_name, NLIST);
// SearchEntities(collection_name, TOP_K, NPROBE);
GetCollectionStats(collection_name);
std::vector<int64_t> delete_ids = {search_id_array_[0], search_id_array_[1]};
DeleteByIds(collection_name, delete_ids);
GetEntityByID(collection_name, search_id_array_);
CompactCollection(collection_name);
LoadCollection(collection_name);
......
......@@ -43,7 +43,7 @@ class ClientTest {
GetCollectionInfo(const std::string&);
void
InsertEntities(const std::string&, int64_t);
InsertEntities(const std::string&);
void BuildSearchEntities(int64_t, int64_t);
......@@ -53,6 +53,9 @@ class ClientTest {
void
GetCollectionStats(const std::string&);
void
BuildVectors(int64_t nq, int64_t dimension);
void
GetEntityByID(const std::string&, const std::vector<int64_t>&);
......@@ -82,6 +85,6 @@ class ClientTest {
private:
std::shared_ptr<milvus::Connection> conn_;
// std::vector<std::pair<int64_t, milvus::Entity>> search_entity_array_;
std::vector<std::pair<int64_t, milvus::VectorData>> search_entity_array_;
std::vector<int64_t> search_id_array_;
};
......@@ -25,6 +25,8 @@
namespace milvus_sdk {
constexpr int64_t SECONDS_EACH_HOUR = 3600;
constexpr int64_t BATCH_ENTITY_COUNT = 100000;
constexpr int64_t SEARCH_TARGET = BATCH_ENTITY_COUNT / 2; // change this value, result is different
#define BLOCK_SPLITER std::cout << "===========================================" << std::endl;
......@@ -152,16 +154,18 @@ Utils::PrintIndexParam(const milvus::IndexParam& index_param) {
}
void
Utils::BuildEntities(int64_t from, int64_t to, std::vector<milvus::VectorData>& entity_array,
std::vector<int64_t>& entity_ids, int64_t dimension) {
Utils::BuildEntities(int64_t from, int64_t to, milvus::FieldValue& field_value, std::vector<int64_t>& entity_ids,
int64_t dimension) {
if (to <= from) {
return;
}
int64_t row_num = to - from;
std::vector<int64_t> int_data(row_num);
std::vector<float> float_data(row_num);
std::vector<milvus::VectorData> entity_array;
entity_array.clear();
entity_ids.clear();
std::default_random_engine e;
std::uniform_real_distribution<float> u(0, 1);
for (int64_t k = from; k < to; k++) {
milvus::VectorData vector_data;
vector_data.float_data.resize(dimension);
......@@ -169,9 +173,15 @@ Utils::BuildEntities(int64_t from, int64_t to, std::vector<milvus::VectorData>&
vector_data.float_data[i] = (float)((k + 100) % (i + 1));
}
int_data[k - from] = k;
float_data[k - from] = (float)k + row_num;
entity_array.emplace_back(vector_data);
entity_ids.push_back(k);
}
field_value.int64_value.insert(std::make_pair("field_1", int_data));
field_value.float_value.insert(std::make_pair("field_2", float_data));
field_value.vector_value.insert(std::make_pair("field_vec", entity_array));
}
void
......@@ -250,15 +260,23 @@ Utils::DoSearch(std::shared_ptr<milvus::Connection> conn, const std::string& col
}
void
Utils::ConstructVector(uint64_t nq, uint64_t dimension, std::vector<milvus::VectorData>& query_vector) {
query_vector.resize(nq);
std::default_random_engine e;
std::uniform_real_distribution<float> u(0, 1);
for (uint64_t i = 0; i < nq; ++i) {
query_vector[i].float_data.resize(dimension);
for (uint64_t j = 0; j < dimension; ++j) {
query_vector[i].float_data[j] = (float)((j + 100) % (i + 1));
Utils::ConstructVectors(int64_t from, int64_t to, std::vector<milvus::VectorData>& query_vector,
std::vector<int64_t>& search_ids, int64_t dimension) {
if (to <= from) {
return;
}
query_vector.clear();
search_ids.clear();
for (int64_t k = from; k < to; k++) {
milvus::VectorData entity;
entity.float_data.resize(dimension);
for (int64_t i = 0; i < dimension; i++) {
entity.float_data[i] = (float)((k + 100) % (i + 1));
}
query_vector.emplace_back(entity);
search_ids.push_back(k);
}
}
......@@ -287,7 +305,17 @@ Utils::GenLeafQuery() {
uint64_t DIMENSION = 128;
uint64_t NPROBE = 32;
milvus::VectorQueryPtr vq = std::make_shared<milvus::VectorQuery>();
ConstructVector(NQ, DIMENSION, vq->query_vector);
std::vector<milvus::VectorData> search_entity_array;
for (int64_t i = 0; i < NQ; i++) {
std::vector<milvus::VectorData> entity_array;
std::vector<int64_t> record_ids;
int64_t index = i * BATCH_ENTITY_COUNT + SEARCH_TARGET;
milvus_sdk::Utils::ConstructVectors(index, index + 1, entity_array, record_ids, DIMENSION);
search_entity_array.push_back(entity_array[0]);
}
vq->query_vector = search_entity_array;
vq->field_name = "field_vec";
vq->topk = 10;
JSON json_params = {{"nprobe", NPROBE}};
......
......@@ -54,8 +54,8 @@ class Utils {
PrintIndexParam(const milvus::IndexParam& index_param);
static void
BuildEntities(int64_t from, int64_t to, std::vector<milvus::VectorData>& entity_array,
std::vector<int64_t>& entity_ids, int64_t dimension);
BuildEntities(int64_t from, int64_t to, milvus::FieldValue& field_value, std::vector<int64_t>& entity_ids,
int64_t dimension);
static void
PrintSearchResult(const std::vector<std::pair<int64_t, milvus::VectorData>>& entity_array,
......@@ -72,7 +72,8 @@ class Utils {
milvus::TopKQueryResult& topk_query_result);
static void
ConstructVector(uint64_t nq, uint64_t dimension, std::vector<milvus::VectorData>& query_vector);
ConstructVectors(int64_t from, int64_t to, std::vector<milvus::VectorData>& query_vector,
std::vector<int64_t>& search_ids, int64_t dimension);
static std::vector<milvus::LeafQueryPtr>
GenLeafQuery();
......
......@@ -292,10 +292,10 @@ CopyEntityToJson(::milvus::grpc::Entities& grpc_entities, JSON& json_entity) {
case ::milvus::grpc::FLOAT: {
row_num = grpc_attr_record.float_value_size();
std::vector<float> data(row_num);
memcpy(data.data(), grpc_attr_record.int64_value().data(), row_num * sizeof(float));
memcpy(data.data(), grpc_attr_record.float_value().data(), row_num * sizeof(float));
float_data.insert(std::make_pair(grpc_field.field_name(), data));
break;
}
}
case ::milvus::grpc::DOUBLE: {
row_num = grpc_attr_record.double_value_size();
std::vector<double> data(row_num);
......@@ -307,7 +307,8 @@ CopyEntityToJson(::milvus::grpc::Entities& grpc_entities, JSON& json_entity) {
row_num = grpc_vector_record.records_size();
std::vector<milvus::VectorData> data(row_num);
for (int j = 0; j < row_num; j++) {
data[j].float_data.resize(row_num);
size_t dim = grpc_vector_record.records(j).float_data_size();
data[j].float_data.resize(dim);
memcpy(data[j].float_data.data(), grpc_vector_record.records(j).float_data().data(),
row_num * sizeof(float));
}
......@@ -568,6 +569,7 @@ ClientProxy::GetEntityByID(const std::string& collection_name, const std::vector
JSON json_entities;
CopyEntityToJson(grpc_entities, json_entities);
entities = json_entities.dump();
return status;
} catch (std::exception& ex) {
return Status(StatusCode::UnknownError, "Failed to get entity by id: " + std::string(ex.what()));
......
......@@ -130,7 +130,7 @@ GrpcClient::GetEntityByID(const grpc::EntityIdentity& entity_identity, ::milvus:
::grpc::Status grpc_status = stub_->GetEntityByID(&context, entity_identity, &entities);
if (!grpc_status.ok()) {
std::cerr << "GetVectorByID rpc failed!" << std::endl;
std::cerr << "GetEntityByID rpc failed!" << std::endl;
return Status(StatusCode::RPCFailed, grpc_status.error_message());
}
if (entities.status().error_code() != grpc::SUCCESS) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册