// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // The design is mainly from MLIR, very thanks to the greate project. #pragma once #include #include #include #include #include #include #include "paddle/ir/core/builder.h" #include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/op_info.h" #include "paddle/ir/core/operation.h" #include "paddle/ir/core/type_id.h" #include "paddle/ir/core/type_name.h" #include "paddle/ir/core/value.h" #include "paddle/utils/optional.h" namespace ir { // This class reprensents the benefit of a pattern. The most common // unit to use is the `numver of operations` in the pattern. class PatternBenefit { public: PatternBenefit() = default; PatternBenefit(uint32_t val) : val_(val) {} // NOLINT uint32_t benefit() { return val_; } bool operator==(const PatternBenefit& rhs) const { return val_ == rhs.val_; } bool operator!=(const PatternBenefit& rhs) const { return !(*this == rhs); } bool operator<(const PatternBenefit& rhs) const { return val_ < rhs.val_; } bool operator>(const PatternBenefit& rhs) const { return rhs < *this; } bool operator<=(const PatternBenefit& rhs) const { return !(*this > rhs); } bool operator>=(const PatternBenefit& rhs) const { return !(*this < rhs); } private: uint32_t val_{0}; }; // This class contains all of the data related to a Pattern, but not contains // any methods for the matching. This class is used to interface with the // metadata of a pattern, such as benefit or root operation. class Pattern { enum class RootKind { // The pattern root matches "any" operation. Any, // The pattern root is matched using a concrete operation. OperationInfo, // The pattern root is matched using an interface id. InterfaceId, // The patter root is matched using a trait id. TraitId }; public: const std::vector& generated_ops() const { return generated_ops_; } paddle::optional root_kind() const { if (root_kind_ == RootKind::OperationInfo) return OpInfo::RecoverFromOpaquePointer(root_val_); return paddle::none; } paddle::optional GetRootInterfaceID() const { if (root_kind_ == RootKind::InterfaceId) return TypeId::RecoverFromOpaquePointer(root_val_); return paddle::none; } paddle::optional GetRootTraitID() const { if (root_kind_ == RootKind::TraitId) return TypeId::RecoverFromOpaquePointer(root_val_); return paddle::none; } PatternBenefit benefit() const { return benefit_; } IrContext* ir_context() const { return context_; } std::string debug_name() const { return debug_name_; } void SetDebugName(const std::string& name) { debug_name_ = name; } const std::vector& debug_labels() const { return debug_labels_; } void AddDebugLabels(const std::vector& labels) { debug_labels_.insert(debug_labels_.end(), labels.begin(), labels.end()); } void AddDebugLabels(const std::string& label) { debug_labels_.push_back(label); } protected: struct MatchAnyOpTypeTag {}; struct MatchInterfaceOpTypeTag {}; struct MatchTraitOpTypeTag {}; Pattern(const std::string& root_name, PatternBenefit benefit, IrContext* context, const std::vector& generated_names = {}); Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, IrContext* context, const std::vector& generated_names = {}); Pattern(MatchInterfaceOpTypeTag tag, TypeId interface_id, PatternBenefit benefit, IrContext* context, const std::vector& generated_names = {}); Pattern(MatchTraitOpTypeTag tag, TypeId trait_id, PatternBenefit benefit, IrContext* context, const std::vector& generated_names = {}); private: Pattern(void* root_val, RootKind root_kind, const std::vector& generated_names, PatternBenefit benefit, IrContext* context); void* root_val_; RootKind root_kind_; const PatternBenefit benefit_; IrContext* context_; std::vector generated_ops_; std::string debug_name_; std::vector debug_labels_; }; class PatternRewriter; class RewritePattern : public Pattern { public: virtual ~RewritePattern(); virtual void Rewrite(Operation* op, PatternRewriter& rewriter) const { // NOLINT throw( "need to implement either MatchAndRewrite or one of the rewrite " "functions."); } virtual bool Match(Operation* op) const { throw("need to implement either MatchAndRewrite or Match."); return false; } virtual bool MatchAndRewrite(Operation* op, PatternRewriter& rewriter) const { // NOLINT if (Match(op)) { Rewrite(op, rewriter); return true; } return false; } virtual void Initialize() {} template static std::unique_ptr Create(Args&&... args) { std::unique_ptr pattern = std::make_unique(std::forward(args)...); pattern->Initialize(); if (pattern->debug_name().empty()) pattern->SetDebugName(ir::get_type_name()); return pattern; } protected: using Pattern::Pattern; }; namespace detail { // A wrapper around PatternWrite that allows for matching and rewriting // against an instance of a derived operation class or Interface. template struct OpOrInterfaceRewritePatternBase : public RewritePattern { using RewritePattern::RewritePattern; void Rewrite(Operation* op, PatternRewriter& rewriter) const final { // NOLINT Rewrite(op->dyn_cast(), rewriter); } bool Match(Operation* op) const final { return Match(op->dyn_cast()); } bool MatchAndRewrite(Operation* op, PatternRewriter& rewriter) const final { // NOLINT return MatchAndRewrite(op->dyn_cast(), rewriter); } virtual void Rewrite(SourceOp op, PatternRewriter& rewriter) const { // NOLINT throw("must override Rewrite or MatchAndRewrite"); } virtual bool Match(SourceOp op) const { throw("must override Match or MatchAndRewrite"); } virtual bool MatchAndRewrite(SourceOp op, PatternRewriter& rewriter) const { // NOLINT if (Match(op)) { Rewrite(op, rewriter); return true; } return false; } }; } // namespace detail // OpRewritePattern is a wrapper around RewritePattern that allows for // matching and rewriting against an instance of a derived operation // class as opposed to a raw Operation. template struct OpRewritePattern : public detail::OpOrInterfaceRewritePatternBase { OpRewritePattern(IrContext* context, PatternBenefit benefit = 1, const std::vector& generated_names = {}) : detail::OpOrInterfaceRewritePatternBase( SourceOp::name(), benefit, context, generated_names) {} }; // TODO(wilber): Support OpInterfaceRewritePattern and OpTraitRewritePattern. // ... // This class provides a series of interfaces for modifying IR and tracking IR // changes. This class provides a unified API for IR modification. class RewriterBase : public Builder { public: // TODO(wilber): Supplementary methods of block and region. // TODO(wilber): Support ValueRange. // virtual void ReplaceOpWithIf(Operation* op, // ValueRange new_values, // bool* all_uses_replaced, // std::function functor); // void ReplaceOpWithIf(Operation* op, // ValueRange new_values, // std::function functor); // virtual void ReplaceOp(Operation* op, ValueRange new_values); // virtual void ReplaceOpWithNewOp() virtual void EraseOp(Operation* op); virtual void StartRootUpdate(Operation* op) {} virtual void FinalizeRootUpdate(Operation* op) {} virtual void CancleRootUpdate(Operation* op) {} template void UpdateRootInplace(Operation* root, CallableT&& callable) { StartRootUpdate(root); callable(); FinalizeRootUpdate(root); } void ReplaceAllUsesWith(Value from, Value to); void ReplaceUseIf(Value from, Value to, std::function functor); protected: explicit RewriterBase(IrContext* ctx) : Builder(ctx) {} virtual ~RewriterBase(); // virtual void NotifyRootReplaced(Operation* op, ValueRange replacement) {} virtual void NotifyOperationRemoved(Operation* op) {} // virtual bool NotifyMatchFailure() private: void operator=(const RewriterBase&) = delete; RewriterBase(const RewriterBase&) = delete; void ReplaceOpWithResultsOfAnotherOp(Operation* op, Operation* new_op); }; class PatternRewriter : public RewriterBase { public: using RewriterBase::RewriterBase; }; // A pattern collection, easy to add patterns. class RewritePatternSet { using NativePatternListT = std::vector>; public: explicit RewritePatternSet(IrContext* context) : context_(context) {} RewritePatternSet(IrContext* context, std::unique_ptr pattern) : context_(context) { native_patterns_.emplace_back(std::move(pattern)); } IrContext* ir_context() const { return context_; } NativePatternListT& native_patterns() { return native_patterns_; } void Clear() { native_patterns_.clear(); } // 'add' methods for adding patterns to the set. template > RewritePatternSet& Add(ConstructorArg&& arg, ConstructorArgs&&... args) { std::initializer_list{ (AddImpl({}, std::forward(arg), std::forward(args)...), 0)...}; return *this; } template > RewritePatternSet& AddWithLabel(const std::vector& debug_labels, ConstructorArg&& arg, ConstructorArgs&&... args) { std::initializer_list{ (AddImpl(debug_labels, std::forward(arg), std::forward(args)...), 0)...}; return *this; } RewritePatternSet& Add(std::unique_ptr pattern) { native_patterns_.emplace_back(std::move(pattern)); return *this; } private: template std::enable_if_t::value> AddImpl( const std::vector& debug_labels, Args&&... args) { std::unique_ptr pattern = RewritePattern::Create(std::forward(args)...); pattern->AddDebugLabels(debug_labels); native_patterns_.emplace_back(std::move(pattern)); } private: IrContext* const context_; NativePatternListT native_patterns_; }; } // namespace ir