diff --git a/paddle/ir/core/block.cc b/paddle/ir/core/block.cc index f99ec340e4c49cf2bd388c41a3c7f7d1d06985ea..04d59e2582ebe5b27b29c9708d5049a464a91ead 100644 --- a/paddle/ir/core/block.cc +++ b/paddle/ir/core/block.cc @@ -13,6 +13,9 @@ // limitations under the License. #include "paddle/ir/core/block.h" + +#include + #include "paddle/ir/core/enforce.h" #include "paddle/ir/core/operation.h" #include "paddle/ir/core/region.h" @@ -60,4 +63,34 @@ Block::UseIterator Block::use_end() const { return Block::UseIterator(); } bool Block::HasOneUse() const { return first_use_ && !first_use_.next_use(); } +void Block::ResetOpListOrder(const OpListType &new_op_list) { + IR_ENFORCE(new_op_list.size() == ops_.size(), + "The size of new_op_list not same with ops_."); + IR_ENFORCE(TopoOrderCheck(new_op_list), + "The new_op_list is not in topological order."); + + ops_.clear(); + for (Operation *op : new_op_list) { + push_back(op); + } +} + +bool Block::TopoOrderCheck(const OpListType &op_list) { + std::unordered_set visited_values; + for (const Operation *op : op_list) { + if (op->num_operands() > 0) { + for (size_t i = 0; i < op->num_operands(); ++i) { + auto operand = op->operand_source(i); + if (operand && visited_values.count(op->operand_source(i)) == 0) { + return false; + } + } + } + for (size_t i = 0; i < op->results().size(); ++i) { + visited_values.insert(op->result(i)); + } + } + return true; +} + } // namespace ir diff --git a/paddle/ir/core/block.h b/paddle/ir/core/block.h index ebe4b6cb8ecf4e6c6147f02e58a08e92b1be592d..2cf00037eb5fcfd55688f75873a3466222f406bf 100644 --- a/paddle/ir/core/block.h +++ b/paddle/ir/core/block.h @@ -70,6 +70,8 @@ class IR_API Block { bool HasOneUse() const; BlockOperand *first_use_addr() { return &first_use_; } + void ResetOpListOrder(const OpListType &new_op_list); + private: Block(Block &) = delete; Block &operator=(const Block &) = delete; @@ -78,6 +80,8 @@ class IR_API Block { friend class Region; void SetParent(Region *parent, Region::iterator position); + static bool TopoOrderCheck(const OpListType &op_list); + private: Region *parent_; // not owned OpListType ops_; // owned diff --git a/paddle/ir/transforms/reorder_block_ops_pass.cc b/paddle/ir/transforms/reorder_block_ops_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..d922326677985a609b47a367ede5af60a80ea523 --- /dev/null +++ b/paddle/ir/transforms/reorder_block_ops_pass.cc @@ -0,0 +1,105 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/transforms/reorder_block_ops_pass.h" + +#include + +#include "paddle/ir/core/builtin_op.h" +#include "paddle/ir/core/program.h" +#include "paddle/ir/pass/pass.h" + +namespace { + +// TODO(wilber): After support SideEffectTrait, Only NoSideEffectTrait op can be +// removed by dce pass. +// Now just a naive implementation. +class ReorderBlockOpsPass : public ir::Pass { + public: + ReorderBlockOpsPass() : ir::Pass("ReorderBlockOpsPass", 0) {} + + void Run(ir::Operation *op) override { + IR_ENFORCE(op->num_regions() > 0, + "ReorderBlockOpsPass should run on Operation which regions " + "number greater than 0."); + for (size_t i = 0; i < op->num_regions(); ++i) { + for (auto *block : op->region(i)) { + std::list res_op_list; + std::unordered_map + reorder_op_dep_cnt; // op -> dependent input count + std::unordered_set visited_values; + std::queue op_que; + + auto update_op_que = [&](ir::Operation *op) { + for (size_t i = 0; i < op->results().size(); ++i) { + auto result = op->result(i); + visited_values.insert(result); + for (auto it = result.use_begin(); it != result.use_end(); ++it) { + if (reorder_op_dep_cnt.count(it->owner())) { + reorder_op_dep_cnt[it->owner()]--; + if (reorder_op_dep_cnt[it->owner()] == 0) { + op_que.push(it->owner()); + } + } + } + } + }; + + for (auto &op : *block) { + bool has_dependency = false; + if (op->num_operands() > 0) { + for (size_t i = 0; i < op->num_operands(); ++i) { + auto operand = op->operand_source(i); + if (operand && visited_values.count(op->operand_source(i)) == 0) { + reorder_op_dep_cnt[op]++; + has_dependency = true; + } + } + } + if (!has_dependency) { + res_op_list.push_back(op); + update_op_que(op); + } + } + + if (reorder_op_dep_cnt.empty()) { + return; + } + + while (!op_que.empty()) { + auto *op = op_que.front(); + op_que.pop(); + res_op_list.push_back(op); + update_op_que(op); + } + VLOG(4) << "ReorderBlockOpsPass is applied."; + block->ResetOpListOrder(res_op_list); + } + } + } + + bool CanApplyOn(ir::Operation *op) const override { + return op->num_regions() > 0; + } +}; + +} // namespace + +namespace ir { + +std::unique_ptr CreateReorderBlockOpsPass() { + return std::make_unique(); +} + +} // namespace ir diff --git a/paddle/ir/transforms/reorder_block_ops_pass.h b/paddle/ir/transforms/reorder_block_ops_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..f668471fc9e04d62b51514d2575498d129e26496 --- /dev/null +++ b/paddle/ir/transforms/reorder_block_ops_pass.h @@ -0,0 +1,26 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ir/core/dll_decl.h" + +namespace ir { + +class Pass; + +IR_API std::unique_ptr CreateReorderBlockOpsPass(); + +} // namespace ir diff --git a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc index e007b73c9f0ed0ec0568478f7668e9f6cba1c95b..fcca8cde7d5aa0b7ec58a7fe42c351e5c068bf74 100644 --- a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc @@ -42,6 +42,7 @@ #include "paddle/ir/pattern_rewrite/pattern_match.h" #include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h" #include "paddle/ir/transforms/dead_code_elimination_pass.h" +#include "paddle/ir/transforms/reorder_block_ops_pass.h" #include "paddle/phi/core/kernel_registry.h" // NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in @@ -1099,6 +1100,7 @@ TEST(pattern_rewrite, Patterns) { pm.AddPass(std::make_unique()); pm.AddPass(ir::CreateConstantFoldingPass()); pm.AddPass(ir::CreateDeadCodeEliminationPass()); + pm.AddPass(ir::CreateReorderBlockOpsPass()); pm.EnablePassTiming(); pm.EnableIRPrinting(); // pm.EnableIRPrinting(std::make_unique(