未验证 提交 a7f211f8 编写于 作者: C chen qingxiang 提交者: GitHub

add valid row implementation (#3053)

* fix bug in test case
Signed-off-by: Ngodchen0212 <qingxiang.chen@zilliz.com>

* fix bugs in test case
Signed-off-by: Ngodchen0212 <qingxiang.chen@zilliz.com>

* fix bug in DB GetEntityById interface
Signed-off-by: Ngodchen0212 <qingxiang.chen@zilliz.com>

* add valid row implementation
Signed-off-by: Ngodchen0212 <qingxiang.chen@zilliz.com>

* uncomment the index thread
Signed-off-by: Ngodchen0212 <qingxiang.chen@zilliz.com>
上级 666524a4
......@@ -95,7 +95,8 @@ class DB {
// Pure-virtual entry point for looking up entities by id in a named collection.
//   collection_name - collection to read from
//   id_array        - ids to look up
//   field_names     - fields requested for each entity
//   valid_row       - out-parameter added by this change; DBImpl forwards it to
//                     GetEntityByIdSegmentHandler, which appends one bool per examined id
//   data_chunk      - non-const reference; callers in this diff read it after the call
// NOTE: this diff view has no +/- markers, so the tail of the parameter list appears twice.
// The first `field_names` line is the pre-change signature (no valid_row); the two lines
// after it are the post-change signature.
virtual Status
GetEntityByID(const std::string& collection_name, const IDNumbers& id_array,
const std::vector<std::string>& field_names, DataChunkPtr& data_chunk) = 0;
const std::vector<std::string>& field_names, std::vector<bool>& valid_row,
DataChunkPtr& data_chunk) = 0;
// Pure-virtual DeleteEntityByID declaration (unchanged context line in this diff).
// NOTE(review): entity_ids is a const by-value parameter, whereas GetEntityByID takes its
// id list by const reference. If engine::IDNumbers is a container this copies on every
// call — confirm whether that is intended.
virtual Status
DeleteEntityByID(const std::string& collection_name, const engine::IDNumbers entity_ids) = 0;
......
......@@ -524,14 +524,16 @@ DBImpl::Insert(const std::string& collection_name, const std::string& partition_
Status
DBImpl::GetEntityByID(const std::string& collection_name, const IDNumbers& id_array,
const std::vector<std::string>& field_names, DataChunkPtr& data_chunk) {
const std::vector<std::string>& field_names, std::vector<bool>& valid_row,
DataChunkPtr& data_chunk) {
CHECK_INITIALIZED;
snapshot::ScopedSnapshotT ss;
STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name));
std::string dir_root = options_.meta_.path_;
auto handler = std::make_shared<GetEntityByIdSegmentHandler>(nullptr, ss, dir_root, id_array, field_names);
auto handler =
std::make_shared<GetEntityByIdSegmentHandler>(nullptr, ss, dir_root, id_array, field_names, valid_row);
handler->Iterate();
STATUS_CHECK(handler->GetStatus());
......
......@@ -90,7 +90,8 @@ class DBImpl : public DB {
// DBImpl's override of DB::GetEntityByID.
// NOTE: rendered diff without +/- markers. The first `field_names` line is the pre-change
// signature; the two lines after it are the post-change signature that adds the
// `valid_row` out-parameter ahead of `data_chunk`.
Status
GetEntityByID(const std::string& collection_name, const IDNumbers& id_array,
const std::vector<std::string>& field_names, DataChunkPtr& data_chunk) override;
const std::vector<std::string>& field_names, std::vector<bool>& valid_row,
DataChunkPtr& data_chunk) override;
// DBImpl's override of DeleteEntityByID (unchanged context line in this diff).
// NOTE(review): entity_ids is taken by const value rather than const reference — confirm
// the copy is intended.
Status
DeleteEntityByID(const std::string& collection_name, const engine::IDNumbers entity_ids) override;
......
......@@ -123,8 +123,9 @@ SegmentsToIndexCollector::Handle(const snapshot::SegmentCommitPtr& segment_commi
// Constructor: passes the snapshot to BaseT and stores the context, storage root, id list
// and field names in members. The post-change version also binds the reference member
// valid_row_ to the caller's vector, so that vector must outlive the handler.
// NOTE: rendered diff without +/- markers. The first `field_names` line and the first
// initializer list are the pre-change version; the lines that mention valid_row are the
// post-change version.
GetEntityByIdSegmentHandler::GetEntityByIdSegmentHandler(const std::shared_ptr<milvus::server::Context>& context,
engine::snapshot::ScopedSnapshotT ss,
const std::string& dir_root, const IDNumbers& ids,
const std::vector<std::string>& field_names)
: BaseT(ss), context_(context), dir_root_(dir_root), ids_(ids), field_names_(field_names) {
const std::vector<std::string>& field_names,
std::vector<bool>& valid_row)
: BaseT(ss), context_(context), dir_root_(dir_root), ids_(ids), field_names_(field_names), valid_row_(valid_row) {
}
Status
......@@ -149,6 +150,7 @@ GetEntityByIdSegmentHandler::Handle(const snapshot::SegmentPtr& segment) {
for (auto id : ids_) {
// fast check using bloom filter
if (!id_bloom_filter_ptr->Check(id)) {
valid_row_.push_back(false);
continue;
}
......@@ -158,6 +160,7 @@ GetEntityByIdSegmentHandler::Handle(const snapshot::SegmentPtr& segment) {
}
auto found = std::find(uids.begin(), uids.end(), id);
if (found == uids.end()) {
valid_row_.push_back(false);
continue;
}
......@@ -170,9 +173,11 @@ GetEntityByIdSegmentHandler::Handle(const snapshot::SegmentPtr& segment) {
auto& deleted_docs = deleted_docs_ptr->GetDeletedDocs();
auto deleted = std::find(deleted_docs.begin(), deleted_docs.end(), offset);
if (deleted != deleted_docs.end()) {
valid_row_.push_back(false);
continue;
}
}
valid_row_.push_back(true);
offsets.push_back(offset);
}
......
......@@ -79,7 +79,7 @@ struct GetEntityByIdSegmentHandler : public snapshot::IterateHandler<snapshot::S
using BaseT = snapshot::IterateHandler<ResourceT>;
// Constructor declaration. The post-change form takes an extra non-const reference,
// valid_row, which the handler keeps as a reference member and appends to in Handle().
// NOTE: rendered diff without +/- markers. The first `field_names` line is the pre-change
// declaration; the second is the post-change declaration.
GetEntityByIdSegmentHandler(const server::ContextPtr& context, snapshot::ScopedSnapshotT ss,
const std::string& dir_root, const IDNumbers& ids,
const std::vector<std::string>& field_names);
const std::vector<std::string>& field_names, std::vector<bool>& valid_row);
Status
Handle(const typename ResourceT::Ptr&) override;
......@@ -89,6 +89,7 @@ struct GetEntityByIdSegmentHandler : public snapshot::IterateHandler<snapshot::S
// Ids to look up; stored by value, so the handler keeps its own copy.
const engine::IDNumbers ids_;
// Requested field names; stored by value.
const std::vector<std::string> field_names_;
engine::DataChunkPtr data_chunk_;
// Reference to a caller-owned vector; the referenced vector must outlive this handler.
// Handle() appends one entry per element of ids_ on each invocation: false when the bloom
// filter rejects the id, when the id is absent from the segment's uids, or when its offset
// is listed in the deleted docs; true otherwise.
// NOTE(review): Handle() takes a single segment. If Iterate() invokes it once per segment,
// this vector ends up with ids_.size() entries per segment rather than one per requested
// id, and an id found in one segment is also recorded as false for every other segment —
// confirm against the consumers of valid_row.
std::vector<bool>& valid_row_;
};
///////////////////////////////////////////////////////////////////////////////
......
......@@ -157,9 +157,10 @@ ReqHandler::Insert(const std::shared_ptr<Context>& context, const std::string& c
// Builds a GetEntityByIDReq through its Create factory, runs it via ReqScheduler::ExecReq,
// and returns the request's status. All reference arguments are forwarded unchanged.
// NOTE: rendered diff without +/- markers. The lines that do not mention valid_row (the
// first parameter-list tail and the first Create call) are the pre-change version; the
// lines that pass valid_row are the post-change version.
Status
ReqHandler::GetEntityByID(const std::shared_ptr<Context>& context, const std::string& collection_name,
const engine::IDNumbers& ids, std::vector<std::string>& field_names,
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk) {
std::vector<bool>& valid_row, engine::snapshot::CollectionMappings& field_mappings,
engine::DataChunkPtr& data_chunk) {
BaseReqPtr req_ptr =
GetEntityByIDReq::Create(context, collection_name, ids, field_names, field_mappings, data_chunk);
GetEntityByIDReq::Create(context, collection_name, ids, field_names, valid_row, field_mappings, data_chunk);
ReqScheduler::ExecReq(req_ptr);
return req_ptr->status();
}
......
......@@ -86,7 +86,7 @@ class ReqHandler {
// Declaration of ReqHandler::GetEntityByID.
// NOTE: rendered diff without +/- markers. The first `ids` line is the pre-change
// declaration; the second adds the `valid_row` out-parameter after `field_names`.
Status
GetEntityByID(const std::shared_ptr<Context>& context, const std::string& collection_name,
const engine::IDNumbers& ids, std::vector<std::string>& field_names,
const engine::IDNumbers& ids, std::vector<std::string>& field_names, std::vector<bool>& valid_row,
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk);
Status
......
......@@ -31,13 +31,14 @@ constexpr uint64_t MAX_COUNT_RETURNED = 1000;
// Constructor: tags the request as BaseReq::kGetEntityByID and stores the arguments in
// members. field_names_, field_mappings_, data_chunk_ and valid_row_ are reference members
// in the class declaration, so the caller's objects must outlive the request.
// NOTE: rendered diff without +/- markers. The first `field_names` parameter line is the
// pre-change version; the one that also declares valid_row is the post-change version.
// NOTE(review): valid_row_ is initialized before field_mappings_ here, but the class
// declares it after data_chunk_. Members are initialized in declaration order, so this is
// harmless for independent reference bindings but will trigger -Wreorder; consider
// matching the declaration order.
GetEntityByIDReq::GetEntityByIDReq(const std::shared_ptr<milvus::server::Context>& context,
const std::string& collection_name, const engine::IDNumbers& id_array,
std::vector<std::string>& field_names,
std::vector<std::string>& field_names, std::vector<bool>& valid_row,
engine::snapshot::CollectionMappings& field_mappings,
engine::DataChunkPtr& data_chunk)
: BaseReq(context, BaseReq::kGetEntityByID),
collection_name_(collection_name),
id_array_(id_array),
field_names_(field_names),
valid_row_(valid_row),
field_mappings_(field_mappings),
data_chunk_(data_chunk) {
}
......@@ -45,9 +46,10 @@ GetEntityByIDReq::GetEntityByIDReq(const std::shared_ptr<milvus::server::Context
// Static factory: allocates a GetEntityByIDReq and returns it as a shared BaseReq pointer.
// The constructor is declared under `protected:` in the class, which is presumably why the
// object is created with `new` inside this member rather than with std::make_shared.
// NOTE: rendered diff without +/- markers. The lines that do not mention valid_row are the
// pre-change version; the lines that pass valid_row are the post-change version.
// NOTE(review): the parameter `field_names_` carries the trailing underscore otherwise used
// for data members; it is a plain parameter here.
BaseReqPtr
GetEntityByIDReq::Create(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name,
const engine::IDNumbers& id_array, std::vector<std::string>& field_names_,
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk) {
std::vector<bool>& valid_row, engine::snapshot::CollectionMappings& field_mappings,
engine::DataChunkPtr& data_chunk) {
return std::shared_ptr<BaseReq>(
new GetEntityByIDReq(context, collection_name, id_array, field_names_, field_mappings, data_chunk));
new GetEntityByIDReq(context, collection_name, id_array, field_names_, valid_row, field_mappings, data_chunk));
}
Status
......@@ -82,8 +84,8 @@ GetEntityByIDReq::OnExecute() {
if (field_names_.empty()) {
for (const auto& schema : field_mappings_) {
if (schema.first->GetFtype() != engine::meta::hybrid::DataType::UID)
field_names_.emplace_back(schema.first->GetName());
// if (schema.first->GetFtype() != engine::meta::hybrid::DataType::UID)
field_names_.emplace_back(schema.first->GetName());
}
} else {
for (const auto& name : field_names_) {
......@@ -101,7 +103,7 @@ GetEntityByIDReq::OnExecute() {
}
// step 2: get vector data, now only support get one id
// NOTE(review): the "only one id" remark looks stale, since the whole id_array_ is passed
// to the DB call — confirm.
// NOTE: rendered diff without +/- markers. The first call is the pre-change line; the
// second, which also passes valid_row_, is the post-change line.
status = DBWrapper::DB()->GetEntityByID(collection_name_, id_array_, field_names_, data_chunk_);
status = DBWrapper::DB()->GetEntityByID(collection_name_, id_array_, field_names_, valid_row_, data_chunk_);
// NOTE(review): every failure from the DB layer is reported as SERVER_INVALID_COLLECTION_NAME
// with a "collection does not exist" message, which discards the original status.
if (!status.ok()) {
return Status(SERVER_INVALID_COLLECTION_NAME, CollectionNotExistMsg(collection_name_));
}
......
......@@ -33,13 +33,14 @@ class GetEntityByIDReq : public BaseReq {
public:
// Static factory declaration; the public way to obtain a request, since the constructor
// below is declared under `protected:`.
// NOTE: rendered diff without +/- markers. The first `id_array` line is the pre-change
// declaration; the second adds the `valid_row` out-parameter after `field_names_`.
static BaseReqPtr
Create(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name,
const engine::IDNumbers& id_array, std::vector<std::string>& field_names_,
const engine::IDNumbers& id_array, std::vector<std::string>& field_names_, std::vector<bool>& valid_row,
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk);
protected:
// Protected constructor declaration; instances are created through Create().
// NOTE: rendered diff without +/- markers. The `field_mappings` line that does not mention
// valid_row is the pre-change tail of the parameter list; the two lines after it are the
// post-change tail.
GetEntityByIDReq(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name,
const engine::IDNumbers& id_array, std::vector<std::string>& field_names,
engine::snapshot::CollectionMappings& field_mappings, engine::DataChunkPtr& data_chunk);
std::vector<bool>& valid_row, engine::snapshot::CollectionMappings& field_mappings,
engine::DataChunkPtr& data_chunk);
Status
OnExecute() override;
......@@ -50,6 +51,7 @@ class GetEntityByIDReq : public BaseReq {
// All four members are references to caller-owned objects; the request does not own them,
// so they must stay alive until the request has finished executing.
std::vector<std::string>& field_names_;
engine::snapshot::CollectionMappings& field_mappings_;
engine::DataChunkPtr& data_chunk_;
// Added by this change. It is declared last, while the constructor's initializer list
// names it before field_mappings_; initialization follows this declaration order.
std::vector<bool>& valid_row_;
};
} // namespace server
......
......@@ -747,9 +747,14 @@ GrpcRequestHandler::GetEntityByID(::grpc::ServerContext* context, const ::milvus
std::vector<engine::AttrsData> attrs;
std::vector<engine::VectorsData> vectors;
std::vector<bool> valid_row;
Status status = req_handler_.GetEntityByID(GetContext(context), request->collection_name(), vector_ids, field_names,
field_mappings, data_chunk);
valid_row, field_mappings, data_chunk);
for (auto it : valid_row) {
response->add_valid_row(it);
}
auto id_size = vector_ids.size();
for (const auto& it : field_mappings) {
......@@ -757,9 +762,16 @@ GrpcRequestHandler::GetEntityByID(::grpc::ServerContext* context, const ::milvus
std::string name = it.first->GetName();
std::vector<uint8_t> data = data_chunk->fixed_fields_[name];
auto single_size = data.size() / id_size;
// UID field: copy the ids out of the raw byte buffer `data` into the response's id list,
// then skip the generic field handling for this mapping.
// NOTE: rendered diff without +/- markers. The Resize + bulk memcpy pair is presumably the
// pre-change code and the per-id loop the post-change code; applying both would add every
// id twice.
if (type == engine::meta::hybrid::DataType::UID) {
// NOTE(review): `data` is a std::vector<uint8_t>, so data.size() is a byte count. This
// resizes ids to that many elements and copies data.size() * sizeof(uint64_t) bytes,
// which reads past the end of `data`.
response->mutable_ids()->Resize(data.size(), 0);
memcpy(response->mutable_ids()->mutable_data(), data.data(), data.size() * sizeof(uint64_t));
int64_t int64_value;
// NOTE(review): int64_size is not used in the visible lines.
auto int64_size = single_size * sizeof(int8_t) / sizeof(int64_t);
// NOTE(review): `i` is a signed int compared against id_size (taken from
// vector_ids.size(), presumably unsigned) — confirm.
for (int i = 0; i < id_size; i++) {
auto offset = i * single_size;
// Copies single_size bytes into an 8-byte variable. This is only correct when
// single_size == sizeof(int64_t): a smaller value leaves int64_value partly
// uninitialized, a larger one writes past it.
memcpy(&int64_value, data.data() + offset, single_size);
response->add_ids(int64_value);
}
continue;
}
......@@ -768,7 +780,6 @@ GrpcRequestHandler::GetEntityByID(::grpc::ServerContext* context, const ::milvus
field_value->set_field_name(name);
field_value->set_type(static_cast<milvus::grpc::DataType>(type));
auto single_size = data.size() / id_size;
// general data
if (type == engine::meta::hybrid::DataType::VECTOR_BINARY) {
// add binary vector data
......
......@@ -804,13 +804,14 @@ WebRequestHandler::DeleteByIDs(const std::string& collection_name, const nlohman
Status
WebRequestHandler::GetEntityByIDs(const std::string& collection_name, const std::vector<int64_t>& ids,
std::vector<std::string>& field_names, nlohmann::json& json_out) {
std::vector<bool> valid_row;
engine::DataChunkPtr data_chunk;
engine::snapshot::CollectionMappings field_mappings;
std::vector<engine::AttrsData> attr_batch;
std::vector<engine::VectorsData> vector_batch;
auto status =
req_handler_.GetEntityByID(context_ptr_, collection_name, ids, field_names, field_mappings, data_chunk);
auto status = req_handler_.GetEntityByID(context_ptr_, collection_name, ids, field_names, valid_row, field_mappings,
data_chunk);
if (!status.ok()) {
return status;
}
......@@ -1679,6 +1680,7 @@ WebRequestHandler::GetEntity(const milvus::server::web::OString& collection_name
StringHelpFunctions::SplitStringByDelimeter(query_fields->c_str(), ",", field_names);
}
// NOTE(review): valid_row is declared here but is not passed to GetEntityByIDs, which
// declares its own local valid_row. It is unused in the visible lines — confirm whether it
// is needed.
std::vector<bool> valid_row;
nlohmann::json entity_result_json;
status = GetEntityByIDs(collection_name->std_str(), entity_ids, field_names, entity_result_json);
if (!status.ok()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册