未验证 提交 9307d357 编写于 作者: Y Yuanle Liu 提交者: GitHub

add ut for pass_infra (AnalysisManager, PreservedAnalyses) (#54849)

上级 fa44ea5c
......@@ -16,6 +16,7 @@
#include <algorithm>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>
......@@ -25,7 +26,6 @@
#include "paddle/ir/core/type_name.h"
#include "paddle/ir/pass/pass_instrumentation.h"
#include "paddle/ir/pass/utils.h"
#include "paddle/utils/optional.h"
namespace ir {
......@@ -149,14 +149,13 @@ class AnalysisMap {
}
template <typename AnalysisT>
paddle::optional<std::reference_wrapper<AnalysisT>> GetCachedAnalysis()
const {
std::optional<std::reference_wrapper<AnalysisT>> GetCachedAnalysis() const {
auto res = analyses_.find(TypeId::get<AnalysisT>());
if (res == analyses_.end()) return paddle::none;
if (res == analyses_.end()) return std::nullopt;
return {static_cast<AnalysisModel<AnalysisT>&>(*res->second).analysis};
}
Operation* getOperation() const { return ir_; }
Operation* GetOperation() const { return ir_; }
void Clear() { analyses_.clear(); }
......@@ -257,8 +256,7 @@ class AnalysisManager {
}
template <typename AnalysisT>
paddle::optional<std::reference_wrapper<AnalysisT>> GetCachedAnalysis()
const {
std::optional<std::reference_wrapper<AnalysisT>> GetCachedAnalysis() const {
return analyses_->GetCachedAnalysis<AnalysisT>();
}
......@@ -269,11 +267,11 @@ class AnalysisManager {
analyses_->Invalidate(pa);
}
void clear() { analyses_->Clear(); }
void Clear() { analyses_->Clear(); }
PassInstrumentor* GetPassInstrumentor() const { return instrumentor_; }
Operation* GetOperation() { return analyses_->getOperation(); }
Operation* GetOperation() { return analyses_->GetOperation(); }
private:
AnalysisManager(detail::AnalysisMap* impl, PassInstrumentor* pi)
......
......@@ -20,7 +20,7 @@
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/pass/analysis_manager.h"
#include "paddle/utils/optional.h"
#include "paddle/phi/core/enforce.h"
namespace ir {
......@@ -91,8 +91,7 @@ class IR_API Pass {
AnalysisManager analysis_manager() { return pass_state().am; }
detail::PassExecutionState& pass_state() {
IR_ENFORCE(pass_state_.is_initialized() == true,
"pass state was never initialized");
IR_ENFORCE(pass_state_.has_value() == true, "pass state has no value");
return *pass_state_;
}
......@@ -101,7 +100,7 @@ class IR_API Pass {
private:
detail::PassInfo pass_info_;
paddle::optional<detail::PassExecutionState> pass_state_;
std::optional<detail::PassExecutionState> pass_state_;
friend class PassManager;
friend class detail::PassAdaptor;
......
......@@ -14,13 +14,13 @@
#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include <algorithm>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include "paddle/ir/core/op_info.h"
#include "paddle/utils/optional.h"
namespace ir {
......@@ -76,19 +76,19 @@ FrozenRewritePatternSet::FrozenRewritePatternSet(
continue;
}
if (paddle::optional<OpInfo> root_name = pat->root_kind()) {
if (std::optional<OpInfo> root_name = pat->root_kind()) {
impl_->op_specific_native_pattern_map_[*root_name].push_back(pat.get());
impl_->op_specific_native_patterns_.push_back(std::move(pat));
continue;
}
if (paddle::optional<TypeId> interface_id = pat->GetRootInterfaceID()) {
if (std::optional<TypeId> interface_id = pat->GetRootInterfaceID()) {
AddToOpsWhen(
pat, [&](OpInfo info) { return info.HasInterface(*interface_id); });
continue;
}
if (paddle::optional<TypeId> trait_id = pat->GetRootTraitID()) {
if (std::optional<TypeId> trait_id = pat->GetRootTraitID()) {
AddToOpsWhen(pat, [&](OpInfo info) { return info.HasTrait(*trait_id); });
continue;
}
......
......@@ -21,20 +21,20 @@
namespace ir {
PatternApplicator::PatternApplicator(
const FrozenRewritePatternSet& frozen_patter_list)
: frozen_patter_list_(frozen_patter_list) {}
const FrozenRewritePatternSet& frozen_pattern_list)
: frozen_pattern_list_(frozen_pattern_list) {}
void PatternApplicator::ApplyCostModel(CostModel model) {
void PatternApplicator::ApplyCostModel(const CostModel& model) {
// TODO(wilber): remove impossible patterns.
patterns_.clear();
for (const auto& it : frozen_patter_list_.op_specific_native_patterns()) {
for (const auto& it : frozen_pattern_list_.op_specific_native_patterns()) {
for (const RewritePattern* pattern : it.second) {
patterns_[it.first].push_back(pattern);
}
}
any_op_patterns_.clear();
for (auto& pattern : frozen_patter_list_.match_any_op_native_patterns()) {
for (auto& pattern : frozen_pattern_list_.match_any_op_native_patterns()) {
any_op_patterns_.push_back(pattern.get());
}
......@@ -59,10 +59,11 @@ void PatternApplicator::ApplyCostModel(CostModel model) {
void PatternApplicator::WalkAllPatterns(
std::function<void(const Pattern&)> walk) {
for (const auto& it : frozen_patter_list_.op_specific_native_patterns())
for (const auto& it : frozen_pattern_list_.op_specific_native_patterns())
for (auto* pattern : it.second) walk(*pattern);
for (auto& it : frozen_patter_list_.match_any_op_native_patterns()) walk(*it);
for (const auto& it : frozen_pattern_list_.match_any_op_native_patterns())
walk(*it);
}
bool PatternApplicator::MatchAndRewrite(
......
......@@ -39,7 +39,7 @@ class PatternApplicator {
std::function<void(const Pattern&)> on_failure = {},
std::function<bool(const Pattern&)> on_success = {});
void ApplyCostModel(CostModel model);
void ApplyCostModel(const CostModel& model);
void ApplyDefaultCostModel() {
ApplyCostModel([](const Pattern& pattern) { return pattern.benefit(); });
......@@ -48,7 +48,7 @@ class PatternApplicator {
void WalkAllPatterns(std::function<void(const Pattern&)> walk);
private:
const FrozenRewritePatternSet& frozen_patter_list_;
const FrozenRewritePatternSet& frozen_pattern_list_;
std::unordered_map<OpInfo, std::vector<const RewritePattern*>> patterns_;
std::vector<const RewritePattern*> any_op_patterns_;
};
......
......@@ -19,6 +19,7 @@
#include <functional>
#include <initializer_list>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <vector>
......@@ -30,7 +31,6 @@
#include "paddle/ir/core/type_id.h"
#include "paddle/ir/core/type_name.h"
#include "paddle/ir/core/value.h"
#include "paddle/utils/optional.h"
namespace ir {
......@@ -72,22 +72,22 @@ class IR_API Pattern {
public:
const std::vector<OpInfo>& generated_ops() const { return generated_ops_; }
paddle::optional<OpInfo> root_kind() const {
std::optional<OpInfo> root_kind() const {
if (root_kind_ == RootKind::OperationInfo)
return OpInfo::RecoverFromOpaquePointer(root_val_);
return paddle::none;
return std::nullopt;
}
paddle::optional<TypeId> GetRootInterfaceID() const {
std::optional<TypeId> GetRootInterfaceID() const {
if (root_kind_ == RootKind::InterfaceId)
return TypeId::RecoverFromOpaquePointer(root_val_);
return paddle::none;
return std::nullopt;
}
paddle::optional<TypeId> GetRootTraitID() const {
std::optional<TypeId> GetRootTraitID() const {
if (root_kind_ == RootKind::TraitId)
return TypeId::RecoverFromOpaquePointer(root_val_);
return paddle::none;
return std::nullopt;
}
PatternBenefit benefit() const { return benefit_; }
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include <gtest/gtest.h>
#include "glog/logging.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
......@@ -29,6 +30,35 @@
#include "paddle/ir/pass/pass_manager.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#ifndef _WIN32
class TestAnalysis1 {};
class TestAnalysis2 {};
IR_DECLARE_EXPLICIT_TYPE_ID(TestAnalysis1)
IR_DEFINE_EXPLICIT_TYPE_ID(TestAnalysis1)
IR_DECLARE_EXPLICIT_TYPE_ID(TestAnalysis2)
IR_DEFINE_EXPLICIT_TYPE_ID(TestAnalysis2)
TEST(pass_manager, PreservedAnalyses) {
ir::detail::PreservedAnalyses pa;
CHECK_EQ(pa.IsNone(), true);
CHECK_EQ(pa.IsPreserved<TestAnalysis1>(), false);
pa.Preserve<TestAnalysis1>();
CHECK_EQ(pa.IsPreserved<TestAnalysis1>(), true);
pa.Unpreserve<TestAnalysis1>();
CHECK_EQ(pa.IsPreserved<TestAnalysis1>(), false);
CHECK_EQ(pa.IsPreserved<TestAnalysis2>(), false);
pa.Preserve<TestAnalysis1, TestAnalysis2>();
CHECK_EQ(pa.IsPreserved<TestAnalysis1>(), true);
CHECK_EQ(pa.IsPreserved<TestAnalysis2>(), true);
CHECK_EQ(pa.IsAll(), false);
pa.PreserveAll();
CHECK_EQ(pa.IsAll(), true);
CHECK_EQ(pa.IsNone(), false);
}
#endif
class AddOp : public ir::Op<AddOp> {
public:
using Op::Op;
......@@ -63,15 +93,49 @@ void AddOp::Build(ir::Builder &,
IR_DECLARE_EXPLICIT_TYPE_ID(AddOp)
IR_DEFINE_EXPLICIT_TYPE_ID(AddOp)
struct CountOpAnalysis {
explicit CountOpAnalysis(ir::Operation *container_op) {
IR_ENFORCE(container_op->num_regions() > 0, true);
LOG(INFO) << "In CountOpAnalysis, op is " << container_op->name() << "\n";
for (size_t i = 0; i < container_op->num_regions(); ++i) {
auto &region = container_op->region(i);
for (auto it = region.begin(); it != region.end(); ++it) {
auto *block = *it;
for (auto it = block->begin(); it != block->end(); ++it) {
++count;
}
}
}
LOG(INFO) << "-- count is " << count << "\n";
}
int count = 0;
};
IR_DECLARE_EXPLICIT_TYPE_ID(CountOpAnalysis)
IR_DEFINE_EXPLICIT_TYPE_ID(CountOpAnalysis)
class TestPass : public ir::Pass {
public:
TestPass() : ir::Pass("TestPass", 1) {}
void Run(ir::Operation *op) override {
auto count_op_analysis = analysis_manager().GetAnalysis<CountOpAnalysis>();
pass_state().preserved_analyses.Preserve<CountOpAnalysis>();
CHECK_EQ(pass_state().preserved_analyses.IsPreserved<CountOpAnalysis>(),
true);
CHECK_EQ(count_op_analysis.count, 4);
auto module_op = op->dyn_cast<ir::ModuleOp>();
CHECK_EQ(module_op.operation(), op);
CHECK_EQ(module_op.name(), module_op->name());
LOG(INFO) << "In " << pass_info().name << ": " << module_op->name()
<< std::endl;
pass_state().preserved_analyses.Unpreserve<CountOpAnalysis>();
CHECK_EQ(pass_state().preserved_analyses.IsPreserved<CountOpAnalysis>(),
false);
}
bool CanApplyOn(ir::Operation *op) const override {
......@@ -79,7 +143,11 @@ class TestPass : public ir::Pass {
}
};
TEST(pass_manager_test, pass_manager) {
TEST(pass_manager, PassManager) {
//
// TODO(liuyuanle): remove test code other than pass manager
//
// (1) Init environment.
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Dialect *builtin_dialect =
......@@ -186,6 +254,7 @@ TEST(pass_manager_test, pass_manager) {
op4->operand(0).type().dialect().GetRegisteredInterface<Interface>();
// ir::Parameter *parameter_c =
// c_interface->VariableToParameter(variable_c.get());
std::unique_ptr<ir::Parameter> parameter_c =
c_interface->VariableToParameter(variable_c.get());
EXPECT_EQ(parameter_c->type(), dense_tensor_dtype);
......@@ -199,6 +268,10 @@ TEST(pass_manager_test, pass_manager) {
EXPECT_EQ(program.block()->size() == 4, true);
EXPECT_EQ(program.parameters_num() == 3, true);
//
// TODO(liuyuanle): remove the code above.
//
// (9) Test pass manager for program.
ir::PassManager pm(ctx);
......
......@@ -20,7 +20,7 @@
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
TEST(PatternBenefit, PatternBenefit) {
TEST(pattern_rewrite, PatternBenefit) {
ir::PatternBenefit benefit1(1);
EXPECT_EQ(benefit1.benefit(), 1U);
ir::PatternBenefit benefit2(2);
......@@ -95,7 +95,7 @@ class TestPatternRewrite2 : public ir::OpRewritePattern<Operation1> {
}
};
TEST(RewritePattern, OpRewritePattern) {
TEST(pattern_rewrite, RewritePatternSet) {
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
auto *test_dialect = ctx->GetOrRegisterDialect<TestDialect>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册