diff --git a/paddle/ir/pass/analysis_manager.h b/paddle/ir/pass/analysis_manager.h index b43c12b8b349e998f6be12ab9d028bfc5bba83f0..417d9026b88d1c571353df9d114622f7f3c8fd4f 100644 --- a/paddle/ir/pass/analysis_manager.h +++ b/paddle/ir/pass/analysis_manager.h @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -25,7 +26,6 @@ #include "paddle/ir/core/type_name.h" #include "paddle/ir/pass/pass_instrumentation.h" #include "paddle/ir/pass/utils.h" -#include "paddle/utils/optional.h" namespace ir { @@ -149,14 +149,13 @@ class AnalysisMap { } template - paddle::optional> GetCachedAnalysis() - const { + std::optional> GetCachedAnalysis() const { auto res = analyses_.find(TypeId::get()); - if (res == analyses_.end()) return paddle::none; + if (res == analyses_.end()) return std::nullopt; return {static_cast&>(*res->second).analysis}; } - Operation* getOperation() const { return ir_; } + Operation* GetOperation() const { return ir_; } void Clear() { analyses_.clear(); } @@ -257,8 +256,7 @@ class AnalysisManager { } template - paddle::optional> GetCachedAnalysis() - const { + std::optional> GetCachedAnalysis() const { return analyses_->GetCachedAnalysis(); } @@ -269,11 +267,11 @@ class AnalysisManager { analyses_->Invalidate(pa); } - void clear() { analyses_->Clear(); } + void Clear() { analyses_->Clear(); } PassInstrumentor* GetPassInstrumentor() const { return instrumentor_; } - Operation* GetOperation() { return analyses_->getOperation(); } + Operation* GetOperation() { return analyses_->GetOperation(); } private: AnalysisManager(detail::AnalysisMap* impl, PassInstrumentor* pi) diff --git a/paddle/ir/pass/pass.h b/paddle/ir/pass/pass.h index e45407b8465172ec2a496b0bad455b5fa773e4fb..d785f3a801f4237f79320f37ee14be6be4d647f4 100644 --- a/paddle/ir/pass/pass.h +++ b/paddle/ir/pass/pass.h @@ -20,7 +20,7 @@ #include "paddle/ir/core/enforce.h" #include "paddle/ir/pass/analysis_manager.h" -#include "paddle/utils/optional.h" +#include "paddle/phi/core/enforce.h" namespace ir { @@ -91,8 +91,7 @@ class IR_API Pass { AnalysisManager analysis_manager() { return pass_state().am; } detail::PassExecutionState& pass_state() { - IR_ENFORCE(pass_state_.is_initialized() == true, - "pass state was never initialized"); + IR_ENFORCE(pass_state_.has_value() == true, "pass state has no value"); return *pass_state_; } @@ -101,7 +100,7 @@ class IR_API Pass { private: detail::PassInfo pass_info_; - paddle::optional pass_state_; + std::optional pass_state_; friend class PassManager; friend class detail::PassAdaptor; diff --git a/paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.cc b/paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.cc index ba9a680a3064e63faa07cd7181af7bd92ffc9e12..363595b91a988c518abef4ad3463c6e3fd93a760 100644 --- a/paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.cc +++ b/paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.cc @@ -14,13 +14,13 @@ #include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h" +#include #include #include #include #include #include "paddle/ir/core/op_info.h" -#include "paddle/utils/optional.h" namespace ir { @@ -76,19 +76,19 @@ FrozenRewritePatternSet::FrozenRewritePatternSet( continue; } - if (paddle::optional root_name = pat->root_kind()) { + if (std::optional root_name = pat->root_kind()) { impl_->op_specific_native_pattern_map_[*root_name].push_back(pat.get()); impl_->op_specific_native_patterns_.push_back(std::move(pat)); continue; } - if (paddle::optional interface_id = pat->GetRootInterfaceID()) { + if (std::optional interface_id = pat->GetRootInterfaceID()) { AddToOpsWhen( pat, [&](OpInfo info) { return info.HasInterface(*interface_id); }); continue; } - if (paddle::optional trait_id = pat->GetRootTraitID()) { + if (std::optional trait_id = pat->GetRootTraitID()) { AddToOpsWhen(pat, [&](OpInfo info) { return info.HasTrait(*trait_id); }); continue; } diff --git a/paddle/ir/pattern_rewrite/pattern_applicator.cc b/paddle/ir/pattern_rewrite/pattern_applicator.cc index 0a0a712afbeb289fd17c1463c545a163725c357d..7087efa9ac64ff16853736ab9eca8b2c277583e9 100644 --- a/paddle/ir/pattern_rewrite/pattern_applicator.cc +++ b/paddle/ir/pattern_rewrite/pattern_applicator.cc @@ -21,20 +21,20 @@ namespace ir { PatternApplicator::PatternApplicator( - const FrozenRewritePatternSet& frozen_patter_list) - : frozen_patter_list_(frozen_patter_list) {} + const FrozenRewritePatternSet& frozen_pattern_list) + : frozen_pattern_list_(frozen_pattern_list) {} -void PatternApplicator::ApplyCostModel(CostModel model) { +void PatternApplicator::ApplyCostModel(const CostModel& model) { // TODO(wilber): remove impossible patterns. patterns_.clear(); - for (const auto& it : frozen_patter_list_.op_specific_native_patterns()) { + for (const auto& it : frozen_pattern_list_.op_specific_native_patterns()) { for (const RewritePattern* pattern : it.second) { patterns_[it.first].push_back(pattern); } } any_op_patterns_.clear(); - for (auto& pattern : frozen_patter_list_.match_any_op_native_patterns()) { + for (auto& pattern : frozen_pattern_list_.match_any_op_native_patterns()) { any_op_patterns_.push_back(pattern.get()); } @@ -59,10 +59,11 @@ void PatternApplicator::ApplyCostModel(CostModel model) { void PatternApplicator::WalkAllPatterns( std::function walk) { - for (const auto& it : frozen_patter_list_.op_specific_native_patterns()) + for (const auto& it : frozen_pattern_list_.op_specific_native_patterns()) for (auto* pattern : it.second) walk(*pattern); - for (auto& it : frozen_patter_list_.match_any_op_native_patterns()) walk(*it); + for (const auto& it : frozen_pattern_list_.match_any_op_native_patterns()) + walk(*it); } bool PatternApplicator::MatchAndRewrite( diff --git a/paddle/ir/pattern_rewrite/pattern_applicator.h b/paddle/ir/pattern_rewrite/pattern_applicator.h index 5c4bc8784607b0a64891c644427265261b4c1423..d0eb4bce1acabb54c9829a99df1efe1a8385a9f4 100644 --- a/paddle/ir/pattern_rewrite/pattern_applicator.h +++ b/paddle/ir/pattern_rewrite/pattern_applicator.h @@ -39,7 +39,7 @@ class PatternApplicator { std::function on_failure = {}, std::function on_success = {}); - void ApplyCostModel(CostModel model); + void ApplyCostModel(const CostModel& model); void ApplyDefaultCostModel() { ApplyCostModel([](const Pattern& pattern) { return pattern.benefit(); }); @@ -48,7 +48,7 @@ class PatternApplicator { void WalkAllPatterns(std::function walk); private: - const FrozenRewritePatternSet& frozen_patter_list_; + const FrozenRewritePatternSet& frozen_pattern_list_; std::unordered_map> patterns_; std::vector any_op_patterns_; }; diff --git a/paddle/ir/pattern_rewrite/pattern_match.h b/paddle/ir/pattern_rewrite/pattern_match.h index 97f3bf09c69460373aa1a3a70b10725a530d4a95..dee11d6bd929620dfc77cc940faec3799e3013c9 100644 --- a/paddle/ir/pattern_rewrite/pattern_match.h +++ b/paddle/ir/pattern_rewrite/pattern_match.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -30,7 +31,6 @@ #include "paddle/ir/core/type_id.h" #include "paddle/ir/core/type_name.h" #include "paddle/ir/core/value.h" -#include "paddle/utils/optional.h" namespace ir { @@ -72,22 +72,22 @@ class IR_API Pattern { public: const std::vector& generated_ops() const { return generated_ops_; } - paddle::optional root_kind() const { + std::optional root_kind() const { if (root_kind_ == RootKind::OperationInfo) return OpInfo::RecoverFromOpaquePointer(root_val_); - return paddle::none; + return std::nullopt; } - paddle::optional GetRootInterfaceID() const { + std::optional GetRootInterfaceID() const { if (root_kind_ == RootKind::InterfaceId) return TypeId::RecoverFromOpaquePointer(root_val_); - return paddle::none; + return std::nullopt; } - paddle::optional GetRootTraitID() const { + std::optional GetRootTraitID() const { if (root_kind_ == RootKind::TraitId) return TypeId::RecoverFromOpaquePointer(root_val_); - return paddle::none; + return std::nullopt; } PatternBenefit benefit() const { return benefit_; } diff --git a/test/cpp/ir/pass/pass_manager_test.cc b/test/cpp/ir/pass/pass_manager_test.cc index 0b3aa35829fab7128400bc4f7ecce6cf53f1a945..22cb62dda27c55d182ae6933bc70dff1b31fd385 100644 --- a/test/cpp/ir/pass/pass_manager_test.cc +++ b/test/cpp/ir/pass/pass_manager_test.cc @@ -13,6 +13,7 @@ // limitations under the License. #include +#include "glog/logging.h" #include "paddle/fluid/ir/dialect/pd_dialect.h" #include "paddle/fluid/ir/dialect/pd_type.h" @@ -29,6 +30,35 @@ #include "paddle/ir/pass/pass_manager.h" #include "paddle/phi/kernels/elementwise_add_kernel.h" +#ifndef _WIN32 +class TestAnalysis1 {}; +class TestAnalysis2 {}; + +IR_DECLARE_EXPLICIT_TYPE_ID(TestAnalysis1) +IR_DEFINE_EXPLICIT_TYPE_ID(TestAnalysis1) +IR_DECLARE_EXPLICIT_TYPE_ID(TestAnalysis2) +IR_DEFINE_EXPLICIT_TYPE_ID(TestAnalysis2) + +TEST(pass_manager, PreservedAnalyses) { + ir::detail::PreservedAnalyses pa; + CHECK_EQ(pa.IsNone(), true); + + CHECK_EQ(pa.IsPreserved(), false); + pa.Preserve(); + CHECK_EQ(pa.IsPreserved(), true); + pa.Unpreserve(); + CHECK_EQ(pa.IsPreserved(), false); + CHECK_EQ(pa.IsPreserved(), false); + pa.Preserve(); + CHECK_EQ(pa.IsPreserved(), true); + CHECK_EQ(pa.IsPreserved(), true); + CHECK_EQ(pa.IsAll(), false); + pa.PreserveAll(); + CHECK_EQ(pa.IsAll(), true); + CHECK_EQ(pa.IsNone(), false); +} +#endif + class AddOp : public ir::Op { public: using Op::Op; @@ -63,15 +93,49 @@ void AddOp::Build(ir::Builder &, IR_DECLARE_EXPLICIT_TYPE_ID(AddOp) IR_DEFINE_EXPLICIT_TYPE_ID(AddOp) +struct CountOpAnalysis { + explicit CountOpAnalysis(ir::Operation *container_op) { + IR_ENFORCE(container_op->num_regions() > 0, true); + + LOG(INFO) << "In CountOpAnalysis, op is " << container_op->name() << "\n"; + for (size_t i = 0; i < container_op->num_regions(); ++i) { + auto ®ion = container_op->region(i); + for (auto it = region.begin(); it != region.end(); ++it) { + auto *block = *it; + for (auto it = block->begin(); it != block->end(); ++it) { + ++count; + } + } + } + + LOG(INFO) << "-- count is " << count << "\n"; + } + + int count = 0; +}; + +IR_DECLARE_EXPLICIT_TYPE_ID(CountOpAnalysis) +IR_DEFINE_EXPLICIT_TYPE_ID(CountOpAnalysis) + class TestPass : public ir::Pass { public: TestPass() : ir::Pass("TestPass", 1) {} void Run(ir::Operation *op) override { + auto count_op_analysis = analysis_manager().GetAnalysis(); + pass_state().preserved_analyses.Preserve(); + CHECK_EQ(pass_state().preserved_analyses.IsPreserved(), + true); + CHECK_EQ(count_op_analysis.count, 4); + auto module_op = op->dyn_cast(); CHECK_EQ(module_op.operation(), op); CHECK_EQ(module_op.name(), module_op->name()); LOG(INFO) << "In " << pass_info().name << ": " << module_op->name() << std::endl; + + pass_state().preserved_analyses.Unpreserve(); + CHECK_EQ(pass_state().preserved_analyses.IsPreserved(), + false); } bool CanApplyOn(ir::Operation *op) const override { @@ -79,7 +143,11 @@ class TestPass : public ir::Pass { } }; -TEST(pass_manager_test, pass_manager) { +TEST(pass_manager, PassManager) { + // + // TODO(liuyuanle): remove test code other than pass manager + // + // (1) Init environment. ir::IrContext *ctx = ir::IrContext::Instance(); ir::Dialect *builtin_dialect = @@ -186,6 +254,7 @@ TEST(pass_manager_test, pass_manager) { op4->operand(0).type().dialect().GetRegisteredInterface(); // ir::Parameter *parameter_c = // c_interface->VariableToParameter(variable_c.get()); + std::unique_ptr parameter_c = c_interface->VariableToParameter(variable_c.get()); EXPECT_EQ(parameter_c->type(), dense_tensor_dtype); @@ -199,6 +268,10 @@ TEST(pass_manager_test, pass_manager) { EXPECT_EQ(program.block()->size() == 4, true); EXPECT_EQ(program.parameters_num() == 3, true); + // + // TODO(liuyuanle): remove the code above. + // + // (9) Test pass manager for program. ir::PassManager pm(ctx); diff --git a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc index ea036ae9c31f1206d64bc3599cb91200bfa91f41..d26cbe9265325471763bb36a184d3e12c7b50158 100644 --- a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc @@ -20,7 +20,7 @@ #include "paddle/ir/core/ir_context.h" #include "paddle/ir/pattern_rewrite/pattern_match.h" -TEST(PatternBenefit, PatternBenefit) { +TEST(pattern_rewrite, PatternBenefit) { ir::PatternBenefit benefit1(1); EXPECT_EQ(benefit1.benefit(), 1U); ir::PatternBenefit benefit2(2); @@ -95,7 +95,7 @@ class TestPatternRewrite2 : public ir::OpRewritePattern { } }; -TEST(RewritePattern, OpRewritePattern) { +TEST(pattern_rewrite, RewritePatternSet) { ir::IrContext *ctx = ir::IrContext::Instance(); ctx->GetOrRegisterDialect(); auto *test_dialect = ctx->GetOrRegisterDialect();