未验证 提交 548fb821 编写于 作者: Y Yuanle Liu 提交者: GitHub

[IR&PASS] part 3-2: add PatternApplicator and FrozenRewritePatternSet, refine...

[IR&PASS] part 3-2: add PatternApplicator and FrozenRewritePatternSet, refine PatternMatch code, add some api for Builder (#54492)

* [IR&PASS] add PatternApplicator and FrozenRewritePatternSet, refine PatternMatch code, add some api for Builder and TypeId

* fix comment
上级 1b5e1e81
...@@ -23,10 +23,12 @@ namespace ir { ...@@ -23,10 +23,12 @@ namespace ir {
class Operation; class Operation;
class Block { class Block {
using OpListType = std::list<Operation *>;
public: public:
using iterator = std::list<Operation *>::iterator; using iterator = OpListType::iterator;
using reverse_iterator = std::list<Operation *>::reverse_iterator; using reverse_iterator = OpListType::reverse_iterator;
using const_iterator = std::list<Operation *>::const_iterator; using const_iterator = OpListType::const_iterator;
Block() = default; Block() = default;
~Block(); ~Block();
...@@ -60,7 +62,7 @@ class Block { ...@@ -60,7 +62,7 @@ class Block {
private: private:
Region *parent_; // not owned Region *parent_; // not owned
OpListType ops_; // owned
Region::iterator position_; Region::iterator position_;
std::list<Operation *> ops_; // owned
}; };
} // namespace ir } // namespace ir
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/ir/core/operation.h" #include "paddle/ir/core/operation.h"
namespace ir { namespace ir {
/// ///
/// \brief Unified interface of the Attribute class. Derivation of all Attribute /// \brief Unified interface of the Attribute class. Derivation of all Attribute
/// classes only derives interfaces, not members. /// classes only derives interfaces, not members.
...@@ -27,11 +28,47 @@ namespace ir { ...@@ -27,11 +28,47 @@ namespace ir {
class Builder { class Builder {
public: public:
Builder(IrContext *context, Block *block, Block::iterator insert_point) Builder(IrContext *context, Block *block, Block::iterator insert_point)
: context_(context), block_(block), insert_point_(insert_point) {} : context_(context) {
SetInsertionPoint(block, insert_point);
}
Builder(IrContext *context, Block *block) Builder(IrContext *context, Block *block)
: Builder(context, block, block->end()) {} : Builder(context, block, block->end()) {}
IrContext *context() const { return context_; } explicit Builder(IrContext *context)
: Builder(context, nullptr, Block::iterator{}) {}
/// Set the insertion point to the specified location.
void SetInsertionPoint(Block *block, Block::iterator insert_point) {
// TODO(liuyuanle): check that insertPoint is in this rather than some other
// block.
this->block_ = block;
this->insert_point_ = insert_point;
}
/// Set the insertion point to the specified operation, which will cause
/// subsequent insertions to go right before it.
void SetInsertionPoint(Operation *op) {
SetInsertionPoint(op->GetParent(), Block::iterator{*op});
}
/// Set the insertion point to the node after the specified operation, which
/// will cause subsequent insertions to go right after it.
void SetInsertionPointAfter(Operation *op) {
SetInsertionPoint(op->GetParent(), std::next(Block::iterator{*op}));
}
/// Set the insertion point to the start of the specified block.
void SetInsertionPointToStart(Block *block) {
SetInsertionPoint(block, block->begin());
}
/// Set the insertion point to the end of the specified block.
void SetInsertionPointToEnd(Block *block) {
SetInsertionPoint(block, block->end());
}
IrContext *ir_context() const { return context_; }
Block *block() const { return block_; } Block *block() const { return block_; }
...@@ -57,8 +94,9 @@ class Builder { ...@@ -57,8 +94,9 @@ class Builder {
Operation *Insert(Operation *op); Operation *Insert(Operation *op);
IrContext *context_; IrContext *context_;
Block *block_ = nullptr; Block *block_;
// The insertion point within the list that this builder is inserting before. // The insertion point within the list that this builder is inserting before.
Block::iterator insert_point_; Block::iterator insert_point_;
}; };
} // namespace ir } // namespace ir
...@@ -142,8 +142,8 @@ class ConstructInterfacesOrTraits { ...@@ -142,8 +142,8 @@ class ConstructInterfacesOrTraits {
static void PlacementConstrctInterface( static void PlacementConstrctInterface(
InterfaceValue *&p_interface) { // NOLINT InterfaceValue *&p_interface) { // NOLINT
p_interface->swap(InterfaceValue::get<ConcreteOp, T>()); p_interface->swap(InterfaceValue::get<ConcreteOp, T>());
VLOG(4) << "New a interface: id[" << (p_interface->type_id()).storage() VLOG(4) << "New a interface: id["
<< "]."; << (p_interface->type_id()).AsOpaquePointer() << "].";
++p_interface; ++p_interface;
} }
...@@ -151,7 +151,7 @@ class ConstructInterfacesOrTraits { ...@@ -151,7 +151,7 @@ class ConstructInterfacesOrTraits {
template <typename T> template <typename T>
static void PlacementConstrctTrait(ir::TypeId *&p_trait) { // NOLINT static void PlacementConstrctTrait(ir::TypeId *&p_trait) { // NOLINT
*p_trait = TypeId::get<T>(); *p_trait = TypeId::get<T>();
VLOG(4) << "New a trait: id[" << p_trait->storage() << "]."; VLOG(4) << "New a trait: id[" << p_trait->AsOpaquePointer() << "].";
++p_trait; ++p_trait;
} }
}; };
...@@ -206,4 +206,5 @@ class Op : public OpBase { ...@@ -206,4 +206,5 @@ class Op : public OpBase {
return trait_set; return trait_set;
} }
}; };
} // namespace ir } // namespace ir
...@@ -71,15 +71,15 @@ class OpInfo { ...@@ -71,15 +71,15 @@ class OpInfo {
typename Interface::Concept *GetInterfaceImpl() const; typename Interface::Concept *GetInterfaceImpl() const;
void *AsOpaquePointer() const { return impl_; } void *AsOpaquePointer() const { return impl_; }
static OpInfo RecoverFromOpaquePointer(void *impl) { static OpInfo RecoverFromOpaquePointer(void *pointer) {
return static_cast<OpInfoImpl *>(impl); return OpInfo(static_cast<OpInfoImpl *>(pointer));
} }
friend class OpInfoImpl; friend class OpInfoImpl;
friend struct std::hash<OpInfo>; friend struct std::hash<OpInfo>;
private: private:
OpInfo(OpInfoImpl *impl) : impl_(impl) {} // NOLINT explicit OpInfo(OpInfoImpl *impl) : impl_(impl) {}
void *GetInterfaceImpl(TypeId interface_id) const; void *GetInterfaceImpl(TypeId interface_id) const;
private: private:
......
...@@ -50,14 +50,14 @@ OpInfo OpInfoImpl::Create(Dialect *dialect, ...@@ -50,14 +50,14 @@ OpInfo OpInfoImpl::Create(Dialect *dialect,
} }
// Construct OpInfoImpl. // Construct OpInfoImpl.
VLOG(4) << "Construct OpInfoImpl at " << base_ptr << " ......"; VLOG(4) << "Construct OpInfoImpl at " << base_ptr << " ......";
OpInfo op_info = new (base_ptr) OpInfoImpl(dialect, OpInfo op_info = OpInfo(new (base_ptr) OpInfoImpl(dialect,
op_id, op_id,
op_name, op_name,
interfaces_num, interfaces_num,
traits_num, traits_num,
attributes_num, attributes_num,
attributes_name, attributes_name,
verify); verify));
return op_info; return op_info;
} }
void OpInfoImpl::Destroy(OpInfo info) { void OpInfoImpl::Destroy(OpInfo info) {
......
...@@ -51,7 +51,10 @@ class TypeId { ...@@ -51,7 +51,10 @@ class TypeId {
TypeId &operator=(const TypeId &other) = default; TypeId &operator=(const TypeId &other) = default;
const Storage *storage() const { return storage_; } void *AsOpaquePointer() const { return storage_; }
static TypeId RecoverFromOpaquePointer(void *pointer) {
return TypeId(static_cast<Storage *>(pointer));
}
/// ///
/// \brief Comparison operations. /// \brief Comparison operations.
...@@ -77,10 +80,11 @@ class TypeId { ...@@ -77,10 +80,11 @@ class TypeId {
/// ///
/// \param storage The storage of this TypeId. /// \param storage The storage of this TypeId.
/// ///
explicit TypeId(const Storage *storage) : storage_(storage) {} explicit TypeId(Storage *storage) : storage_(storage) {}
const Storage *storage_{nullptr}; Storage *storage_{nullptr};
}; };
} // namespace ir } // namespace ir
namespace std { namespace std {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/ir/pass/pass.h" #include "paddle/ir/pass/pass.h"
#include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h" #include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h" #include "paddle/ir/core/program.h"
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include <memory>
#include <optional>
#include <set>
#include <string>
#include "paddle/ir/core/op_info.h"
#include "paddle/utils/optional.h"
namespace ir {
FrozenRewritePatternSet::FrozenRewritePatternSet()
: impl_(std::make_shared<Impl>()) {}
FrozenRewritePatternSet::FrozenRewritePatternSet(
RewritePatternSet&& patterns,
const std::vector<std::string>& disabled_pattern_labels,
const std::vector<std::string>& enabled_pattern_labels)
: impl_(std::make_shared<Impl>()) {
std::set<std::string> disabled_patterns, enabled_patterns;
disabled_patterns.insert(disabled_pattern_labels.begin(),
disabled_pattern_labels.end());
enabled_patterns.insert(enabled_pattern_labels.begin(),
enabled_pattern_labels.end());
ir::OpInfoMap op_info_map;
auto AddToOpsWhen = [&](std::unique_ptr<RewritePattern>& pattern,
std::function<bool(OpInfo)> callback) {
if (op_info_map.empty())
op_info_map = pattern->ir_context()->registered_op_info_map();
for (auto& info_map : op_info_map) {
if (callback(info_map.second))
impl_->op_specific_native_pattern_map_[info_map.second].push_back(
pattern.get());
impl_->op_specific_native_patterns_.push_back(std::move(pattern));
}
};
for (std::unique_ptr<RewritePattern>& pat : patterns.native_patterns()) {
// Don't add patterns that haven't been enabled by the user.
if (!enabled_patterns.empty()) {
auto IsEnableFn = [&](const std::string& label) {
return enabled_patterns.count(label);
};
if (!IsEnableFn(pat->debug_name()) &&
std::none_of(pat->debug_labels().begin(),
pat->debug_labels().end(),
IsEnableFn))
continue;
}
// Don't add patterns that have been disabled by the user.
if (!disabled_patterns.empty()) {
auto IsDisabledFn = [&](const std::string& label) {
return disabled_patterns.count(label);
};
if (IsDisabledFn(pat->debug_name()) ||
std::any_of(pat->debug_labels().begin(),
pat->debug_labels().end(),
IsDisabledFn))
continue;
}
if (paddle::optional<OpInfo> root_name = pat->root_kind()) {
impl_->op_specific_native_pattern_map_[*root_name].push_back(pat.get());
impl_->op_specific_native_patterns_.push_back(std::move(pat));
continue;
}
if (paddle::optional<TypeId> interface_id = pat->GetRootInterfaceID()) {
AddToOpsWhen(
pat, [&](OpInfo info) { return info.HasInterface(*interface_id); });
continue;
}
if (paddle::optional<TypeId> trait_id = pat->GetRootTraitID()) {
AddToOpsWhen(pat, [&](OpInfo info) { return info.HasTrait(*trait_id); });
continue;
}
impl_->match_any_op_native_patterns_.push_back(std::move(pat));
}
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/// The design and code is mainly from MLIR, thanks to the greate project.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
namespace ir {
class FrozenRewritePatternSet {
using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
public:
using OpSpecificNativePatternListT =
std::unordered_map<OpInfo, std::vector<RewritePattern*>>;
FrozenRewritePatternSet();
FrozenRewritePatternSet(FrozenRewritePatternSet&& patterns) = default;
FrozenRewritePatternSet(const FrozenRewritePatternSet& patterns) = default;
FrozenRewritePatternSet& operator=(FrozenRewritePatternSet&& patterns) =
default;
FrozenRewritePatternSet& operator=(const FrozenRewritePatternSet& patterns) =
default;
~FrozenRewritePatternSet() = default;
/// Freeze the patterns held in `patterns`, and take ownership.
FrozenRewritePatternSet(
RewritePatternSet&& patterns,
const std::vector<std::string>& disabled_pattern_labels = {},
const std::vector<std::string>& enabled_pattern_labels = {});
/// Return the op specific native patterns held by this list.
const OpSpecificNativePatternListT& op_specific_native_patterns() const {
return impl_->op_specific_native_pattern_map_;
}
/// Return the "match any" native patterns held by this list.
const NativePatternListT& match_any_op_native_patterns() const {
return impl_->match_any_op_native_patterns_;
}
private:
struct Impl {
OpSpecificNativePatternListT op_specific_native_pattern_map_;
NativePatternListT op_specific_native_patterns_;
NativePatternListT match_any_op_native_patterns_;
};
std::shared_ptr<Impl> impl_;
};
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "paddle/ir/pattern_rewrite/pattern_applicator.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
namespace ir {
PatternApplicator::PatternApplicator(
const FrozenRewritePatternSet& frozen_patter_list)
: frozen_patter_list_(frozen_patter_list) {}
void PatternApplicator::ApplyCostModel(CostModel model) {
// TODO(wilber): remove impossible patterns.
patterns_.clear();
for (const auto& it : frozen_patter_list_.op_specific_native_patterns()) {
for (const RewritePattern* pattern : it.second) {
patterns_[it.first].push_back(pattern);
}
}
any_op_patterns_.clear();
for (auto& pattern : frozen_patter_list_.match_any_op_native_patterns()) {
any_op_patterns_.push_back(pattern.get());
}
// Sort by benefit based on the cost model.
std::unordered_map<const Pattern*, PatternBenefit> benefits;
auto cmp = [&benefits](const Pattern* lhs, const Pattern* rhs) {
return benefits[lhs] > benefits[rhs];
};
auto ProcessPatternList = [&](std::vector<const RewritePattern*>& list) {
if (list.size() == 1) return;
benefits.clear();
for (const Pattern* pat : list) benefits.emplace(pat, model(*pat));
std::stable_sort(list.begin(), list.end(), cmp);
};
for (auto& it : patterns_) {
ProcessPatternList(it.second);
}
ProcessPatternList(any_op_patterns_);
}
void PatternApplicator::WalkAllPatterns(
std::function<void(const Pattern&)> walk) {
for (const auto& it : frozen_patter_list_.op_specific_native_patterns())
for (auto* pattern : it.second) walk(*pattern);
for (auto& it : frozen_patter_list_.match_any_op_native_patterns()) walk(*it);
}
bool PatternApplicator::MatchAndRewrite(
Operation* op,
PatternRewriter& rewriter,
std::function<bool(const Pattern&)> can_apply,
std::function<void(const Pattern&)> on_failure,
std::function<bool(const Pattern&)> on_success) {
// whether there are patterns matching this operation type.
std::vector<const RewritePattern*> op_patterns;
auto pattern_it = patterns_.find(op->info());
if (pattern_it != patterns_.end()) op_patterns = pattern_it->second;
unsigned op_it = 0, op_e = op_patterns.size();
unsigned any_it = 0, any_e = any_op_patterns_.size();
bool result = false;
do {
// Find the next pattern with the highest benefit.
const Pattern* best_pattern = nullptr;
unsigned* best_pattern_it = &op_it;
// For specific patterns
if (op_it < op_e) best_pattern = op_patterns[op_it];
// For op-agnostic patterns
if (any_it < any_e &&
(!best_pattern ||
best_pattern->benefit() < any_op_patterns_[any_it]->benefit())) {
best_pattern_it = &any_it;
best_pattern = any_op_patterns_[any_it];
}
if (!best_pattern) break;
// Update the pattern iterator, so that this pattern isn't attempted again.
++(*best_pattern_it);
if (can_apply && !can_apply(*best_pattern)) continue;
rewriter.SetInsertionPoint(op);
const auto* pattern = static_cast<const RewritePattern*>(best_pattern);
result = pattern->MatchAndRewrite(op, rewriter);
if (result && on_success && !on_success(*best_pattern)) result = false;
if (result) break;
if (on_failure) on_failure(*best_pattern);
} while (true);
return result;
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The design and code is mainly from MLIR, thanks to the greate project.
#pragma once
#include <functional>
#include <unordered_map>
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
namespace ir {
class PatternApplicator {
public:
using CostModel = std::function<PatternBenefit(const Pattern&)>;
explicit PatternApplicator(const FrozenRewritePatternSet& frozen_patter_list);
~PatternApplicator() = default;
bool MatchAndRewrite(Operation* op,
PatternRewriter& rewriter, // NOLINT
std::function<bool(const Pattern&)> can_apply = {},
std::function<void(const Pattern&)> on_failure = {},
std::function<bool(const Pattern&)> on_success = {});
void ApplyCostModel(CostModel model);
void ApplyDefaultCostModel() {
ApplyCostModel([](const Pattern& pattern) { return pattern.benefit(); });
}
void WalkAllPatterns(std::function<void(const Pattern&)> walk);
private:
const FrozenRewritePatternSet& frozen_patter_list_;
std::unordered_map<OpInfo, std::vector<const RewritePattern*>> patterns_;
std::vector<const RewritePattern*> any_op_patterns_;
};
} // namespace ir
...@@ -13,8 +13,11 @@ ...@@ -13,8 +13,11 @@
// limitations under the License. // limitations under the License.
#include "paddle/ir/pattern_rewrite/pattern_match.h" #include "paddle/ir/pattern_rewrite/pattern_match.h"
#include <algorithm>
#include <cassert> #include <cassert>
#include <cstdint> #include <cstdint>
#include "paddle/ir/core/operation.h" #include "paddle/ir/core/operation.h"
namespace ir { namespace ir {
...@@ -22,62 +25,69 @@ namespace ir { ...@@ -22,62 +25,69 @@ namespace ir {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Pattern // Pattern
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Pattern::Pattern(const void* root_val,
// RootKind root_kind,
// const std::vector<std::string>& generated_names,
// PatternBenefit benefit,
// ir::IrContext* context)
// : benefit_(benefit), context_(context), generated_names_(generated_names)
// {}
Pattern::Pattern(const std::string& root_name, Pattern::Pattern(const std::string& root_name,
PatternBenefit benefit, PatternBenefit benefit,
IrContext* context, IrContext* context,
const std::vector<std::string>& generated_names) const std::vector<std::string>& generated_names)
: op_name_(root_name), : Pattern(context->GetRegisteredOpInfo(root_name).AsOpaquePointer(),
root_kind_(RootKind::OperationName), RootKind::OperationInfo,
benefit_(benefit), generated_names,
context_(context), benefit,
generated_names_(generated_names) {} context) {}
Pattern::Pattern(MatchAnyOpTypeTag tag, Pattern::Pattern(MatchAnyOpTypeTag tag,
PatternBenefit benefit, PatternBenefit benefit,
ir::IrContext* context, IrContext* context,
const std::vector<std::string>& generated_names) const std::vector<std::string>& generated_names)
: root_kind_(RootKind::Any), : Pattern(nullptr, RootKind::Any, generated_names, benefit, context) {}
benefit_(benefit),
context_(context),
generated_names_(generated_names) {}
Pattern::Pattern(MatchInterfaceOpTypeTag tag, Pattern::Pattern(MatchInterfaceOpTypeTag tag,
ir::TypeId interface_id, TypeId interface_id,
PatternBenefit benefit, PatternBenefit benefit,
ir::IrContext* context, IrContext* context,
const std::vector<std::string>& generated_names) const std::vector<std::string>& generated_names)
: interface_id_(interface_id), : Pattern(interface_id.AsOpaquePointer(),
root_kind_(RootKind::InterfaceId), RootKind::InterfaceId,
benefit_(benefit), generated_names,
context_(context), benefit,
generated_names_(generated_names) {} context) {}
Pattern::Pattern(MatchTraitOpTypeTag tag, Pattern::Pattern(MatchTraitOpTypeTag tag,
ir::TypeId trait_id, TypeId trait_id,
PatternBenefit benefit, PatternBenefit benefit,
ir::IrContext* context, IrContext* context,
const std::vector<std::string>& generated_names) const std::vector<std::string>& generated_names)
: trait_id_(trait_id), : Pattern(trait_id.AsOpaquePointer(),
root_kind_(RootKind::TraitId), RootKind::TraitId,
generated_names,
benefit,
context) {}
Pattern::Pattern(void* root_val,
RootKind root_kind,
const std::vector<std::string>& generated_names,
PatternBenefit benefit,
IrContext* context)
: root_val_(root_val),
root_kind_(root_kind),
benefit_(benefit), benefit_(benefit),
context_(context), context_(context) {
generated_names_(generated_names) {} if (generated_names.empty()) return;
generated_ops_.reserve(generated_names.size());
std::transform(generated_names.begin(),
generated_names.end(),
std::back_inserter(generated_ops_),
[context](const std::string& name) {
return context->GetRegisteredOpInfo(name);
});
}
RewritePattern::~RewritePattern() = default; RewritePattern::~RewritePattern() = default;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// RewriterBase // RewriterBase
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
RewriterBase::~RewriterBase() = default; RewriterBase::~RewriterBase() = default;
// TODO(wilber): value support replace method. // TODO(wilber): value support replace method.
...@@ -113,9 +123,9 @@ void RewriterBase::EraseOp(Operation* op) { ...@@ -113,9 +123,9 @@ void RewriterBase::EraseOp(Operation* op) {
void RewriterBase::ReplaceAllUsesWith(Value from, Value to) { void RewriterBase::ReplaceAllUsesWith(Value from, Value to) {
// from. // from.
// for (mlir::OpOperand& operand : llvm::make_early_inc_range(from.getUses())) // for (OpOperand& operand : llvm::make_early_inc_range(from.getUses()))
// { // {
// mlir::Operation* op = operand.getOwner(); // Operation* op = operand.getOwner();
// UpdateRootInPlace(op, [&]() { operand.set(to); }); // UpdateRootInPlace(op, [&]() { operand.set(to); });
// } // }
} }
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// The design is mainly from MLIR, very thanks to the greate project.
#pragma once #pragma once
#include <functional> #include <functional>
...@@ -19,46 +21,78 @@ ...@@ -19,46 +21,78 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include <utility>
#include <vector> #include <vector>
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation.h" #include "paddle/ir/core/operation.h"
#include "paddle/ir/core/type_id.h" #include "paddle/ir/core/type_id.h"
#include "paddle/ir/core/type_name.h" #include "paddle/ir/core/type_name.h"
#include "paddle/ir/core/value.h" #include "paddle/ir/core/value.h"
namespace ir { #include "paddle/utils/optional.h"
/// The design is mainly from MLIR, very thanks to the greate project. namespace ir {
/// This class reprensents the benefit of a pattern. The most common // This class reprensents the benefit of a pattern. The most common
/// unit to use is the `numver of operations` in the pattern. // unit to use is the `numver of operations` in the pattern.
class PatternBenefit { class PatternBenefit {
public: public:
PatternBenefit(unsigned val) : val_(val) {} // NOLINT PatternBenefit() = default;
PatternBenefit(uint32_t val) : val_(val) {} // NOLINT
unsigned benefit() { return val_; } uint32_t benefit() { return val_; }
bool operator==(const PatternBenefit& rhs) const { return val_ == rhs.val_; } bool operator==(const PatternBenefit& rhs) const { return val_ == rhs.val_; }
bool operator!=(const PatternBenefit& rhs) const { return !(*this == rhs); } bool operator!=(const PatternBenefit& rhs) const { return !(*this == rhs); }
bool operator<(const PatternBenefit& rhs) const { return val_ < rhs.val_; } bool operator<(const PatternBenefit& rhs) const { return val_ < rhs.val_; }
bool operator>(const PatternBenefit& rhs) const { return rhs < *this; } bool operator>(const PatternBenefit& rhs) const { return rhs < *this; }
bool operator<=(const PatternBenefit& rhs) const { return !(*this > rhs); } bool operator<=(const PatternBenefit& rhs) const { return !(*this > rhs); }
bool operator>=(const PatternBenefit& rhs) const { return !(*this <= rhs); } bool operator>=(const PatternBenefit& rhs) const { return !(*this < rhs); }
private: private:
unsigned int val_{0}; uint32_t val_{0};
}; };
/// This class contains all of the data related to a Pattern, but not contains // This class contains all of the data related to a Pattern, but not contains
/// any methods for the matching. This class is used to interface with the // any methods for the matching. This class is used to interface with the
/// metadata of a pattern, such as benefit or root operation. // metadata of a pattern, such as benefit or root operation.
class Pattern { class Pattern {
enum class RootKind { Any, OperationName, InterfaceId, TraitId }; enum class RootKind {
// The pattern root matches "any" operation.
Any,
// The pattern root is matched using a concrete operation.
OperationInfo,
// The pattern root is matched using an interface id.
InterfaceId,
// The patter root is matched using a trait id.
TraitId
};
public: public:
const std::vector<OpInfo>& generated_ops() const { return generated_ops_; }
paddle::optional<OpInfo> root_kind() const {
if (root_kind_ == RootKind::OperationInfo)
return OpInfo::RecoverFromOpaquePointer(root_val_);
return paddle::none;
}
paddle::optional<TypeId> GetRootInterfaceID() const {
if (root_kind_ == RootKind::InterfaceId)
return TypeId::RecoverFromOpaquePointer(root_val_);
return paddle::none;
}
paddle::optional<TypeId> GetRootTraitID() const {
if (root_kind_ == RootKind::TraitId)
return TypeId::RecoverFromOpaquePointer(root_val_);
return paddle::none;
}
PatternBenefit benefit() const { return benefit_; } PatternBenefit benefit() const { return benefit_; }
IrContext* context() const { return context_; } IrContext* ir_context() const { return context_; }
std::string debug_name() const { return debug_name_; } std::string debug_name() const { return debug_name_; }
...@@ -81,41 +115,39 @@ class Pattern { ...@@ -81,41 +115,39 @@ class Pattern {
Pattern(const std::string& root_name, Pattern(const std::string& root_name,
PatternBenefit benefit, PatternBenefit benefit,
ir::IrContext* context, IrContext* context,
const std::vector<std::string>& generated_names = {}); const std::vector<std::string>& generated_names = {});
Pattern(MatchAnyOpTypeTag tag, Pattern(MatchAnyOpTypeTag tag,
PatternBenefit benefit, PatternBenefit benefit,
ir::IrContext* context, IrContext* context,
const std::vector<std::string>& generated_names = {}); const std::vector<std::string>& generated_names = {});
Pattern(MatchInterfaceOpTypeTag tag, Pattern(MatchInterfaceOpTypeTag tag,
ir::TypeId interface_id, TypeId interface_id,
PatternBenefit benefit, PatternBenefit benefit,
ir::IrContext* context, IrContext* context,
const std::vector<std::string>& generated_names = {}); const std::vector<std::string>& generated_names = {});
Pattern(MatchTraitOpTypeTag tag, Pattern(MatchTraitOpTypeTag tag,
ir::TypeId trait_id, TypeId trait_id,
PatternBenefit benefit, PatternBenefit benefit,
ir::IrContext* context, IrContext* context,
const std::vector<std::string>& generated_names = {}); const std::vector<std::string>& generated_names = {});
private: private:
// TODO(wilber): How to uniform variables and constructor. Pattern(void* root_val,
// Pattern(const void* root_val, RootKind root_kind,
// RootKind root_kind, const std::vector<std::string>& generated_names,
// const std::vector<std::string>& generated_names, PatternBenefit benefit,
// PatternBenefit benefit, IrContext* context);
// ir::IrContext* context);
std::string op_name_; void* root_val_;
ir::TypeId interface_id_;
ir::TypeId trait_id_;
RootKind root_kind_; RootKind root_kind_;
const PatternBenefit benefit_; const PatternBenefit benefit_;
ir::IrContext* context_; IrContext* context_;
std::vector<std::string> generated_names_; std::vector<OpInfo> generated_ops_;
std::string debug_name_; std::string debug_name_;
std::vector<std::string> debug_labels_; std::vector<std::string> debug_labels_;
...@@ -127,19 +159,19 @@ class RewritePattern : public Pattern { ...@@ -127,19 +159,19 @@ class RewritePattern : public Pattern {
public: public:
virtual ~RewritePattern(); virtual ~RewritePattern();
virtual void Rewrite(ir::Operation* op, virtual void Rewrite(Operation* op,
PatternRewriter& rewriter) const { // NOLINT PatternRewriter& rewriter) const { // NOLINT
throw( throw(
"need to implement either MatchAndRewrite or one of the rewrite " "need to implement either MatchAndRewrite or one of the rewrite "
"functions."); "functions.");
} }
virtual bool Match(ir::Operation* op) const { virtual bool Match(Operation* op) const {
throw("need to implement either MatchAndRewrite or Match."); throw("need to implement either MatchAndRewrite or Match.");
return false; return false;
} }
virtual bool MatchAndRewrite(ir::Operation* op, virtual bool MatchAndRewrite(Operation* op,
PatternRewriter& rewriter) const { // NOLINT PatternRewriter& rewriter) const { // NOLINT
if (Match(op)) { if (Match(op)) {
Rewrite(op, rewriter); Rewrite(op, rewriter);
...@@ -157,7 +189,7 @@ class RewritePattern : public Pattern { ...@@ -157,7 +189,7 @@ class RewritePattern : public Pattern {
pattern->Initialize(); pattern->Initialize();
if (pattern->debug_name().empty()) if (pattern->debug_name().empty())
pattern->SetDebugName(get_type_name<T>()); pattern->SetDebugName(ir::get_type_name<T>());
return pattern; return pattern;
} }
...@@ -166,8 +198,8 @@ class RewritePattern : public Pattern { ...@@ -166,8 +198,8 @@ class RewritePattern : public Pattern {
}; };
namespace detail { namespace detail {
/// A wrapper around PatternWrite that allows for matching and rewriting // A wrapper around PatternWrite that allows for matching and rewriting
/// against an instance of a derived operation class or Interface. // against an instance of a derived operation class or Interface.
template <typename SourceOp> template <typename SourceOp>
struct OpOrInterfaceRewritePatternBase : public RewritePattern { struct OpOrInterfaceRewritePatternBase : public RewritePattern {
using RewritePattern::RewritePattern; using RewritePattern::RewritePattern;
...@@ -203,30 +235,25 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern { ...@@ -203,30 +235,25 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
}; };
} // namespace detail } // namespace detail
/// OpRewritePattern is a wrapper around RewritePattern that allows for // OpRewritePattern is a wrapper around RewritePattern that allows for
/// matching and rewriting against an instance of a derived operation // matching and rewriting against an instance of a derived operation
/// class as opposed to a raw Operation. // class as opposed to a raw Operation.
template <typename SourceOp> template <typename SourceOp>
struct OpRewritePattern struct OpRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> { : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
OpRewritePattern(ir::IrContext* context, OpRewritePattern(IrContext* context,
PatternBenefit benefit = 1, PatternBenefit benefit = 1,
const std::vector<std::string>& generated_names = {}) const std::vector<std::string>& generated_names = {})
: detail::OpOrInterfaceRewritePatternBase<SourceOp>( : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
"NeedToFix", // TODO(wilber): Need to fix. SourceOp maybe should SourceOp::name(), benefit, context, generated_names) {}
// have a getOperationName static method.
benefit,
context,
generated_names) {}
}; };
// TODO(wilber): Support OpInterfaceRewritePattern and OpTraitRewritePattern. // TODO(wilber): Support OpInterfaceRewritePattern and OpTraitRewritePattern.
// ... // ...
/// This class provides a series of interfaces for modifying IR and tracking IR // This class provides a series of interfaces for modifying IR and tracking IR
/// changes. This class provides a unified API for IR modification. // changes. This class provides a unified API for IR modification.
/// class RewriterBase : public Builder {
class RewriterBase { // maybe should inherit OpBuilder.
public: public:
// TODO(wilber): Supplementary methods of block and region. // TODO(wilber): Supplementary methods of block and region.
...@@ -262,7 +289,7 @@ class RewriterBase { // maybe should inherit OpBuilder. ...@@ -262,7 +289,7 @@ class RewriterBase { // maybe should inherit OpBuilder.
std::function<bool(OpOperand&)> functor); std::function<bool(OpOperand&)> functor);
protected: protected:
explicit RewriterBase(IrContext* ctx) : ctx_(ctx) {} explicit RewriterBase(IrContext* ctx) : Builder(ctx) {}
virtual ~RewriterBase(); virtual ~RewriterBase();
...@@ -277,9 +304,6 @@ class RewriterBase { // maybe should inherit OpBuilder. ...@@ -277,9 +304,6 @@ class RewriterBase { // maybe should inherit OpBuilder.
RewriterBase(const RewriterBase&) = delete; RewriterBase(const RewriterBase&) = delete;
void ReplaceOpWithResultsOfAnotherOp(Operation* op, Operation* new_op); void ReplaceOpWithResultsOfAnotherOp(Operation* op, Operation* new_op);
private:
IrContext* ctx_;
}; };
class PatternRewriter : public RewriterBase { class PatternRewriter : public RewriterBase {
...@@ -287,7 +311,7 @@ class PatternRewriter : public RewriterBase { ...@@ -287,7 +311,7 @@ class PatternRewriter : public RewriterBase {
using RewriterBase::RewriterBase; using RewriterBase::RewriterBase;
}; };
/// A pattern collection, easy to add patterns. // A pattern collection, easy to add patterns.
class RewritePatternSet { class RewritePatternSet {
using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>; using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
...@@ -299,7 +323,7 @@ class RewritePatternSet { ...@@ -299,7 +323,7 @@ class RewritePatternSet {
native_patterns_.emplace_back(std::move(pattern)); native_patterns_.emplace_back(std::move(pattern));
} }
IrContext* context() const { return context_; } IrContext* ir_context() const { return context_; }
NativePatternListT& native_patterns() { return native_patterns_; } NativePatternListT& native_patterns() { return native_patterns_; }
...@@ -351,6 +375,8 @@ class RewritePatternSet { ...@@ -351,6 +375,8 @@ class RewritePatternSet {
private: private:
IrContext* const context_; IrContext* const context_;
NativePatternListT native_patterns_; NativePatternListT native_patterns_;
}; };
} // namespace ir } // namespace ir
...@@ -102,7 +102,7 @@ class Operation1 : public ir::Op<Operation1> { ...@@ -102,7 +102,7 @@ class Operation1 : public ir::Op<Operation1> {
ir::OperationArgument &argument) { // NOLINT ir::OperationArgument &argument) { // NOLINT
std::vector<ir::OpResult> inputs = {}; std::vector<ir::OpResult> inputs = {};
std::vector<ir::Type> output_types = { std::vector<ir::Type> output_types = {
ir::Float32Type::get(builder.context())}; ir::Float32Type::get(builder.ir_context())};
std::unordered_map<std::string, ir::Attribute> attributes = std::unordered_map<std::string, ir::Attribute> attributes =
CreateAttributeMap({"op1_attr1", "op1_attr2"}, CreateAttributeMap({"op1_attr1", "op1_attr2"},
{"op1_attr1", "op1_attr2"}); {"op1_attr1", "op1_attr2"});
......
...@@ -5,6 +5,4 @@ cc_test_old( ...@@ -5,6 +5,4 @@ cc_test_old(
DEPS DEPS
new_pass new_pass
pattern_rewrite pattern_rewrite
pd_dialect
phi
gtest) gtest)
...@@ -80,6 +80,7 @@ class TestPatternRewrite : public ir::OpRewritePattern<Operation1> { ...@@ -80,6 +80,7 @@ class TestPatternRewrite : public ir::OpRewritePattern<Operation1> {
void Rewrite(Operation1 op, ir::PatternRewriter &rewriter) const override {} void Rewrite(Operation1 op, ir::PatternRewriter &rewriter) const override {}
bool Match(Operation1 op) const override { return false; } bool Match(Operation1 op) const override { return false; }
}; };
class TestPatternRewrite2 : public ir::OpRewritePattern<Operation1> { class TestPatternRewrite2 : public ir::OpRewritePattern<Operation1> {
public: public:
using ir::OpRewritePattern<Operation1>::OpRewritePattern; using ir::OpRewritePattern<Operation1>::OpRewritePattern;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册