diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index f9ab60c5c7478c59a2a99a3acbf510687c61a939..f2f7e16ff2bbe195822397fc9d271c14be8c4449 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -23,9 +23,9 @@ function(pass_library TARGET DEST) cmake_parse_arguments(pass_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) if(pass_library_DIR) - cc_library(${TARGET} SRCS ${pass_library_DIR}/${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${pass_library_DEPS}) + cc_library(${TARGET} SRCS ${pass_library_DIR}/${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base op_version_registry ${pass_library_DEPS}) else() - cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${pass_library_DEPS}) + cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base op_version_registry ${pass_library_DEPS}) endif() # add more DEST here, such as train, dist and collect USE_PASS into a file automatically. diff --git a/paddle/fluid/framework/op_version_registry.cc b/paddle/fluid/framework/op_version_registry.cc index 9a67c160f0233565e97b0d1280c39eab2e1bd4f6..38eb8af77db7d046610adacb8ff7dcc024e89a14 100644 --- a/paddle/fluid/framework/op_version_registry.cc +++ b/paddle/fluid/framework/op_version_registry.cc @@ -13,3 +13,75 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace compatible { + +namespace { +template +OpUpdate* new_update(InfoType&& info) { + return new OpUpdate(info); +} +} + +OpVersionDesc&& OpVersionDesc::ModifyAttr(const std::string& name, + const std::string& remark, + const OpAttrVariantT& default_value) { + infos_.emplace_back(new_update( + OpAttrInfo(name, remark, default_value))); + return std::move(*this); +} + +OpVersionDesc&& OpVersionDesc::NewAttr(const std::string& name, + const std::string& remark, + const OpAttrVariantT& default_value) { + infos_.emplace_back(new_update( + OpAttrInfo(name, remark, default_value))); + return std::move(*this); +} + +OpVersionDesc&& OpVersionDesc::NewInput(const std::string& name, + const std::string& remark) { + infos_.emplace_back( + new_update(OpInputOutputInfo(name, remark))); + return std::move(*this); +} + +OpVersionDesc&& OpVersionDesc::NewOutput(const std::string& name, + const std::string& remark) { + infos_.emplace_back( + new_update(OpInputOutputInfo(name, remark))); + return std::move(*this); +} + +OpVersionDesc&& OpVersionDesc::BugfixWithBehaviorChanged( + const std::string& remark) { + infos_.emplace_back(new_update( + OpBugfixInfo(remark))); + return std::move(*this); +} + +OpVersion& OpVersionRegistrar::Register(const std::string& op_type) { + PADDLE_ENFORCE_EQ( + op_version_map_.find(op_type), op_version_map_.end(), + platform::errors::AlreadyExists( + "'%s' is registered in operator version more than once.", op_type)); + op_version_map_.insert( + std::pair{op_type, OpVersion()}); + return op_version_map_[op_type]; +} +uint32_t OpVersionRegistrar::version_id(const std::string& op_type) const { + PADDLE_ENFORCE_NE( + op_version_map_.count(op_type), 0, + platform::errors::InvalidArgument( + "The version of operator type %s has not been registered.", op_type)); + return op_version_map_.find(op_type)->second.version_id(); +} + +// Provide a fake registration item for pybind testing. +#include "paddle/fluid/framework/op_version_registry.inl" + +} // namespace compatible +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/op_version_registry.h b/paddle/fluid/framework/op_version_registry.h index 5ddaf1bd8d8ce1b3881db914455684c5cfabb566..5822dfa11dd25a5e84800871c7efc73f375e2109 100644 --- a/paddle/fluid/framework/op_version_registry.h +++ b/paddle/fluid/framework/op_version_registry.h @@ -20,7 +20,7 @@ limitations under the License. */ #include #include -#include +#include #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/op_version_proto.h" #include "paddle/fluid/platform/enforce.h" @@ -29,160 +29,173 @@ namespace paddle { namespace framework { namespace compatible { -struct OpUpdateRecord { - enum class Type { - kInvalid = 0, - kModifyAttr, - kNewAttr, - kNewInput, - kNewOutput, - kBugfixWithBehaviorChanged, - }; - Type type_; - std::string remark_; +using OpAttrVariantT = + boost::variant, /* AttrType::BOOLS */ + std::vector, /* AttrType::FLOATS */ + std::vector, /* AttrType::INTS */ + std::vector, /* AttrType::LONGS */ + std::vector /* AttrType::STRINGS */ + >; + +struct OpUpdateInfo { + virtual ~OpUpdateInfo() = default; }; -struct ModifyAttr : OpUpdateRecord { - ModifyAttr(const std::string& name, const std::string& remark, - const boost::any& default_value) - : OpUpdateRecord({Type::kModifyAttr, remark}), - name_(name), - default_value_(default_value) { - // TODO(Shixiaowei02): Check the data type with proto::OpDesc. - } +struct OpAttrInfo : OpUpdateInfo { + OpAttrInfo(const std::string& name, const std::string& remark, + const OpAttrVariantT& default_value) + : name_{name}, default_value_{default_value}, remark_{remark} {} + + const std::string& name() const { return name_; } + const OpAttrVariantT& default_value() const { return default_value_; } + const std::string& remark() const { return remark_; } private: std::string name_; - boost::any default_value_; + OpAttrVariantT default_value_; + std::string remark_; }; -struct NewAttr : OpUpdateRecord { - NewAttr(const std::string& name, const std::string& remark, - const boost::any& default_value) - : OpUpdateRecord({Type::kNewAttr, remark}), - name_(name), - default_value_(default_value) {} +struct OpInputOutputInfo : OpUpdateInfo { + OpInputOutputInfo(const std::string& name, const std::string& remark) + : name_{name}, remark_{remark} {} + + const std::string& name() const { return name_; } + const std::string& remark() const { return remark_; } private: std::string name_; - boost::any default_value_; + std::string remark_; }; -struct NewInput : OpUpdateRecord { - NewInput(const std::string& name, const std::string& remark) - : OpUpdateRecord({Type::kNewInput, remark}), name_(name) {} +struct OpBugfixInfo : OpUpdateInfo { + explicit OpBugfixInfo(const std::string& remark) : remark_{remark} {} + const std::string& remark() const { return remark_; } private: - std::string name_; + std::string remark_; }; -struct NewOutput : OpUpdateRecord { - NewOutput(const std::string& name, const std::string& remark) - : OpUpdateRecord({Type::kNewOutput, remark}), name_(name) {} +enum class OpUpdateType { + kInvalid = 0, + kModifyAttr, + kNewAttr, + kNewInput, + kNewOutput, + kBugfixWithBehaviorChanged, +}; - private: - std::string name_; +class OpUpdateBase { + public: + virtual const OpUpdateInfo* info() const = 0; + virtual OpUpdateType type() const = 0; + virtual ~OpUpdateBase() = default; }; -struct BugfixWithBehaviorChanged : OpUpdateRecord { - explicit BugfixWithBehaviorChanged(const std::string& remark) - : OpUpdateRecord({Type::kBugfixWithBehaviorChanged, remark}) {} +template +class OpUpdate : public OpUpdateBase { + public: + explicit OpUpdate(const InfoType& info) : info_{info}, type_{type__} {} + const OpUpdateInfo* info() const override { return &info_; } + OpUpdateType type() const override { return type_; } + + private: + InfoType info_; + OpUpdateType type_; }; class OpVersionDesc { public: - OpVersionDesc& ModifyAttr(const std::string& name, const std::string& remark, - boost::any default_value) { - infos_.push_back(std::shared_ptr( - new compatible::ModifyAttr(name, remark, default_value))); - return *this; + OpVersionDesc&& ModifyAttr(const std::string& name, const std::string& remark, + const OpAttrVariantT& default_value); + OpVersionDesc&& NewAttr(const std::string& name, const std::string& remark, + const OpAttrVariantT& default_value); + OpVersionDesc&& NewInput(const std::string& name, const std::string& remark); + OpVersionDesc&& NewOutput(const std::string& name, const std::string& remark); + OpVersionDesc&& BugfixWithBehaviorChanged(const std::string& remark); + const std::vector>& infos() const { + return infos_; } - OpVersionDesc& NewAttr(const std::string& name, const std::string& remark, - boost::any default_value) { - infos_.push_back(std::shared_ptr( - new compatible::NewAttr(name, remark, default_value))); - return *this; - } + OpVersionDesc() = default; + OpVersionDesc(OpVersionDesc&&) = default; + OpVersionDesc& operator=(OpVersionDesc&&) = default; - OpVersionDesc& NewInput(const std::string& name, const std::string& remark) { - infos_.push_back(std::shared_ptr( - new compatible::NewInput(name, remark))); - return *this; - } + private: + std::vector> infos_; +}; - OpVersionDesc& NewOutput(const std::string& name, const std::string& remark) { - infos_.push_back(std::shared_ptr( - new compatible::NewOutput(name, remark))); - return *this; - } +class OpCheckpoint { + public: + OpCheckpoint(const std::string& note, OpVersionDesc&& op_version_desc) + : note_{note}, + op_version_desc_{std::forward(op_version_desc)} {} + const std::string& note() const { return note_; } + const OpVersionDesc& version_desc() { return op_version_desc_; } - OpVersionDesc& BugfixWithBehaviorChanged(const std::string& remark) { - infos_.push_back(std::shared_ptr( - new compatible::BugfixWithBehaviorChanged(remark))); - return *this; - } + OpCheckpoint() = default; + OpCheckpoint(OpCheckpoint&&) = default; + OpCheckpoint& operator=(OpCheckpoint&&) = default; private: - std::vector> infos_; + std::string note_; + OpVersionDesc op_version_desc_; }; class OpVersion { public: OpVersion& AddCheckpoint(const std::string& note, - const OpVersionDesc& op_version_desc) { - checkpoints_.push_back(Checkpoint({note, op_version_desc})); + OpVersionDesc&& op_version_desc) { + checkpoints_.emplace_back(OpCheckpoint{note, std::move(op_version_desc)}); return *this; } - uint32_t GetVersionID() const { + uint32_t version_id() const { return static_cast(checkpoints_.size()); } + const std::vector& checkpoints() const { return checkpoints_; } + + OpVersion() = default; + OpVersion(OpVersion&&) = default; + OpVersion& operator=(OpVersion&&) = default; private: - struct Checkpoint { - std::string note_; - OpVersionDesc op_version_desc_; - }; - std::vector checkpoints_; + std::vector checkpoints_; }; class OpVersionRegistrar { public: + OpVersionRegistrar() = default; static OpVersionRegistrar& GetInstance() { static OpVersionRegistrar instance; return instance; } - OpVersion& Register(const std::string& op_type) { - PADDLE_ENFORCE_EQ( - op_version_map_.find(op_type), op_version_map_.end(), - platform::errors::AlreadyExists( - "'%s' is registered in operator version more than once.", op_type)); - op_version_map_.insert({op_type, OpVersion()}); - return op_version_map_[op_type]; - } + OpVersion& Register(const std::string& op_type); const std::unordered_map& GetVersionMap() { return op_version_map_; } - uint32_t GetVersionID(const std::string& op_type) const { - auto it = op_version_map_.find(op_type); - if (it == op_version_map_.end()) { - return 0; - } - return it->second.GetVersionID(); + bool Has(const std::string& op_type) const { + return op_version_map_.count(op_type); } + uint32_t version_id(const std::string& op_type) const; private: std::unordered_map op_version_map_; - - OpVersionRegistrar() = default; - OpVersionRegistrar& operator=(const OpVersionRegistrar&) = delete; }; +inline const std::unordered_map& get_op_version_map() { + return OpVersionRegistrar::GetInstance().GetVersionMap(); +} + inline void SaveOpVersions( const std::unordered_map& src, pb::OpVersionMap* dst) { for (const auto& pair : src) { - (*dst)[pair.first].SetVersionID(pair.second.GetVersionID()); + (*dst)[pair.first].SetVersionID(pair.second.version_id()); } } @@ -192,21 +205,24 @@ class OpVersionComparator { virtual ~OpVersionComparator() = default; }; -#define ADD_OP_VERSION_COMPARATOR(cmp_name, cmp_math) \ - class OpVersion##cmp_name##Comparator : public OpVersionComparator { \ - public: \ - explicit OpVersion##cmp_name##Comparator(const std::string op_name, \ - uint32_t target_version) \ - : op_name_(op_name), target_version_(target_version) {} \ - virtual bool operator()() { \ - return OpVersionRegistrar::GetInstance().GetVersionID(op_name_) \ - cmp_math target_version_; \ - } \ - virtual ~OpVersion##cmp_name##Comparator() {} \ - \ - private: \ - std::string op_name_; \ - uint32_t target_version_; \ +#define ADD_OP_VERSION_COMPARATOR(cmp_name, cmp_math) \ + class OpVersion##cmp_name##Comparator : public OpVersionComparator { \ + public: \ + explicit OpVersion##cmp_name##Comparator(const std::string op_name, \ + uint32_t target_version) \ + : op_name_(op_name), target_version_(target_version) {} \ + virtual bool operator()() { \ + uint32_t version_id = 0; \ + if (OpVersionRegistrar::GetInstance().Has(op_name_)) { \ + version_id = OpVersionRegistrar::GetInstance().version_id(op_name_); \ + } \ + return version_id cmp_math target_version_; \ + } \ + virtual ~OpVersion##cmp_name##Comparator() {} \ + \ + private: \ + std::string op_name_; \ + uint32_t target_version_; \ }; ADD_OP_VERSION_COMPARATOR(LE, <=); @@ -310,7 +326,7 @@ class PassVersionCheckerRegistrar { } // namespace paddle #define REGISTER_OP_VERSION(op_type) \ - static paddle::framework::compatible::OpVersion \ + UNUSED static paddle::framework::compatible::OpVersion& \ RegisterOpVersion__##op_type = \ paddle::framework::compatible::OpVersionRegistrar::GetInstance() \ .Register(#op_type) diff --git a/paddle/fluid/framework/op_version_registry.inl b/paddle/fluid/framework/op_version_registry.inl new file mode 100644 index 0000000000000000000000000000000000000000..ec90b3028be220256763120355ee6fac86c0fd69 --- /dev/null +++ b/paddle/fluid/framework/op_version_registry.inl @@ -0,0 +1,42 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +REGISTER_OP_VERSION(for_pybind_test__) + .AddCheckpoint("Note 0", framework::compatible::OpVersionDesc() + .BugfixWithBehaviorChanged( + "BugfixWithBehaviorChanged Remark")) + .AddCheckpoint("Note 1", framework::compatible::OpVersionDesc() + .ModifyAttr("BOOL", "bool", true) + .ModifyAttr("FLOAT", "float", 1.23f) + .ModifyAttr("INT", "int32", -1) + .ModifyAttr("STRING", "std::string", + std::string{"hello"})) + .AddCheckpoint("Note 2", + framework::compatible::OpVersionDesc() + .ModifyAttr("BOOLS", "std::vector", + std::vector{true, false}) + .ModifyAttr("FLOATS", "std::vector", + std::vector{2.56f, 1.28f}) + .ModifyAttr("INTS", "std::vector", + std::vector{10, 100}) + .NewAttr("LONGS", "std::vector", + std::vector{10000001, -10000001})) + .AddCheckpoint("Note 3", framework::compatible::OpVersionDesc() + .NewAttr("STRINGS", "std::vector", + std::vector{"str1", "str2"}) + .ModifyAttr("LONG", "int64", static_cast(10000001)) + .NewInput("NewInput", "NewInput_") + .NewOutput("NewOutput", "NewOutput_") + .BugfixWithBehaviorChanged( + "BugfixWithBehaviorChanged_")); diff --git a/paddle/fluid/framework/op_version_registry_test.cc b/paddle/fluid/framework/op_version_registry_test.cc index 2b173c95715881ad54fe623fe3d84ae3ce06b5d5..ef8384c1e7ee1d58f1e8e8cfda6d0ae54fc756ed 100644 --- a/paddle/fluid/framework/op_version_registry_test.cc +++ b/paddle/fluid/framework/op_version_registry_test.cc @@ -21,7 +21,7 @@ namespace framework { namespace compatible { TEST(test_operator_version, test_operator_version) { - REGISTER_OP_VERSION(test__) + REGISTER_OP_VERSION(op_name__) .AddCheckpoint( R"ROC(Fix the bug of reshape op, support the case of axis < 0)ROC", framework::compatible::OpVersionDesc().BugfixWithBehaviorChanged( @@ -56,6 +56,7 @@ TEST(test_operator_version, test_operator_version) { } TEST(test_pass_op_version_checker, test_pass_op_version_checker) { + const std::string fake_op_name{"op_name__"}; ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( "no_bind_pass")); @@ -90,7 +91,7 @@ TEST(test_pass_op_version_checker, test_pass_op_version_checker) { REGISTER_PASS_CAPABILITY(test_pass4) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .GE("test__", 5) + .GE(fake_op_name, 5) .EQ("fc", 0)); ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( "test_pass4")); @@ -98,7 +99,7 @@ TEST(test_pass_op_version_checker, test_pass_op_version_checker) { REGISTER_PASS_CAPABILITY(test_pass5) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .GE("test__", 4) + .GE(fake_op_name, 4) .EQ("fc", 0)); ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( "test_pass5")); @@ -106,7 +107,7 @@ TEST(test_pass_op_version_checker, test_pass_op_version_checker) { REGISTER_PASS_CAPABILITY(test_pass6) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("test__", 4) + .EQ(fake_op_name, 4) .EQ("fc", 0)); ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( "test_pass6")); @@ -114,7 +115,7 @@ TEST(test_pass_op_version_checker, test_pass_op_version_checker) { REGISTER_PASS_CAPABILITY(test_pass7) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .NE("test__", 4) + .NE(fake_op_name, 4) .EQ("fc", 0)); ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( "test_pass7")); diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 5fa8f6bab8cca94617f401f8b50b2572d9a55cb3..ca80ada7b6ea78d96c3b4e3e2657c6e0e929acab 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -104,6 +104,7 @@ endif() set(COMMON_OP_DEPS ${COMMON_OP_DEPS} device_memory_aligment) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} layer) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} tensor_formatter) +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} op_version_registry) # FIXME(typhoonzero): operator deps may not needed. # op_library(lod_tensor_to_array_op DEPS lod_rank_table_op) diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 92d9473141009216e3c7e64ccb793884dc67aadc..6fd1b7e1d36c2363ec6730df7b979ed279702087 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -1,7 +1,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context - gloo_wrapper infer_io_utils heter_wrapper generator) + gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry) if (WITH_NCCL) set(PYBIND_DEPS ${PYBIND_DEPS} nccl_wrapper) diff --git a/paddle/fluid/pybind/compatible.cc b/paddle/fluid/pybind/compatible.cc index 971d230458db4bc2196ca529e01b0586da79567c..57b024c25cbaf9ce87081a33e2b8756ed4e725eb 100644 --- a/paddle/fluid/pybind/compatible.cc +++ b/paddle/fluid/pybind/compatible.cc @@ -13,26 +13,136 @@ // limitations under the License. #include "paddle/fluid/pybind/compatible.h" - #include #include - #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/pybind/pybind_boost_headers.h" namespace py = pybind11; -using paddle::framework::compatible::PassVersionCheckerRegistrar; +using paddle::framework::compatible::OpAttrVariantT; +using paddle::framework::compatible::OpUpdateInfo; +using paddle::framework::compatible::OpAttrInfo; +using paddle::framework::compatible::OpInputOutputInfo; +using paddle::framework::compatible::OpBugfixInfo; +using paddle::framework::compatible::OpUpdateType; +using paddle::framework::compatible::OpUpdateBase; +using paddle::framework::compatible::OpVersionDesc; +using paddle::framework::compatible::OpCheckpoint; +using paddle::framework::compatible::OpVersion; namespace paddle { namespace pybind { -void BindCompatible(py::module* m) { +namespace { +using paddle::framework::compatible::PassVersionCheckerRegistrar; +void BindPassVersionChecker(py::module *m) { py::class_(*m, "PassVersionChecker") - .def_static("IsCompatible", [](const std::string& name) -> bool { + .def_static("IsCompatible", [](const std::string &name) -> bool { auto instance = PassVersionCheckerRegistrar::GetInstance(); return instance.IsPassCompatible(name); }); } +void BindPassCompatible(py::module *m) { BindPassVersionChecker(m); } + +void BindOpUpdateInfo(py::module *m) { + py::class_(*m, "OpUpdateInfo").def(py::init<>()); +} + +void BindOpAttrInfo(py::module *m) { + py::class_(*m, "OpAttrInfo") + .def(py::init()) + .def(py::init()) + .def("name", &OpAttrInfo::name) + .def("default_value", &OpAttrInfo::default_value) + .def("remark", &OpAttrInfo::remark); +} + +void BindOpInputOutputInfo(py::module *m) { + py::class_(*m, "OpInputOutputInfo") + .def(py::init()) + .def(py::init()) + .def("name", &OpInputOutputInfo::name) + .def("remark", &OpInputOutputInfo::remark); +} + +void BindOpBugfixInfo(py::module *m) { + py::class_(*m, "OpBugfixInfo") + .def(py::init()) + .def(py::init()) + .def("remark", &OpBugfixInfo::remark); +} + +void BindOpCompatible(py::module *m) { + BindOpUpdateInfo(m); + BindOpAttrInfo(m); + BindOpInputOutputInfo(m); + BindOpBugfixInfo(m); +} + +void BindOpUpdateType(py::module *m) { + py::enum_(*m, "OpUpdateType") + .value("kInvalid", OpUpdateType::kInvalid) + .value("kModifyAttr", OpUpdateType::kModifyAttr) + .value("kNewAttr", OpUpdateType::kNewAttr) + .value("kNewInput", OpUpdateType::kNewInput) + .value("kNewOutput", OpUpdateType::kNewOutput) + .value("kBugfixWithBehaviorChanged", + OpUpdateType::kBugfixWithBehaviorChanged); +} + +void BindOpUpdateBase(py::module *m) { + py::class_(*m, "OpUpdateBase") + .def("info", [](const OpUpdateBase &obj) { return obj.info(); }, + py::return_value_policy::reference) + .def("type", &OpUpdateBase::type); +} + +void BindOpVersionDesc(py::module *m) { + py::class_(*m, "OpVersionDesc") + // Pybind11 does not yet support the transfer of `const + // std::vector>&` type objects. + .def("infos", [](const OpVersionDesc &obj) { + auto pylist = py::list(); + for (const auto &ptr : obj.infos()) { + auto pyobj = py::cast(*ptr, py::return_value_policy::reference); + pylist.append(pyobj); + } + return pylist; + }); +} + +void BindOpCheckpoint(py::module *m) { + py::class_(*m, "OpCheckpoint") + .def("note", &OpCheckpoint::note, py::return_value_policy::reference) + .def("version_desc", &OpCheckpoint::version_desc, + py::return_value_policy::reference); +} + +void BindOpVersion(py::module *m) { + py::class_(*m, "OpVersion") + .def("version_id", &OpVersion::version_id, + py::return_value_policy::reference) + .def("checkpoints", &OpVersion::checkpoints, + py::return_value_policy::reference); + // At least pybind v2.3.0 is required because of bug #1603 of pybind11. + m->def("get_op_version_map", &framework::compatible::get_op_version_map, + py::return_value_policy::reference); +} + +} // namespace + +void BindCompatible(py::module *m) { + BindPassCompatible(m); + BindOpCompatible(m); + BindOpUpdateType(m); + BindOpUpdateBase(m); + BindOpVersionDesc(m); + BindOpCheckpoint(m); + BindOpVersion(m); +} + } // namespace pybind } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_op_version.py b/python/paddle/fluid/tests/unittests/test_op_version.py new file mode 100644 index 0000000000000000000000000000000000000000..1d7167955ac7c93563936c0bae7c1898a50740d6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_op_version.py @@ -0,0 +1,83 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import paddle.utils as utils +import paddle.fluid as fluid + + +class OpLastCheckpointCheckerTest(unittest.TestCase): + def __init__(self, methodName='runTest'): + super(OpLastCheckpointCheckerTest, self).__init__(methodName) + self.checker = utils.OpLastCheckpointChecker() + self.fake_op = 'for_pybind_test__' + + def test_op_attr_info(self): + update_type = fluid.core.OpUpdateType.kNewAttr + info_list = self.checker.filter_updates(self.fake_op, update_type, + 'STRINGS') + self.assertTrue(info_list) + self.assertEqual(info_list[0].name(), 'STRINGS') + self.assertEqual(info_list[0].default_value(), ['str1', 'str2']) + self.assertEqual(info_list[0].remark(), 'std::vector') + + def test_op_input_output_info(self): + update_type = fluid.core.OpUpdateType.kNewInput + info_list = self.checker.filter_updates(self.fake_op, update_type, + 'NewInput') + self.assertTrue(info_list) + self.assertEqual(info_list[0].name(), 'NewInput') + self.assertEqual(info_list[0].remark(), 'NewInput_') + + def test_op_bug_fix_info(self): + update_type = fluid.core.OpUpdateType.kBugfixWithBehaviorChanged + info_list = self.checker.filter_updates(self.fake_op, update_type) + self.assertTrue(info_list) + self.assertEqual(info_list[0].remark(), 'BugfixWithBehaviorChanged_') + + +class OpVersionTest(unittest.TestCase): + def __init__(self, methodName='runTest'): + super(OpVersionTest, self).__init__(methodName) + self.vmap = fluid.core.get_op_version_map() + self.fake_op = 'for_pybind_test__' + + def test_checkpoints(self): + version_id = self.vmap[self.fake_op].version_id() + checkpoints = self.vmap[self.fake_op].checkpoints() + self.assertEqual(version_id, 4) + self.assertEqual(len(checkpoints), 4) + self.assertEqual(checkpoints[2].note(), 'Note 2') + desc_1 = checkpoints[1].version_desc().infos() + self.assertEqual(desc_1[0].info().default_value(), True) + self.assertAlmostEqual(desc_1[1].info().default_value(), 1.23, 2) + self.assertEqual(desc_1[2].info().default_value(), -1) + self.assertEqual(desc_1[3].info().default_value(), 'hello') + desc_2 = checkpoints[2].version_desc().infos() + self.assertEqual(desc_2[0].info().default_value(), [True, False]) + true_l = [2.56, 1.28] + self.assertEqual(len(true_l), len(desc_2[1].info().default_value())) + for i in range(len(true_l)): + self.assertAlmostEqual(desc_2[1].info().default_value()[i], + true_l[i], 2) + self.assertEqual(desc_2[2].info().default_value(), [10, 100]) + self.assertEqual(desc_2[3].info().default_value(), + [10000001, -10000001]) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/utils/__init__.py b/python/paddle/utils/__init__.py index 9d7a05131ffa13c294ac63dfa58d8e42d5143165..faf0fd4984d7ca0c1d1c2457cf9b6f6186852eb9 100644 --- a/python/paddle/utils/__init__.py +++ b/python/paddle/utils/__init__.py @@ -17,6 +17,7 @@ from .profiler import Profiler from .profiler import get_profiler from .deprecated import deprecated from .lazy_import import try_import +from .op_version import OpLastCheckpointChecker from .install_check import run_check from ..fluid.framework import unique_name from ..fluid.framework import load_op_library diff --git a/python/paddle/utils/op_version.py b/python/paddle/utils/op_version.py new file mode 100644 index 0000000000000000000000000000000000000000..68acc9de081518959b0016f8e7ec2064b6e01527 --- /dev/null +++ b/python/paddle/utils/op_version.py @@ -0,0 +1,70 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..fluid import core + +__all__ = ['OpLastCheckpointChecker'] + + +def Singleton(cls): + _instance = {} + + def _singleton(*args, **kargs): + if cls not in _instance: + _instance[cls] = cls(*args, **kargs) + return _instance[cls] + + return _singleton + + +class OpUpdateInfoHelper(object): + def __init__(self, info): + self._info = info + + def verify_key_value(self, name=''): + result = False + key_funcs = { + core.OpAttrInfo: 'name', + core.OpInputOutputInfo: 'name', + } + if name == '': + result = True + elif type(self._info) in key_funcs: + if getattr(self._info, key_funcs[type(self._info)])() == name: + result = True + return result + + +@Singleton +class OpLastCheckpointChecker(object): + def __init__(self): + self.raw_version_map = core.get_op_version_map() + self.checkpoints_map = {} + self._construct_map() + + def _construct_map(self): + for op_name in self.raw_version_map: + last_checkpoint = self.raw_version_map[op_name].checkpoints()[-1] + infos = last_checkpoint.version_desc().infos() + self.checkpoints_map[op_name] = infos + + def filter_updates(self, op_name, type=core.OpUpdateType.kInvalid, key=''): + updates = [] + if op_name in self.checkpoints_map: + for update in self.checkpoints_map[op_name]: + if (update.type() == type) or ( + type == core.OpUpdateType.kInvalid): + if OpUpdateInfoHelper(update.info()).verify_key_value(key): + updates.append(update.info()) + return updates