未验证 提交 548fb821 编写于 作者: Y Yuanle Liu 提交者: GitHub

[IR&PASS] part 3-2: add PatternApplicator and FrozenRewritePatternSet, refine...

[IR&PASS] part 3-2: add PatternApplicator and FrozenRewritePatternSet, refine PatternMatch code, add some api for Builder (#54492)

* [IR&PASS] add PatternApplicator and FrozenRewritePatternSet, refine PatternMatch code, add some api for Builder and TypeId

* fix comment
上级 1b5e1e81
......@@ -23,10 +23,12 @@ namespace ir {
class Operation;
class Block {
using OpListType = std::list<Operation *>;
public:
using iterator = std::list<Operation *>::iterator;
using reverse_iterator = std::list<Operation *>::reverse_iterator;
using const_iterator = std::list<Operation *>::const_iterator;
using iterator = OpListType::iterator;
using reverse_iterator = OpListType::reverse_iterator;
using const_iterator = OpListType::const_iterator;
Block() = default;
~Block();
......@@ -60,7 +62,7 @@ class Block {
private:
Region *parent_; // not owned
OpListType ops_; // owned
Region::iterator position_;
std::list<Operation *> ops_; // owned
};
} // namespace ir
......@@ -20,6 +20,7 @@
#include "paddle/ir/core/operation.h"
namespace ir {
///
/// \brief Unified interface of the Attribute class. Derivation of all Attribute
/// classes only derives interfaces, not members.
......@@ -27,11 +28,47 @@ namespace ir {
class Builder {
public:
Builder(IrContext *context, Block *block, Block::iterator insert_point)
: context_(context), block_(block), insert_point_(insert_point) {}
: context_(context) {
SetInsertionPoint(block, insert_point);
}
Builder(IrContext *context, Block *block)
: Builder(context, block, block->end()) {}
IrContext *context() const { return context_; }
explicit Builder(IrContext *context)
: Builder(context, nullptr, Block::iterator{}) {}
/// Set the insertion point to the specified location.
void SetInsertionPoint(Block *block, Block::iterator insert_point) {
// TODO(liuyuanle): check that insertPoint is in this rather than some other
// block.
this->block_ = block;
this->insert_point_ = insert_point;
}
/// Set the insertion point to the specified operation, which will cause
/// subsequent insertions to go right before it.
void SetInsertionPoint(Operation *op) {
SetInsertionPoint(op->GetParent(), Block::iterator{*op});
}
/// Set the insertion point to the node after the specified operation, which
/// will cause subsequent insertions to go right after it.
void SetInsertionPointAfter(Operation *op) {
SetInsertionPoint(op->GetParent(), std::next(Block::iterator{*op}));
}
/// Set the insertion point to the start of the specified block.
void SetInsertionPointToStart(Block *block) {
SetInsertionPoint(block, block->begin());
}
/// Set the insertion point to the end of the specified block.
void SetInsertionPointToEnd(Block *block) {
SetInsertionPoint(block, block->end());
}
IrContext *ir_context() const { return context_; }
Block *block() const { return block_; }
......@@ -57,8 +94,9 @@ class Builder {
Operation *Insert(Operation *op);
IrContext *context_;
Block *block_ = nullptr;
Block *block_;
// The insertion point within the list that this builder is inserting before.
Block::iterator insert_point_;
};
} // namespace ir
......@@ -142,8 +142,8 @@ class ConstructInterfacesOrTraits {
static void PlacementConstrctInterface(
InterfaceValue *&p_interface) { // NOLINT
p_interface->swap(InterfaceValue::get<ConcreteOp, T>());
VLOG(4) << "New a interface: id[" << (p_interface->type_id()).storage()
<< "].";
VLOG(4) << "New a interface: id["
<< (p_interface->type_id()).AsOpaquePointer() << "].";
++p_interface;
}
......@@ -151,7 +151,7 @@ class ConstructInterfacesOrTraits {
template <typename T>
static void PlacementConstrctTrait(ir::TypeId *&p_trait) { // NOLINT
*p_trait = TypeId::get<T>();
VLOG(4) << "New a trait: id[" << p_trait->storage() << "].";
VLOG(4) << "New a trait: id[" << p_trait->AsOpaquePointer() << "].";
++p_trait;
}
};
......@@ -206,4 +206,5 @@ class Op : public OpBase {
return trait_set;
}
};
} // namespace ir
......@@ -71,15 +71,15 @@ class OpInfo {
typename Interface::Concept *GetInterfaceImpl() const;
void *AsOpaquePointer() const { return impl_; }
static OpInfo RecoverFromOpaquePointer(void *impl) {
return static_cast<OpInfoImpl *>(impl);
static OpInfo RecoverFromOpaquePointer(void *pointer) {
return OpInfo(static_cast<OpInfoImpl *>(pointer));
}
friend class OpInfoImpl;
friend struct std::hash<OpInfo>;
private:
OpInfo(OpInfoImpl *impl) : impl_(impl) {} // NOLINT
explicit OpInfo(OpInfoImpl *impl) : impl_(impl) {}
void *GetInterfaceImpl(TypeId interface_id) const;
private:
......
......@@ -50,14 +50,14 @@ OpInfo OpInfoImpl::Create(Dialect *dialect,
}
// Construct OpInfoImpl.
VLOG(4) << "Construct OpInfoImpl at " << base_ptr << " ......";
OpInfo op_info = new (base_ptr) OpInfoImpl(dialect,
op_id,
op_name,
interfaces_num,
traits_num,
attributes_num,
attributes_name,
verify);
OpInfo op_info = OpInfo(new (base_ptr) OpInfoImpl(dialect,
op_id,
op_name,
interfaces_num,
traits_num,
attributes_num,
attributes_name,
verify));
return op_info;
}
void OpInfoImpl::Destroy(OpInfo info) {
......
......@@ -51,7 +51,10 @@ class TypeId {
TypeId &operator=(const TypeId &other) = default;
const Storage *storage() const { return storage_; }
void *AsOpaquePointer() const { return storage_; }
static TypeId RecoverFromOpaquePointer(void *pointer) {
return TypeId(static_cast<Storage *>(pointer));
}
///
/// \brief Comparison operations.
......@@ -77,10 +80,11 @@ class TypeId {
///
/// \param storage The storage of this TypeId.
///
explicit TypeId(const Storage *storage) : storage_(storage) {}
explicit TypeId(Storage *storage) : storage_(storage) {}
const Storage *storage_{nullptr};
Storage *storage_{nullptr};
};
} // namespace ir
namespace std {
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h"
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include <memory>
#include <optional>
#include <set>
#include <string>
#include "paddle/ir/core/op_info.h"
#include "paddle/utils/optional.h"
namespace ir {
FrozenRewritePatternSet::FrozenRewritePatternSet()
: impl_(std::make_shared<Impl>()) {}
FrozenRewritePatternSet::FrozenRewritePatternSet(
RewritePatternSet&& patterns,
const std::vector<std::string>& disabled_pattern_labels,
const std::vector<std::string>& enabled_pattern_labels)
: impl_(std::make_shared<Impl>()) {
std::set<std::string> disabled_patterns, enabled_patterns;
disabled_patterns.insert(disabled_pattern_labels.begin(),
disabled_pattern_labels.end());
enabled_patterns.insert(enabled_pattern_labels.begin(),
enabled_pattern_labels.end());
ir::OpInfoMap op_info_map;
auto AddToOpsWhen = [&](std::unique_ptr<RewritePattern>& pattern,
std::function<bool(OpInfo)> callback) {
if (op_info_map.empty())
op_info_map = pattern->ir_context()->registered_op_info_map();
for (auto& info_map : op_info_map) {
if (callback(info_map.second))
impl_->op_specific_native_pattern_map_[info_map.second].push_back(
pattern.get());
impl_->op_specific_native_patterns_.push_back(std::move(pattern));
}
};
for (std::unique_ptr<RewritePattern>& pat : patterns.native_patterns()) {
// Don't add patterns that haven't been enabled by the user.
if (!enabled_patterns.empty()) {
auto IsEnableFn = [&](const std::string& label) {
return enabled_patterns.count(label);
};
if (!IsEnableFn(pat->debug_name()) &&
std::none_of(pat->debug_labels().begin(),
pat->debug_labels().end(),
IsEnableFn))
continue;
}
// Don't add patterns that have been disabled by the user.
if (!disabled_patterns.empty()) {
auto IsDisabledFn = [&](const std::string& label) {
return disabled_patterns.count(label);
};
if (IsDisabledFn(pat->debug_name()) ||
std::any_of(pat->debug_labels().begin(),
pat->debug_labels().end(),
IsDisabledFn))
continue;
}
if (paddle::optional<OpInfo> root_name = pat->root_kind()) {
impl_->op_specific_native_pattern_map_[*root_name].push_back(pat.get());
impl_->op_specific_native_patterns_.push_back(std::move(pat));
continue;
}
if (paddle::optional<TypeId> interface_id = pat->GetRootInterfaceID()) {
AddToOpsWhen(
pat, [&](OpInfo info) { return info.HasInterface(*interface_id); });
continue;
}
if (paddle::optional<TypeId> trait_id = pat->GetRootTraitID()) {
AddToOpsWhen(pat, [&](OpInfo info) { return info.HasTrait(*trait_id); });
continue;
}
impl_->match_any_op_native_patterns_.push_back(std::move(pat));
}
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/// The design and code is mainly from MLIR, thanks to the greate project.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
namespace ir {
class FrozenRewritePatternSet {
using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
public:
using OpSpecificNativePatternListT =
std::unordered_map<OpInfo, std::vector<RewritePattern*>>;
FrozenRewritePatternSet();
FrozenRewritePatternSet(FrozenRewritePatternSet&& patterns) = default;
FrozenRewritePatternSet(const FrozenRewritePatternSet& patterns) = default;
FrozenRewritePatternSet& operator=(FrozenRewritePatternSet&& patterns) =
default;
FrozenRewritePatternSet& operator=(const FrozenRewritePatternSet& patterns) =
default;
~FrozenRewritePatternSet() = default;
/// Freeze the patterns held in `patterns`, and take ownership.
FrozenRewritePatternSet(
RewritePatternSet&& patterns,
const std::vector<std::string>& disabled_pattern_labels = {},
const std::vector<std::string>& enabled_pattern_labels = {});
/// Return the op specific native patterns held by this list.
const OpSpecificNativePatternListT& op_specific_native_patterns() const {
return impl_->op_specific_native_pattern_map_;
}
/// Return the "match any" native patterns held by this list.
const NativePatternListT& match_any_op_native_patterns() const {
return impl_->match_any_op_native_patterns_;
}
private:
struct Impl {
OpSpecificNativePatternListT op_specific_native_pattern_map_;
NativePatternListT op_specific_native_patterns_;
NativePatternListT match_any_op_native_patterns_;
};
std::shared_ptr<Impl> impl_;
};
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "paddle/ir/pattern_rewrite/pattern_applicator.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
namespace ir {
PatternApplicator::PatternApplicator(
const FrozenRewritePatternSet& frozen_patter_list)
: frozen_patter_list_(frozen_patter_list) {}
void PatternApplicator::ApplyCostModel(CostModel model) {
// TODO(wilber): remove impossible patterns.
patterns_.clear();
for (const auto& it : frozen_patter_list_.op_specific_native_patterns()) {
for (const RewritePattern* pattern : it.second) {
patterns_[it.first].push_back(pattern);
}
}
any_op_patterns_.clear();
for (auto& pattern : frozen_patter_list_.match_any_op_native_patterns()) {
any_op_patterns_.push_back(pattern.get());
}
// Sort by benefit based on the cost model.
std::unordered_map<const Pattern*, PatternBenefit> benefits;
auto cmp = [&benefits](const Pattern* lhs, const Pattern* rhs) {
return benefits[lhs] > benefits[rhs];
};
auto ProcessPatternList = [&](std::vector<const RewritePattern*>& list) {
if (list.size() == 1) return;
benefits.clear();
for (const Pattern* pat : list) benefits.emplace(pat, model(*pat));
std::stable_sort(list.begin(), list.end(), cmp);
};
for (auto& it : patterns_) {
ProcessPatternList(it.second);
}
ProcessPatternList(any_op_patterns_);
}
void PatternApplicator::WalkAllPatterns(
std::function<void(const Pattern&)> walk) {
for (const auto& it : frozen_patter_list_.op_specific_native_patterns())
for (auto* pattern : it.second) walk(*pattern);
for (auto& it : frozen_patter_list_.match_any_op_native_patterns()) walk(*it);
}
bool PatternApplicator::MatchAndRewrite(
Operation* op,
PatternRewriter& rewriter,
std::function<bool(const Pattern&)> can_apply,
std::function<void(const Pattern&)> on_failure,
std::function<bool(const Pattern&)> on_success) {
// whether there are patterns matching this operation type.
std::vector<const RewritePattern*> op_patterns;
auto pattern_it = patterns_.find(op->info());
if (pattern_it != patterns_.end()) op_patterns = pattern_it->second;
unsigned op_it = 0, op_e = op_patterns.size();
unsigned any_it = 0, any_e = any_op_patterns_.size();
bool result = false;
do {
// Find the next pattern with the highest benefit.
const Pattern* best_pattern = nullptr;
unsigned* best_pattern_it = &op_it;
// For specific patterns
if (op_it < op_e) best_pattern = op_patterns[op_it];
// For op-agnostic patterns
if (any_it < any_e &&
(!best_pattern ||
best_pattern->benefit() < any_op_patterns_[any_it]->benefit())) {
best_pattern_it = &any_it;
best_pattern = any_op_patterns_[any_it];
}
if (!best_pattern) break;
// Update the pattern iterator, so that this pattern isn't attempted again.
++(*best_pattern_it);
if (can_apply && !can_apply(*best_pattern)) continue;
rewriter.SetInsertionPoint(op);
const auto* pattern = static_cast<const RewritePattern*>(best_pattern);
result = pattern->MatchAndRewrite(op, rewriter);
if (result && on_success && !on_success(*best_pattern)) result = false;
if (result) break;
if (on_failure) on_failure(*best_pattern);
} while (true);
return result;
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The design and code is mainly from MLIR, thanks to the greate project.
#pragma once
#include <functional>
#include <unordered_map>
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
namespace ir {
class PatternApplicator {
public:
using CostModel = std::function<PatternBenefit(const Pattern&)>;
explicit PatternApplicator(const FrozenRewritePatternSet& frozen_patter_list);
~PatternApplicator() = default;
bool MatchAndRewrite(Operation* op,
PatternRewriter& rewriter, // NOLINT
std::function<bool(const Pattern&)> can_apply = {},
std::function<void(const Pattern&)> on_failure = {},
std::function<bool(const Pattern&)> on_success = {});
void ApplyCostModel(CostModel model);
void ApplyDefaultCostModel() {
ApplyCostModel([](const Pattern& pattern) { return pattern.benefit(); });
}
void WalkAllPatterns(std::function<void(const Pattern&)> walk);
private:
const FrozenRewritePatternSet& frozen_patter_list_;
std::unordered_map<OpInfo, std::vector<const RewritePattern*>> patterns_;
std::vector<const RewritePattern*> any_op_patterns_;
};
} // namespace ir
......@@ -13,8 +13,11 @@
// limitations under the License.
#include "paddle/ir/pattern_rewrite/pattern_match.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include "paddle/ir/core/operation.h"
namespace ir {
......@@ -22,62 +25,69 @@ namespace ir {
//===----------------------------------------------------------------------===//
// Pattern
//===----------------------------------------------------------------------===//
// Pattern::Pattern(const void* root_val,
// RootKind root_kind,
// const std::vector<std::string>& generated_names,
// PatternBenefit benefit,
// ir::IrContext* context)
// : benefit_(benefit), context_(context), generated_names_(generated_names)
// {}
Pattern::Pattern(const std::string& root_name,
PatternBenefit benefit,
IrContext* context,
const std::vector<std::string>& generated_names)
: op_name_(root_name),
root_kind_(RootKind::OperationName),
benefit_(benefit),
context_(context),
generated_names_(generated_names) {}
: Pattern(context->GetRegisteredOpInfo(root_name).AsOpaquePointer(),
RootKind::OperationInfo,
generated_names,
benefit,
context) {}
Pattern::Pattern(MatchAnyOpTypeTag tag,
PatternBenefit benefit,
ir::IrContext* context,
IrContext* context,
const std::vector<std::string>& generated_names)
: root_kind_(RootKind::Any),
benefit_(benefit),
context_(context),
generated_names_(generated_names) {}
: Pattern(nullptr, RootKind::Any, generated_names, benefit, context) {}
Pattern::Pattern(MatchInterfaceOpTypeTag tag,
ir::TypeId interface_id,
TypeId interface_id,
PatternBenefit benefit,
ir::IrContext* context,
IrContext* context,
const std::vector<std::string>& generated_names)
: interface_id_(interface_id),
root_kind_(RootKind::InterfaceId),
benefit_(benefit),
context_(context),
generated_names_(generated_names) {}
: Pattern(interface_id.AsOpaquePointer(),
RootKind::InterfaceId,
generated_names,
benefit,
context) {}
Pattern::Pattern(MatchTraitOpTypeTag tag,
ir::TypeId trait_id,
TypeId trait_id,
PatternBenefit benefit,
ir::IrContext* context,
IrContext* context,
const std::vector<std::string>& generated_names)
: trait_id_(trait_id),
root_kind_(RootKind::TraitId),
: Pattern(trait_id.AsOpaquePointer(),
RootKind::TraitId,
generated_names,
benefit,
context) {}
Pattern::Pattern(void* root_val,
RootKind root_kind,
const std::vector<std::string>& generated_names,
PatternBenefit benefit,
IrContext* context)
: root_val_(root_val),
root_kind_(root_kind),
benefit_(benefit),
context_(context),
generated_names_(generated_names) {}
context_(context) {
if (generated_names.empty()) return;
generated_ops_.reserve(generated_names.size());
std::transform(generated_names.begin(),
generated_names.end(),
std::back_inserter(generated_ops_),
[context](const std::string& name) {
return context->GetRegisteredOpInfo(name);
});
}
RewritePattern::~RewritePattern() = default;
//===----------------------------------------------------------------------===//
// RewriterBase
//===----------------------------------------------------------------------===//
RewriterBase::~RewriterBase() = default;
// TODO(wilber): value support replace method.
......@@ -113,9 +123,9 @@ void RewriterBase::EraseOp(Operation* op) {
void RewriterBase::ReplaceAllUsesWith(Value from, Value to) {
// from.
// for (mlir::OpOperand& operand : llvm::make_early_inc_range(from.getUses()))
// for (OpOperand& operand : llvm::make_early_inc_range(from.getUses()))
// {
// mlir::Operation* op = operand.getOwner();
// Operation* op = operand.getOwner();
// UpdateRootInPlace(op, [&]() { operand.set(to); });
// }
}
......
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// The design is mainly from MLIR, very thanks to the greate project.
#pragma once
#include <functional>
......@@ -19,46 +21,78 @@
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/type_id.h"
#include "paddle/ir/core/type_name.h"
#include "paddle/ir/core/value.h"
namespace ir {
#include "paddle/utils/optional.h"
/// The design is mainly from MLIR, very thanks to the greate project.
namespace ir {
/// This class reprensents the benefit of a pattern. The most common
/// unit to use is the `numver of operations` in the pattern.
// This class reprensents the benefit of a pattern. The most common
// unit to use is the `numver of operations` in the pattern.
class PatternBenefit {
public:
PatternBenefit(unsigned val) : val_(val) {} // NOLINT
PatternBenefit() = default;
PatternBenefit(uint32_t val) : val_(val) {} // NOLINT
unsigned benefit() { return val_; }
uint32_t benefit() { return val_; }
bool operator==(const PatternBenefit& rhs) const { return val_ == rhs.val_; }
bool operator!=(const PatternBenefit& rhs) const { return !(*this == rhs); }
bool operator<(const PatternBenefit& rhs) const { return val_ < rhs.val_; }
bool operator>(const PatternBenefit& rhs) const { return rhs < *this; }
bool operator<=(const PatternBenefit& rhs) const { return !(*this > rhs); }
bool operator>=(const PatternBenefit& rhs) const { return !(*this <= rhs); }
bool operator>=(const PatternBenefit& rhs) const { return !(*this < rhs); }
private:
unsigned int val_{0};
uint32_t val_{0};
};
/// This class contains all of the data related to a Pattern, but not contains
/// any methods for the matching. This class is used to interface with the
/// metadata of a pattern, such as benefit or root operation.
// This class contains all of the data related to a Pattern, but not contains
// any methods for the matching. This class is used to interface with the
// metadata of a pattern, such as benefit or root operation.
class Pattern {
enum class RootKind { Any, OperationName, InterfaceId, TraitId };
enum class RootKind {
// The pattern root matches "any" operation.
Any,
// The pattern root is matched using a concrete operation.
OperationInfo,
// The pattern root is matched using an interface id.
InterfaceId,
// The patter root is matched using a trait id.
TraitId
};
public:
const std::vector<OpInfo>& generated_ops() const { return generated_ops_; }
paddle::optional<OpInfo> root_kind() const {
if (root_kind_ == RootKind::OperationInfo)
return OpInfo::RecoverFromOpaquePointer(root_val_);
return paddle::none;
}
paddle::optional<TypeId> GetRootInterfaceID() const {
if (root_kind_ == RootKind::InterfaceId)
return TypeId::RecoverFromOpaquePointer(root_val_);
return paddle::none;
}
paddle::optional<TypeId> GetRootTraitID() const {
if (root_kind_ == RootKind::TraitId)
return TypeId::RecoverFromOpaquePointer(root_val_);
return paddle::none;
}
PatternBenefit benefit() const { return benefit_; }
IrContext* context() const { return context_; }
IrContext* ir_context() const { return context_; }
std::string debug_name() const { return debug_name_; }
......@@ -81,41 +115,39 @@ class Pattern {
Pattern(const std::string& root_name,
PatternBenefit benefit,
ir::IrContext* context,
IrContext* context,
const std::vector<std::string>& generated_names = {});
Pattern(MatchAnyOpTypeTag tag,
PatternBenefit benefit,
ir::IrContext* context,
IrContext* context,
const std::vector<std::string>& generated_names = {});
Pattern(MatchInterfaceOpTypeTag tag,
ir::TypeId interface_id,
TypeId interface_id,
PatternBenefit benefit,
ir::IrContext* context,
IrContext* context,
const std::vector<std::string>& generated_names = {});
Pattern(MatchTraitOpTypeTag tag,
ir::TypeId trait_id,
TypeId trait_id,
PatternBenefit benefit,
ir::IrContext* context,
IrContext* context,
const std::vector<std::string>& generated_names = {});
private:
// TODO(wilber): How to uniform variables and constructor.
// Pattern(const void* root_val,
// RootKind root_kind,
// const std::vector<std::string>& generated_names,
// PatternBenefit benefit,
// ir::IrContext* context);
std::string op_name_;
ir::TypeId interface_id_;
ir::TypeId trait_id_;
Pattern(void* root_val,
RootKind root_kind,
const std::vector<std::string>& generated_names,
PatternBenefit benefit,
IrContext* context);
void* root_val_;
RootKind root_kind_;
const PatternBenefit benefit_;
ir::IrContext* context_;
std::vector<std::string> generated_names_;
IrContext* context_;
std::vector<OpInfo> generated_ops_;
std::string debug_name_;
std::vector<std::string> debug_labels_;
......@@ -127,19 +159,19 @@ class RewritePattern : public Pattern {
public:
virtual ~RewritePattern();
virtual void Rewrite(ir::Operation* op,
virtual void Rewrite(Operation* op,
PatternRewriter& rewriter) const { // NOLINT
throw(
"need to implement either MatchAndRewrite or one of the rewrite "
"functions.");
}
virtual bool Match(ir::Operation* op) const {
virtual bool Match(Operation* op) const {
throw("need to implement either MatchAndRewrite or Match.");
return false;
}
virtual bool MatchAndRewrite(ir::Operation* op,
virtual bool MatchAndRewrite(Operation* op,
PatternRewriter& rewriter) const { // NOLINT
if (Match(op)) {
Rewrite(op, rewriter);
......@@ -157,7 +189,7 @@ class RewritePattern : public Pattern {
pattern->Initialize();
if (pattern->debug_name().empty())
pattern->SetDebugName(get_type_name<T>());
pattern->SetDebugName(ir::get_type_name<T>());
return pattern;
}
......@@ -166,8 +198,8 @@ class RewritePattern : public Pattern {
};
namespace detail {
/// A wrapper around PatternWrite that allows for matching and rewriting
/// against an instance of a derived operation class or Interface.
// A wrapper around PatternWrite that allows for matching and rewriting
// against an instance of a derived operation class or Interface.
template <typename SourceOp>
struct OpOrInterfaceRewritePatternBase : public RewritePattern {
using RewritePattern::RewritePattern;
......@@ -203,30 +235,25 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
};
} // namespace detail
/// OpRewritePattern is a wrapper around RewritePattern that allows for
/// matching and rewriting against an instance of a derived operation
/// class as opposed to a raw Operation.
// OpRewritePattern is a wrapper around RewritePattern that allows for
// matching and rewriting against an instance of a derived operation
// class as opposed to a raw Operation.
template <typename SourceOp>
struct OpRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
OpRewritePattern(ir::IrContext* context,
OpRewritePattern(IrContext* context,
PatternBenefit benefit = 1,
const std::vector<std::string>& generated_names = {})
: detail::OpOrInterfaceRewritePatternBase<SourceOp>(
"NeedToFix", // TODO(wilber): Need to fix. SourceOp maybe should
// have a getOperationName static method.
benefit,
context,
generated_names) {}
SourceOp::name(), benefit, context, generated_names) {}
};
// TODO(wilber): Support OpInterfaceRewritePattern and OpTraitRewritePattern.
// ...
/// This class provides a series of interfaces for modifying IR and tracking IR
/// changes. This class provides a unified API for IR modification.
///
class RewriterBase { // maybe should inherit OpBuilder.
// This class provides a series of interfaces for modifying IR and tracking IR
// changes. This class provides a unified API for IR modification.
class RewriterBase : public Builder {
public:
// TODO(wilber): Supplementary methods of block and region.
......@@ -262,7 +289,7 @@ class RewriterBase { // maybe should inherit OpBuilder.
std::function<bool(OpOperand&)> functor);
protected:
explicit RewriterBase(IrContext* ctx) : ctx_(ctx) {}
explicit RewriterBase(IrContext* ctx) : Builder(ctx) {}
virtual ~RewriterBase();
......@@ -277,9 +304,6 @@ class RewriterBase { // maybe should inherit OpBuilder.
RewriterBase(const RewriterBase&) = delete;
void ReplaceOpWithResultsOfAnotherOp(Operation* op, Operation* new_op);
private:
IrContext* ctx_;
};
class PatternRewriter : public RewriterBase {
......@@ -287,7 +311,7 @@ class PatternRewriter : public RewriterBase {
using RewriterBase::RewriterBase;
};
/// A pattern collection, easy to add patterns.
// A pattern collection, easy to add patterns.
class RewritePatternSet {
using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
......@@ -299,7 +323,7 @@ class RewritePatternSet {
native_patterns_.emplace_back(std::move(pattern));
}
IrContext* context() const { return context_; }
IrContext* ir_context() const { return context_; }
NativePatternListT& native_patterns() { return native_patterns_; }
......@@ -351,6 +375,8 @@ class RewritePatternSet {
private:
IrContext* const context_;
NativePatternListT native_patterns_;
};
} // namespace ir
......@@ -102,7 +102,7 @@ class Operation1 : public ir::Op<Operation1> {
ir::OperationArgument &argument) { // NOLINT
std::vector<ir::OpResult> inputs = {};
std::vector<ir::Type> output_types = {
ir::Float32Type::get(builder.context())};
ir::Float32Type::get(builder.ir_context())};
std::unordered_map<std::string, ir::Attribute> attributes =
CreateAttributeMap({"op1_attr1", "op1_attr2"},
{"op1_attr1", "op1_attr2"});
......
......@@ -5,6 +5,4 @@ cc_test_old(
DEPS
new_pass
pattern_rewrite
pd_dialect
phi
gtest)
......@@ -80,6 +80,7 @@ class TestPatternRewrite : public ir::OpRewritePattern<Operation1> {
void Rewrite(Operation1 op, ir::PatternRewriter &rewriter) const override {}
bool Match(Operation1 op) const override { return false; }
};
class TestPatternRewrite2 : public ir::OpRewritePattern<Operation1> {
public:
using ir::OpRewritePattern<Operation1>::OpRewritePattern;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册