未验证 提交 3bcc91e4 编写于 作者: Z zyfncg 提交者: GitHub

add reorder_block_ops_pass (#56990)

上级 14ede2b9
...@@ -13,6 +13,9 @@ ...@@ -13,6 +13,9 @@
// limitations under the License. // limitations under the License.
#include "paddle/ir/core/block.h" #include "paddle/ir/core/block.h"
#include <unordered_set>
#include "paddle/ir/core/enforce.h" #include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/operation.h" #include "paddle/ir/core/operation.h"
#include "paddle/ir/core/region.h" #include "paddle/ir/core/region.h"
...@@ -60,4 +63,34 @@ Block::UseIterator Block::use_end() const { return Block::UseIterator(); } ...@@ -60,4 +63,34 @@ Block::UseIterator Block::use_end() const { return Block::UseIterator(); }
bool Block::HasOneUse() const { return first_use_ && !first_use_.next_use(); } bool Block::HasOneUse() const { return first_use_ && !first_use_.next_use(); }
void Block::ResetOpListOrder(const OpListType &new_op_list) {
IR_ENFORCE(new_op_list.size() == ops_.size(),
"The size of new_op_list not same with ops_.");
IR_ENFORCE(TopoOrderCheck(new_op_list),
"The new_op_list is not in topological order.");
ops_.clear();
for (Operation *op : new_op_list) {
push_back(op);
}
}
bool Block::TopoOrderCheck(const OpListType &op_list) {
std::unordered_set<Value> visited_values;
for (const Operation *op : op_list) {
if (op->num_operands() > 0) {
for (size_t i = 0; i < op->num_operands(); ++i) {
auto operand = op->operand_source(i);
if (operand && visited_values.count(op->operand_source(i)) == 0) {
return false;
}
}
}
for (size_t i = 0; i < op->results().size(); ++i) {
visited_values.insert(op->result(i));
}
}
return true;
}
} // namespace ir } // namespace ir
...@@ -70,6 +70,8 @@ class IR_API Block { ...@@ -70,6 +70,8 @@ class IR_API Block {
bool HasOneUse() const; bool HasOneUse() const;
BlockOperand *first_use_addr() { return &first_use_; } BlockOperand *first_use_addr() { return &first_use_; }
void ResetOpListOrder(const OpListType &new_op_list);
private: private:
Block(Block &) = delete; Block(Block &) = delete;
Block &operator=(const Block &) = delete; Block &operator=(const Block &) = delete;
...@@ -78,6 +80,8 @@ class IR_API Block { ...@@ -78,6 +80,8 @@ class IR_API Block {
friend class Region; friend class Region;
void SetParent(Region *parent, Region::iterator position); void SetParent(Region *parent, Region::iterator position);
static bool TopoOrderCheck(const OpListType &op_list);
private: private:
Region *parent_; // not owned Region *parent_; // not owned
OpListType ops_; // owned OpListType ops_; // owned
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/transforms/reorder_block_ops_pass.h"
#include <queue>
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/pass/pass.h"
namespace {
// TODO(wilber): After support SideEffectTrait, Only NoSideEffectTrait op can be
// removed by dce pass.
// Now just a naive implementation.
class ReorderBlockOpsPass : public ir::Pass {
public:
ReorderBlockOpsPass() : ir::Pass("ReorderBlockOpsPass", 0) {}
void Run(ir::Operation *op) override {
IR_ENFORCE(op->num_regions() > 0,
"ReorderBlockOpsPass should run on Operation which regions "
"number greater than 0.");
for (size_t i = 0; i < op->num_regions(); ++i) {
for (auto *block : op->region(i)) {
std::list<ir::Operation *> res_op_list;
std::unordered_map<ir::Operation *, int>
reorder_op_dep_cnt; // op -> dependent input count
std::unordered_set<ir::Value> visited_values;
std::queue<ir::Operation *> op_que;
auto update_op_que = [&](ir::Operation *op) {
for (size_t i = 0; i < op->results().size(); ++i) {
auto result = op->result(i);
visited_values.insert(result);
for (auto it = result.use_begin(); it != result.use_end(); ++it) {
if (reorder_op_dep_cnt.count(it->owner())) {
reorder_op_dep_cnt[it->owner()]--;
if (reorder_op_dep_cnt[it->owner()] == 0) {
op_que.push(it->owner());
}
}
}
}
};
for (auto &op : *block) {
bool has_dependency = false;
if (op->num_operands() > 0) {
for (size_t i = 0; i < op->num_operands(); ++i) {
auto operand = op->operand_source(i);
if (operand && visited_values.count(op->operand_source(i)) == 0) {
reorder_op_dep_cnt[op]++;
has_dependency = true;
}
}
}
if (!has_dependency) {
res_op_list.push_back(op);
update_op_que(op);
}
}
if (reorder_op_dep_cnt.empty()) {
return;
}
while (!op_que.empty()) {
auto *op = op_que.front();
op_que.pop();
res_op_list.push_back(op);
update_op_que(op);
}
VLOG(4) << "ReorderBlockOpsPass is applied.";
block->ResetOpListOrder(res_op_list);
}
}
}
bool CanApplyOn(ir::Operation *op) const override {
return op->num_regions() > 0;
}
};
} // namespace
namespace ir {
std::unique_ptr<Pass> CreateReorderBlockOpsPass() {
return std::make_unique<ReorderBlockOpsPass>();
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/ir/core/dll_decl.h"
namespace ir {
class Pass;
IR_API std::unique_ptr<Pass> CreateReorderBlockOpsPass();
} // namespace ir
...@@ -42,6 +42,7 @@ ...@@ -42,6 +42,7 @@
#include "paddle/ir/pattern_rewrite/pattern_match.h" #include "paddle/ir/pattern_rewrite/pattern_match.h"
#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h" #include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h"
#include "paddle/ir/transforms/dead_code_elimination_pass.h" #include "paddle/ir/transforms/dead_code_elimination_pass.h"
#include "paddle/ir/transforms/reorder_block_ops_pass.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in // NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
...@@ -1099,6 +1100,7 @@ TEST(pattern_rewrite, Patterns) { ...@@ -1099,6 +1100,7 @@ TEST(pattern_rewrite, Patterns) {
pm.AddPass(std::make_unique<TestPass>()); pm.AddPass(std::make_unique<TestPass>());
pm.AddPass(ir::CreateConstantFoldingPass()); pm.AddPass(ir::CreateConstantFoldingPass());
pm.AddPass(ir::CreateDeadCodeEliminationPass()); pm.AddPass(ir::CreateDeadCodeEliminationPass());
pm.AddPass(ir::CreateReorderBlockOpsPass());
pm.EnablePassTiming(); pm.EnablePassTiming();
pm.EnableIRPrinting(); pm.EnableIRPrinting();
// pm.EnableIRPrinting(std::make_unique<ir::PassManager::IRPrinterOption>( // pm.EnableIRPrinting(std::make_unique<ir::PassManager::IRPrinterOption>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册