未验证 提交 a7f211f8 编写于 作者: C chen qingxiang 提交者: GitHub

add valid row implementation (#3053)

* fix bug in test case
Signed-off-by: Ngodchen0212 <qingxiang.chen@zilliz.com>

* fix bugs in test case
Signed-off-by: Ngodchen0212 <qingxiang.chen@zilliz.com>

* fix bug in DB GetEntityByID interface
Signed-off-by: Ngodchen0212 <qingxiang.chen@zilliz.com>

* add valid row implementation
Signed-off-by: Ngodchen0212 <qingxiang.chen@zilliz.com>

* uncomment the index thread
Signed-off-by: Ngodchen0212 <qingxiang.chen@zilliz.com>
上级 666524a4
...@@ -95,7 +95,8 @@ class DB { ...@@ -95,7 +95,8 @@ class DB {
virtual Status virtual Status
GetEntityByID(const std::string& collection_name, const IDNumbers& id_array, GetEntityByID(const std::string& collection_name, const IDNumbers& id_array,
const std::vector<std::string>& field_names, DataChunkPtr& data_chunk) = 0; const std::vector<std::string>& field_names, std::vector<bool>& valid_row,
DataChunkPtr& data_chunk) = 0;
virtual Status virtual Status
DeleteEntityByID(const std::string& collection_name, const engine::IDNumbers entity_ids) = 0; DeleteEntityByID(const std::string& collection_name, const engine::IDNumbers entity_ids) = 0;
......
...@@ -524,14 +524,16 @@ DBImpl::Insert(const std::string& collection_name, const std::string& partition_ ...@@ -524,14 +524,16 @@ DBImpl::Insert(const std::string& collection_name, const std::string& partition_
Status Status
DBImpl::GetEntityByID(const std::string& collection_name, const IDNumbers& id_array, DBImpl::GetEntityByID(const std::string& collection_name, const IDNumbers& id_array,
const std::vector<std::string>& field_names, DataChunkPtr& data_chunk) { const std::vector<std::string>& field_names, std::vector<bool>& valid_row,
DataChunkPtr& data_chunk) {
CHECK_INITIALIZED; CHECK_INITIALIZED;
snapshot::ScopedSnapshotT ss; snapshot::ScopedSnapshotT ss;
STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name)); STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name));
std::string dir_root = options_.meta_.path_; std::string dir_root = options_.meta_.path_;
auto handler = std::make_shared<GetEntityByIdSegmentHandler>(nullptr, ss, dir_root, id_array, field_names); auto handler =
std::make_shared<GetEntityByIdSegmentHandler>(nullptr, ss, dir_root, id_array, field_names, valid_row);
handler->Iterate(); handler->Iterate();
STATUS_CHECK(handler->GetStatus()); STATUS_CHECK(handler->GetStatus());
......
...@@ -90,7 +90,8 @@ class DBImpl : public DB { ...@@ -90,7 +90,8 @@ class DBImpl : public DB {
Status Status
GetEntityByID(const std::string& collection_name, const IDNumbers& id_array, GetEntityByID(const std::string& collection_name, const IDNumbers& id_array,
const std::vector<std::string>& field_names, DataChunkPtr& data_chunk) override; const std::vector<std::string>& field_names, std::vector<bool>& valid_row,
DataChunkPtr& data_chunk) override;
Status Status
DeleteEntityByID(const std::string& collection_name, const engine::IDNumbers entity_ids) override; DeleteEntityByID(const std::string& collection_name, const engine::IDNumbers entity_ids) override;
......
...@@ -123,8 +123,9 @@ SegmentsToIndexCollector::Handle(const snapshot::SegmentCommitPtr& segment_commi ...@@ -123,8 +123,9 @@ SegmentsToIndexCollector::Handle(const snapshot::SegmentCommitPtr& segment_commi
GetEntityByIdSegmentHandler::GetEntityByIdSegmentHandler(const std::shared_ptr<milvus::server::Context>& context, GetEntityByIdSegmentHandler::GetEntityByIdSegmentHandler(const std::shared_ptr<milvus::server::Context>& context,
engine::snapshot::ScopedSnapshotT ss, engine::snapshot::ScopedSnapshotT ss,
const std::string& dir_root, const IDNumbers& ids, const std::string& dir_root, const IDNumbers& ids,
const std::vector<std::string>& field_names) const std::vector<std::string>& field_names,
: BaseT(ss), context_(context), dir_root_(dir_root), ids_(ids), field_names_(field_names) { std::vector<bool>& valid_row)
: BaseT(ss), context_(context), dir_root_(dir_root), ids_(ids), field_names_(field_names), valid_row_(valid_row) {
} }
Status Status
...@@ -149,6 +150,7 @@ GetEntityByIdSegmentHandler::Handle(const snapshot::SegmentPtr& segment) { ...@@ -149,6 +150,7 @@ GetEntityByIdSegmentHandler::Handle(const snapshot::SegmentPtr& segment) {
for (auto id : ids_) { for (auto id : ids_) {
// fast check using bloom filter // fast check using bloom filter
if (!id_bloom_filter_ptr->Check(id)) { if (!id_bloom_filter_ptr->Check(id)) {
valid_row_.push_back(false);
continue; continue;
} }
...@@ -158,6 +160,7 @@ GetEntityByIdSegmentHandler::Handle(const snapshot::SegmentPtr& segment) { ...@@ -158,6 +160,7 @@ GetEntityByIdSegmentHandler::Handle(const snapshot::SegmentPtr& segment) {
} }
auto found = std::find(uids.begin(), uids.end(), id); auto found = std::find(uids.begin(), uids.end(), id);
if (found == uids.end()) { if (found == uids.end()) {
valid_row_.push_back(false);
continue; continue;
} }
...@@ -170,9 +173,11 @@ GetEntityByIdSegmentHandler::Handle(const snapshot::SegmentPtr& segment) { ...@@ -170,9 +173,11 @@ GetEntityByIdSegmentHandler::Handle(const snapshot::SegmentPtr& segment) {
auto& deleted_docs = deleted_docs_ptr->GetDeletedDocs(); auto& deleted_docs = deleted_docs_ptr->GetDeletedDocs();
auto deleted = std::find(deleted_docs.begin(), deleted_docs.end(), offset); auto deleted = std::find(deleted_docs.begin(), deleted_docs.end(), offset);
if (deleted != deleted_docs.end()) { if (deleted != deleted_docs.end()) {
valid_row_.push_back(false);
continue; continue;
} }
} }
valid_row_.push_back(true);
offsets.push_back(offset); offsets.push_back(offset);
} }
......
...@@ -79,7 +79,7 @@ struct GetEntityByIdSegmentHandler : public snapshot::IterateHandler<snapshot::S ...@@ -79,7 +79,7 @@ struct GetEntityByIdSegmentHandler : public snapshot::IterateHandler<snapshot::S
using BaseT = snapshot::IterateHandler<ResourceT>; using BaseT = snapshot::IterateHandler<ResourceT>;
GetEntityByIdSegmentHandler(const server::ContextPtr& context, snapshot::ScopedSnapshotT ss, GetEntityByIdSegmentHandler(const server::ContextPtr& context, snapshot::ScopedSnapshotT ss,
const std::string& dir_root, const IDNumbers& ids, const std::string& dir_root, const IDNumbers& ids,
const std::vector<std::string>& field_names); const std::vector<std::string>& field_names, std::vector<bool>& valid_row);
Status Status
Handle(const typename ResourceT::Ptr&) override; Handle(const typename ResourceT::Ptr&) override;
...@@ -89,6 +89,7 @@ struct GetEntityByIdSegmentHandler : public snapshot::IterateHandler<snapshot::S ...@@ -89,6 +89,7 @@ struct GetEntityByIdSegmentHandler : public snapshot::IterateHandler<snapshot::S
const engine::IDNumbers ids_; const engine::IDNumbers ids_;
const std::vector<std::string> field_names_; const std::vector<std::string> field_names_;
engine::DataChunkPtr data_chunk_; engine::DataChunkPtr data_chunk_;
std::vector<bool>& valid_row_;
}; };
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
......
...@@ -157,9 +157,10 @@ ReqHandler::Insert(const std::shared_ptr<Context>& context, const std::string& c ...@@ -157,9 +157,10 @@ ReqHandler::Insert(const std::shared_ptr<Context>& context, const std::string& c
Status Status
ReqHandler::GetEntityByID(const std::shared_ptr<Context>& context, const std::string& collection_name, ReqHandler::GetEntityByID(const std::shared_ptr<Context>& context, const std::string& collection_name,
const engine::IDNumbers& ids, std::vector<std::string>& field_names, const engine::IDNumbers& ids, std::vector<std::string>& field_names,
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk) { std::vector<bool>& valid_row, engine::snapshot::CollectionMappings& field_mappings,
engine::DataChunkPtr& data_chunk) {
BaseReqPtr req_ptr = BaseReqPtr req_ptr =
GetEntityByIDReq::Create(context, collection_name, ids, field_names, field_mappings, data_chunk); GetEntityByIDReq::Create(context, collection_name, ids, field_names, valid_row, field_mappings, data_chunk);
ReqScheduler::ExecReq(req_ptr); ReqScheduler::ExecReq(req_ptr);
return req_ptr->status(); return req_ptr->status();
} }
......
...@@ -86,7 +86,7 @@ class ReqHandler { ...@@ -86,7 +86,7 @@ class ReqHandler {
Status Status
GetEntityByID(const std::shared_ptr<Context>& context, const std::string& collection_name, GetEntityByID(const std::shared_ptr<Context>& context, const std::string& collection_name,
const engine::IDNumbers& ids, std::vector<std::string>& field_names, const engine::IDNumbers& ids, std::vector<std::string>& field_names, std::vector<bool>& valid_row,
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk); engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk);
Status Status
......
...@@ -31,13 +31,14 @@ constexpr uint64_t MAX_COUNT_RETURNED = 1000; ...@@ -31,13 +31,14 @@ constexpr uint64_t MAX_COUNT_RETURNED = 1000;
GetEntityByIDReq::GetEntityByIDReq(const std::shared_ptr<milvus::server::Context>& context, GetEntityByIDReq::GetEntityByIDReq(const std::shared_ptr<milvus::server::Context>& context,
const std::string& collection_name, const engine::IDNumbers& id_array, const std::string& collection_name, const engine::IDNumbers& id_array,
std::vector<std::string>& field_names, std::vector<std::string>& field_names, std::vector<bool>& valid_row,
engine::snapshot::CollectionMappings& field_mappings, engine::snapshot::CollectionMappings& field_mappings,
engine::DataChunkPtr& data_chunk) engine::DataChunkPtr& data_chunk)
: BaseReq(context, BaseReq::kGetEntityByID), : BaseReq(context, BaseReq::kGetEntityByID),
collection_name_(collection_name), collection_name_(collection_name),
id_array_(id_array), id_array_(id_array),
field_names_(field_names), field_names_(field_names),
valid_row_(valid_row),
field_mappings_(field_mappings), field_mappings_(field_mappings),
data_chunk_(data_chunk) { data_chunk_(data_chunk) {
} }
...@@ -45,9 +46,10 @@ GetEntityByIDReq::GetEntityByIDReq(const std::shared_ptr<milvus::server::Context ...@@ -45,9 +46,10 @@ GetEntityByIDReq::GetEntityByIDReq(const std::shared_ptr<milvus::server::Context
BaseReqPtr BaseReqPtr
GetEntityByIDReq::Create(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name, GetEntityByIDReq::Create(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name,
const engine::IDNumbers& id_array, std::vector<std::string>& field_names_, const engine::IDNumbers& id_array, std::vector<std::string>& field_names_,
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk) { std::vector<bool>& valid_row, engine::snapshot::CollectionMappings& field_mappings,
engine::DataChunkPtr& data_chunk) {
return std::shared_ptr<BaseReq>( return std::shared_ptr<BaseReq>(
new GetEntityByIDReq(context, collection_name, id_array, field_names_, field_mappings, data_chunk)); new GetEntityByIDReq(context, collection_name, id_array, field_names_, valid_row, field_mappings, data_chunk));
} }
Status Status
...@@ -82,7 +84,7 @@ GetEntityByIDReq::OnExecute() { ...@@ -82,7 +84,7 @@ GetEntityByIDReq::OnExecute() {
if (field_names_.empty()) { if (field_names_.empty()) {
for (const auto& schema : field_mappings_) { for (const auto& schema : field_mappings_) {
if (schema.first->GetFtype() != engine::meta::hybrid::DataType::UID) // if (schema.first->GetFtype() != engine::meta::hybrid::DataType::UID)
field_names_.emplace_back(schema.first->GetName()); field_names_.emplace_back(schema.first->GetName());
} }
} else { } else {
...@@ -101,7 +103,7 @@ GetEntityByIDReq::OnExecute() { ...@@ -101,7 +103,7 @@ GetEntityByIDReq::OnExecute() {
} }
// step 2: get vector data, now only support get one id // step 2: get vector data, now only support get one id
status = DBWrapper::DB()->GetEntityByID(collection_name_, id_array_, field_names_, data_chunk_); status = DBWrapper::DB()->GetEntityByID(collection_name_, id_array_, field_names_, valid_row_, data_chunk_);
if (!status.ok()) { if (!status.ok()) {
return Status(SERVER_INVALID_COLLECTION_NAME, CollectionNotExistMsg(collection_name_)); return Status(SERVER_INVALID_COLLECTION_NAME, CollectionNotExistMsg(collection_name_));
} }
......
...@@ -33,13 +33,14 @@ class GetEntityByIDReq : public BaseReq { ...@@ -33,13 +33,14 @@ class GetEntityByIDReq : public BaseReq {
public: public:
static BaseReqPtr static BaseReqPtr
Create(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name, Create(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name,
const engine::IDNumbers& id_array, std::vector<std::string>& field_names_, const engine::IDNumbers& id_array, std::vector<std::string>& field_names_, std::vector<bool>& valid_row,
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk); engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk);
protected: protected:
GetEntityByIDReq(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name, GetEntityByIDReq(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name,
const engine::IDNumbers& id_array, std::vector<std::string>& field_names, const engine::IDNumbers& id_array, std::vector<std::string>& field_names,
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk); std::vector<bool>& valid_row, engine::snapshot::CollectionMappings& field_mappings,
engine::DataChunkPtr& data_chunk);
Status Status
OnExecute() override; OnExecute() override;
...@@ -50,6 +51,7 @@ class GetEntityByIDReq : public BaseReq { ...@@ -50,6 +51,7 @@ class GetEntityByIDReq : public BaseReq {
std::vector<std::string>& field_names_; std::vector<std::string>& field_names_;
engine::snapshot::CollectionMappings& field_mappings_; engine::snapshot::CollectionMappings& field_mappings_;
engine::DataChunkPtr& data_chunk_; engine::DataChunkPtr& data_chunk_;
std::vector<bool>& valid_row_;
}; };
} // namespace server } // namespace server
......
...@@ -747,9 +747,14 @@ GrpcRequestHandler::GetEntityByID(::grpc::ServerContext* context, const ::milvus ...@@ -747,9 +747,14 @@ GrpcRequestHandler::GetEntityByID(::grpc::ServerContext* context, const ::milvus
std::vector<engine::AttrsData> attrs; std::vector<engine::AttrsData> attrs;
std::vector<engine::VectorsData> vectors; std::vector<engine::VectorsData> vectors;
std::vector<bool> valid_row;
Status status = req_handler_.GetEntityByID(GetContext(context), request->collection_name(), vector_ids, field_names, Status status = req_handler_.GetEntityByID(GetContext(context), request->collection_name(), vector_ids, field_names,
field_mappings, data_chunk); valid_row, field_mappings, data_chunk);
for (auto it : valid_row) {
response->add_valid_row(it);
}
auto id_size = vector_ids.size(); auto id_size = vector_ids.size();
for (const auto& it : field_mappings) { for (const auto& it : field_mappings) {
...@@ -757,9 +762,16 @@ GrpcRequestHandler::GetEntityByID(::grpc::ServerContext* context, const ::milvus ...@@ -757,9 +762,16 @@ GrpcRequestHandler::GetEntityByID(::grpc::ServerContext* context, const ::milvus
std::string name = it.first->GetName(); std::string name = it.first->GetName();
std::vector<uint8_t> data = data_chunk->fixed_fields_[name]; std::vector<uint8_t> data = data_chunk->fixed_fields_[name];
auto single_size = data.size() / id_size;
if (type == engine::meta::hybrid::DataType::UID) { if (type == engine::meta::hybrid::DataType::UID) {
response->mutable_ids()->Resize(data.size(), 0); int64_t int64_value;
memcpy(response->mutable_ids()->mutable_data(), data.data(), data.size() * sizeof(uint64_t)); auto int64_size = single_size * sizeof(int8_t) / sizeof(int64_t);
for (int i = 0; i < id_size; i++) {
auto offset = i * single_size;
memcpy(&int64_value, data.data() + offset, single_size);
response->add_ids(int64_value);
}
continue; continue;
} }
...@@ -768,7 +780,6 @@ GrpcRequestHandler::GetEntityByID(::grpc::ServerContext* context, const ::milvus ...@@ -768,7 +780,6 @@ GrpcRequestHandler::GetEntityByID(::grpc::ServerContext* context, const ::milvus
field_value->set_field_name(name); field_value->set_field_name(name);
field_value->set_type(static_cast<milvus::grpc::DataType>(type)); field_value->set_type(static_cast<milvus::grpc::DataType>(type));
auto single_size = data.size() / id_size;
// general data // general data
if (type == engine::meta::hybrid::DataType::VECTOR_BINARY) { if (type == engine::meta::hybrid::DataType::VECTOR_BINARY) {
// add binary vector data // add binary vector data
......
...@@ -804,13 +804,14 @@ WebRequestHandler::DeleteByIDs(const std::string& collection_name, const nlohman ...@@ -804,13 +804,14 @@ WebRequestHandler::DeleteByIDs(const std::string& collection_name, const nlohman
Status Status
WebRequestHandler::GetEntityByIDs(const std::string& collection_name, const std::vector<int64_t>& ids, WebRequestHandler::GetEntityByIDs(const std::string& collection_name, const std::vector<int64_t>& ids,
std::vector<std::string>& field_names, nlohmann::json& json_out) { std::vector<std::string>& field_names, nlohmann::json& json_out) {
std::vector<bool> valid_row;
engine::DataChunkPtr data_chunk; engine::DataChunkPtr data_chunk;
engine::snapshot::CollectionMappings field_mappings; engine::snapshot::CollectionMappings field_mappings;
std::vector<engine::AttrsData> attr_batch; std::vector<engine::AttrsData> attr_batch;
std::vector<engine::VectorsData> vector_batch; std::vector<engine::VectorsData> vector_batch;
auto status = auto status = req_handler_.GetEntityByID(context_ptr_, collection_name, ids, field_names, valid_row, field_mappings,
req_handler_.GetEntityByID(context_ptr_, collection_name, ids, field_names, field_mappings, data_chunk); data_chunk);
if (!status.ok()) { if (!status.ok()) {
return status; return status;
} }
...@@ -1679,6 +1680,7 @@ WebRequestHandler::GetEntity(const milvus::server::web::OString& collection_name ...@@ -1679,6 +1680,7 @@ WebRequestHandler::GetEntity(const milvus::server::web::OString& collection_name
StringHelpFunctions::SplitStringByDelimeter(query_fields->c_str(), ",", field_names); StringHelpFunctions::SplitStringByDelimeter(query_fields->c_str(), ",", field_names);
} }
std::vector<bool> valid_row;
nlohmann::json entity_result_json; nlohmann::json entity_result_json;
status = GetEntityByIDs(collection_name->std_str(), entity_ids, field_names, entity_result_json); status = GetEntityByIDs(collection_name->std_str(), entity_ids, field_names, entity_result_json);
if (!status.ok()) { if (!status.ok()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册