// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include "paddle/ir/core/cast_utils.h" #include "paddle/ir/core/type_id.h" #include "paddle/ir/core/type_name.h" #include "paddle/ir/pass/pass_instrumentation.h" #include "paddle/ir/pass/utils.h" namespace ir { class Operation; class AnalysisManager; class PassInstrumentor; namespace detail { /// A utility class to reprensent the analyses that are kwnown to be preserved. class PreservedAnalyses { struct AllAnalysesType {}; public: /// Mark all analyses as preserved. void PreserveAll() { preserved_ids_.insert(TypeId::get()); } bool IsAll() const { return preserved_ids_.count(TypeId::get()); } bool IsNone() const { return preserved_ids_.empty(); } template void Preserve() { Preserve(TypeId::get()); } template void Preserve() { Preserve(); Preserve(); } void Preserve(TypeId id) { preserved_ids_.insert(id); } template bool IsPreserved() const { return IsPreserved(TypeId::get()); } bool IsPreserved(TypeId id) const { return preserved_ids_.count(id); } template void Unpreserve() { preserved_ids_.erase(TypeId::get()); } friend ir::detail::TypeIdResolver; private: template friend struct AnalysisModel; std::unordered_set preserved_ids_; }; namespace detail { /// Trait to check if T provides a static `IsInvalidated` method. template using has_is_invalidated = decltype(std::declval().IsInvalidated( std::declval())); /// Implementation of `IsInvalidated` if the analysis provides a definition. template std::enable_if_t::value, bool> IsInvalidated(AnalysisT& analysis, const PreservedAnalyses& pa) { // NOLINT return analysis.IsInvalidated(pa); } /// Default implementation of `IsInvalidated`. template std::enable_if_t::value, bool> IsInvalidated(AnalysisT& analysis, const PreservedAnalyses& pa) { // NOLINT return !pa.IsPreserved(); } } // namespace detail /// Abstract base class representing an analysis. struct AnalysisConcept { virtual ~AnalysisConcept() = default; // A hook used to query analyses for invalidation. virtual bool Invalidate(PreservedAnalyses& pa) = 0; // NOLINT }; template struct AnalysisModel : public AnalysisConcept { template explicit AnalysisModel(Args&&... args) : analysis(std::forward(args)...) {} bool Invalidate(PreservedAnalyses& pa) final { bool result = detail::IsInvalidated(analysis, pa); if (result) pa.Unpreserve(); return result; } AnalysisT analysis; }; /// This class represents a cache of analyses for a single operation. /// All computation, caching and invalidation of analyses takes place here. class AnalysisMap { public: explicit AnalysisMap(Operation* ir) : ir_(ir) {} template AnalysisT& GetAnalysis(PassInstrumentor* pi, AnalysisManager& am) { return GetAnalysisImpl(pi, ir_, am); } template std::enable_if_t< std::is_constructible::value || std::is_constructible::value, AnalysisT&> GetAnalysis(PassInstrumentor* pi, AnalysisManager& am) { // NOLINT return GetAnalysisImpl(pi, ir::cast(ir_), am); } template std::optional> GetCachedAnalysis() const { auto res = analyses_.find(TypeId::get()); if (res == analyses_.end()) return std::nullopt; return {static_cast&>(*res->second).analysis}; } Operation* GetOperation() const { return ir_; } void Clear() { analyses_.clear(); } /// Invalidate any cached analyses based upon the given set of preserved void Invalidate(const PreservedAnalyses& pa) { PreservedAnalyses pa_copy(pa); // Remove any analyses that were invalidaed. // As using MapVector, order of insertion is preserved and // dependencies always go before users, so need only one iteration. for (auto it = analyses_.begin(); it != analyses_.end();) { if (it->second->Invalidate(pa_copy)) it = analyses_.erase(it); else ++it; } } private: template static std::string GetAnalysisName() { std::string name = ir::get_type_name(); auto pos = name.rfind("::"); if (pos != std::string::npos) { name = name.substr(pos + 2); } return name; } template AnalysisT& GetAnalysisImpl(PassInstrumentor* pi, OpT op, AnalysisManager& am) { // NOLINT TypeId id = TypeId::get(); auto it = analyses_.find(id); if (it == analyses_.end()) { if (pi) { pi->RunBeforeAnalysis(GetAnalysisName(), id, ir_); } bool was_inserted; std::tie(it, was_inserted) = analyses_.insert({id, ConstructAnalysis(am, op)}); assert(was_inserted); if (pi) { pi->RunAfterAnalysis(GetAnalysisName(), id, ir_); } } return static_cast&>(*it->second).analysis; } /// Construct analysis using two arguments constructor (OpT, /// AnalysisManager&). template < typename AnalysisT, typename OpT, std::enable_if_t< std::is_constructible::value>* = nullptr> static auto ConstructAnalysis(AnalysisManager& am, OpT op) { // NOLINT return std::make_unique>(op, am); } /// Construct analysis using single argument constructor (OpT) template < typename AnalysisT, typename OpT, std::enable_if_t< !std::is_constructible::value>* = nullptr> static auto ConstructAnalysis(AnalysisManager&, OpT op) { return std::make_unique>(op); } private: Operation* ir_; std::unordered_map> analyses_; }; } // namespace detail /// This class is intended to be passed around by value, and can not be /// constructed direcyly. class AnalysisManager { public: using PreservedAnalyses = detail::PreservedAnalyses; template AnalysisT& GetAnalysis() { return analyses_->GetAnalysis(GetPassInstrumentor(), *this); } template AnalysisT& GetAnalysis() { return analyses_->GetAnalysis(GetPassInstrumentor(), *this); } template std::optional> GetCachedAnalysis() const { return analyses_->GetCachedAnalysis(); } void Invalidate(const PreservedAnalyses& pa) { if (pa.IsAll()) return; // Invalidate the analyses for the current operation directly. analyses_->Invalidate(pa); } void Clear() { analyses_->Clear(); } PassInstrumentor* GetPassInstrumentor() const { return instrumentor_; } Operation* GetOperation() { return analyses_->GetOperation(); } private: AnalysisManager(detail::AnalysisMap* impl, PassInstrumentor* pi) : analyses_(impl), instrumentor_(pi) {} private: detail::AnalysisMap* analyses_; PassInstrumentor* instrumentor_; // For access constructor. friend class AnalysisManagerHolder; }; /// A manager class for the container operation. This class hold the /// memory for the analyses. AnalysisManager just hold the ref to the /// analyses. class AnalysisManagerHolder { public: AnalysisManagerHolder(Operation* op, PassInstrumentor* pi) : analyses_(op), pi_(pi) {} AnalysisManagerHolder(const AnalysisManagerHolder&) = delete; AnalysisManagerHolder& operator=(const AnalysisManagerHolder&) = delete; /// Returns an analysis manager for the current container op. operator AnalysisManager() { return AnalysisManager(&analyses_, pi_); } private: detail::AnalysisMap analyses_; PassInstrumentor* pi_; }; } // namespace ir IR_DECLARE_EXPLICIT_TYPE_ID(ir::detail::PreservedAnalyses::AllAnalysesType)