未验证 提交 61fc7a3e 编写于 作者: S Shang Zhizhou 提交者: GitHub

Pass version check (#26887)

上级 f772540d
......@@ -133,6 +133,9 @@ class OpVersion {
checkpoints_.push_back(Checkpoint({note, op_version_desc}));
return *this;
}
uint32_t GetVersionID() const {
return static_cast<uint32_t>(checkpoints_.size());
}
private:
struct Checkpoint {
......@@ -156,6 +159,14 @@ class OpVersionRegistrar {
op_version_map_.insert({op_type, OpVersion()});
return op_version_map_[op_type];
}
uint32_t GetVersionID(const std::string& op_type) const {
auto it = op_version_map_.find(op_type);
if (it == op_version_map_.end()) {
return 0;
}
return it->second.GetVersionID();
}
private:
std::unordered_map<std::string, OpVersion> op_version_map_;
......@@ -164,6 +175,125 @@ class OpVersionRegistrar {
OpVersionRegistrar& operator=(const OpVersionRegistrar&) = delete;
};
class OpVersionComparator {
public:
virtual bool operator()() = 0;
virtual ~OpVersionComparator() = default;
};
#define ADD_OP_VERSION_COMPARATOR(cmp_name, cmp_math) \
class OpVersion##cmp_name##Comparator : public OpVersionComparator { \
public: \
explicit OpVersion##cmp_name##Comparator(const std::string op_name, \
uint32_t target_version) \
: op_name_(op_name), target_version_(target_version) {} \
virtual bool operator()() { \
return OpVersionRegistrar::GetInstance().GetVersionID(op_name_) \
cmp_math target_version_; \
} \
virtual ~OpVersion##cmp_name##Comparator() {} \
\
private: \
std::string op_name_; \
uint32_t target_version_; \
};
ADD_OP_VERSION_COMPARATOR(LE, <=);
ADD_OP_VERSION_COMPARATOR(EQ, ==);
ADD_OP_VERSION_COMPARATOR(GE, >=);
ADD_OP_VERSION_COMPARATOR(NE, !=);
class OpVersionComparatorCombination {
public:
OpVersionComparatorCombination() {}
OpVersionComparatorCombination& LE(const std::string& op_name,
int target_version) {
op_version_comparators_.push_back(std::shared_ptr<OpVersionComparator>(
new OpVersionLEComparator(op_name, target_version)));
return *this;
}
OpVersionComparatorCombination& EQ(const std::string& op_name,
int target_version) {
op_version_comparators_.push_back(std::shared_ptr<OpVersionComparator>(
new OpVersionEQComparator(op_name, target_version)));
return *this;
}
OpVersionComparatorCombination& GE(const std::string& op_name,
int target_version) {
op_version_comparators_.push_back(std::shared_ptr<OpVersionComparator>(
new OpVersionGEComparator(op_name, target_version)));
return *this;
}
OpVersionComparatorCombination& NE(const std::string& op_name,
int target_version) {
op_version_comparators_.push_back(std::shared_ptr<OpVersionComparator>(
new OpVersionNEComparator(op_name, target_version)));
return *this;
}
bool IsMatched() const {
for (const auto& cmp : op_version_comparators_) {
if (!(*cmp)()) {
return false;
}
}
return true;
}
private:
std::vector<std::shared_ptr<OpVersionComparator>> op_version_comparators_;
};
class PassVersionCheckers {
public:
PassVersionCheckers& AddCombination(
const OpVersionComparatorCombination& combinations) {
pass_version_checkers_.push_back(combinations);
return *this;
}
bool IsPassCompatible() const {
if (pass_version_checkers_.empty()) {
return true;
}
for (const auto& checker : pass_version_checkers_) {
if (checker.IsMatched()) {
return true;
}
}
return false;
}
private:
std::vector<OpVersionComparatorCombination> pass_version_checkers_;
};
class PassVersionCheckerRegistrar {
public:
static PassVersionCheckerRegistrar& GetInstance() {
static PassVersionCheckerRegistrar instance;
return instance;
}
PassVersionCheckers& Register(const std::string& pass_name) {
return pass_version_checkers_map_[pass_name];
}
bool IsPassCompatible(const std::string& fuse_pass_name) const {
auto iter = pass_version_checkers_map_.find(fuse_pass_name);
if (iter == pass_version_checkers_map_.end()) {
return true;
}
return iter->second.IsPassCompatible();
}
private:
std::unordered_map<std::string, PassVersionCheckers>
pass_version_checkers_map_;
PassVersionCheckerRegistrar() = default;
PassVersionCheckerRegistrar& operator=(const PassVersionCheckerRegistrar&) =
delete;
};
} // namespace compatible
} // namespace framework
} // namespace paddle
......@@ -173,3 +303,9 @@ class OpVersionRegistrar {
RegisterOpVersion__##op_type = \
paddle::framework::compatible::OpVersionRegistrar::GetInstance() \
.Register(#op_type)
#define REGISTER_PASS_CAPABILITY(pass_name) \
static auto RegisterOpPassVersionChecker__##pass_name = \
paddle::framework::compatible::PassVersionCheckerRegistrar:: \
GetInstance() \
.Register(#pass_name)
......@@ -55,6 +55,72 @@ TEST(test_operator_version, test_operator_version) {
.NewInput("X2", "The second input.")
.NewOutput("Y2", "The second output."));
}
TEST(test_pass_op_version_checker, test_pass_op_version_checker) {
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"no_bind_pass"));
REGISTER_PASS_CAPABILITY(test_pass1)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("mul", 1)
.EQ("fc", 0));
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass1"));
REGISTER_PASS_CAPABILITY(test_pass2)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("mul", 0)
.NE("fc", 0));
ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass2"));
REGISTER_PASS_CAPABILITY(test_pass3)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("mul", 0)
.NE("fc", 0))
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("mul", 1)
.EQ("fc", 0));
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass3"));
REGISTER_PASS_CAPABILITY(test_pass4)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("test__", 5)
.EQ("fc", 0));
ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass4"));
REGISTER_PASS_CAPABILITY(test_pass5)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("test__", 4)
.EQ("fc", 0));
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass5"));
REGISTER_PASS_CAPABILITY(test_pass6)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("test__", 4)
.EQ("fc", 0));
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass6"));
REGISTER_PASS_CAPABILITY(test_pass7)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.NE("test__", 4)
.EQ("fc", 0));
ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass7"));
}
} // namespace compatible
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册