From 61fc7a3e45291e6c206afde34c5826b1a8c7c336 Mon Sep 17 00:00:00 2001 From: Shang Zhizhou Date: Thu, 3 Sep 2020 09:53:05 +0800 Subject: [PATCH] Pass version check (#26887) --- paddle/fluid/framework/op_version_registry.h | 136 ++++++++++++++++++ .../framework/op_version_registry_test.cc | 66 +++++++++ 2 files changed, 202 insertions(+) diff --git a/paddle/fluid/framework/op_version_registry.h b/paddle/fluid/framework/op_version_registry.h index 79b15fc87d..5edd70e035 100644 --- a/paddle/fluid/framework/op_version_registry.h +++ b/paddle/fluid/framework/op_version_registry.h @@ -133,6 +133,9 @@ class OpVersion { checkpoints_.push_back(Checkpoint({note, op_version_desc})); return *this; } + uint32_t GetVersionID() const { + return static_cast(checkpoints_.size()); + } private: struct Checkpoint { @@ -156,6 +159,14 @@ class OpVersionRegistrar { op_version_map_.insert({op_type, OpVersion()}); return op_version_map_[op_type]; } + uint32_t GetVersionID(const std::string& op_type) const { + auto it = op_version_map_.find(op_type); + if (it == op_version_map_.end()) { + return 0; + } + + return it->second.GetVersionID(); + } private: std::unordered_map op_version_map_; @@ -164,6 +175,125 @@ class OpVersionRegistrar { OpVersionRegistrar& operator=(const OpVersionRegistrar&) = delete; }; +class OpVersionComparator { + public: + virtual bool operator()() = 0; + virtual ~OpVersionComparator() = default; +}; + +#define ADD_OP_VERSION_COMPARATOR(cmp_name, cmp_math) \ + class OpVersion##cmp_name##Comparator : public OpVersionComparator { \ + public: \ + explicit OpVersion##cmp_name##Comparator(const std::string op_name, \ + uint32_t target_version) \ + : op_name_(op_name), target_version_(target_version) {} \ + virtual bool operator()() { \ + return OpVersionRegistrar::GetInstance().GetVersionID(op_name_) \ + cmp_math target_version_; \ + } \ + virtual ~OpVersion##cmp_name##Comparator() {} \ + \ + private: \ + std::string op_name_; \ + uint32_t target_version_; \ + }; + +ADD_OP_VERSION_COMPARATOR(LE, <=); +ADD_OP_VERSION_COMPARATOR(EQ, ==); +ADD_OP_VERSION_COMPARATOR(GE, >=); +ADD_OP_VERSION_COMPARATOR(NE, !=); + +class OpVersionComparatorCombination { + public: + OpVersionComparatorCombination() {} + + OpVersionComparatorCombination& LE(const std::string& op_name, + int target_version) { + op_version_comparators_.push_back(std::shared_ptr( + new OpVersionLEComparator(op_name, target_version))); + return *this; + } + OpVersionComparatorCombination& EQ(const std::string& op_name, + int target_version) { + op_version_comparators_.push_back(std::shared_ptr( + new OpVersionEQComparator(op_name, target_version))); + return *this; + } + OpVersionComparatorCombination& GE(const std::string& op_name, + int target_version) { + op_version_comparators_.push_back(std::shared_ptr( + new OpVersionGEComparator(op_name, target_version))); + return *this; + } + OpVersionComparatorCombination& NE(const std::string& op_name, + int target_version) { + op_version_comparators_.push_back(std::shared_ptr( + new OpVersionNEComparator(op_name, target_version))); + return *this; + } + + bool IsMatched() const { + for (const auto& cmp : op_version_comparators_) { + if (!(*cmp)()) { + return false; + } + } + return true; + } + + private: + std::vector> op_version_comparators_; +}; + +class PassVersionCheckers { + public: + PassVersionCheckers& AddCombination( + const OpVersionComparatorCombination& combinations) { + pass_version_checkers_.push_back(combinations); + return *this; + } + bool IsPassCompatible() const { + if (pass_version_checkers_.empty()) { + return true; + } + for (const auto& checker : pass_version_checkers_) { + if (checker.IsMatched()) { + return true; + } + } + return false; + } + + private: + std::vector pass_version_checkers_; +}; + +class PassVersionCheckerRegistrar { + public: + static PassVersionCheckerRegistrar& GetInstance() { + static PassVersionCheckerRegistrar instance; + return instance; + } + PassVersionCheckers& Register(const std::string& pass_name) { + return pass_version_checkers_map_[pass_name]; + } + bool IsPassCompatible(const std::string& fuse_pass_name) const { + auto iter = pass_version_checkers_map_.find(fuse_pass_name); + if (iter == pass_version_checkers_map_.end()) { + return true; + } + return iter->second.IsPassCompatible(); + } + + private: + std::unordered_map + pass_version_checkers_map_; + + PassVersionCheckerRegistrar() = default; + PassVersionCheckerRegistrar& operator=(const PassVersionCheckerRegistrar&) = + delete; +}; + } // namespace compatible } // namespace framework } // namespace paddle @@ -173,3 +303,9 @@ class OpVersionRegistrar { RegisterOpVersion__##op_type = \ paddle::framework::compatible::OpVersionRegistrar::GetInstance() \ .Register(#op_type) + +#define REGISTER_PASS_CAPABILITY(pass_name) \ + static auto RegisterOpPassVersionChecker__##pass_name = \ + paddle::framework::compatible::PassVersionCheckerRegistrar:: \ + GetInstance() \ + .Register(#pass_name) diff --git a/paddle/fluid/framework/op_version_registry_test.cc b/paddle/fluid/framework/op_version_registry_test.cc index 80ad51ad07..239dbc4357 100644 --- a/paddle/fluid/framework/op_version_registry_test.cc +++ b/paddle/fluid/framework/op_version_registry_test.cc @@ -55,6 +55,72 @@ TEST(test_operator_version, test_operator_version) { .NewInput("X2", "The second input.") .NewOutput("Y2", "The second output.")); } + +TEST(test_pass_op_version_checker, test_pass_op_version_checker) { + ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( + "no_bind_pass")); + + REGISTER_PASS_CAPABILITY(test_pass1) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("mul", 1) + .EQ("fc", 0)); + ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( + "test_pass1")); + + REGISTER_PASS_CAPABILITY(test_pass2) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .GE("mul", 0) + .NE("fc", 0)); + ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( + "test_pass2")); + + REGISTER_PASS_CAPABILITY(test_pass3) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .GE("mul", 0) + .NE("fc", 0)) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("mul", 1) + .EQ("fc", 0)); + ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( + "test_pass3")); + + REGISTER_PASS_CAPABILITY(test_pass4) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .GE("test__", 5) + .EQ("fc", 0)); + ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( + "test_pass4")); + + REGISTER_PASS_CAPABILITY(test_pass5) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .GE("test__", 4) + .EQ("fc", 0)); + ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( + "test_pass5")); + + REGISTER_PASS_CAPABILITY(test_pass6) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("test__", 4) + .EQ("fc", 0)); + ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( + "test_pass6")); + + REGISTER_PASS_CAPABILITY(test_pass7) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .NE("test__", 4) + .EQ("fc", 0)); + ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( + "test_pass7")); +} + } // namespace compatible } // namespace framework } // namespace paddle -- GitLab