diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index 991f8dbe2e107dcdf3e8fc159f63e0ebbfe232b3..ae23338cb22e9521b4e5d6eb5d2e468039a1df59 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -224,4 +224,12 @@ void Operation::SetParent(Block *parent, const Block::iterator &position) { position_ = position; } +void Operation::ReplaceAllUsesWith(const std::vector &values) { + IR_ENFORCE(num_results_ == values.size(), + "the num of result should be the same."); + for (uint32_t i = 0; i < num_results_; ++i) { + result(i).ReplaceAllUsesWith(values[i]); + } +} + } // namespace ir diff --git a/paddle/ir/core/operation.h b/paddle/ir/core/operation.h index 6a6d9dc19de5bb51a72d8c9182bffe9103d4ef28..bf223f2fdf966bb12a9630a10a954e518764ec9b 100644 --- a/paddle/ir/core/operation.h +++ b/paddle/ir/core/operation.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include "paddle/ir/core/block.h" #include "paddle/ir/core/op_info.h" #include "paddle/ir/core/operation_utils.h" @@ -102,6 +103,13 @@ class IR_API alignas(8) Operation final { operator Block::const_iterator() const { return position_; } + /// Replace all uses of results of this operation with the provided 'values'. + void ReplaceAllUsesWith(const std::vector &values); + + inline void ReplaceAllUsesWith(Value value) { + ReplaceAllUsesWith(std::vector{value}); + } + private: Operation(const AttributeMap &attribute, ir::OpInfo op_info, diff --git a/paddle/ir/core/region.cc b/paddle/ir/core/region.cc index 3d2cbe5be64913ecc9c1f0edb5e9c09efd450682..ac7cd4ccdfa8b8554b1adebfc5f29874af6cdddd 100644 --- a/paddle/ir/core/region.cc +++ b/paddle/ir/core/region.cc @@ -15,6 +15,7 @@ #include "paddle/ir/core/region.h" #include "paddle/ir/core/block.h" #include "paddle/ir/core/enforce.h" +#include "paddle/ir/core/operation.h" namespace ir { Region::~Region() { clear(); } @@ -50,4 +51,9 @@ void Region::clear() { blocks_.pop_back(); } } + +IrContext *Region::ir_context() const { + IR_ENFORCE(parent_, "Region is not attached to a container."); + return parent_->ir_context(); +} } // namespace ir diff --git a/paddle/ir/core/region.h b/paddle/ir/core/region.h index 5d3c78e59a6001e5d8321c56602db4765059587a..5335588790f021c97c1f6dc4fbf02c44ab3ce7e3 100644 --- a/paddle/ir/core/region.h +++ b/paddle/ir/core/region.h @@ -23,6 +23,7 @@ namespace ir { class Block; class Operation; +class IrContext; class IR_API Region { public: @@ -55,6 +56,8 @@ class IR_API Region { Operation *GetParent() const { return parent_; } + IrContext *ir_context() const; + private: Region(Region &) = delete; Region &operator=(const Region &) = delete; diff --git a/paddle/ir/core/value.cc b/paddle/ir/core/value.cc index 4e7afd2d835edb49f8a43c0c4690cda4fe99d968..a5ca59d19759b5fefe677002a61898470f3e4f3a 100644 --- a/paddle/ir/core/value.cc +++ b/paddle/ir/core/value.cc @@ -85,6 +85,8 @@ OpOperand Value::first_use() const { return impl()->first_use(); } bool Value::use_empty() const { return !first_use(); } +bool Value::HasOneUse() const { return impl()->HasOneUse(); } + void Value::ReplaceUsesWithIf( Value new_value, const std::function &should_replace) const { diff --git a/paddle/ir/core/value.h b/paddle/ir/core/value.h index 7fa336fed4a4b4244ede26705ba0bea2579611c7..429516acc4a6b395690f37c916ea7d93faf9fb1f 100644 --- a/paddle/ir/core/value.h +++ b/paddle/ir/core/value.h @@ -158,10 +158,12 @@ class IR_API Value { OpOperand first_use() const; - friend struct std::hash; - bool use_empty() const; + bool HasOneUse() const; + + friend struct std::hash; + void ReplaceUsesWithIf( Value new_value, const std::function &should_replace) const; diff --git a/paddle/ir/core/value_impl.h b/paddle/ir/core/value_impl.h index f7032d87ce9375eb3ee23c4c3cc2b0b57207d64a..1e21e8f0d19c6bb24c49a09f2eeb53e6af168797 100644 --- a/paddle/ir/core/value_impl.h +++ b/paddle/ir/core/value_impl.h @@ -98,6 +98,10 @@ class alignas(8) ValueImpl { bool use_empty() const { return first_use() == nullptr; } + bool HasOneUse() const { + return (first_use() != nullptr) && (first_use()->next_use() == nullptr); + } + std::string PrintUdChain(); protected: diff --git a/paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h b/paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h index eb7b33d7a1f2bdf6c50c21f11dd220cbf255b391..59d7e2a8e8141ec2c725f43a0319a31a656bbdb4 100644 --- a/paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h +++ b/paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h @@ -21,12 +21,13 @@ #include #include +#include "paddle/ir/core/dll_decl.h" #include "paddle/ir/core/op_info.h" #include "paddle/ir/pattern_rewrite/pattern_match.h" namespace ir { -class FrozenRewritePatternSet { +class IR_API FrozenRewritePatternSet { using NativePatternListT = std::vector>; public: diff --git a/paddle/ir/pattern_rewrite/pattern_match.cc b/paddle/ir/pattern_rewrite/pattern_match.cc index cd7950b0af5d9cbc920429696ddb3f6390bc996e..1f465809be37ca004a184977d367182a200f4ae6 100644 --- a/paddle/ir/pattern_rewrite/pattern_match.cc +++ b/paddle/ir/pattern_rewrite/pattern_match.cc @@ -15,9 +15,9 @@ #include "paddle/ir/pattern_rewrite/pattern_match.h" #include -#include #include +#include "paddle/ir/core/enforce.h" #include "paddle/ir/core/operation.h" namespace ir { @@ -90,44 +90,55 @@ RewritePattern::~RewritePattern() = default; //===----------------------------------------------------------------------===// RewriterBase::~RewriterBase() = default; -// TODO(wilber): value support replace method. -// void RewriterBase::ReplaceOpWithIf(Operation* op, -// ValueRange new_values, -// bool* all_uses_replaced, -// std::function functor) { -// // assert(op->num_results() == new_values.size() && "incorrect number of -// values to replace operation"); NotifyRootReplaced(op, new_values); bool -// replace_all_uses = true; for (uint32_t i = 0; i < op->num_results(); ++i) { -// // op->result(0) -// } -// } -// void RewriterBase::ReplaceOpWithIf(Operation* op, -// ValueRange new_values, -// std::function functor) { -// ReplaceOpWithIf(op, new_values, nullptr, functor); -// } - -// TODO(wilber): support erase. -// void ReplaceOp(Operation* op, ValueRange new_values) { -// NotifyRootReplaced(op, new_values); -// assert(op->num_results() == new_values.size() && "incorrect # of -// replacement values"); op->ReplaceAllUsesWith(new_values); -// NotifyOperationRemoved(op); -// op->erase(); -// } +void RewriterBase::ReplaceOpWithIf( + Operation* op, + const std::vector& new_values, + bool* all_uses_replaced, + const std::function& functor) { + IR_ENFORCE(op->num_results() == new_values.size(), + "incorrect number of values to replace operation"); + NotifyRootReplaced(op, new_values); + + // Replace each use of the results when the functor is true. + bool replace_all_uses = true; + for (uint32_t i = 0; i < op->num_results(); ++i) { + auto src_res = op->result(i); + src_res.ReplaceUsesWithIf(new_values[i], functor); + replace_all_uses &= src_res.use_empty(); + } + if (replace_all_uses) { + *all_uses_replaced = replace_all_uses; + } +} + +void RewriterBase::ReplaceOpWithIf( + Operation* op, + const std::vector& new_values, + const std::function& functor) { + ReplaceOpWithIf(op, new_values, nullptr, functor); +} + +void RewriterBase::ReplaceOp(Operation* op, + const std::vector& new_values) { + NotifyRootReplaced(op, new_values); + IR_ENFORCE(op->num_results() == new_values.size(), + "incorrect # of replacement values"); + op->ReplaceAllUsesWith(new_values); + NotifyOperationRemoved(op); + op->GetParent()->erase(*op); +} + void RewriterBase::EraseOp(Operation* op) { - // assert(op->use_empty() && "expected 'op' to have no uses"); - // NotifyOperationRemoved(op); - // op->erase(); + // TODO(wilber): Operation support use_empty. + // IR_ENFORCE(op->use_empty(), "expected 'op' to have no uses"); + NotifyOperationRemoved(op); + op->GetParent()->erase(*op); } +/// Find uses of `from` and replace it with `to` void RewriterBase::ReplaceAllUsesWith(Value from, Value to) { - // from. - // for (OpOperand& operand : llvm::make_early_inc_range(from.getUses())) - // { - // Operation* op = operand.getOwner(); - // UpdateRootInPlace(op, [&]() { operand.set(to); }); - // } + // TODO(wilber): Substitue a low level impl. + from.ReplaceAllUsesWith(to); } // TODO(wilber): iterator maybe should support modify inplace. @@ -135,8 +146,8 @@ void RewriterBase::ReplaceUseIf(Value from, Value to, std::function functor) { // for (auto it = from.begin(); it != from.end(); ++it) { - // // TODO: need a lvalue. - // if (functor(it.get())) { + // // // TODO: need a lvalue. + // if (functor(*it)) { // UpdateRootInplace(it.owner(), [&](){it.get().set(to)}); // } // } @@ -144,8 +155,8 @@ void RewriterBase::ReplaceUseIf(Value from, void RewriterBase::ReplaceOpWithResultsOfAnotherOp(Operation* op, Operation* new_op) { - assert(op->num_results() == new_op->num_results() && - "replacement op doesn't match results of original op"); + IR_ENFORCE(op->num_results() == new_op->num_results(), + "replacement op doesn't match results of original op"); // TODO(wilber): Op support results method. // if (op->num_results() == 1) return ReplaceOp(op, // new_op->result(0)); return ReplaceOp(op, new_op->GetResults()); diff --git a/paddle/ir/pattern_rewrite/pattern_match.h b/paddle/ir/pattern_rewrite/pattern_match.h index dee11d6bd929620dfc77cc940faec3799e3013c9..6c90f366564fc9c28813b80a0368971375c8a749 100644 --- a/paddle/ir/pattern_rewrite/pattern_match.h +++ b/paddle/ir/pattern_rewrite/pattern_match.h @@ -25,6 +25,7 @@ #include #include "paddle/ir/core/builder.h" +#include "paddle/ir/core/dll_decl.h" #include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/op_info.h" #include "paddle/ir/core/operation.h" @@ -36,7 +37,7 @@ namespace ir { // This class reprensents the benefit of a pattern. The most common // unit to use is the `numver of operations` in the pattern. -class PatternBenefit { +class IR_API PatternBenefit { public: PatternBenefit() = default; PatternBenefit(uint32_t val) : val_(val) {} // NOLINT @@ -257,30 +258,21 @@ class RewriterBase : public Builder { public: // TODO(wilber): Supplementary methods of block and region. - // TODO(wilber): Support ValueRange. - // virtual void ReplaceOpWithIf(Operation* op, - // ValueRange new_values, - // bool* all_uses_replaced, - // std::function functor); - // void ReplaceOpWithIf(Operation* op, - // ValueRange new_values, - // std::function functor); - // virtual void ReplaceOp(Operation* op, ValueRange new_values); + virtual void ReplaceOpWithIf(Operation* op, + const std::vector& new_values, + bool* all_uses_replaced, + const std::function& functor); - // virtual void ReplaceOpWithNewOp() + void ReplaceOpWithIf(Operation* op, + const std::vector& new_values, + const std::function& functor); - virtual void EraseOp(Operation* op); + virtual void ReplaceOp(Operation* op, const std::vector& new_values); - virtual void StartRootUpdate(Operation* op) {} - virtual void FinalizeRootUpdate(Operation* op) {} - virtual void CancleRootUpdate(Operation* op) {} + // template + // OpTy ReplaceOpWithNewOp(Operation *op, Args &&...args); - template - void UpdateRootInplace(Operation* root, CallableT&& callable) { - StartRootUpdate(root); - callable(); - FinalizeRootUpdate(root); - } + virtual void EraseOp(Operation* op); void ReplaceAllUsesWith(Value from, Value to); @@ -293,11 +285,25 @@ class RewriterBase : public Builder { virtual ~RewriterBase(); - // virtual void NotifyRootReplaced(Operation* op, ValueRange replacement) {} + virtual void NotifyRootReplaced(Operation* op, + const std::vector& replacement) {} virtual void NotifyOperationRemoved(Operation* op) {} - // virtual bool NotifyMatchFailure() + virtual void NotifyOperationInserted(Operation* op) {} + + virtual void StartRootUpdate(Operation* op) {} + + virtual void FinalizeRootUpdate(Operation* op) {} + + virtual void CancleRootUpdate(Operation* op) {} + + template + void UpdateRootInplace(Operation* root, CallableT&& callable) { + StartRootUpdate(root); + callable(); + FinalizeRootUpdate(root); + } private: void operator=(const RewriterBase&) = delete; diff --git a/paddle/ir/pattern_rewrite/pattern_rewrite_driver.cc b/paddle/ir/pattern_rewrite/pattern_rewrite_driver.cc new file mode 100644 index 0000000000000000000000000000000000000000..21a673e6b3a15c24fe58d89f5cba7ebc56697c06 --- /dev/null +++ b/paddle/ir/pattern_rewrite/pattern_rewrite_driver.cc @@ -0,0 +1,227 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/ir/core/block.h" +#include "paddle/ir/core/ir_context.h" +#include "paddle/ir/core/operation.h" +#include "paddle/ir/core/region.h" +#include "paddle/ir/core/value.h" +#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h" +#include "paddle/ir/pattern_rewrite/pattern_applicator.h" +#include "paddle/ir/pattern_rewrite/pattern_match.h" + +namespace { + +class GreedyPatternRewriteDriver : public ir::PatternRewriter { + public: + explicit GreedyPatternRewriteDriver( + ir::IrContext* ctx, + const ir::FrozenRewritePatternSet& patterns, + const ir::GreedyRewriteConfig& config) + : ir::PatternRewriter(ctx), + config_(config), + region_(*config.region), + matcher_(patterns) { + worklist_.reserve(128); + matcher_.ApplyDefaultCostModel(); + if (config.strict_mode != ir::GreedyRewriteStrictness::AnyOp) { + for (auto it = region_.begin(); it != region_.end(); ++it) { + for (auto op_it = (*it)->begin(); op_it != (*it)->end(); ++op_it) { + strict_mode_filtered_ops_.insert(*op_it); + } + } + } + } + + bool Simplify() { + bool changed = false; + int64_t iteration = 0; + do { + // Check if the iteration limit was reached. + if (iteration++ >= config_.max_iterations && + config_.max_iterations != ir::GreedyRewriteConfig::kNoLimit) + break; + VLOG(6) << "Iteration[" << iteration << "] for PatternRewrite"; + worklist_.clear(); + worklist_map_.clear(); + + for (auto block_it = region_.begin(); block_it != region_.end(); + ++block_it) { + for (auto op_it = (*block_it)->begin(); op_it != (*block_it)->end(); + ++op_it) { + worklist_.push_back(*op_it); + } + } + if (config_.use_top_down_traversal) { + // Reverse the list so out pop-back loop process them in-order. + std::reverse(worklist_.begin(), worklist_.end()); + } + for (size_t i = 0; i < worklist_.size(); ++i) { + worklist_map_[worklist_[i]] = i; + VLOG(6) << "worklist[" << i << "] is " << worklist_[i]->name(); + } + + changed = ProcessWorklist(); + } while (changed); + + return !changed; + } + + private: + /// Process ops until the worklist is empty or `config.max_num_rewrites` + /// is reached. Return `true` if any IR was changed. + bool ProcessWorklist() { + bool changed = false; + int64_t num_rewrites = 0; + + while (!worklist_.empty() && + (num_rewrites < config_.max_num_rewrites || + config_.max_num_rewrites == ir::GreedyRewriteConfig::kNoLimit)) { + auto* op = PopFromWorklist(); + if (op == nullptr) continue; + VLOG(6) << "PopFromWorklist, get op: " << op->name(); + + // TODO(wilber): ir is dead. + // ... + + // TODO(wilber): fold logical. + // ... + + bool match_result = matcher_.MatchAndRewrite(op, *this); + if (match_result) { + changed = true; + ++num_rewrites; + } + } + + return changed; + } + + // TODO(wilber): OpResult support GetUsers method. + void NotifyRootReplaced(ir::Operation* op, + const std::vector& replacement) override { + // for (uint32_t i = 0; i < op->num_results(); ++i) { + // auto res = op->GetResultByIndex(i); + // } + // } + } + + void FinalizeRootUpdate(ir::Operation* op) override { AddToWorklist(op); } + + void NotifyOperationRemoved(ir::Operation* op) override { + for (uint32_t i = 0; i < op->num_operands(); ++i) { + AddOperandToWorklist(op->operand(i).source()); + } + for (uint32_t i = 0; i < op->num_regions(); ++i) { + auto& region = op->region(i); + for (auto it = region.begin(); it != region.end(); ++it) { + for (auto op_it = (*it)->begin(); op_it != (*it)->end(); ++op_it) { + RemoveFromWorklist(*op_it); + } + } + } + + if (config_.strict_mode != ir::GreedyRewriteStrictness::AnyOp) { + strict_mode_filtered_ops_.erase(op); + } + } + + void NotifyOperationInserted(ir::Operation* op) override { + if (config_.strict_mode == ir::GreedyRewriteStrictness::ExistingAndNewOps) + strict_mode_filtered_ops_.insert(op); + AddToWorklist(op); + } + + /// Add the given operation to the worklist. + void AddToWorklist(ir::Operation* op) { + if (config_.strict_mode == ir::GreedyRewriteStrictness::AnyOp || + strict_mode_filtered_ops_.count(op)) { + if (worklist_map_.count(op)) return; + + worklist_map_[op] = worklist_.size(); + worklist_.push_back(op); + } + } + + void AddOperandToWorklist(ir::Value operand) { + // If the use count of this operand is now < 2, we re-add the defining + // operation to the worklist. + // This is based on the fact that zero use operations may be deleted, and + // that single use values often have more canonicalization opportunities. + if (!operand || (!operand.use_empty() && !operand.HasOneUse())) return; + + if (auto* def_op = operand.GetDefiningOp()) AddToWorklist(def_op); + } + + void AddOperandsToWorklist(const std::vector operands) { + for (auto& v : operands) { + AddOperandToWorklist(v); + } + } + + /// Pop the next operation from the worklist + ir::Operation* PopFromWorklist() { + auto* op = worklist_.back(); + worklist_.pop_back(); + if (op) worklist_map_.erase(op); + return op; + } + + /// If the specified operation is in the worklist, remove it. + void RemoveFromWorklist(ir::Operation* op) { + auto it = worklist_map_.find(op); + if (it != worklist_map_.end()) { + worklist_[it->second] = nullptr; + worklist_map_.erase(it); + } + } + + private: + std::vector worklist_; + std::unordered_map worklist_map_; + ir::GreedyRewriteConfig config_; + std::unordered_set strict_mode_filtered_ops_; + ir::Region& region_; + ir::PatternApplicator matcher_; +}; + +} // namespace + +namespace ir { + +bool ApplyPatternsGreedily(Region& region, // NOLINT + const FrozenRewritePatternSet& patterns, + GreedyRewriteConfig config) { + if (!config.region) config.region = ®ion; + + GreedyPatternRewriteDriver driver(region.ir_context(), patterns, config); + bool converged = driver.Simplify(); + if (!converged) { + LOG(WARNING) << "The pattern rewrite did not converge after scaning " + << config.max_iterations << " times"; + } + return converged; +} + +} // namespace ir diff --git a/paddle/ir/pattern_rewrite/pattern_rewrite_driver.h b/paddle/ir/pattern_rewrite/pattern_rewrite_driver.h new file mode 100644 index 0000000000000000000000000000000000000000..2e87eac5fef0bf45a7493f1ba187f036bf55cfd2 --- /dev/null +++ b/paddle/ir/pattern_rewrite/pattern_rewrite_driver.h @@ -0,0 +1,86 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/core/dll_decl.h" +#include "paddle/ir/core/region.h" +#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h" +#include "paddle/ir/pattern_rewrite/pattern_match.h" + +namespace ir { + +/// This enum will control which ops will be added to the worklist during the +/// match rewrite process +enum class IR_API GreedyRewriteStrictness { + /// No restrictions wrt. any ops are processed. + AnyOp, + /// Only pre-existing and newly created ops are processed. + ExistingAndNewOps, + /// Only pre-existing ops are processed. + ExistingOps +}; + +/// Control over how the GreedyPatternRewriteDriver works. +class IR_API GreedyRewriteConfig { + public: + /// Control the way op is added to the worklist: bottom-up or top-down. + bool use_top_down_traversal = false; + + /// Control the maximum number of iterations in the process of applying the + /// pattern, use `kNolimit` to represent unlimited. + int64_t max_iterations = 10; + + /// Control the upper limit of rewrite times during each iteration, use + /// kNoLimit to represent unlimited. + int64_t max_num_rewrites = kNoLimit; + + /// Only the op inside this region will be added to the worklist. + Region* region{nullptr}; + + /// Limit which ops will be added to the worklist during the Match and Rewrite + /// process. + /// - AnyOp: all ops will be added to the worklist. + /// - ExistingAndNewOps: pre-existing ops and newly created ops are added to + /// the worklist. + /// - ExistingOps: only pre-existing ops are added to the worklist. + GreedyRewriteStrictness strict_mode = GreedyRewriteStrictness::AnyOp; + + static constexpr int64_t kNoLimit = -1; +}; + +/// Perform the Match and Rewrite process in the specified region, greedily +/// apply the Pattern with the highest benefit, and repeat this process until +/// convergence or the upper limit of iterations. +/// +/// Returns true if the iteration converges and no patterns can be applied. +bool IR_API +ApplyPatternsGreedily(Region& region, // NOLINT + const FrozenRewritePatternSet& patterns, + GreedyRewriteConfig config = GreedyRewriteConfig()); + +/// Perform a match and rewrite process for all regions of a given op. +inline IR_API bool ApplyPatternsGreedily( + Operation* op, + const FrozenRewritePatternSet& patterns, + GreedyRewriteConfig config = GreedyRewriteConfig()) { + bool failed = false; + for (uint32_t i = 0; i < op->num_regions(); ++i) { + Region& region = op->region(i); + failed |= !ApplyPatternsGreedily(region, patterns, config); + } + return !failed; +} + +} // namespace ir diff --git a/test/cpp/ir/core/ir_op_test.cc b/test/cpp/ir/core/ir_op_test.cc index 04758da4f6ecc33543741ae25308fc3f01bdb882..cb04f440c01193a9b745807e19eba93bfe6b51b6 100644 --- a/test/cpp/ir/core/ir_op_test.cc +++ b/test/cpp/ir/core/ir_op_test.cc @@ -256,6 +256,7 @@ TEST(op_test, region_test) { block->push_front(op1); block->insert(block->begin(), op1_2); ir::Operation *op2 = ir::Operation::Create(std::move(argument)); + EXPECT_EQ(op2->region(0).ir_context(), ctx); op2->Destroy(); } diff --git a/test/cpp/ir/core/ir_value_test.cc b/test/cpp/ir/core/ir_value_test.cc index ae1edfe685dd2ba3fb6e7c1a5aad34536074952b..b77552122bfc19ae888573d23308b0b5e512e2f2 100644 --- a/test/cpp/ir/core/ir_value_test.cc +++ b/test/cpp/ir/core/ir_value_test.cc @@ -45,6 +45,7 @@ TEST(value_test, value_test) { ir::OpInfo()); op1->Print(std::cout); ir::OpResult a = op1->result(0); + EXPECT_TRUE(a.use_empty()); // 2. Construct OP2: b = OP2(); std::vector op2_inputs = {}; std::vector op2_output_types = {ir::Float32Type::get(ctx)}; @@ -55,6 +56,7 @@ TEST(value_test, value_test) { ir::OpInfo()); op2->Print(std::cout); ir::OpResult b = op2->result(0); + EXPECT_TRUE(b.use_empty()); // 3. Construct OP3: c = OP3(a, b); std::vector op3_inputs{a, b}; std::vector op3_output_types = {ir::Float32Type::get(ctx)}; @@ -63,6 +65,9 @@ TEST(value_test, value_test) { CreateAttributeMap("op3_name", "op3_attr"), op3_output_types, ir::OpInfo()); + + EXPECT_TRUE(op1->result(0).HasOneUse()); + EXPECT_TRUE(op2->result(0).HasOneUse()); op3->Print(std::cout); ir::OpResult c = op3->result(0); // 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c); diff --git a/test/cpp/ir/pattern_rewrite/CMakeLists.txt b/test/cpp/ir/pattern_rewrite/CMakeLists.txt index c5332866c6aef21171f186cff1703fb8ccefc437..62dfa3b8dece57bd6cb6a7447777593facfa94a4 100644 --- a/test/cpp/ir/pattern_rewrite/CMakeLists.txt +++ b/test/cpp/ir/pattern_rewrite/CMakeLists.txt @@ -1 +1,8 @@ -cc_test_old(pattern_rewrite_test SRCS pattern_rewrite_test.cc DEPS ir gtest) +cc_test_old( + pattern_rewrite_test + SRCS + pattern_rewrite_test.cc + DEPS + ir + pd_dialect + gtest) diff --git a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc index d26cbe9265325471763bb36a184d3e12c7b50158..607108d582b44523fd50e32d20ca9963e4a93f43 100644 --- a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc @@ -13,27 +13,33 @@ // limitations under the License. #include +#include +#include +#include +#include +#include "paddle/fluid/ir/dialect/pd_attribute.h" +#include "paddle/ir/core/builder.h" #include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_dialect.h" +#include "paddle/ir/core/cast_utils.h" #include "paddle/ir/core/dialect.h" +#include "paddle/ir/core/enforce.h" #include "paddle/ir/core/ir_context.h" +#include "paddle/ir/core/op_info.h" +#include "paddle/ir/core/program.h" +#include "paddle/ir/pass/pass.h" +#include "paddle/ir/pass/pass_manager.h" +#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h" +#include "paddle/ir/pattern_rewrite/pattern_applicator.h" #include "paddle/ir/pattern_rewrite/pattern_match.h" +#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h" -TEST(pattern_rewrite, PatternBenefit) { - ir::PatternBenefit benefit1(1); - EXPECT_EQ(benefit1.benefit(), 1U); - ir::PatternBenefit benefit2(2); - EXPECT_EQ(benefit2.benefit(), 2U); - - EXPECT_TRUE(benefit2 > benefit1); - EXPECT_TRUE(benefit2 >= benefit1); - EXPECT_TRUE(benefit1 < benefit2); - EXPECT_TRUE(benefit1 <= benefit2); - EXPECT_TRUE(benefit1 != benefit2); - ir::PatternBenefit benefit3(2); - EXPECT_TRUE(benefit2 == benefit3); -} +// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in +// paddle/fluid/ir/dialect/CMakeLists.txt. +#include "paddle/fluid/ir/dialect/pd_dialect.h" +#include "paddle/fluid/ir/dialect/pd_op.h" +#include "paddle/fluid/ir/dialect/pd_type.h" // Define op1. class Operation1 : public ir::Op { @@ -95,7 +101,22 @@ class TestPatternRewrite2 : public ir::OpRewritePattern { } }; -TEST(pattern_rewrite, RewritePatternSet) { +TEST(PatternRewrite, PatternBenefit) { + ir::PatternBenefit benefit1(1); + EXPECT_EQ(benefit1.benefit(), 1U); + ir::PatternBenefit benefit2(2); + EXPECT_EQ(benefit2.benefit(), 2U); + + EXPECT_TRUE(benefit2 > benefit1); + EXPECT_TRUE(benefit2 >= benefit1); + EXPECT_TRUE(benefit1 < benefit2); + EXPECT_TRUE(benefit1 <= benefit2); + EXPECT_TRUE(benefit1 != benefit2); + ir::PatternBenefit benefit3(2); + EXPECT_TRUE(benefit2 == benefit3); +} + +TEST(RewritePattern, RewritePatternSet) { ir::IrContext *ctx = ir::IrContext::Instance(); ctx->GetOrRegisterDialect(); auto *test_dialect = ctx->GetOrRegisterDialect(); @@ -118,3 +139,144 @@ TEST(pattern_rewrite, RewritePatternSet) { EXPECT_EQ(ps.native_patterns()[0]->benefit(), 2U); EXPECT_EQ(ps.native_patterns()[1]->benefit(), 2U); } + +// TODO(wilber): Add actual case. +// TEST(PatternRewrite, PatternApplicator) { +// ir::IrContext *ctx = ir::IrContext::Instance(); +// ctx->GetOrRegisterDialect(); +// auto *test_dialect = ctx->GetOrRegisterDialect(); +// test_dialect->RegisterOp(); +// ir::RewritePatternSet ps(ctx); +// ps.Add(ctx, 2); +// ir::FrozenRewritePatternSet frozen_set(std::move(ps)); +// ir::PatternApplicator applicator(frozen_set); +// applicator.ApplyDefaultCostModel(); +// } + +// // TODO(wilber): Add actual case. +TEST(PatternRewrite, FrozenRewritePatternSet) { + ir::FrozenRewritePatternSet frozen_set; + EXPECT_TRUE(frozen_set.match_any_op_native_patterns().empty()); + EXPECT_TRUE(frozen_set.op_specific_native_patterns().empty()); + + ir::IrContext *ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + auto *test_dialect = ctx->GetOrRegisterDialect(); + test_dialect->RegisterOp(); + ir::RewritePatternSet ps(ctx); + ps.Add(ctx, 2); + + ir::FrozenRewritePatternSet frozen_set2(std::move(ps)); + EXPECT_TRUE(frozen_set2.match_any_op_native_patterns().empty()); + const auto &pattern_maps = frozen_set2.op_specific_native_patterns(); + EXPECT_EQ(pattern_maps.size(), 1U); + EXPECT_EQ(pattern_maps.at(ctx->GetRegisteredOpInfo("test.Operation1")).size(), + 2U); +} + +class TransposePatternRewrite + : public ir::OpRewritePattern { + public: + using ir::OpRewritePattern::OpRewritePattern; + + bool MatchAndRewrite(paddle::dialect::TransposeOp op, + ir::PatternRewriter &rewriter) const override { + auto prev_op = op->operand(0).source().GetDefiningOp(); + std::vector axis_last = GetAxis(op); + auto prev_trans_op = prev_op->dyn_cast(); + if (prev_trans_op) { + std::vector axis_first = GetAxis(prev_trans_op); + IR_ENFORCE(axis_first.size() == axis_last.size(), + "tranpose op's perm rank should be same."); + auto new_perm = GetPerm(axis_first, axis_last); + rewriter.SetInsertionPoint(op); + auto new_op = rewriter.Build( + prev_op->operand(0).source().GetDefiningOp()->result(0), new_perm); + rewriter.ReplaceOp(op, {new_op.out()}); + return true; + } + + return false; + } + + private: + std::vector GetAxis(paddle::dialect::TransposeOp op) const { + auto attr_map = op->attributes(); + ir::ArrayAttribute array_attr = + attr_map.at("perm").dyn_cast(); + std::vector axis(array_attr.size()); + for (size_t i = 0; i < array_attr.size(); ++i) { + axis[i] = array_attr[i].dyn_cast().data(); + } + return axis; + } + + std::vector GetPerm(const std::vector &perm1, + const std::vector &perm2) const { + int n = perm1.size(); + std::vector axis(n), axis1(n), axis2(n); + std::iota(axis.begin(), axis.end(), 0); + for (int i = 0; i < n; ++i) { + axis1[i] = axis[perm1[i]]; + } + for (int i = 0; i < n; ++i) { + axis2[i] = axis1[perm2[i]]; + } + return axis2; + } +}; + +class TestPass : public ir::Pass { + public: + TestPass() : ir::Pass("TestPass", 1) {} + void Run(ir::Operation *op) override { + ir::RewritePatternSet ps(op->ir_context()); + ps.Add(op->ir_context()); + ir::FrozenRewritePatternSet frozen_ps(std::move(ps)); + ir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 1; + ir::ApplyPatternsGreedily(op->region(0), frozen_ps, cfg); + } + + bool CanApplyOn(ir::Operation *op) const override { + return op->name() == "builtin.module" && op->num_regions() > 0; + } +}; + +void BuildProgram(ir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_op = + builder.Build(std::vector{1, 3, 16, 16}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + ir::OpResult full_op_output = full_op->result(0); + + auto transpose1_op = builder.Build( + full_op_output, std::vector{0, 2, 3, 1}); + + builder.Build(transpose1_op.out(), + std::vector{0, 3, 1, 2}); + + // builder.Build(transpose2_op.out()); +} + +// TODO(wilber): Add a normal test. +TEST(PatternRewrite, GreedyPatternRewriteDriver) { + ir::IrContext *ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ir::Program program(ctx); + ir::Builder builder = ir::Builder(ctx, program.block()); + BuildProgram(builder); + EXPECT_EQ(program.block()->size(), 3u); + + ir::PassManager pm(ctx); + pm.AddPass(std::make_unique()); + std::stringstream o1, o2; + program.Print(o1); + LOG(INFO) << o1.str(); + pm.Run(&program); + LOG(INFO) << "After Pass."; + program.Print(o2); + LOG(INFO) << o2.str(); +}