diff --git a/core/src/db/SnapshotHandlers.h b/core/src/db/SnapshotHandlers.h index fc81534cab3f532634af165af774423dd1f45a44..fe8602f8c98391777658b8ce4c79523da67d2dc9 100644 --- a/core/src/db/SnapshotHandlers.h +++ b/core/src/db/SnapshotHandlers.h @@ -25,7 +25,7 @@ namespace milvus { namespace engine { -struct LoadVectorFieldElementHandler : public snapshot::IterateHandler { +struct LoadVectorFieldElementHandler : public snapshot::FieldElementIterator { using ResourceT = snapshot::FieldElement; using BaseT = snapshot::IterateHandler; LoadVectorFieldElementHandler(const server::ContextPtr& context, snapshot::ScopedSnapshotT ss, @@ -38,7 +38,7 @@ struct LoadVectorFieldElementHandler : public snapshot::IterateHandler { +struct LoadVectorFieldHandler : public snapshot::FieldIterator { using ResourceT = snapshot::Field; using BaseT = snapshot::IterateHandler; LoadVectorFieldHandler(const server::ContextPtr& context, snapshot::ScopedSnapshotT ss); @@ -49,7 +49,7 @@ struct LoadVectorFieldHandler : public snapshot::IterateHandler const server::ContextPtr context_; }; -struct SegmentsToSearchCollector : public snapshot::IterateHandler { +struct SegmentsToSearchCollector : public snapshot::SegmentCommitIterator { using ResourceT = snapshot::SegmentCommit; using BaseT = snapshot::IterateHandler; SegmentsToSearchCollector(snapshot::ScopedSnapshotT ss, snapshot::IDS_TYPE& segment_ids); @@ -60,7 +60,7 @@ struct SegmentsToSearchCollector : public snapshot::IterateHandler { +struct SegmentsToIndexCollector : public snapshot::SegmentCommitIterator { using ResourceT = snapshot::SegmentCommit; using BaseT = snapshot::IterateHandler; SegmentsToIndexCollector(snapshot::ScopedSnapshotT ss, const std::string& field_name, @@ -74,7 +74,7 @@ struct SegmentsToIndexCollector : public snapshot::IterateHandler { +struct GetEntityByIdSegmentHandler : public snapshot::SegmentIterator { using ResourceT = snapshot::Segment; using BaseT = snapshot::IterateHandler; GetEntityByIdSegmentHandler(const server::ContextPtr& context, snapshot::ScopedSnapshotT ss, diff --git a/core/src/db/engine/ExecutionEngineImpl.cpp b/core/src/db/engine/ExecutionEngineImpl.cpp index ba6ba691b21d4e78f61bfc1a0b56a8c6201119cf..fdafccfeef610837ba2a3ed080d8e3ef43c13879 100644 --- a/core/src/db/engine/ExecutionEngineImpl.cpp +++ b/core/src/db/engine/ExecutionEngineImpl.cpp @@ -742,8 +742,8 @@ ExecutionEngineImpl::BuildKnowhereIndex(const std::string& field_name, const Col auto field_json = field->GetParams(); auto dimension = field_json[milvus::knowhere::meta::DIM]; - auto segment_commit = snapshot->GetSegmentCommitBySegmentId(segment->GetID()); - auto row_count = segment_commit->GetRowCount(); + snapshot::SIZE_TYPE row_count; + snapshot->GetSegmentRowCount(segment->GetID(), row_count); milvus::json conf = index_info.extra_params_; conf[knowhere::meta::DIM] = dimension; diff --git a/core/src/db/snapshot/CompoundOperations.cpp b/core/src/db/snapshot/CompoundOperations.cpp index 97805c19da016b4d6e1039ad9e0ab34d308e09ee..84818d1ea28aff4212df88aa58af41134c11ac8e 100644 --- a/core/src/db/snapshot/CompoundOperations.cpp +++ b/core/src/db/snapshot/CompoundOperations.cpp @@ -17,6 +17,7 @@ #include #include "db/meta/MetaAdapter.h" +#include "db/snapshot/IterateHandler.h" #include "db/snapshot/OperationExecutor.h" #include "db/snapshot/ResourceContext.h" #include "db/snapshot/Snapshots.h" @@ -383,18 +384,21 @@ DropAllIndexOperation::DropAllIndexOperation(const OperationContext& context, Sc Status DropAllIndexOperation::PreCheck() { - if (context_.stale_field_element == nullptr) { + if (context_.stale_field_elements.size() == 0) { std::stringstream emsg; emsg << GetRepr() << ". Stale field element is requried"; return Status(SS_INVALID_CONTEX_ERROR, emsg.str()); } - if (!GetStartedSS()->GetResource(context_.stale_field_element->GetID())) { - std::stringstream emsg; - emsg << GetRepr() << ". Specified field element " << context_.stale_field_element->GetName(); - emsg << " is stale"; - return Status(SS_INVALID_CONTEX_ERROR, emsg.str()); + for (auto stale_fe : context_.stale_field_elements) { + if (!GetStartedSS()->GetResource(stale_fe->GetID())) { + std::stringstream emsg; + emsg << GetRepr() << ". Specified field element " << stale_fe->GetName(); + emsg << " is stale"; + return Status(SS_INVALID_CONTEX_ERROR, emsg.str()); + } } + // TODO: Check type return Status::OK(); } @@ -406,7 +410,6 @@ DropAllIndexOperation::DoExecute(StorePtr store) { OperationContext cc_context; { auto context = context_; - context.stale_field_elements.push_back(context.stale_field_element); FieldCommitOperation fc_op(context, GetAdjustedSS()); STATUS_CHECK(fc_op(store)); @@ -430,20 +433,40 @@ DropAllIndexOperation::DoExecute(StorePtr store) { } std::map> p_sc_map; - for (auto& kv : segment_files) { - if (kv.second->GetFieldElementId() != context_.stale_field_element->GetID()) { - continue; - } + std::set stale_fe_ids; + for (auto& fe : context_.stale_field_elements) { + stale_fe_ids.insert(fe->GetID()); + } + + auto seg_executor = [&](const SegmentPtr& segment, SegmentIterator* handler) -> Status { + auto sf_ids = handler->ss_->GetSegmentFileIds(segment->GetID()); + if (sf_ids.size() == 0) { + return Status::OK(); + } auto context = context_; - context.stale_segment_files.push_back(kv.second.Get()); + for (auto& sf_id : sf_ids) { + auto sf = handler->ss_->GetResource(sf_id); + if (stale_fe_ids.find(sf->GetFieldElementId()) == stale_fe_ids.end()) { + continue; + } + context.stale_segment_files.push_back(sf); + } + if (context.stale_segment_files.size() == 0) { + return Status::OK(); + } SegmentCommitOperation sc_op(context, GetAdjustedSS()); STATUS_CHECK(sc_op(store)); STATUS_CHECK(sc_op.GetResource(context.new_segment_commit)); auto segc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); AddStepWithLsn(*context.new_segment_commit, context.lsn, segc_ctx_p); p_sc_map[context.new_segment_commit->GetPartitionId()].push_back(context.new_segment_commit); - } + return Status::OK(); + }; + + auto segment_iter = std::make_shared(GetAdjustedSS(), seg_executor); + segment_iter->Iterate(); + STATUS_CHECK(segment_iter->GetStatus()); for (auto& kv : p_sc_map) { auto& partition_id = kv.first; diff --git a/core/src/db/snapshot/Context.h b/core/src/db/snapshot/Context.h index 5987d4f5d6b818c92e52b8ba5690f32c634481ec..ad6cb02fc69c27af21c810400548ef9426f6d0be 100644 --- a/core/src/db/snapshot/Context.h +++ b/core/src/db/snapshot/Context.h @@ -66,9 +66,6 @@ struct OperationContext { std::vector stale_segments; - FieldPtr prev_field = nullptr; - FieldElementPtr prev_field_element = nullptr; - FieldElementPtr stale_field_element = nullptr; std::vector new_field_elements; std::vector stale_field_elements; diff --git a/core/src/db/snapshot/IterateHandler.h b/core/src/db/snapshot/IterateHandler.h index 034b3299cc10b9d50eaabffca50d2102634c0e40..cb72aa6d813d6ef4081375b8826fb7f6da295faf 100644 --- a/core/src/db/snapshot/IterateHandler.h +++ b/core/src/db/snapshot/IterateHandler.h @@ -72,6 +72,7 @@ struct IterateHandler : public std::enable_shared_from_this> { using CollectionIterator = IterateHandler; using PartitionIterator = IterateHandler; +using SegmentCommitIterator = IterateHandler; using SegmentIterator = IterateHandler; using SegmentFileIterator = IterateHandler; using FieldIterator = IterateHandler; diff --git a/core/src/db/snapshot/Snapshot.cpp b/core/src/db/snapshot/Snapshot.cpp index d6ba6cc0a23e7827271f40a71182bb5004dea9d7..df72ca4696b9903de9716df9ea1c508e7144b0c3 100644 --- a/core/src/db/snapshot/Snapshot.cpp +++ b/core/src/db/snapshot/Snapshot.cpp @@ -130,6 +130,19 @@ Snapshot::Snapshot(StorePtr store, ID_TYPE ss_id) { RefAll(); } +Status +Snapshot::GetSegmentRowCount(ID_TYPE segment_id, SIZE_TYPE& row_cnt) const { + auto sc = GetSegmentCommitBySegmentId(segment_id); + if (!sc) { + std::stringstream emsg; + emsg << "Snapshot::GetSegmentRowCount: Specified segment \"" << segment_id; + emsg << "\" not found"; + return Status(SS_NOT_FOUND_ERROR, emsg.str()); + } + row_cnt = sc->GetRowCount(); + return Status::OK(); +} + FieldPtr Snapshot::GetField(const std::string& name) const { auto it = field_names_map_.find(name); diff --git a/core/src/db/snapshot/Snapshot.h b/core/src/db/snapshot/Snapshot.h index 45250f6ec424f23718f42c08d95b5b273ec5608f..2400537eb82fa4123af45a20aa07b2b31494ed18 100644 --- a/core/src/db/snapshot/Snapshot.h +++ b/core/src/db/snapshot/Snapshot.h @@ -131,7 +131,6 @@ class Snapshot : public ReferenceProxy { GetFieldElement(const std::string& field_name, const std::string& field_element_name, FieldElementPtr& field_element) const; - // PXU TODO: add const. Need to change Scopedxxxx::Get SegmentCommitPtr GetSegmentCommitBySegmentId(ID_TYPE segment_id) const { auto it = seg_segc_map_.find(segment_id); @@ -140,6 +139,9 @@ class Snapshot : public ReferenceProxy { return GetResource(it->second); } + Status + GetSegmentRowCount(ID_TYPE segment_id, SIZE_TYPE&) const; + std::vector GetPartitionNames() const { std::vector names; diff --git a/core/unittest/db/test_snapshot.cpp b/core/unittest/db/test_snapshot.cpp index 966e6b63ec7716629b78fa921e47daa100178d5f..66b11cb913a377b484954d02c5fd8b8f865bfc13 100644 --- a/core/unittest/db/test_snapshot.cpp +++ b/core/unittest/db/test_snapshot.cpp @@ -573,14 +573,14 @@ TEST_F(SnapshotTest, IndexTest) { OperationContext d_a_i_ctx; d_a_i_ctx.lsn = next_lsn(); - d_a_i_ctx.stale_field_element = ss->GetResource(field_element_id); + d_a_i_ctx.stale_field_elements.push_back(ss->GetResource(field_element_id)); FieldElement::Ptr fe; status = ss->GetFieldElement(sf_context.field_name, sf_context.field_element_name, fe); ASSERT_TRUE(status.ok()); - ASSERT_EQ(fe, d_a_i_ctx.stale_field_element); + ASSERT_EQ(fe, d_a_i_ctx.stale_field_elements[0]); auto drop_all_index_op = std::make_shared(d_a_i_ctx, ss); status = drop_all_index_op->Push();