/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include #include #include #include #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/op_version_proto.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { namespace framework { namespace compatible { using OpAttrVariantT = boost::variant, /* AttrType::BOOLS */ std::vector, /* AttrType::FLOATS */ std::vector, /* AttrType::INTS */ std::vector, /* AttrType::LONGS */ std::vector /* AttrType::STRINGS */ >; struct OpUpdateInfo { virtual ~OpUpdateInfo() = default; }; struct OpAttrInfo : OpUpdateInfo { OpAttrInfo(const std::string& name, const std::string& remark, const OpAttrVariantT& default_value) : name_{name}, default_value_{default_value}, remark_{remark} {} const std::string& name() const { return name_; } const OpAttrVariantT& default_value() const { return default_value_; } const std::string& remark() const { return remark_; } private: std::string name_; OpAttrVariantT default_value_; std::string remark_; }; struct OpInputOutputInfo : OpUpdateInfo { OpInputOutputInfo(const std::string& name, const std::string& remark) : name_{name}, remark_{remark} {} const std::string& name() const { return name_; } const std::string& remark() const { return remark_; } private: std::string name_; std::string remark_; }; struct OpBugfixInfo : OpUpdateInfo { explicit OpBugfixInfo(const std::string& remark) : remark_{remark} {} const std::string& remark() const { return remark_; } private: std::string remark_; }; enum class OpUpdateType { kInvalid = 0, kModifyAttr, kNewAttr, kNewInput, kNewOutput, kBugfixWithBehaviorChanged, }; class OpUpdateBase { public: virtual const OpUpdateInfo* info() const = 0; virtual OpUpdateType type() const = 0; virtual ~OpUpdateBase() = default; }; template class OpUpdate : public OpUpdateBase { public: explicit OpUpdate(const InfoType& info) : info_{info}, type_{type__} {} const OpUpdateInfo* info() const override { return &info_; } OpUpdateType type() const override { return type_; } private: InfoType info_; OpUpdateType type_; }; class OpVersionDesc { public: OpVersionDesc&& ModifyAttr(const std::string& name, const std::string& remark, const OpAttrVariantT& default_value); OpVersionDesc&& NewAttr(const std::string& name, const std::string& remark, const OpAttrVariantT& default_value); OpVersionDesc&& NewInput(const std::string& name, const std::string& remark); OpVersionDesc&& NewOutput(const std::string& name, const std::string& remark); OpVersionDesc&& BugfixWithBehaviorChanged(const std::string& remark); const std::vector>& infos() const { return infos_; } OpVersionDesc() = default; OpVersionDesc(OpVersionDesc&&) = default; OpVersionDesc& operator=(OpVersionDesc&&) = default; private: std::vector> infos_; }; class OpCheckpoint { public: OpCheckpoint(const std::string& note, OpVersionDesc&& op_version_desc) : note_{note}, op_version_desc_{std::forward(op_version_desc)} {} const std::string& note() const { return note_; } const OpVersionDesc& version_desc() { return op_version_desc_; } OpCheckpoint() = default; OpCheckpoint(OpCheckpoint&&) = default; OpCheckpoint& operator=(OpCheckpoint&&) = default; private: std::string note_; OpVersionDesc op_version_desc_; }; class OpVersion { public: OpVersion& AddCheckpoint(const std::string& note, OpVersionDesc&& op_version_desc) { checkpoints_.emplace_back(OpCheckpoint{note, std::move(op_version_desc)}); return *this; } uint32_t version_id() const { return static_cast(checkpoints_.size()); } const std::vector& checkpoints() const { return checkpoints_; } OpVersion() = default; OpVersion(OpVersion&&) = default; OpVersion& operator=(OpVersion&&) = default; private: std::vector checkpoints_; }; class OpVersionRegistrar { public: OpVersionRegistrar() = default; static OpVersionRegistrar& GetInstance() { static OpVersionRegistrar instance; return instance; } OpVersion& Register(const std::string& op_type); const std::unordered_map& GetVersionMap() { return op_version_map_; } bool Has(const std::string& op_type) const { return op_version_map_.count(op_type); } uint32_t version_id(const std::string& op_type) const; private: std::unordered_map op_version_map_; }; inline const std::unordered_map& get_op_version_map() { return OpVersionRegistrar::GetInstance().GetVersionMap(); } inline void SaveOpVersions( const std::unordered_map& src, pb::OpVersionMap* dst) { for (const auto& pair : src) { (*dst)[pair.first].SetVersionID(pair.second.version_id()); } } class OpVersionComparator { public: virtual bool operator()() = 0; virtual ~OpVersionComparator() = default; }; #define ADD_OP_VERSION_COMPARATOR(cmp_name, cmp_math) \ class OpVersion##cmp_name##Comparator : public OpVersionComparator { \ public: \ explicit OpVersion##cmp_name##Comparator(const std::string op_name, \ uint32_t target_version) \ : op_name_(op_name), target_version_(target_version) {} \ virtual bool operator()() { \ uint32_t version_id = 0; \ if (OpVersionRegistrar::GetInstance().Has(op_name_)) { \ version_id = OpVersionRegistrar::GetInstance().version_id(op_name_); \ } \ return version_id cmp_math target_version_; \ } \ virtual ~OpVersion##cmp_name##Comparator() {} \ \ private: \ std::string op_name_; \ uint32_t target_version_; \ }; ADD_OP_VERSION_COMPARATOR(LE, <=); ADD_OP_VERSION_COMPARATOR(EQ, ==); ADD_OP_VERSION_COMPARATOR(GE, >=); ADD_OP_VERSION_COMPARATOR(NE, !=); class OpVersionComparatorCombination { public: OpVersionComparatorCombination() {} OpVersionComparatorCombination& LE(const std::string& op_name, int target_version) { op_version_comparators_.push_back(std::shared_ptr( new OpVersionLEComparator(op_name, target_version))); return *this; } OpVersionComparatorCombination& EQ(const std::string& op_name, int target_version) { op_version_comparators_.push_back(std::shared_ptr( new OpVersionEQComparator(op_name, target_version))); return *this; } OpVersionComparatorCombination& GE(const std::string& op_name, int target_version) { op_version_comparators_.push_back(std::shared_ptr( new OpVersionGEComparator(op_name, target_version))); return *this; } OpVersionComparatorCombination& NE(const std::string& op_name, int target_version) { op_version_comparators_.push_back(std::shared_ptr( new OpVersionNEComparator(op_name, target_version))); return *this; } bool IsMatched() const { for (const auto& cmp : op_version_comparators_) { if (!(*cmp)()) { return false; } } return true; } private: std::vector> op_version_comparators_; }; class PassVersionCheckers { public: PassVersionCheckers& AddCombination( const OpVersionComparatorCombination& combinations) { pass_version_checkers_.push_back(combinations); return *this; } bool IsPassCompatible() const { if (pass_version_checkers_.empty()) { return true; } for (const auto& checker : pass_version_checkers_) { if (checker.IsMatched()) { return true; } } return false; } private: std::vector pass_version_checkers_; }; class PassVersionCheckerRegistrar { public: static PassVersionCheckerRegistrar& GetInstance() { static PassVersionCheckerRegistrar instance; return instance; } PassVersionCheckers& Register(const std::string& pass_name) { return pass_version_checkers_map_[pass_name]; } bool IsPassCompatible(const std::string& fuse_pass_name) const { auto iter = pass_version_checkers_map_.find(fuse_pass_name); if (iter == pass_version_checkers_map_.end()) { return true; } return iter->second.IsPassCompatible(); } private: std::unordered_map pass_version_checkers_map_; PassVersionCheckerRegistrar() = default; PassVersionCheckerRegistrar& operator=(const PassVersionCheckerRegistrar&) = delete; }; } // namespace compatible } // namespace framework } // namespace paddle #define REGISTER_OP_VERSION(op_type) \ UNUSED static paddle::framework::compatible::OpVersion& \ RegisterOpVersion__##op_type = \ paddle::framework::compatible::OpVersionRegistrar::GetInstance() \ .Register(#op_type) #define REGISTER_PASS_CAPABILITY(pass_name) \ static auto RegisterOpPassVersionChecker__##pass_name = \ paddle::framework::compatible::PassVersionCheckerRegistrar:: \ GetInstance() \ .Register(#pass_name)