diff --git a/paddle/ir/core/block.h b/paddle/ir/core/block.h index ef87ac2871e312dbb6c01d3b3b56e2d28ce0ece8..24b48036916c274eb3bc650a3dd02cc104aadcf9 100644 --- a/paddle/ir/core/block.h +++ b/paddle/ir/core/block.h @@ -23,10 +23,12 @@ namespace ir { class Operation; class Block { + using OpListType = std::list; + public: - using iterator = std::list::iterator; - using reverse_iterator = std::list::reverse_iterator; - using const_iterator = std::list::const_iterator; + using iterator = OpListType::iterator; + using reverse_iterator = OpListType::reverse_iterator; + using const_iterator = OpListType::const_iterator; Block() = default; ~Block(); @@ -60,7 +62,7 @@ class Block { private: Region *parent_; // not owned + OpListType ops_; // owned Region::iterator position_; - std::list ops_; // owned }; } // namespace ir diff --git a/paddle/ir/core/builder.h b/paddle/ir/core/builder.h index 92b7b23f5ef14bcd83066aae2cca8c28c6ae4440..0e7f7b427f6bdf90acb072f53f5db3bf3e069598 100644 --- a/paddle/ir/core/builder.h +++ b/paddle/ir/core/builder.h @@ -20,6 +20,7 @@ #include "paddle/ir/core/operation.h" namespace ir { + /// /// \brief Unified interface of the Attribute class. Derivation of all Attribute /// classes only derives interfaces, not members. @@ -27,11 +28,47 @@ namespace ir { class Builder { public: Builder(IrContext *context, Block *block, Block::iterator insert_point) - : context_(context), block_(block), insert_point_(insert_point) {} + : context_(context) { + SetInsertionPoint(block, insert_point); + } + Builder(IrContext *context, Block *block) : Builder(context, block, block->end()) {} - IrContext *context() const { return context_; } + explicit Builder(IrContext *context) + : Builder(context, nullptr, Block::iterator{}) {} + + /// Set the insertion point to the specified location. + void SetInsertionPoint(Block *block, Block::iterator insert_point) { + // TODO(liuyuanle): check that insertPoint is in this rather than some other + // block. + this->block_ = block; + this->insert_point_ = insert_point; + } + + /// Set the insertion point to the specified operation, which will cause + /// subsequent insertions to go right before it. + void SetInsertionPoint(Operation *op) { + SetInsertionPoint(op->GetParent(), Block::iterator{*op}); + } + + /// Set the insertion point to the node after the specified operation, which + /// will cause subsequent insertions to go right after it. + void SetInsertionPointAfter(Operation *op) { + SetInsertionPoint(op->GetParent(), std::next(Block::iterator{*op})); + } + + /// Set the insertion point to the start of the specified block. + void SetInsertionPointToStart(Block *block) { + SetInsertionPoint(block, block->begin()); + } + + /// Set the insertion point to the end of the specified block. + void SetInsertionPointToEnd(Block *block) { + SetInsertionPoint(block, block->end()); + } + + IrContext *ir_context() const { return context_; } Block *block() const { return block_; } @@ -57,8 +94,9 @@ class Builder { Operation *Insert(Operation *op); IrContext *context_; - Block *block_ = nullptr; + Block *block_; // The insertion point within the list that this builder is inserting before. Block::iterator insert_point_; }; + } // namespace ir diff --git a/paddle/ir/core/op_base.h b/paddle/ir/core/op_base.h index 271ba135aacbaffff82a1ce548d1395c5490bc47..6df6e3a195655a03a639309bedab6c9425c1ae89 100644 --- a/paddle/ir/core/op_base.h +++ b/paddle/ir/core/op_base.h @@ -142,8 +142,8 @@ class ConstructInterfacesOrTraits { static void PlacementConstrctInterface( InterfaceValue *&p_interface) { // NOLINT p_interface->swap(InterfaceValue::get()); - VLOG(4) << "New a interface: id[" << (p_interface->type_id()).storage() - << "]."; + VLOG(4) << "New a interface: id[" + << (p_interface->type_id()).AsOpaquePointer() << "]."; ++p_interface; } @@ -151,7 +151,7 @@ class ConstructInterfacesOrTraits { template static void PlacementConstrctTrait(ir::TypeId *&p_trait) { // NOLINT *p_trait = TypeId::get(); - VLOG(4) << "New a trait: id[" << p_trait->storage() << "]."; + VLOG(4) << "New a trait: id[" << p_trait->AsOpaquePointer() << "]."; ++p_trait; } }; @@ -206,4 +206,5 @@ class Op : public OpBase { return trait_set; } }; + } // namespace ir diff --git a/paddle/ir/core/op_info.h b/paddle/ir/core/op_info.h index 345f6c984d9ffa87956777a98ca9ec6de4c95178..5b77aa8de3ec6dc4a932a33919c6a3e216d7bfb5 100644 --- a/paddle/ir/core/op_info.h +++ b/paddle/ir/core/op_info.h @@ -71,15 +71,15 @@ class OpInfo { typename Interface::Concept *GetInterfaceImpl() const; void *AsOpaquePointer() const { return impl_; } - static OpInfo RecoverFromOpaquePointer(void *impl) { - return static_cast(impl); + static OpInfo RecoverFromOpaquePointer(void *pointer) { + return OpInfo(static_cast(pointer)); } friend class OpInfoImpl; friend struct std::hash; private: - OpInfo(OpInfoImpl *impl) : impl_(impl) {} // NOLINT + explicit OpInfo(OpInfoImpl *impl) : impl_(impl) {} void *GetInterfaceImpl(TypeId interface_id) const; private: diff --git a/paddle/ir/core/op_info_impl.cc b/paddle/ir/core/op_info_impl.cc index 57d15b22c289689992a7af0196465f95ebf37450..c28686d10aa6c1b8477c7c2dc32161c25605aa70 100644 --- a/paddle/ir/core/op_info_impl.cc +++ b/paddle/ir/core/op_info_impl.cc @@ -50,14 +50,14 @@ OpInfo OpInfoImpl::Create(Dialect *dialect, } // Construct OpInfoImpl. VLOG(4) << "Construct OpInfoImpl at " << base_ptr << " ......"; - OpInfo op_info = new (base_ptr) OpInfoImpl(dialect, - op_id, - op_name, - interfaces_num, - traits_num, - attributes_num, - attributes_name, - verify); + OpInfo op_info = OpInfo(new (base_ptr) OpInfoImpl(dialect, + op_id, + op_name, + interfaces_num, + traits_num, + attributes_num, + attributes_name, + verify)); return op_info; } void OpInfoImpl::Destroy(OpInfo info) { diff --git a/paddle/ir/core/type_id.h b/paddle/ir/core/type_id.h index 736152b4ff6400177b91769779af74fcbccc6881..7d22abb7388a2600f99c21c944d94150b43a050b 100644 --- a/paddle/ir/core/type_id.h +++ b/paddle/ir/core/type_id.h @@ -51,7 +51,10 @@ class TypeId { TypeId &operator=(const TypeId &other) = default; - const Storage *storage() const { return storage_; } + void *AsOpaquePointer() const { return storage_; } + static TypeId RecoverFromOpaquePointer(void *pointer) { + return TypeId(static_cast(pointer)); + } /// /// \brief Comparison operations. @@ -77,10 +80,11 @@ class TypeId { /// /// \param storage The storage of this TypeId. /// - explicit TypeId(const Storage *storage) : storage_(storage) {} + explicit TypeId(Storage *storage) : storage_(storage) {} - const Storage *storage_{nullptr}; + Storage *storage_{nullptr}; }; + } // namespace ir namespace std { diff --git a/paddle/ir/pass/pass.cc b/paddle/ir/pass/pass.cc index 5ede04a97c810593a10acb409405e433f50cab74..4b4db8e1d28d4ec9b4c10ec10a85383c16ee38ba 100644 --- a/paddle/ir/pass/pass.cc +++ b/paddle/ir/pass/pass.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/ir/pass/pass.h" + #include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/operation.h" #include "paddle/ir/core/program.h" diff --git a/paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.cc b/paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.cc new file mode 100644 index 0000000000000000000000000000000000000000..ba9a680a3064e63faa07cd7181af7bd92ffc9e12 --- /dev/null +++ b/paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.cc @@ -0,0 +1,100 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h" + +#include +#include +#include +#include + +#include "paddle/ir/core/op_info.h" +#include "paddle/utils/optional.h" + +namespace ir { + +FrozenRewritePatternSet::FrozenRewritePatternSet() + : impl_(std::make_shared()) {} + +FrozenRewritePatternSet::FrozenRewritePatternSet( + RewritePatternSet&& patterns, + const std::vector& disabled_pattern_labels, + const std::vector& enabled_pattern_labels) + : impl_(std::make_shared()) { + std::set disabled_patterns, enabled_patterns; + disabled_patterns.insert(disabled_pattern_labels.begin(), + disabled_pattern_labels.end()); + enabled_patterns.insert(enabled_pattern_labels.begin(), + enabled_pattern_labels.end()); + + ir::OpInfoMap op_info_map; + auto AddToOpsWhen = [&](std::unique_ptr& pattern, + std::function callback) { + if (op_info_map.empty()) + op_info_map = pattern->ir_context()->registered_op_info_map(); + for (auto& info_map : op_info_map) { + if (callback(info_map.second)) + impl_->op_specific_native_pattern_map_[info_map.second].push_back( + pattern.get()); + impl_->op_specific_native_patterns_.push_back(std::move(pattern)); + } + }; + + for (std::unique_ptr& pat : patterns.native_patterns()) { + // Don't add patterns that haven't been enabled by the user. + if (!enabled_patterns.empty()) { + auto IsEnableFn = [&](const std::string& label) { + return enabled_patterns.count(label); + }; + if (!IsEnableFn(pat->debug_name()) && + std::none_of(pat->debug_labels().begin(), + pat->debug_labels().end(), + IsEnableFn)) + continue; + } + + // Don't add patterns that have been disabled by the user. + if (!disabled_patterns.empty()) { + auto IsDisabledFn = [&](const std::string& label) { + return disabled_patterns.count(label); + }; + if (IsDisabledFn(pat->debug_name()) || + std::any_of(pat->debug_labels().begin(), + pat->debug_labels().end(), + IsDisabledFn)) + continue; + } + + if (paddle::optional root_name = pat->root_kind()) { + impl_->op_specific_native_pattern_map_[*root_name].push_back(pat.get()); + impl_->op_specific_native_patterns_.push_back(std::move(pat)); + continue; + } + + if (paddle::optional interface_id = pat->GetRootInterfaceID()) { + AddToOpsWhen( + pat, [&](OpInfo info) { return info.HasInterface(*interface_id); }); + continue; + } + + if (paddle::optional trait_id = pat->GetRootTraitID()) { + AddToOpsWhen(pat, [&](OpInfo info) { return info.HasTrait(*trait_id); }); + continue; + } + + impl_->match_any_op_native_patterns_.push_back(std::move(pat)); + } +} + +} // namespace ir diff --git a/paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h b/paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h new file mode 100644 index 0000000000000000000000000000000000000000..eb7b33d7a1f2bdf6c50c21f11dd220cbf255b391 --- /dev/null +++ b/paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h @@ -0,0 +1,73 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/// The design and code is mainly from MLIR, thanks to the greate project. + +#pragma once + +#include +#include +#include +#include + +#include "paddle/ir/core/op_info.h" +#include "paddle/ir/pattern_rewrite/pattern_match.h" + +namespace ir { + +class FrozenRewritePatternSet { + using NativePatternListT = std::vector>; + + public: + using OpSpecificNativePatternListT = + std::unordered_map>; + + FrozenRewritePatternSet(); + FrozenRewritePatternSet(FrozenRewritePatternSet&& patterns) = default; + FrozenRewritePatternSet(const FrozenRewritePatternSet& patterns) = default; + FrozenRewritePatternSet& operator=(FrozenRewritePatternSet&& patterns) = + default; + FrozenRewritePatternSet& operator=(const FrozenRewritePatternSet& patterns) = + default; + ~FrozenRewritePatternSet() = default; + + /// Freeze the patterns held in `patterns`, and take ownership. + FrozenRewritePatternSet( + RewritePatternSet&& patterns, + const std::vector& disabled_pattern_labels = {}, + const std::vector& enabled_pattern_labels = {}); + + /// Return the op specific native patterns held by this list. + const OpSpecificNativePatternListT& op_specific_native_patterns() const { + return impl_->op_specific_native_pattern_map_; + } + + /// Return the "match any" native patterns held by this list. + const NativePatternListT& match_any_op_native_patterns() const { + return impl_->match_any_op_native_patterns_; + } + + private: + struct Impl { + OpSpecificNativePatternListT op_specific_native_pattern_map_; + + NativePatternListT op_specific_native_patterns_; + + NativePatternListT match_any_op_native_patterns_; + }; + + std::shared_ptr impl_; +}; + +} // namespace ir diff --git a/paddle/ir/pattern_rewrite/pattern_applicator.cc b/paddle/ir/pattern_rewrite/pattern_applicator.cc new file mode 100644 index 0000000000000000000000000000000000000000..0a0a712afbeb289fd17c1463c545a163725c357d --- /dev/null +++ b/paddle/ir/pattern_rewrite/pattern_applicator.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/ir/pattern_rewrite/pattern_applicator.h" + +#include "paddle/ir/pattern_rewrite/pattern_match.h" + +namespace ir { + +PatternApplicator::PatternApplicator( + const FrozenRewritePatternSet& frozen_patter_list) + : frozen_patter_list_(frozen_patter_list) {} + +void PatternApplicator::ApplyCostModel(CostModel model) { + // TODO(wilber): remove impossible patterns. + patterns_.clear(); + for (const auto& it : frozen_patter_list_.op_specific_native_patterns()) { + for (const RewritePattern* pattern : it.second) { + patterns_[it.first].push_back(pattern); + } + } + + any_op_patterns_.clear(); + for (auto& pattern : frozen_patter_list_.match_any_op_native_patterns()) { + any_op_patterns_.push_back(pattern.get()); + } + + // Sort by benefit based on the cost model. + std::unordered_map benefits; + auto cmp = [&benefits](const Pattern* lhs, const Pattern* rhs) { + return benefits[lhs] > benefits[rhs]; + }; + auto ProcessPatternList = [&](std::vector& list) { + if (list.size() == 1) return; + + benefits.clear(); + for (const Pattern* pat : list) benefits.emplace(pat, model(*pat)); + + std::stable_sort(list.begin(), list.end(), cmp); + }; + for (auto& it : patterns_) { + ProcessPatternList(it.second); + } + ProcessPatternList(any_op_patterns_); +} + +void PatternApplicator::WalkAllPatterns( + std::function walk) { + for (const auto& it : frozen_patter_list_.op_specific_native_patterns()) + for (auto* pattern : it.second) walk(*pattern); + + for (auto& it : frozen_patter_list_.match_any_op_native_patterns()) walk(*it); +} + +bool PatternApplicator::MatchAndRewrite( + Operation* op, + PatternRewriter& rewriter, + std::function can_apply, + std::function on_failure, + std::function on_success) { + // whether there are patterns matching this operation type. + std::vector op_patterns; + auto pattern_it = patterns_.find(op->info()); + if (pattern_it != patterns_.end()) op_patterns = pattern_it->second; + + unsigned op_it = 0, op_e = op_patterns.size(); + unsigned any_it = 0, any_e = any_op_patterns_.size(); + bool result = false; + do { + // Find the next pattern with the highest benefit. + const Pattern* best_pattern = nullptr; + unsigned* best_pattern_it = &op_it; + + // For specific patterns + if (op_it < op_e) best_pattern = op_patterns[op_it]; + // For op-agnostic patterns + if (any_it < any_e && + (!best_pattern || + best_pattern->benefit() < any_op_patterns_[any_it]->benefit())) { + best_pattern_it = &any_it; + best_pattern = any_op_patterns_[any_it]; + } + + if (!best_pattern) break; + + // Update the pattern iterator, so that this pattern isn't attempted again. + ++(*best_pattern_it); + + if (can_apply && !can_apply(*best_pattern)) continue; + + rewriter.SetInsertionPoint(op); + + const auto* pattern = static_cast(best_pattern); + result = pattern->MatchAndRewrite(op, rewriter); + + if (result && on_success && !on_success(*best_pattern)) result = false; + + if (result) break; + + if (on_failure) on_failure(*best_pattern); + } while (true); + + return result; +} + +} // namespace ir diff --git a/paddle/ir/pattern_rewrite/pattern_applicator.h b/paddle/ir/pattern_rewrite/pattern_applicator.h new file mode 100644 index 0000000000000000000000000000000000000000..5c4bc8784607b0a64891c644427265261b4c1423 --- /dev/null +++ b/paddle/ir/pattern_rewrite/pattern_applicator.h @@ -0,0 +1,56 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The design and code is mainly from MLIR, thanks to the greate project. + +#pragma once + +#include +#include + +#include "paddle/ir/core/op_info.h" +#include "paddle/ir/core/operation.h" +#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h" +#include "paddle/ir/pattern_rewrite/pattern_match.h" + +namespace ir { + +class PatternApplicator { + public: + using CostModel = std::function; + + explicit PatternApplicator(const FrozenRewritePatternSet& frozen_patter_list); + ~PatternApplicator() = default; + + bool MatchAndRewrite(Operation* op, + PatternRewriter& rewriter, // NOLINT + std::function can_apply = {}, + std::function on_failure = {}, + std::function on_success = {}); + + void ApplyCostModel(CostModel model); + + void ApplyDefaultCostModel() { + ApplyCostModel([](const Pattern& pattern) { return pattern.benefit(); }); + } + + void WalkAllPatterns(std::function walk); + + private: + const FrozenRewritePatternSet& frozen_patter_list_; + std::unordered_map> patterns_; + std::vector any_op_patterns_; +}; + +} // namespace ir diff --git a/paddle/ir/pattern_rewrite/pattern_match.cc b/paddle/ir/pattern_rewrite/pattern_match.cc index 7763789fe5976494ef26be951a031848ed2777b0..cd7950b0af5d9cbc920429696ddb3f6390bc996e 100644 --- a/paddle/ir/pattern_rewrite/pattern_match.cc +++ b/paddle/ir/pattern_rewrite/pattern_match.cc @@ -13,8 +13,11 @@ // limitations under the License. #include "paddle/ir/pattern_rewrite/pattern_match.h" + +#include #include #include + #include "paddle/ir/core/operation.h" namespace ir { @@ -22,62 +25,69 @@ namespace ir { //===----------------------------------------------------------------------===// // Pattern //===----------------------------------------------------------------------===// - -// Pattern::Pattern(const void* root_val, -// RootKind root_kind, -// const std::vector& generated_names, -// PatternBenefit benefit, -// ir::IrContext* context) -// : benefit_(benefit), context_(context), generated_names_(generated_names) -// {} - Pattern::Pattern(const std::string& root_name, PatternBenefit benefit, IrContext* context, const std::vector& generated_names) - : op_name_(root_name), - root_kind_(RootKind::OperationName), - benefit_(benefit), - context_(context), - generated_names_(generated_names) {} + : Pattern(context->GetRegisteredOpInfo(root_name).AsOpaquePointer(), + RootKind::OperationInfo, + generated_names, + benefit, + context) {} Pattern::Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, - ir::IrContext* context, + IrContext* context, const std::vector& generated_names) - : root_kind_(RootKind::Any), - benefit_(benefit), - context_(context), - generated_names_(generated_names) {} + : Pattern(nullptr, RootKind::Any, generated_names, benefit, context) {} Pattern::Pattern(MatchInterfaceOpTypeTag tag, - ir::TypeId interface_id, + TypeId interface_id, PatternBenefit benefit, - ir::IrContext* context, + IrContext* context, const std::vector& generated_names) - : interface_id_(interface_id), - root_kind_(RootKind::InterfaceId), - benefit_(benefit), - context_(context), - generated_names_(generated_names) {} + : Pattern(interface_id.AsOpaquePointer(), + RootKind::InterfaceId, + generated_names, + benefit, + context) {} Pattern::Pattern(MatchTraitOpTypeTag tag, - ir::TypeId trait_id, + TypeId trait_id, PatternBenefit benefit, - ir::IrContext* context, + IrContext* context, const std::vector& generated_names) - : trait_id_(trait_id), - root_kind_(RootKind::TraitId), + : Pattern(trait_id.AsOpaquePointer(), + RootKind::TraitId, + generated_names, + benefit, + context) {} + +Pattern::Pattern(void* root_val, + RootKind root_kind, + const std::vector& generated_names, + PatternBenefit benefit, + IrContext* context) + : root_val_(root_val), + root_kind_(root_kind), benefit_(benefit), - context_(context), - generated_names_(generated_names) {} + context_(context) { + if (generated_names.empty()) return; + + generated_ops_.reserve(generated_names.size()); + std::transform(generated_names.begin(), + generated_names.end(), + std::back_inserter(generated_ops_), + [context](const std::string& name) { + return context->GetRegisteredOpInfo(name); + }); +} RewritePattern::~RewritePattern() = default; //===----------------------------------------------------------------------===// // RewriterBase //===----------------------------------------------------------------------===// - RewriterBase::~RewriterBase() = default; // TODO(wilber): value support replace method. @@ -113,9 +123,9 @@ void RewriterBase::EraseOp(Operation* op) { void RewriterBase::ReplaceAllUsesWith(Value from, Value to) { // from. - // for (mlir::OpOperand& operand : llvm::make_early_inc_range(from.getUses())) + // for (OpOperand& operand : llvm::make_early_inc_range(from.getUses())) // { - // mlir::Operation* op = operand.getOwner(); + // Operation* op = operand.getOwner(); // UpdateRootInPlace(op, [&]() { operand.set(to); }); // } } diff --git a/paddle/ir/pattern_rewrite/pattern_match.h b/paddle/ir/pattern_rewrite/pattern_match.h index 0017afea6123636367103e6ea0d7cfd1c35a686f..f2f073ffc047dcec20022957e6c77e5b1487b7ac 100644 --- a/paddle/ir/pattern_rewrite/pattern_match.h +++ b/paddle/ir/pattern_rewrite/pattern_match.h @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// The design is mainly from MLIR, very thanks to the greate project. + #pragma once #include @@ -19,46 +21,78 @@ #include #include #include -#include #include + +#include "paddle/ir/core/builder.h" #include "paddle/ir/core/ir_context.h" +#include "paddle/ir/core/op_info.h" #include "paddle/ir/core/operation.h" #include "paddle/ir/core/type_id.h" #include "paddle/ir/core/type_name.h" #include "paddle/ir/core/value.h" -namespace ir { +#include "paddle/utils/optional.h" -/// The design is mainly from MLIR, very thanks to the greate project. +namespace ir { -/// This class reprensents the benefit of a pattern. The most common -/// unit to use is the `numver of operations` in the pattern. +// This class reprensents the benefit of a pattern. The most common +// unit to use is the `numver of operations` in the pattern. class PatternBenefit { public: - PatternBenefit(unsigned val) : val_(val) {} // NOLINT + PatternBenefit() = default; + PatternBenefit(uint32_t val) : val_(val) {} // NOLINT - unsigned benefit() { return val_; } + uint32_t benefit() { return val_; } bool operator==(const PatternBenefit& rhs) const { return val_ == rhs.val_; } bool operator!=(const PatternBenefit& rhs) const { return !(*this == rhs); } bool operator<(const PatternBenefit& rhs) const { return val_ < rhs.val_; } bool operator>(const PatternBenefit& rhs) const { return rhs < *this; } bool operator<=(const PatternBenefit& rhs) const { return !(*this > rhs); } - bool operator>=(const PatternBenefit& rhs) const { return !(*this <= rhs); } + bool operator>=(const PatternBenefit& rhs) const { return !(*this < rhs); } private: - unsigned int val_{0}; + uint32_t val_{0}; }; -/// This class contains all of the data related to a Pattern, but not contains -/// any methods for the matching. This class is used to interface with the -/// metadata of a pattern, such as benefit or root operation. +// This class contains all of the data related to a Pattern, but not contains +// any methods for the matching. This class is used to interface with the +// metadata of a pattern, such as benefit or root operation. class Pattern { - enum class RootKind { Any, OperationName, InterfaceId, TraitId }; + enum class RootKind { + // The pattern root matches "any" operation. + Any, + // The pattern root is matched using a concrete operation. + OperationInfo, + // The pattern root is matched using an interface id. + InterfaceId, + // The patter root is matched using a trait id. + TraitId + }; public: + const std::vector& generated_ops() const { return generated_ops_; } + + paddle::optional root_kind() const { + if (root_kind_ == RootKind::OperationInfo) + return OpInfo::RecoverFromOpaquePointer(root_val_); + return paddle::none; + } + + paddle::optional GetRootInterfaceID() const { + if (root_kind_ == RootKind::InterfaceId) + return TypeId::RecoverFromOpaquePointer(root_val_); + return paddle::none; + } + + paddle::optional GetRootTraitID() const { + if (root_kind_ == RootKind::TraitId) + return TypeId::RecoverFromOpaquePointer(root_val_); + return paddle::none; + } + PatternBenefit benefit() const { return benefit_; } - IrContext* context() const { return context_; } + IrContext* ir_context() const { return context_; } std::string debug_name() const { return debug_name_; } @@ -81,41 +115,39 @@ class Pattern { Pattern(const std::string& root_name, PatternBenefit benefit, - ir::IrContext* context, + IrContext* context, const std::vector& generated_names = {}); Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, - ir::IrContext* context, + IrContext* context, const std::vector& generated_names = {}); Pattern(MatchInterfaceOpTypeTag tag, - ir::TypeId interface_id, + TypeId interface_id, PatternBenefit benefit, - ir::IrContext* context, + IrContext* context, const std::vector& generated_names = {}); Pattern(MatchTraitOpTypeTag tag, - ir::TypeId trait_id, + TypeId trait_id, PatternBenefit benefit, - ir::IrContext* context, + IrContext* context, const std::vector& generated_names = {}); private: - // TODO(wilber): How to uniform variables and constructor. - // Pattern(const void* root_val, - // RootKind root_kind, - // const std::vector& generated_names, - // PatternBenefit benefit, - // ir::IrContext* context); - std::string op_name_; - ir::TypeId interface_id_; - ir::TypeId trait_id_; + Pattern(void* root_val, + RootKind root_kind, + const std::vector& generated_names, + PatternBenefit benefit, + IrContext* context); + + void* root_val_; RootKind root_kind_; const PatternBenefit benefit_; - ir::IrContext* context_; - std::vector generated_names_; + IrContext* context_; + std::vector generated_ops_; std::string debug_name_; std::vector debug_labels_; @@ -127,19 +159,19 @@ class RewritePattern : public Pattern { public: virtual ~RewritePattern(); - virtual void Rewrite(ir::Operation* op, + virtual void Rewrite(Operation* op, PatternRewriter& rewriter) const { // NOLINT throw( "need to implement either MatchAndRewrite or one of the rewrite " "functions."); } - virtual bool Match(ir::Operation* op) const { + virtual bool Match(Operation* op) const { throw("need to implement either MatchAndRewrite or Match."); return false; } - virtual bool MatchAndRewrite(ir::Operation* op, + virtual bool MatchAndRewrite(Operation* op, PatternRewriter& rewriter) const { // NOLINT if (Match(op)) { Rewrite(op, rewriter); @@ -157,7 +189,7 @@ class RewritePattern : public Pattern { pattern->Initialize(); if (pattern->debug_name().empty()) - pattern->SetDebugName(get_type_name()); + pattern->SetDebugName(ir::get_type_name()); return pattern; } @@ -166,8 +198,8 @@ class RewritePattern : public Pattern { }; namespace detail { -/// A wrapper around PatternWrite that allows for matching and rewriting -/// against an instance of a derived operation class or Interface. +// A wrapper around PatternWrite that allows for matching and rewriting +// against an instance of a derived operation class or Interface. template struct OpOrInterfaceRewritePatternBase : public RewritePattern { using RewritePattern::RewritePattern; @@ -203,30 +235,25 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern { }; } // namespace detail -/// OpRewritePattern is a wrapper around RewritePattern that allows for -/// matching and rewriting against an instance of a derived operation -/// class as opposed to a raw Operation. +// OpRewritePattern is a wrapper around RewritePattern that allows for +// matching and rewriting against an instance of a derived operation +// class as opposed to a raw Operation. template struct OpRewritePattern : public detail::OpOrInterfaceRewritePatternBase { - OpRewritePattern(ir::IrContext* context, + OpRewritePattern(IrContext* context, PatternBenefit benefit = 1, const std::vector& generated_names = {}) : detail::OpOrInterfaceRewritePatternBase( - "NeedToFix", // TODO(wilber): Need to fix. SourceOp maybe should - // have a getOperationName static method. - benefit, - context, - generated_names) {} + SourceOp::name(), benefit, context, generated_names) {} }; // TODO(wilber): Support OpInterfaceRewritePattern and OpTraitRewritePattern. // ... -/// This class provides a series of interfaces for modifying IR and tracking IR -/// changes. This class provides a unified API for IR modification. -/// -class RewriterBase { // maybe should inherit OpBuilder. +// This class provides a series of interfaces for modifying IR and tracking IR +// changes. This class provides a unified API for IR modification. +class RewriterBase : public Builder { public: // TODO(wilber): Supplementary methods of block and region. @@ -262,7 +289,7 @@ class RewriterBase { // maybe should inherit OpBuilder. std::function functor); protected: - explicit RewriterBase(IrContext* ctx) : ctx_(ctx) {} + explicit RewriterBase(IrContext* ctx) : Builder(ctx) {} virtual ~RewriterBase(); @@ -277,9 +304,6 @@ class RewriterBase { // maybe should inherit OpBuilder. RewriterBase(const RewriterBase&) = delete; void ReplaceOpWithResultsOfAnotherOp(Operation* op, Operation* new_op); - - private: - IrContext* ctx_; }; class PatternRewriter : public RewriterBase { @@ -287,7 +311,7 @@ class PatternRewriter : public RewriterBase { using RewriterBase::RewriterBase; }; -/// A pattern collection, easy to add patterns. +// A pattern collection, easy to add patterns. class RewritePatternSet { using NativePatternListT = std::vector>; @@ -299,7 +323,7 @@ class RewritePatternSet { native_patterns_.emplace_back(std::move(pattern)); } - IrContext* context() const { return context_; } + IrContext* ir_context() const { return context_; } NativePatternListT& native_patterns() { return native_patterns_; } @@ -351,6 +375,8 @@ class RewritePatternSet { private: IrContext* const context_; + NativePatternListT native_patterns_; }; + } // namespace ir diff --git a/test/cpp/ir/core/ir_op_test.cc b/test/cpp/ir/core/ir_op_test.cc index 95f6b7c598e5cde7dfcd01e19a0278c91eb38c13..d9189df2be09f913ad1aa384352badaabbb0201b 100644 --- a/test/cpp/ir/core/ir_op_test.cc +++ b/test/cpp/ir/core/ir_op_test.cc @@ -102,7 +102,7 @@ class Operation1 : public ir::Op { ir::OperationArgument &argument) { // NOLINT std::vector inputs = {}; std::vector output_types = { - ir::Float32Type::get(builder.context())}; + ir::Float32Type::get(builder.ir_context())}; std::unordered_map attributes = CreateAttributeMap({"op1_attr1", "op1_attr2"}, {"op1_attr1", "op1_attr2"}); diff --git a/test/cpp/ir/pattern_rewrite/CMakeLists.txt b/test/cpp/ir/pattern_rewrite/CMakeLists.txt index 67cd3c8c0fd809819484e300e63a380582705418..4332a6989828e656c5bcd7412c586e70ee9465ac 100644 --- a/test/cpp/ir/pattern_rewrite/CMakeLists.txt +++ b/test/cpp/ir/pattern_rewrite/CMakeLists.txt @@ -5,6 +5,4 @@ cc_test_old( DEPS new_pass pattern_rewrite - pd_dialect - phi gtest) diff --git a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc index de8f25809acc4dc729aa5519b69684fd8e950418..a4eb263d35c1d5ecc6226c1ee5795c774c84a75f 100644 --- a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc @@ -80,6 +80,7 @@ class TestPatternRewrite : public ir::OpRewritePattern { void Rewrite(Operation1 op, ir::PatternRewriter &rewriter) const override {} bool Match(Operation1 op) const override { return false; } }; + class TestPatternRewrite2 : public ir::OpRewritePattern { public: using ir::OpRewritePattern::OpRewritePattern;