未验证 提交 f20d3fc2 编写于 作者: W Wilber 提交者: GitHub

[IR&PASS] part 3-1: Add PatternMatch base class. (#54385)

上级 c15e53d6
......@@ -4,3 +4,4 @@ endif()
add_subdirectory(core)
add_subdirectory(pass)
add_subdirectory(pattern_rewrite)
file(GLOB PATTERN_SRCS "*.cc")
cc_library(
pattern_rewrite
SRCS ${PATTERN_SRCS}
DEPS new_ir)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/pattern_rewrite/pattern_match.h"
#include <cassert>
#include <cstdint>
#include "paddle/ir/core/operation.h"
namespace ir {
//===----------------------------------------------------------------------===//
// Pattern
//===----------------------------------------------------------------------===//
// Pattern::Pattern(const void* root_val,
// RootKind root_kind,
// const std::vector<std::string>& generated_names,
// PatternBenefit benefit,
// ir::IrContext* context)
// : benefit_(benefit), context_(context), generated_names_(generated_names)
// {}
Pattern::Pattern(const std::string& root_name,
PatternBenefit benefit,
IrContext* context,
const std::vector<std::string>& generated_names)
: op_name_(root_name),
root_kind_(RootKind::OperationName),
benefit_(benefit),
context_(context),
generated_names_(generated_names) {}
Pattern::Pattern(MatchAnyOpTypeTag tag,
PatternBenefit benefit,
ir::IrContext* context,
const std::vector<std::string>& generated_names)
: root_kind_(RootKind::Any),
benefit_(benefit),
context_(context),
generated_names_(generated_names) {}
Pattern::Pattern(MatchInterfaceOpTypeTag tag,
ir::TypeId interface_id,
PatternBenefit benefit,
ir::IrContext* context,
const std::vector<std::string>& generated_names)
: interface_id_(interface_id),
root_kind_(RootKind::InterfaceId),
benefit_(benefit),
context_(context),
generated_names_(generated_names) {}
Pattern::Pattern(MatchTraitOpTypeTag tag,
ir::TypeId trait_id,
PatternBenefit benefit,
ir::IrContext* context,
const std::vector<std::string>& generated_names)
: trait_id_(trait_id),
root_kind_(RootKind::TraitId),
benefit_(benefit),
context_(context),
generated_names_(generated_names) {}
RewritePattern::~RewritePattern() = default;
//===----------------------------------------------------------------------===//
// RewriterBase
//===----------------------------------------------------------------------===//
RewriterBase::~RewriterBase() = default;
// TODO(wilber): value support replace method.
// void RewriterBase::ReplaceOpWithIf(Operation* op,
// ValueRange new_values,
// bool* all_uses_replaced,
// std::function<bool(OpOperand&)> functor) {
// // assert(op->num_results() == new_values.size() && "incorrect number of
// values to replace operation"); NotifyRootReplaced(op, new_values); bool
// replace_all_uses = true; for (uint32_t i = 0; i < op->num_results(); ++i) {
// // op->GetResultByIndex(0)
// }
// }
// void RewriterBase::ReplaceOpWithIf(Operation* op,
// ValueRange new_values,
// std::function<bool(OpOperand&)> functor) {
// ReplaceOpWithIf(op, new_values, nullptr, functor);
// }
// TODO(wilber): support erase.
// void ReplaceOp(Operation* op, ValueRange new_values) {
// NotifyRootReplaced(op, new_values);
// assert(op->num_results() == new_values.size() && "incorrect # of
// replacement values"); op->ReplaceAllUsesWith(new_values);
// NotifyOperationRemoved(op);
// op->erase();
// }
void RewriterBase::EraseOp(Operation* op) {
// assert(op->use_empty() && "expected 'op' to have no uses");
// NotifyOperationRemoved(op);
// op->erase();
}
void RewriterBase::ReplaceAllUsesWith(Value from, Value to) {
// from.
// for (mlir::OpOperand& operand : llvm::make_early_inc_range(from.getUses()))
// {
// mlir::Operation* op = operand.getOwner();
// UpdateRootInPlace(op, [&]() { operand.set(to); });
// }
}
// TODO(wilber): iterator maybe should support modify inplace.
void RewriterBase::ReplaceUseIf(Value from,
Value to,
std::function<bool(OpOperand&)> functor) {
// for (auto it = from.begin(); it != from.end(); ++it) {
// // TODO: need a lvalue.
// if (functor(it.get())) {
// UpdateRootInplace(it.owner(), [&](){it.get().set(to)});
// }
// }
}
void RewriterBase::ReplaceOpWithResultsOfAnotherOp(Operation* op,
Operation* new_op) {
assert(op->num_results() == new_op->num_results() &&
"replacement op doesn't match results of original op");
// TODO(wilber): Op support results method.
// if (op->num_results() == 1) return ReplaceOp(op,
// new_op->GetResultByIndex(0)); return ReplaceOp(op, new_op->GetResults());
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <initializer_list>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/type_id.h"
#include "paddle/ir/core/type_name.h"
#include "paddle/ir/core/value.h"
namespace ir {
/// The design is mainly from MLIR, very thanks to the greate project.
/// This class reprensents the benefit of a pattern. The most common
/// unit to use is the `numver of operations` in the pattern.
class PatternBenefit {
public:
PatternBenefit(unsigned val) : val_(val) {} // NOLINT
unsigned benefit() { return val_; }
bool operator==(const PatternBenefit& rhs) const { return val_ == rhs.val_; }
bool operator!=(const PatternBenefit& rhs) const { return !(*this == rhs); }
bool operator<(const PatternBenefit& rhs) const { return val_ < rhs.val_; }
bool operator>(const PatternBenefit& rhs) const { return rhs < *this; }
bool operator<=(const PatternBenefit& rhs) const { return !(*this > rhs); }
bool operator>=(const PatternBenefit& rhs) const { return !(*this <= rhs); }
private:
unsigned int val_{0};
};
/// This class contains all of the data related to a Pattern, but not contains
/// any methods for the matching. This class is used to interface with the
/// metadata of a pattern, such as benefit or root operation.
class Pattern {
enum class RootKind { Any, OperationName, InterfaceId, TraitId };
public:
PatternBenefit benefit() const { return benefit_; }
IrContext* context() const { return context_; }
std::string debug_name() const { return debug_name_; }
void SetDebugName(const std::string& name) { debug_name_ = name; }
const std::vector<std::string>& debug_labels() const { return debug_labels_; }
void AddDebugLabels(const std::vector<std::string>& labels) {
debug_labels_.insert(debug_labels_.end(), labels.begin(), labels.end());
}
void AddDebugLabels(const std::string& label) {
debug_labels_.push_back(label);
}
protected:
struct MatchAnyOpTypeTag {};
struct MatchInterfaceOpTypeTag {};
struct MatchTraitOpTypeTag {};
Pattern(const std::string& root_name,
PatternBenefit benefit,
ir::IrContext* context,
const std::vector<std::string>& generated_names = {});
Pattern(MatchAnyOpTypeTag tag,
PatternBenefit benefit,
ir::IrContext* context,
const std::vector<std::string>& generated_names = {});
Pattern(MatchInterfaceOpTypeTag tag,
ir::TypeId interface_id,
PatternBenefit benefit,
ir::IrContext* context,
const std::vector<std::string>& generated_names = {});
Pattern(MatchTraitOpTypeTag tag,
ir::TypeId trait_id,
PatternBenefit benefit,
ir::IrContext* context,
const std::vector<std::string>& generated_names = {});
private:
// TODO(wilber): How to uniform variables and constructor.
// Pattern(const void* root_val,
// RootKind root_kind,
// const std::vector<std::string>& generated_names,
// PatternBenefit benefit,
// ir::IrContext* context);
std::string op_name_;
ir::TypeId interface_id_;
ir::TypeId trait_id_;
RootKind root_kind_;
const PatternBenefit benefit_;
ir::IrContext* context_;
std::vector<std::string> generated_names_;
std::string debug_name_;
std::vector<std::string> debug_labels_;
};
class PatternRewriter;
class RewritePattern : public Pattern {
public:
virtual ~RewritePattern();
virtual void Rewrite(ir::Operation* op,
PatternRewriter& rewriter) const { // NOLINT
throw(
"need to implement either MatchAndRewrite or one of the rewrite "
"functions.");
}
virtual bool Match(ir::Operation* op) const {
throw("need to implement either MatchAndRewrite or Match.");
return false;
}
virtual bool MatchAndRewrite(ir::Operation* op,
PatternRewriter& rewriter) const { // NOLINT
if (Match(op)) {
Rewrite(op, rewriter);
return true;
}
return false;
}
virtual void Initialize() {}
template <typename T, typename... Args>
static std::unique_ptr<T> Create(Args&&... args) {
std::unique_ptr<T> pattern =
std::make_unique<T>(std::forward<Args>(args)...);
pattern->Initialize();
if (pattern->debug_name().empty())
pattern->SetDebugName(get_type_name<T>());
return pattern;
}
protected:
using Pattern::Pattern;
};
namespace detail {
/// A wrapper around PatternWrite that allows for matching and rewriting
/// against an instance of a derived operation class or Interface.
template <typename SourceOp>
struct OpOrInterfaceRewritePatternBase : public RewritePattern {
using RewritePattern::RewritePattern;
void Rewrite(Operation* op,
PatternRewriter& rewriter) const final { // NOLINT
Rewrite(op->dyn_cast<SourceOp>(), rewriter);
}
bool Match(Operation* op) const final {
return Match(op->dyn_cast<SourceOp>());
}
bool MatchAndRewrite(Operation* op,
PatternRewriter& rewriter) const final { // NOLINT
return MatchAndRewrite(op->dyn_cast<SourceOp>(), rewriter);
}
virtual void Rewrite(SourceOp op,
PatternRewriter& rewriter) const { // NOLINT
throw("must override Rewrite or MatchAndRewrite");
}
virtual bool Match(SourceOp op) const {
throw("must override Match or MatchAndRewrite");
}
virtual bool MatchAndRewrite(SourceOp op,
PatternRewriter& rewriter) const { // NOLINT
if (Match(op)) {
Rewrite(op, rewriter);
return true;
}
return false;
}
};
} // namespace detail
/// OpRewritePattern is a wrapper around RewritePattern that allows for
/// matching and rewriting against an instance of a derived operation
/// class as opposed to a raw Operation.
template <typename SourceOp>
struct OpRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
OpRewritePattern(ir::IrContext* context,
PatternBenefit benefit = 1,
const std::vector<std::string>& generated_names = {})
: detail::OpOrInterfaceRewritePatternBase<SourceOp>(
"NeedToFix", // TODO(wilber): Need to fix. SourceOp maybe should
// have a getOperationName static method.
benefit,
context,
generated_names) {}
};
// TODO(wilber): Support OpInterfaceRewritePattern and OpTraitRewritePattern.
// ...
/// This class provides a series of interfaces for modifying IR and tracking IR
/// changes. This class provides a unified API for IR modification.
///
class RewriterBase { // maybe should inherit OpBuilder.
public:
// TODO(wilber): Supplementary methods of block and region.
// TODO(wilber): Support ValueRange.
// virtual void ReplaceOpWithIf(Operation* op,
// ValueRange new_values,
// bool* all_uses_replaced,
// std::function<bool(OpOperand&)> functor);
// void ReplaceOpWithIf(Operation* op,
// ValueRange new_values,
// std::function<bool(OpOperand&)> functor);
// virtual void ReplaceOp(Operation* op, ValueRange new_values);
// virtual void ReplaceOpWithNewOp()
virtual void EraseOp(Operation* op);
virtual void StartRootUpdate(Operation* op) {}
virtual void FinalizeRootUpdate(Operation* op) {}
virtual void CancleRootUpdate(Operation* op) {}
template <typename CallableT>
void UpdateRootInplace(Operation* root, CallableT&& callable) {
StartRootUpdate(root);
callable();
FinalizeRootUpdate(root);
}
void ReplaceAllUsesWith(Value from, Value to);
void ReplaceUseIf(Value from,
Value to,
std::function<bool(OpOperand&)> functor);
protected:
explicit RewriterBase(IrContext* ctx) : ctx_(ctx) {}
virtual ~RewriterBase();
// virtual void NotifyRootReplaced(Operation* op, ValueRange replacement) {}
virtual void NotifyOperationRemoved(Operation* op) {}
// virtual bool NotifyMatchFailure()
private:
void operator=(const RewriterBase&) = delete;
RewriterBase(const RewriterBase&) = delete;
void ReplaceOpWithResultsOfAnotherOp(Operation* op, Operation* new_op);
private:
IrContext* ctx_;
};
class PatternRewriter : public RewriterBase {
public:
using RewriterBase::RewriterBase;
};
/// A pattern collection, easy to add patterns.
class RewritePatternSet {
using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
public:
explicit RewritePatternSet(IrContext* context) : context_(context) {}
RewritePatternSet(IrContext* context, std::unique_ptr<RewritePattern> pattern)
: context_(context) {
native_patterns_.emplace_back(std::move(pattern));
}
IrContext* context() const { return context_; }
NativePatternListT& native_patterns() { return native_patterns_; }
void Clear() { native_patterns_.clear(); }
// 'add' methods for adding patterns to the set.
template <typename... Ts,
typename ConstructorArg,
typename... ConstructorArgs,
typename = std::enable_if_t<sizeof...(Ts) != 0>>
RewritePatternSet& Add(ConstructorArg&& arg, ConstructorArgs&&... args) {
std::initializer_list<int>{
(AddImpl<Ts>({},
std::forward<ConstructorArg>(arg),
std::forward<ConstructorArgs>(args)...),
0)...};
return *this;
}
template <typename... Ts,
typename ConstructorArg,
typename... ConstructorArgs,
typename = std::enable_if_t<sizeof...(Ts) != 0>>
RewritePatternSet& AddWithLabel(const std::vector<std::string>& debug_labels,
ConstructorArg&& arg,
ConstructorArgs&&... args) {
std::initializer_list<int>{
(AddImpl<Ts>(debug_labels,
std::forward<ConstructorArg>(arg),
std::forward<ConstructorArgs>(args)...),
0)...};
return *this;
}
RewritePatternSet& Add(std::unique_ptr<RewritePattern> pattern) {
native_patterns_.emplace_back(std::move(pattern));
return *this;
}
private:
template <typename T, typename... Args>
std::enable_if_t<std::is_base_of<RewritePattern, T>::value> AddImpl(
const std::vector<std::string>& debug_labels, Args&&... args) {
std::unique_ptr<T> pattern =
RewritePattern::Create<T>(std::forward<Args>(args)...);
pattern->AddDebugLabels(debug_labels);
native_patterns_.emplace_back(std::move(pattern));
}
private:
IrContext* const context_;
NativePatternListT native_patterns_;
};
} // namespace ir
......@@ -4,3 +4,4 @@ endif()
add_subdirectory(core)
add_subdirectory(pass)
add_subdirectory(pattern_rewrite)
cc_test_old(
pattern_rewrite_test
SRCS
pattern_rewrite_test.cc
DEPS
new_pass
pattern_rewrite
pd_dialect
phi
gtest)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
TEST(PatternBenefit, PatternBenefit) {
ir::PatternBenefit benefit1(1);
EXPECT_EQ(benefit1.benefit(), 1U);
ir::PatternBenefit benefit2(2);
EXPECT_EQ(benefit2.benefit(), 2U);
EXPECT_TRUE(benefit2 > benefit1);
EXPECT_TRUE(benefit2 >= benefit1);
EXPECT_TRUE(benefit1 < benefit2);
EXPECT_TRUE(benefit1 <= benefit2);
EXPECT_TRUE(benefit1 != benefit2);
ir::PatternBenefit benefit3(2);
EXPECT_TRUE(benefit2 == benefit3);
}
// Define op1.
class Operation1 : public ir::Op<Operation1> {
public:
using Op::Op;
static const char *name() { return "test.Operation1"; }
static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num];
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
if (attributes.count("op2_attr1") == 0 ||
(!attributes.at("op2_attr1").isa<ir::StrAttribute>())) {
throw("Type of attribute: parameter_name is not right.");
}
if (attributes.count("op2_attr2") == 0 ||
(!attributes.at("op2_attr2").isa<ir::StrAttribute>())) {
throw("Type of attribute: parameter_name is not right.");
}
}
static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; }
};
const char *Operation1::attributes_name[attributes_num] = {"op2_attr1",
"op2_attr2"};
// Define a dialect, op1 and op2 will be registered by this dialect.
class TestDialect : public ir::Dialect {
public:
explicit TestDialect(ir::IrContext *context)
: ir::Dialect(name(), context, ir::TypeId::get<TestDialect>()) {
initialize();
}
static const char *name() { return "test"; }
private:
void initialize() { RegisterOps<Operation1>(); }
};
// TODO(wilber): Add logical when ir support erase, replace or update.
class TestPatternRewrite : public ir::OpRewritePattern<Operation1> {
public:
using ir::OpRewritePattern<Operation1>::OpRewritePattern;
void Rewrite(Operation1 op, ir::PatternRewriter &rewriter) const override {}
bool Match(Operation1 op) const override { return false; }
};
class TestPatternRewrite2 : public ir::OpRewritePattern<Operation1> {
public:
using ir::OpRewritePattern<Operation1>::OpRewritePattern;
bool MatchAndRewrite(
Operation1 op,
ir::PatternRewriter &rewriter) const override { // NOLINT
return false;
}
};
TEST(RewritePattern, OpRewritePattern) {
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
auto *test_dialect = ctx->GetOrRegisterDialect<TestDialect>();
test_dialect->RegisterOp<Operation1>();
ir::RewritePatternSet ps(ctx);
ps.Add<TestPatternRewrite>(ctx, 1);
EXPECT_EQ(ps.native_patterns().size(), 1U);
EXPECT_TRUE(ps.native_patterns().back()->debug_labels().empty());
EXPECT_EQ(ps.native_patterns().back()->benefit(), 1U);
ps.AddWithLabel<TestPatternRewrite2>({"TestPatternRewrite2"}, ctx, 2);
EXPECT_EQ(ps.native_patterns().size(), 2U);
EXPECT_EQ(ps.native_patterns().back()->debug_labels()[0],
"TestPatternRewrite2");
EXPECT_EQ(ps.native_patterns().back()->benefit(), 2U);
ps.Clear();
ps.Add<TestPatternRewrite, TestPatternRewrite2>(ctx, 2);
EXPECT_EQ(ps.native_patterns().size(), 2U);
EXPECT_EQ(ps.native_patterns()[0]->benefit(), 2U);
EXPECT_EQ(ps.native_patterns()[1]->benefit(), 2U);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册