// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include #include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/operation.h" #include "paddle/ir/core/type_id.h" #include "paddle/ir/core/type_name.h" #include "paddle/ir/core/value.h" namespace ir { /// The design is mainly from MLIR, very thanks to the greate project. /// This class reprensents the benefit of a pattern. The most common /// unit to use is the `numver of operations` in the pattern. class PatternBenefit { public: PatternBenefit(unsigned val) : val_(val) {} // NOLINT unsigned benefit() { return val_; } bool operator==(const PatternBenefit& rhs) const { return val_ == rhs.val_; } bool operator!=(const PatternBenefit& rhs) const { return !(*this == rhs); } bool operator<(const PatternBenefit& rhs) const { return val_ < rhs.val_; } bool operator>(const PatternBenefit& rhs) const { return rhs < *this; } bool operator<=(const PatternBenefit& rhs) const { return !(*this > rhs); } bool operator>=(const PatternBenefit& rhs) const { return !(*this <= rhs); } private: unsigned int val_{0}; }; /// This class contains all of the data related to a Pattern, but not contains /// any methods for the matching. This class is used to interface with the /// metadata of a pattern, such as benefit or root operation. class Pattern { enum class RootKind { Any, OperationName, InterfaceId, TraitId }; public: PatternBenefit benefit() const { return benefit_; } IrContext* context() const { return context_; } std::string debug_name() const { return debug_name_; } void SetDebugName(const std::string& name) { debug_name_ = name; } const std::vector& debug_labels() const { return debug_labels_; } void AddDebugLabels(const std::vector& labels) { debug_labels_.insert(debug_labels_.end(), labels.begin(), labels.end()); } void AddDebugLabels(const std::string& label) { debug_labels_.push_back(label); } protected: struct MatchAnyOpTypeTag {}; struct MatchInterfaceOpTypeTag {}; struct MatchTraitOpTypeTag {}; Pattern(const std::string& root_name, PatternBenefit benefit, ir::IrContext* context, const std::vector& generated_names = {}); Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, ir::IrContext* context, const std::vector& generated_names = {}); Pattern(MatchInterfaceOpTypeTag tag, ir::TypeId interface_id, PatternBenefit benefit, ir::IrContext* context, const std::vector& generated_names = {}); Pattern(MatchTraitOpTypeTag tag, ir::TypeId trait_id, PatternBenefit benefit, ir::IrContext* context, const std::vector& generated_names = {}); private: // TODO(wilber): How to uniform variables and constructor. // Pattern(const void* root_val, // RootKind root_kind, // const std::vector& generated_names, // PatternBenefit benefit, // ir::IrContext* context); std::string op_name_; ir::TypeId interface_id_; ir::TypeId trait_id_; RootKind root_kind_; const PatternBenefit benefit_; ir::IrContext* context_; std::vector generated_names_; std::string debug_name_; std::vector debug_labels_; }; class PatternRewriter; class RewritePattern : public Pattern { public: virtual ~RewritePattern(); virtual void Rewrite(ir::Operation* op, PatternRewriter& rewriter) const { // NOLINT throw( "need to implement either MatchAndRewrite or one of the rewrite " "functions."); } virtual bool Match(ir::Operation* op) const { throw("need to implement either MatchAndRewrite or Match."); return false; } virtual bool MatchAndRewrite(ir::Operation* op, PatternRewriter& rewriter) const { // NOLINT if (Match(op)) { Rewrite(op, rewriter); return true; } return false; } virtual void Initialize() {} template static std::unique_ptr Create(Args&&... args) { std::unique_ptr pattern = std::make_unique(std::forward(args)...); pattern->Initialize(); if (pattern->debug_name().empty()) pattern->SetDebugName(get_type_name()); return pattern; } protected: using Pattern::Pattern; }; namespace detail { /// A wrapper around PatternWrite that allows for matching and rewriting /// against an instance of a derived operation class or Interface. template struct OpOrInterfaceRewritePatternBase : public RewritePattern { using RewritePattern::RewritePattern; void Rewrite(Operation* op, PatternRewriter& rewriter) const final { // NOLINT Rewrite(op->dyn_cast(), rewriter); } bool Match(Operation* op) const final { return Match(op->dyn_cast()); } bool MatchAndRewrite(Operation* op, PatternRewriter& rewriter) const final { // NOLINT return MatchAndRewrite(op->dyn_cast(), rewriter); } virtual void Rewrite(SourceOp op, PatternRewriter& rewriter) const { // NOLINT throw("must override Rewrite or MatchAndRewrite"); } virtual bool Match(SourceOp op) const { throw("must override Match or MatchAndRewrite"); } virtual bool MatchAndRewrite(SourceOp op, PatternRewriter& rewriter) const { // NOLINT if (Match(op)) { Rewrite(op, rewriter); return true; } return false; } }; } // namespace detail /// OpRewritePattern is a wrapper around RewritePattern that allows for /// matching and rewriting against an instance of a derived operation /// class as opposed to a raw Operation. template struct OpRewritePattern : public detail::OpOrInterfaceRewritePatternBase { OpRewritePattern(ir::IrContext* context, PatternBenefit benefit = 1, const std::vector& generated_names = {}) : detail::OpOrInterfaceRewritePatternBase( "NeedToFix", // TODO(wilber): Need to fix. SourceOp maybe should // have a getOperationName static method. benefit, context, generated_names) {} }; // TODO(wilber): Support OpInterfaceRewritePattern and OpTraitRewritePattern. // ... /// This class provides a series of interfaces for modifying IR and tracking IR /// changes. This class provides a unified API for IR modification. /// class RewriterBase { // maybe should inherit OpBuilder. public: // TODO(wilber): Supplementary methods of block and region. // TODO(wilber): Support ValueRange. // virtual void ReplaceOpWithIf(Operation* op, // ValueRange new_values, // bool* all_uses_replaced, // std::function functor); // void ReplaceOpWithIf(Operation* op, // ValueRange new_values, // std::function functor); // virtual void ReplaceOp(Operation* op, ValueRange new_values); // virtual void ReplaceOpWithNewOp() virtual void EraseOp(Operation* op); virtual void StartRootUpdate(Operation* op) {} virtual void FinalizeRootUpdate(Operation* op) {} virtual void CancleRootUpdate(Operation* op) {} template void UpdateRootInplace(Operation* root, CallableT&& callable) { StartRootUpdate(root); callable(); FinalizeRootUpdate(root); } void ReplaceAllUsesWith(Value from, Value to); void ReplaceUseIf(Value from, Value to, std::function functor); protected: explicit RewriterBase(IrContext* ctx) : ctx_(ctx) {} virtual ~RewriterBase(); // virtual void NotifyRootReplaced(Operation* op, ValueRange replacement) {} virtual void NotifyOperationRemoved(Operation* op) {} // virtual bool NotifyMatchFailure() private: void operator=(const RewriterBase&) = delete; RewriterBase(const RewriterBase&) = delete; void ReplaceOpWithResultsOfAnotherOp(Operation* op, Operation* new_op); private: IrContext* ctx_; }; class PatternRewriter : public RewriterBase { public: using RewriterBase::RewriterBase; }; /// A pattern collection, easy to add patterns. class RewritePatternSet { using NativePatternListT = std::vector>; public: explicit RewritePatternSet(IrContext* context) : context_(context) {} RewritePatternSet(IrContext* context, std::unique_ptr pattern) : context_(context) { native_patterns_.emplace_back(std::move(pattern)); } IrContext* context() const { return context_; } NativePatternListT& native_patterns() { return native_patterns_; } void Clear() { native_patterns_.clear(); } // 'add' methods for adding patterns to the set. template > RewritePatternSet& Add(ConstructorArg&& arg, ConstructorArgs&&... args) { std::initializer_list{ (AddImpl({}, std::forward(arg), std::forward(args)...), 0)...}; return *this; } template > RewritePatternSet& AddWithLabel(const std::vector& debug_labels, ConstructorArg&& arg, ConstructorArgs&&... args) { std::initializer_list{ (AddImpl(debug_labels, std::forward(arg), std::forward(args)...), 0)...}; return *this; } RewritePatternSet& Add(std::unique_ptr pattern) { native_patterns_.emplace_back(std::move(pattern)); return *this; } private: template std::enable_if_t::value> AddImpl( const std::vector& debug_labels, Args&&... args) { std::unique_ptr pattern = RewritePattern::Create(std::forward(args)...); pattern->AddDebugLabels(debug_labels); native_patterns_.emplace_back(std::move(pattern)); } private: IrContext* const context_; NativePatternListT native_patterns_; }; } // namespace ir