未验证 提交 72b8c7c2 编写于 作者: W Wilber 提交者: GitHub

[IR&PASS] part 3-3: Add PatternRewrite Driver code. (#54738)

上级 813266a2
......@@ -224,4 +224,12 @@ void Operation::SetParent(Block *parent, const Block::iterator &position) {
position_ = position;
}
void Operation::ReplaceAllUsesWith(const std::vector<Value> &values) {
IR_ENFORCE(num_results_ == values.size(),
"the num of result should be the same.");
for (uint32_t i = 0; i < num_results_; ++i) {
result(i).ReplaceAllUsesWith(values[i]);
}
}
} // namespace ir
......@@ -15,6 +15,7 @@
#pragma once
#include <ostream>
#include <vector>
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation_utils.h"
......@@ -102,6 +103,13 @@ class IR_API alignas(8) Operation final {
operator Block::const_iterator() const { return position_; }
/// Replace all uses of results of this operation with the provided 'values'.
void ReplaceAllUsesWith(const std::vector<Value> &values);
inline void ReplaceAllUsesWith(Value value) {
ReplaceAllUsesWith(std::vector<Value>{value});
}
private:
Operation(const AttributeMap &attribute,
ir::OpInfo op_info,
......
......@@ -15,6 +15,7 @@
#include "paddle/ir/core/region.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/operation.h"
namespace ir {
Region::~Region() { clear(); }
......@@ -50,4 +51,9 @@ void Region::clear() {
blocks_.pop_back();
}
}
IrContext *Region::ir_context() const {
IR_ENFORCE(parent_, "Region is not attached to a container.");
return parent_->ir_context();
}
} // namespace ir
......@@ -23,6 +23,7 @@ namespace ir {
class Block;
class Operation;
class IrContext;
class IR_API Region {
public:
......@@ -55,6 +56,8 @@ class IR_API Region {
Operation *GetParent() const { return parent_; }
IrContext *ir_context() const;
private:
Region(Region &) = delete;
Region &operator=(const Region &) = delete;
......
......@@ -85,6 +85,8 @@ OpOperand Value::first_use() const { return impl()->first_use(); }
bool Value::use_empty() const { return !first_use(); }
bool Value::HasOneUse() const { return impl()->HasOneUse(); }
void Value::ReplaceUsesWithIf(
Value new_value,
const std::function<bool(OpOperand)> &should_replace) const {
......
......@@ -158,10 +158,12 @@ class IR_API Value {
OpOperand first_use() const;
friend struct std::hash<Value>;
bool use_empty() const;
bool HasOneUse() const;
friend struct std::hash<Value>;
void ReplaceUsesWithIf(
Value new_value,
const std::function<bool(OpOperand)> &should_replace) const;
......
......@@ -98,6 +98,10 @@ class alignas(8) ValueImpl {
bool use_empty() const { return first_use() == nullptr; }
bool HasOneUse() const {
return (first_use() != nullptr) && (first_use()->next_use() == nullptr);
}
std::string PrintUdChain();
protected:
......
......@@ -21,12 +21,13 @@
#include <unordered_map>
#include <vector>
#include "paddle/ir/core/dll_decl.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
namespace ir {
class FrozenRewritePatternSet {
class IR_API FrozenRewritePatternSet {
using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
public:
......
......@@ -15,9 +15,9 @@
#include "paddle/ir/pattern_rewrite/pattern_match.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/operation.h"
namespace ir {
......@@ -90,44 +90,55 @@ RewritePattern::~RewritePattern() = default;
//===----------------------------------------------------------------------===//
RewriterBase::~RewriterBase() = default;
// TODO(wilber): value support replace method.
// void RewriterBase::ReplaceOpWithIf(Operation* op,
// ValueRange new_values,
// bool* all_uses_replaced,
// std::function<bool(OpOperand&)> functor) {
// // assert(op->num_results() == new_values.size() && "incorrect number of
// values to replace operation"); NotifyRootReplaced(op, new_values); bool
// replace_all_uses = true; for (uint32_t i = 0; i < op->num_results(); ++i) {
// // op->result(0)
// }
// }
// void RewriterBase::ReplaceOpWithIf(Operation* op,
// ValueRange new_values,
// std::function<bool(OpOperand&)> functor) {
// ReplaceOpWithIf(op, new_values, nullptr, functor);
// }
// TODO(wilber): support erase.
// void ReplaceOp(Operation* op, ValueRange new_values) {
// NotifyRootReplaced(op, new_values);
// assert(op->num_results() == new_values.size() && "incorrect # of
// replacement values"); op->ReplaceAllUsesWith(new_values);
// NotifyOperationRemoved(op);
// op->erase();
// }
void RewriterBase::ReplaceOpWithIf(
Operation* op,
const std::vector<Value>& new_values,
bool* all_uses_replaced,
const std::function<bool(OpOperand)>& functor) {
IR_ENFORCE(op->num_results() == new_values.size(),
"incorrect number of values to replace operation");
NotifyRootReplaced(op, new_values);
// Replace each use of the results when the functor is true.
bool replace_all_uses = true;
for (uint32_t i = 0; i < op->num_results(); ++i) {
auto src_res = op->result(i);
src_res.ReplaceUsesWithIf(new_values[i], functor);
replace_all_uses &= src_res.use_empty();
}
if (replace_all_uses) {
*all_uses_replaced = replace_all_uses;
}
}
void RewriterBase::ReplaceOpWithIf(
Operation* op,
const std::vector<Value>& new_values,
const std::function<bool(OpOperand)>& functor) {
ReplaceOpWithIf(op, new_values, nullptr, functor);
}
void RewriterBase::ReplaceOp(Operation* op,
const std::vector<Value>& new_values) {
NotifyRootReplaced(op, new_values);
IR_ENFORCE(op->num_results() == new_values.size(),
"incorrect # of replacement values");
op->ReplaceAllUsesWith(new_values);
NotifyOperationRemoved(op);
op->GetParent()->erase(*op);
}
void RewriterBase::EraseOp(Operation* op) {
// assert(op->use_empty() && "expected 'op' to have no uses");
// NotifyOperationRemoved(op);
// op->erase();
// TODO(wilber): Operation support use_empty.
// IR_ENFORCE(op->use_empty(), "expected 'op' to have no uses");
NotifyOperationRemoved(op);
op->GetParent()->erase(*op);
}
/// Find uses of `from` and replace it with `to`
void RewriterBase::ReplaceAllUsesWith(Value from, Value to) {
// from.
// for (OpOperand& operand : llvm::make_early_inc_range(from.getUses()))
// {
// Operation* op = operand.getOwner();
// UpdateRootInPlace(op, [&]() { operand.set(to); });
// }
// TODO(wilber): Substitue a low level impl.
from.ReplaceAllUsesWith(to);
}
// TODO(wilber): iterator maybe should support modify inplace.
......@@ -135,8 +146,8 @@ void RewriterBase::ReplaceUseIf(Value from,
Value to,
std::function<bool(OpOperand&)> functor) {
// for (auto it = from.begin(); it != from.end(); ++it) {
// // TODO: need a lvalue.
// if (functor(it.get())) {
// // // TODO: need a lvalue.
// if (functor(*it)) {
// UpdateRootInplace(it.owner(), [&](){it.get().set(to)});
// }
// }
......@@ -144,8 +155,8 @@ void RewriterBase::ReplaceUseIf(Value from,
void RewriterBase::ReplaceOpWithResultsOfAnotherOp(Operation* op,
Operation* new_op) {
assert(op->num_results() == new_op->num_results() &&
"replacement op doesn't match results of original op");
IR_ENFORCE(op->num_results() == new_op->num_results(),
"replacement op doesn't match results of original op");
// TODO(wilber): Op support results method.
// if (op->num_results() == 1) return ReplaceOp(op,
// new_op->result(0)); return ReplaceOp(op, new_op->GetResults());
......
......@@ -25,6 +25,7 @@
#include <vector>
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/dll_decl.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation.h"
......@@ -36,7 +37,7 @@ namespace ir {
// This class reprensents the benefit of a pattern. The most common
// unit to use is the `numver of operations` in the pattern.
class PatternBenefit {
class IR_API PatternBenefit {
public:
PatternBenefit() = default;
PatternBenefit(uint32_t val) : val_(val) {} // NOLINT
......@@ -257,30 +258,21 @@ class RewriterBase : public Builder {
public:
// TODO(wilber): Supplementary methods of block and region.
// TODO(wilber): Support ValueRange.
// virtual void ReplaceOpWithIf(Operation* op,
// ValueRange new_values,
// bool* all_uses_replaced,
// std::function<bool(OpOperand&)> functor);
// void ReplaceOpWithIf(Operation* op,
// ValueRange new_values,
// std::function<bool(OpOperand&)> functor);
// virtual void ReplaceOp(Operation* op, ValueRange new_values);
virtual void ReplaceOpWithIf(Operation* op,
const std::vector<Value>& new_values,
bool* all_uses_replaced,
const std::function<bool(OpOperand)>& functor);
// virtual void ReplaceOpWithNewOp()
void ReplaceOpWithIf(Operation* op,
const std::vector<Value>& new_values,
const std::function<bool(OpOperand)>& functor);
virtual void EraseOp(Operation* op);
virtual void ReplaceOp(Operation* op, const std::vector<Value>& new_values);
virtual void StartRootUpdate(Operation* op) {}
virtual void FinalizeRootUpdate(Operation* op) {}
virtual void CancleRootUpdate(Operation* op) {}
// template <typename OpTy, typename... Args>
// OpTy ReplaceOpWithNewOp(Operation *op, Args &&...args);
template <typename CallableT>
void UpdateRootInplace(Operation* root, CallableT&& callable) {
StartRootUpdate(root);
callable();
FinalizeRootUpdate(root);
}
virtual void EraseOp(Operation* op);
void ReplaceAllUsesWith(Value from, Value to);
......@@ -293,11 +285,25 @@ class RewriterBase : public Builder {
virtual ~RewriterBase();
// virtual void NotifyRootReplaced(Operation* op, ValueRange replacement) {}
virtual void NotifyRootReplaced(Operation* op,
const std::vector<Value>& replacement) {}
virtual void NotifyOperationRemoved(Operation* op) {}
// virtual bool NotifyMatchFailure()
virtual void NotifyOperationInserted(Operation* op) {}
virtual void StartRootUpdate(Operation* op) {}
virtual void FinalizeRootUpdate(Operation* op) {}
virtual void CancleRootUpdate(Operation* op) {}
template <typename CallableT>
void UpdateRootInplace(Operation* root, CallableT&& callable) {
StartRootUpdate(root);
callable();
FinalizeRootUpdate(root);
}
private:
void operator=(const RewriterBase&) = delete;
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h"
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>
#include <unordered_map>
#include <unordered_set>
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/region.h"
#include "paddle/ir/core/value.h"
#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/ir/pattern_rewrite/pattern_applicator.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
namespace {
class GreedyPatternRewriteDriver : public ir::PatternRewriter {
public:
explicit GreedyPatternRewriteDriver(
ir::IrContext* ctx,
const ir::FrozenRewritePatternSet& patterns,
const ir::GreedyRewriteConfig& config)
: ir::PatternRewriter(ctx),
config_(config),
region_(*config.region),
matcher_(patterns) {
worklist_.reserve(128);
matcher_.ApplyDefaultCostModel();
if (config.strict_mode != ir::GreedyRewriteStrictness::AnyOp) {
for (auto it = region_.begin(); it != region_.end(); ++it) {
for (auto op_it = (*it)->begin(); op_it != (*it)->end(); ++op_it) {
strict_mode_filtered_ops_.insert(*op_it);
}
}
}
}
bool Simplify() {
bool changed = false;
int64_t iteration = 0;
do {
// Check if the iteration limit was reached.
if (iteration++ >= config_.max_iterations &&
config_.max_iterations != ir::GreedyRewriteConfig::kNoLimit)
break;
VLOG(6) << "Iteration[" << iteration << "] for PatternRewrite";
worklist_.clear();
worklist_map_.clear();
for (auto block_it = region_.begin(); block_it != region_.end();
++block_it) {
for (auto op_it = (*block_it)->begin(); op_it != (*block_it)->end();
++op_it) {
worklist_.push_back(*op_it);
}
}
if (config_.use_top_down_traversal) {
// Reverse the list so out pop-back loop process them in-order.
std::reverse(worklist_.begin(), worklist_.end());
}
for (size_t i = 0; i < worklist_.size(); ++i) {
worklist_map_[worklist_[i]] = i;
VLOG(6) << "worklist[" << i << "] is " << worklist_[i]->name();
}
changed = ProcessWorklist();
} while (changed);
return !changed;
}
private:
/// Process ops until the worklist is empty or `config.max_num_rewrites`
/// is reached. Return `true` if any IR was changed.
bool ProcessWorklist() {
bool changed = false;
int64_t num_rewrites = 0;
while (!worklist_.empty() &&
(num_rewrites < config_.max_num_rewrites ||
config_.max_num_rewrites == ir::GreedyRewriteConfig::kNoLimit)) {
auto* op = PopFromWorklist();
if (op == nullptr) continue;
VLOG(6) << "PopFromWorklist, get op: " << op->name();
// TODO(wilber): ir is dead.
// ...
// TODO(wilber): fold logical.
// ...
bool match_result = matcher_.MatchAndRewrite(op, *this);
if (match_result) {
changed = true;
++num_rewrites;
}
}
return changed;
}
// TODO(wilber): OpResult support GetUsers method.
void NotifyRootReplaced(ir::Operation* op,
const std::vector<ir::Value>& replacement) override {
// for (uint32_t i = 0; i < op->num_results(); ++i) {
// auto res = op->GetResultByIndex(i);
// }
// }
}
void FinalizeRootUpdate(ir::Operation* op) override { AddToWorklist(op); }
void NotifyOperationRemoved(ir::Operation* op) override {
for (uint32_t i = 0; i < op->num_operands(); ++i) {
AddOperandToWorklist(op->operand(i).source());
}
for (uint32_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->region(i);
for (auto it = region.begin(); it != region.end(); ++it) {
for (auto op_it = (*it)->begin(); op_it != (*it)->end(); ++op_it) {
RemoveFromWorklist(*op_it);
}
}
}
if (config_.strict_mode != ir::GreedyRewriteStrictness::AnyOp) {
strict_mode_filtered_ops_.erase(op);
}
}
void NotifyOperationInserted(ir::Operation* op) override {
if (config_.strict_mode == ir::GreedyRewriteStrictness::ExistingAndNewOps)
strict_mode_filtered_ops_.insert(op);
AddToWorklist(op);
}
/// Add the given operation to the worklist.
void AddToWorklist(ir::Operation* op) {
if (config_.strict_mode == ir::GreedyRewriteStrictness::AnyOp ||
strict_mode_filtered_ops_.count(op)) {
if (worklist_map_.count(op)) return;
worklist_map_[op] = worklist_.size();
worklist_.push_back(op);
}
}
void AddOperandToWorklist(ir::Value operand) {
// If the use count of this operand is now < 2, we re-add the defining
// operation to the worklist.
// This is based on the fact that zero use operations may be deleted, and
// that single use values often have more canonicalization opportunities.
if (!operand || (!operand.use_empty() && !operand.HasOneUse())) return;
if (auto* def_op = operand.GetDefiningOp()) AddToWorklist(def_op);
}
void AddOperandsToWorklist(const std::vector<ir::Value> operands) {
for (auto& v : operands) {
AddOperandToWorklist(v);
}
}
/// Pop the next operation from the worklist
ir::Operation* PopFromWorklist() {
auto* op = worklist_.back();
worklist_.pop_back();
if (op) worklist_map_.erase(op);
return op;
}
/// If the specified operation is in the worklist, remove it.
void RemoveFromWorklist(ir::Operation* op) {
auto it = worklist_map_.find(op);
if (it != worklist_map_.end()) {
worklist_[it->second] = nullptr;
worklist_map_.erase(it);
}
}
private:
std::vector<ir::Operation*> worklist_;
std::unordered_map<ir::Operation*, unsigned> worklist_map_;
ir::GreedyRewriteConfig config_;
std::unordered_set<ir::Operation*> strict_mode_filtered_ops_;
ir::Region& region_;
ir::PatternApplicator matcher_;
};
} // namespace
namespace ir {
bool ApplyPatternsGreedily(Region& region, // NOLINT
const FrozenRewritePatternSet& patterns,
GreedyRewriteConfig config) {
if (!config.region) config.region = &region;
GreedyPatternRewriteDriver driver(region.ir_context(), patterns, config);
bool converged = driver.Simplify();
if (!converged) {
LOG(WARNING) << "The pattern rewrite did not converge after scaning "
<< config.max_iterations << " times";
}
return converged;
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/dll_decl.h"
#include "paddle/ir/core/region.h"
#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
namespace ir {
/// This enum will control which ops will be added to the worklist during the
/// match rewrite process
enum class IR_API GreedyRewriteStrictness {
/// No restrictions wrt. any ops are processed.
AnyOp,
/// Only pre-existing and newly created ops are processed.
ExistingAndNewOps,
/// Only pre-existing ops are processed.
ExistingOps
};
/// Control over how the GreedyPatternRewriteDriver works.
class IR_API GreedyRewriteConfig {
public:
/// Control the way op is added to the worklist: bottom-up or top-down.
bool use_top_down_traversal = false;
/// Control the maximum number of iterations in the process of applying the
/// pattern, use `kNolimit` to represent unlimited.
int64_t max_iterations = 10;
/// Control the upper limit of rewrite times during each iteration, use
/// kNoLimit to represent unlimited.
int64_t max_num_rewrites = kNoLimit;
/// Only the op inside this region will be added to the worklist.
Region* region{nullptr};
/// Limit which ops will be added to the worklist during the Match and Rewrite
/// process.
/// - AnyOp: all ops will be added to the worklist.
/// - ExistingAndNewOps: pre-existing ops and newly created ops are added to
/// the worklist.
/// - ExistingOps: only pre-existing ops are added to the worklist.
GreedyRewriteStrictness strict_mode = GreedyRewriteStrictness::AnyOp;
static constexpr int64_t kNoLimit = -1;
};
/// Perform the Match and Rewrite process in the specified region, greedily
/// apply the Pattern with the highest benefit, and repeat this process until
/// convergence or the upper limit of iterations.
///
/// Returns true if the iteration converges and no patterns can be applied.
bool IR_API
ApplyPatternsGreedily(Region& region, // NOLINT
const FrozenRewritePatternSet& patterns,
GreedyRewriteConfig config = GreedyRewriteConfig());
/// Perform a match and rewrite process for all regions of a given op.
inline IR_API bool ApplyPatternsGreedily(
Operation* op,
const FrozenRewritePatternSet& patterns,
GreedyRewriteConfig config = GreedyRewriteConfig()) {
bool failed = false;
for (uint32_t i = 0; i < op->num_regions(); ++i) {
Region& region = op->region(i);
failed |= !ApplyPatternsGreedily(region, patterns, config);
}
return !failed;
}
} // namespace ir
......@@ -256,6 +256,7 @@ TEST(op_test, region_test) {
block->push_front(op1);
block->insert(block->begin(), op1_2);
ir::Operation *op2 = ir::Operation::Create(std::move(argument));
EXPECT_EQ(op2->region(0).ir_context(), ctx);
op2->Destroy();
}
......
......@@ -45,6 +45,7 @@ TEST(value_test, value_test) {
ir::OpInfo());
op1->Print(std::cout);
ir::OpResult a = op1->result(0);
EXPECT_TRUE(a.use_empty());
// 2. Construct OP2: b = OP2();
std::vector<ir::OpResult> op2_inputs = {};
std::vector<ir::Type> op2_output_types = {ir::Float32Type::get(ctx)};
......@@ -55,6 +56,7 @@ TEST(value_test, value_test) {
ir::OpInfo());
op2->Print(std::cout);
ir::OpResult b = op2->result(0);
EXPECT_TRUE(b.use_empty());
// 3. Construct OP3: c = OP3(a, b);
std::vector<ir::OpResult> op3_inputs{a, b};
std::vector<ir::Type> op3_output_types = {ir::Float32Type::get(ctx)};
......@@ -63,6 +65,9 @@ TEST(value_test, value_test) {
CreateAttributeMap("op3_name", "op3_attr"),
op3_output_types,
ir::OpInfo());
EXPECT_TRUE(op1->result(0).HasOneUse());
EXPECT_TRUE(op2->result(0).HasOneUse());
op3->Print(std::cout);
ir::OpResult c = op3->result(0);
// 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c);
......
cc_test_old(pattern_rewrite_test SRCS pattern_rewrite_test.cc DEPS ir gtest)
cc_test_old(
pattern_rewrite_test
SRCS
pattern_rewrite_test.cc
DEPS
ir
pd_dialect
gtest)
......@@ -13,27 +13,33 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <iostream>
#include <numeric>
#include <sstream>
#include <vector>
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/cast_utils.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_manager.h"
#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/ir/pattern_rewrite/pattern_applicator.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h"
TEST(pattern_rewrite, PatternBenefit) {
ir::PatternBenefit benefit1(1);
EXPECT_EQ(benefit1.benefit(), 1U);
ir::PatternBenefit benefit2(2);
EXPECT_EQ(benefit2.benefit(), 2U);
EXPECT_TRUE(benefit2 > benefit1);
EXPECT_TRUE(benefit2 >= benefit1);
EXPECT_TRUE(benefit1 < benefit2);
EXPECT_TRUE(benefit1 <= benefit2);
EXPECT_TRUE(benefit1 != benefit2);
ir::PatternBenefit benefit3(2);
EXPECT_TRUE(benefit2 == benefit3);
}
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/ir/dialect/CMakeLists.txt.
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
// Define op1.
class Operation1 : public ir::Op<Operation1> {
......@@ -95,7 +101,22 @@ class TestPatternRewrite2 : public ir::OpRewritePattern<Operation1> {
}
};
TEST(pattern_rewrite, RewritePatternSet) {
TEST(PatternRewrite, PatternBenefit) {
ir::PatternBenefit benefit1(1);
EXPECT_EQ(benefit1.benefit(), 1U);
ir::PatternBenefit benefit2(2);
EXPECT_EQ(benefit2.benefit(), 2U);
EXPECT_TRUE(benefit2 > benefit1);
EXPECT_TRUE(benefit2 >= benefit1);
EXPECT_TRUE(benefit1 < benefit2);
EXPECT_TRUE(benefit1 <= benefit2);
EXPECT_TRUE(benefit1 != benefit2);
ir::PatternBenefit benefit3(2);
EXPECT_TRUE(benefit2 == benefit3);
}
TEST(RewritePattern, RewritePatternSet) {
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
auto *test_dialect = ctx->GetOrRegisterDialect<TestDialect>();
......@@ -118,3 +139,144 @@ TEST(pattern_rewrite, RewritePatternSet) {
EXPECT_EQ(ps.native_patterns()[0]->benefit(), 2U);
EXPECT_EQ(ps.native_patterns()[1]->benefit(), 2U);
}
// TODO(wilber): Add actual case.
// TEST(PatternRewrite, PatternApplicator) {
// ir::IrContext *ctx = ir::IrContext::Instance();
// ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
// auto *test_dialect = ctx->GetOrRegisterDialect<TestDialect>();
// test_dialect->RegisterOp<Operation1>();
// ir::RewritePatternSet ps(ctx);
// ps.Add<TestPatternRewrite, TestPatternRewrite2>(ctx, 2);
// ir::FrozenRewritePatternSet frozen_set(std::move(ps));
// ir::PatternApplicator applicator(frozen_set);
// applicator.ApplyDefaultCostModel();
// }
// // TODO(wilber): Add actual case.
TEST(PatternRewrite, FrozenRewritePatternSet) {
ir::FrozenRewritePatternSet frozen_set;
EXPECT_TRUE(frozen_set.match_any_op_native_patterns().empty());
EXPECT_TRUE(frozen_set.op_specific_native_patterns().empty());
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
auto *test_dialect = ctx->GetOrRegisterDialect<TestDialect>();
test_dialect->RegisterOp<Operation1>();
ir::RewritePatternSet ps(ctx);
ps.Add<TestPatternRewrite, TestPatternRewrite2>(ctx, 2);
ir::FrozenRewritePatternSet frozen_set2(std::move(ps));
EXPECT_TRUE(frozen_set2.match_any_op_native_patterns().empty());
const auto &pattern_maps = frozen_set2.op_specific_native_patterns();
EXPECT_EQ(pattern_maps.size(), 1U);
EXPECT_EQ(pattern_maps.at(ctx->GetRegisteredOpInfo("test.Operation1")).size(),
2U);
}
class TransposePatternRewrite
: public ir::OpRewritePattern<paddle::dialect::TransposeOp> {
public:
using ir::OpRewritePattern<paddle::dialect::TransposeOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::TransposeOp op,
ir::PatternRewriter &rewriter) const override {
auto prev_op = op->operand(0).source().GetDefiningOp();
std::vector<int> axis_last = GetAxis(op);
auto prev_trans_op = prev_op->dyn_cast<paddle::dialect::TransposeOp>();
if (prev_trans_op) {
std::vector<int> axis_first = GetAxis(prev_trans_op);
IR_ENFORCE(axis_first.size() == axis_last.size(),
"tranpose op's perm rank should be same.");
auto new_perm = GetPerm(axis_first, axis_last);
rewriter.SetInsertionPoint(op);
auto new_op = rewriter.Build<paddle::dialect::TransposeOp>(
prev_op->operand(0).source().GetDefiningOp()->result(0), new_perm);
rewriter.ReplaceOp(op, {new_op.out()});
return true;
}
return false;
}
private:
std::vector<int> GetAxis(paddle::dialect::TransposeOp op) const {
auto attr_map = op->attributes();
ir::ArrayAttribute array_attr =
attr_map.at("perm").dyn_cast<ir::ArrayAttribute>();
std::vector<int> axis(array_attr.size());
for (size_t i = 0; i < array_attr.size(); ++i) {
axis[i] = array_attr[i].dyn_cast<ir::Int32Attribute>().data();
}
return axis;
}
std::vector<int> GetPerm(const std::vector<int> &perm1,
const std::vector<int> &perm2) const {
int n = perm1.size();
std::vector<int> axis(n), axis1(n), axis2(n);
std::iota(axis.begin(), axis.end(), 0);
for (int i = 0; i < n; ++i) {
axis1[i] = axis[perm1[i]];
}
for (int i = 0; i < n; ++i) {
axis2[i] = axis1[perm2[i]];
}
return axis2;
}
};
class TestPass : public ir::Pass {
public:
TestPass() : ir::Pass("TestPass", 1) {}
void Run(ir::Operation *op) override {
ir::RewritePatternSet ps(op->ir_context());
ps.Add<TransposePatternRewrite>(op->ir_context());
ir::FrozenRewritePatternSet frozen_ps(std::move(ps));
ir::GreedyRewriteConfig cfg;
cfg.use_top_down_traversal = true;
cfg.max_iterations = 1;
ir::ApplyPatternsGreedily(op->region(0), frozen_ps, cfg);
}
bool CanApplyOn(ir::Operation *op) const override {
return op->name() == "builtin.module" && op->num_regions() > 0;
}
};
void BuildProgram(ir::Builder &builder) { // NOLINT
paddle::dialect::FullOp full_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{1, 3, 16, 16},
1.5,
phi::DataType::FLOAT32,
phi::CPUPlace());
ir::OpResult full_op_output = full_op->result(0);
auto transpose1_op = builder.Build<paddle::dialect::TransposeOp>(
full_op_output, std::vector<int>{0, 2, 3, 1});
builder.Build<paddle::dialect::TransposeOp>(transpose1_op.out(),
std::vector<int>{0, 3, 1, 2});
// builder.Build<paddle::dialect::FetchOp>(transpose2_op.out());
}
// TODO(wilber): Add a normal test.
TEST(PatternRewrite, GreedyPatternRewriteDriver) {
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program(ctx);
ir::Builder builder = ir::Builder(ctx, program.block());
BuildProgram(builder);
EXPECT_EQ(program.block()->size(), 3u);
ir::PassManager pm(ctx);
pm.AddPass(std::make_unique<TestPass>());
std::stringstream o1, o2;
program.Print(o1);
LOG(INFO) << o1.str();
pm.Run(&program);
LOG(INFO) << "After Pass.";
program.Print(o2);
LOG(INFO) << o2.str();
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册