未验证 提交 f3f3f8a8 编写于 作者: Z zhagnlu 提交者: GitHub

Segcore retrieve by pk optimazation (#24659) (#24660)

Signed-off-by: Nluzhang <luzhang@zilliz.com>
Co-authored-by: Nluzhang <luzhang@zilliz.com>
上级 e40d95e0
...@@ -268,4 +268,10 @@ struct ExistsExpr : Expr { ...@@ -268,4 +268,10 @@ struct ExistsExpr : Expr {
accept(ExprVisitor&) override; accept(ExprVisitor&) override;
}; };
inline bool
IsTermExpr(Expr* expr) {
TermExpr* term_expr = dynamic_cast<TermExpr*>(expr);
return term_expr != nullptr;
}
} // namespace milvus::query } // namespace milvus::query
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "segcore/SegmentGrowingImpl.h" #include "segcore/SegmentGrowingImpl.h"
#include "query/ExprImpl.h" #include "query/ExprImpl.h"
#include "ExprVisitor.h" #include "ExprVisitor.h"
#include "ExecPlanNodeVisitor.h"
namespace milvus::query { namespace milvus::query {
...@@ -56,7 +57,20 @@ class ExecExprVisitor : public ExprVisitor { ...@@ -56,7 +57,20 @@ class ExecExprVisitor : public ExprVisitor {
ExecExprVisitor(const segcore::SegmentInternalInterface& segment, ExecExprVisitor(const segcore::SegmentInternalInterface& segment,
int64_t row_count, int64_t row_count,
Timestamp timestamp) Timestamp timestamp)
: segment_(segment), row_count_(row_count), timestamp_(timestamp) { : segment_(segment),
row_count_(row_count),
timestamp_(timestamp),
plan_visitor_(nullptr) {
}
ExecExprVisitor(const segcore::SegmentInternalInterface& segment,
ExecPlanNodeVisitor* plan_visitor,
int64_t row_count,
Timestamp timestamp)
: segment_(segment),
plan_visitor_(plan_visitor),
row_count_(row_count),
timestamp_(timestamp) {
} }
BitsetType BitsetType
...@@ -165,5 +179,6 @@ class ExecExprVisitor : public ExprVisitor { ...@@ -165,5 +179,6 @@ class ExecExprVisitor : public ExprVisitor {
int64_t row_count_; int64_t row_count_;
BitsetTypeOpt bitset_opt_; BitsetTypeOpt bitset_opt_;
ExecPlanNodeVisitor* plan_visitor_;
}; };
} // namespace milvus::query } // namespace milvus::query
...@@ -68,6 +68,31 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor { ...@@ -68,6 +68,31 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor {
return ret; return ret;
} }
void
SetExprCacheOffsets(std::vector<int64_t>&& offsets) {
expr_cached_pk_id_offsets_ = std::move(offsets);
}
void
AddExprCacheOffset(int64_t offset) {
expr_cached_pk_id_offsets_.push_back(offset);
}
const std::vector<int64_t>&
GetExprCacheOffsets() {
return expr_cached_pk_id_offsets_;
}
void
SetExprUsePkIndex(bool use_pk_index) {
expr_use_pk_index_ = use_pk_index;
}
bool
GetExprUsePkIndex() {
return expr_use_pk_index_;
}
private: private:
template <typename VectorType> template <typename VectorType>
void void
...@@ -80,5 +105,7 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor { ...@@ -80,5 +105,7 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor {
SearchResultOpt search_result_opt_; SearchResultOpt search_result_opt_;
RetrieveResultOpt retrieve_result_opt_; RetrieveResultOpt retrieve_result_opt_;
bool expr_use_pk_index_ = false;
std::vector<int64_t> expr_cached_pk_id_offsets_;
}; };
} // namespace milvus::query } // namespace milvus::query
...@@ -1709,9 +1709,16 @@ ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> BitsetType { ...@@ -1709,9 +1709,16 @@ ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> BitsetType {
auto [uids, seg_offsets] = segment_.search_ids(*id_array, timestamp_); auto [uids, seg_offsets] = segment_.search_ids(*id_array, timestamp_);
BitsetType bitset(row_count_); BitsetType bitset(row_count_);
std::vector<int64_t> cached_offsets;
for (const auto& offset : seg_offsets) { for (const auto& offset : seg_offsets) {
auto _offset = (int64_t)offset.get(); auto _offset = (int64_t)offset.get();
bitset[_offset] = true; bitset[_offset] = true;
cached_offsets.push_back(_offset);
}
// If enable plan_visitor pk index cache, pass offsets to it
if (plan_visitor_ != nullptr) {
plan_visitor_->SetExprUsePkIndex(true);
plan_visitor_->SetExprCacheOffsets(std::move(cached_offsets));
} }
AssertInfo(bitset.size() == row_count_, AssertInfo(bitset.size() == row_count_,
"[ExecExprVisitor]Size of results not equal row count"); "[ExecExprVisitor]Size of results not equal row count");
......
...@@ -100,7 +100,7 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) { ...@@ -100,7 +100,7 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
std::unique_ptr<BitsetType> bitset_holder; std::unique_ptr<BitsetType> bitset_holder;
if (node.predicate_.has_value()) { if (node.predicate_.has_value()) {
bitset_holder = std::make_unique<BitsetType>( bitset_holder = std::make_unique<BitsetType>(
ExecExprVisitor(*segment, active_count, timestamp_) ExecExprVisitor(*segment, this, active_count, timestamp_)
.call_child(*node.predicate_.value())); .call_child(*node.predicate_.value()));
bitset_holder->flip(); bitset_holder->flip();
} else { } else {
...@@ -166,8 +166,9 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) { ...@@ -166,8 +166,9 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) {
} }
if (node.predicate_.has_value() && node.predicate_.value() != nullptr) { if (node.predicate_.has_value() && node.predicate_.value() != nullptr) {
bitset_holder = ExecExprVisitor(*segment, active_count, timestamp_) bitset_holder =
.call_child(*(node.predicate_.value())); ExecExprVisitor(*segment, this, active_count, timestamp_)
.call_child(*(node.predicate_.value()));
bitset_holder.flip(); bitset_holder.flip();
} }
...@@ -188,7 +189,11 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) { ...@@ -188,7 +189,11 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) {
} }
BitsetView final_view = bitset_holder; BitsetView final_view = bitset_holder;
auto seg_offsets = segment->search_ids(final_view, timestamp_); auto seg_offsets =
GetExprUsePkIndex() && IsTermExpr(node.predicate_.value().get())
? segment->search_ids(
final_view, expr_cached_pk_id_offsets_, timestamp_)
: segment->search_ids(final_view, timestamp_);
retrieve_result.result_offsets_.assign( retrieve_result.result_offsets_.assign(
(int64_t*)seg_offsets.data(), (int64_t*)seg_offsets.data(),
(int64_t*)seg_offsets.data() + seg_offsets.size()); (int64_t*)seg_offsets.data() + seg_offsets.size());
......
...@@ -420,6 +420,22 @@ SegmentGrowingImpl::search_ids(const BitsetView& bitset, ...@@ -420,6 +420,22 @@ SegmentGrowingImpl::search_ids(const BitsetView& bitset,
return res_offsets; return res_offsets;
} }
std::vector<SegOffset>
SegmentGrowingImpl::search_ids(const BitsetView& bitset,
const std::vector<int64_t>& offsets,
Timestamp timestamp) const {
std::vector<SegOffset> res_offsets;
for (auto& offset : offsets) {
if (!bitset.test(offset)) {
if (insert_record_.timestamps_[offset] <= timestamp) {
res_offsets.push_back(SegOffset(offset));
}
}
}
return res_offsets;
}
std::pair<std::unique_ptr<IdArray>, std::vector<SegOffset>> std::pair<std::unique_ptr<IdArray>, std::vector<SegOffset>>
SegmentGrowingImpl::search_ids(const IdArray& id_array, SegmentGrowingImpl::search_ids(const IdArray& id_array,
Timestamp timestamp) const { Timestamp timestamp) const {
......
...@@ -211,6 +211,11 @@ class SegmentGrowingImpl : public SegmentGrowing { ...@@ -211,6 +211,11 @@ class SegmentGrowingImpl : public SegmentGrowing {
std::vector<SegOffset> std::vector<SegOffset>
search_ids(const BitsetView& view, Timestamp timestamp) const override; search_ids(const BitsetView& view, Timestamp timestamp) const override;
std::vector<SegOffset>
search_ids(const BitsetView& view,
const std::vector<int64_t>& offsets,
Timestamp timestamp) const override;
bool bool
HasIndex(FieldId field_id) const override { HasIndex(FieldId field_id) const override {
return true; return true;
......
...@@ -186,6 +186,11 @@ class SegmentInternalInterface : public SegmentInterface { ...@@ -186,6 +186,11 @@ class SegmentInternalInterface : public SegmentInterface {
virtual std::vector<SegOffset> virtual std::vector<SegOffset>
search_ids(const BitsetView& view, Timestamp timestamp) const = 0; search_ids(const BitsetView& view, Timestamp timestamp) const = 0;
virtual std::vector<SegOffset>
search_ids(const BitsetView& view,
const std::vector<int64_t>& offsets,
Timestamp timestamp) const = 0;
virtual std::pair<std::unique_ptr<IdArray>, std::vector<SegOffset>> virtual std::pair<std::unique_ptr<IdArray>, std::vector<SegOffset>>
search_ids(const IdArray& id_array, Timestamp timestamp) const = 0; search_ids(const IdArray& id_array, Timestamp timestamp) const = 0;
......
...@@ -917,6 +917,21 @@ SegmentSealedImpl::search_ids(const BitsetView& bitset, ...@@ -917,6 +917,21 @@ SegmentSealedImpl::search_ids(const BitsetView& bitset,
return dst_offset; return dst_offset;
} }
std::vector<SegOffset>
SegmentSealedImpl::search_ids(const BitsetView& bitset,
const std::vector<int64_t>& offsets,
Timestamp timestamp) const {
std::vector<SegOffset> dst_offset;
for (auto& offset : offsets) {
if (!bitset.test(offset)) {
if (insert_record_.timestamps_[offset] <= timestamp) {
dst_offset.push_back(SegOffset(offset));
}
}
}
return dst_offset;
}
std::string std::string
SegmentSealedImpl::debug() const { SegmentSealedImpl::debug() const {
std::string log_str; std::string log_str;
......
...@@ -203,6 +203,11 @@ class SegmentSealedImpl : public SegmentSealed { ...@@ -203,6 +203,11 @@ class SegmentSealedImpl : public SegmentSealed {
std::vector<SegOffset> std::vector<SegOffset>
search_ids(const BitsetView& view, Timestamp timestamp) const override; search_ids(const BitsetView& view, Timestamp timestamp) const override;
std::vector<SegOffset>
search_ids(const BitsetView& view,
const std::vector<int64_t>& offsets,
Timestamp timestamp) const override;
std::vector<SegOffset> std::vector<SegOffset>
search_ids(const BitsetType& view, Timestamp timestamp) const override; search_ids(const BitsetType& view, Timestamp timestamp) const override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册