未验证 提交 7ace52ba 编写于 作者: W Wilber 提交者: GitHub

[IR&Pass] Add DCE Pass for new ir. (#54935)

上级 9137adb9
......@@ -40,6 +40,7 @@ endif()
add_subdirectory(core)
add_subdirectory(pass)
add_subdirectory(pattern_rewrite)
add_subdirectory(transforms)
if(WIN32)
if(WITH_SHARED_IR)
......
file(GLOB PATTERN_SRCS "*.cc")
ir_library(
ir_builtin_transforms
SRCS
${PATTERN_SRCS}
DEPS
ir_core
ir_pattern_rewrite
ir_pass)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/transforms/dce.h"
#include <memory>
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/pass/pass.h"
namespace {
// TODO(wilber): After support SideEffectTrait, Only NoSideEffectTrait op can be
// removed by dce pass.
// Now just a naive implementation.
class DCEPass : public ir::Pass {
public:
DCEPass() : ir::Pass("DCEPass", 0) {}
void Run(ir::Operation *op) override {
auto module_op = op->dyn_cast<ir::ModuleOp>();
IR_ENFORCE(module_op, "DCEPass should run on module op.");
auto *block = module_op.block();
std::vector<ir::Operation> erased_op;
for (auto it = block->begin(); it != block->end(); ++it) {
// TODO(wilber): Support NoSideEffect trait.
// if (!(*it)->HasTrait<NoSideEffect>()) continue;
bool use_empty = true;
for (uint32_t i = 0; i < (*it)->num_results(); ++i) {
use_empty &= (*it)->result(i).use_empty();
}
if (use_empty && (*it)->name() != "pd.fetch") {
erased_op.push_back(**it);
}
}
for (auto ep : erased_op) block->erase(ep);
}
bool CanApplyOn(ir::Operation *op) const override {
return op->name() == "builtin.module" && op->num_regions() > 0;
}
};
} // namespace
namespace ir {
std::unique_ptr<Pass> CreateDCEPass() { return std::make_unique<DCEPass>(); }
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/ir/core/dll_decl.h"
namespace ir {
class Pass;
IR_API std::unique_ptr<Pass> CreateDCEPass();
} // namespace ir
......@@ -22,6 +22,7 @@
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/cast_utils.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/enforce.h"
......@@ -34,6 +35,7 @@
#include "paddle/ir/pattern_rewrite/pattern_applicator.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h"
#include "paddle/ir/transforms/dce.h"
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/ir/dialect/CMakeLists.txt.
......@@ -235,7 +237,7 @@ class TestPass : public ir::Pass {
ir::FrozenRewritePatternSet frozen_ps(std::move(ps));
ir::GreedyRewriteConfig cfg;
cfg.use_top_down_traversal = true;
cfg.max_iterations = 1;
cfg.max_iterations = 10;
ir::ApplyPatternsGreedily(op->region(0), frozen_ps, cfg);
}
......@@ -255,10 +257,10 @@ void BuildProgram(ir::Builder &builder) { // NOLINT
auto transpose1_op = builder.Build<paddle::dialect::TransposeOp>(
full_op_output, std::vector<int>{0, 2, 3, 1});
builder.Build<paddle::dialect::TransposeOp>(transpose1_op.out(),
std::vector<int>{0, 3, 1, 2});
auto transpose2_op = builder.Build<paddle::dialect::TransposeOp>(
transpose1_op.out(), std::vector<int>{0, 3, 1, 2});
// builder.Build<paddle::dialect::FetchOp>(transpose2_op.out());
builder.Build<paddle::dialect::FetchOp>(transpose2_op.out(), "out");
}
// TODO(wilber): Add a normal test.
......@@ -268,10 +270,11 @@ TEST(PatternRewrite, GreedyPatternRewriteDriver) {
ir::Program program(ctx);
ir::Builder builder = ir::Builder(ctx, program.block());
BuildProgram(builder);
EXPECT_EQ(program.block()->size(), 3u);
EXPECT_EQ(program.block()->size(), 4u);
ir::PassManager pm(ctx);
pm.AddPass(std::make_unique<TestPass>());
pm.AddPass(ir::CreateDCEPass());
std::stringstream o1, o2;
program.Print(o1);
LOG(INFO) << o1.str();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册