diff --git a/paddle/fluid/framework/op_version_registry.cc b/paddle/fluid/framework/op_version_registry.cc index 38eb8af77db7d046610adacb8ff7dcc024e89a14..bab1f20079c5ba2b9d8686a8f9c536bc42e54883 100644 --- a/paddle/fluid/framework/op_version_registry.cc +++ b/paddle/fluid/framework/op_version_registry.cc @@ -62,6 +62,37 @@ OpVersionDesc&& OpVersionDesc::BugfixWithBehaviorChanged( return std::move(*this); } +OpVersionDesc&& OpVersionDesc::DeleteAttr(const std::string& name, + const std::string& remark) { + infos_.emplace_back( + new_update(OpAttrInfo(name, remark))); + return std::move(*this); +} +OpVersionDesc&& OpVersionDesc::ModifyInput(const std::string& name, + const std::string& remark) { + infos_.emplace_back( + new_update(OpInputOutputInfo(name, remark))); + return std::move(*this); +} +OpVersionDesc&& OpVersionDesc::ModifyOutput(const std::string& name, + const std::string& remark) { + infos_.emplace_back( + new_update(OpInputOutputInfo(name, remark))); + return std::move(*this); +} +OpVersionDesc&& OpVersionDesc::DeleteInput(const std::string& name, + const std::string& remark) { + infos_.emplace_back( + new_update(OpInputOutputInfo(name, remark))); + return std::move(*this); +} +OpVersionDesc&& OpVersionDesc::DeleteOutput(const std::string& name, + const std::string& remark) { + infos_.emplace_back( + new_update(OpInputOutputInfo(name, remark))); + return std::move(*this); +} + OpVersion& OpVersionRegistrar::Register(const std::string& op_type) { PADDLE_ENFORCE_EQ( op_version_map_.find(op_type), op_version_map_.end(), diff --git a/paddle/fluid/framework/op_version_registry.h b/paddle/fluid/framework/op_version_registry.h index c121e6429dbb414ceb3773ad62f23951f525626d..125346cb22789fd194ce2846f026702a723b4999 100644 --- a/paddle/fluid/framework/op_version_registry.h +++ b/paddle/fluid/framework/op_version_registry.h @@ -20,6 +20,7 @@ limitations under the License. */ #include #include +#include #include #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/op_version_proto.h" @@ -30,16 +31,17 @@ namespace framework { namespace compatible { using OpAttrVariantT = - boost::variant, /* AttrType::BOOLS */ - std::vector, /* AttrType::FLOATS */ - std::vector, /* AttrType::INTS */ - std::vector, /* AttrType::LONGS */ - std::vector /* AttrType::STRINGS */ + boost::variant, /* AttrType::BOOLS */ + std::vector, /* AttrType::FLOATS */ + std::vector, /* AttrType::INTS */ + std::vector, /* AttrType::LONGS */ + std::vector, /* AttrType::STRINGS */ + boost::none_t /* None */ >; struct OpUpdateInfo { @@ -48,7 +50,7 @@ struct OpUpdateInfo { struct OpAttrInfo : OpUpdateInfo { OpAttrInfo(const std::string& name, const std::string& remark, - const OpAttrVariantT& default_value) + const OpAttrVariantT& default_value = boost::none) : name_{name}, default_value_{default_value}, remark_{remark} {} const std::string& name() const { return name_; } @@ -83,11 +85,18 @@ struct OpBugfixInfo : OpUpdateInfo { enum class OpUpdateType { kInvalid = 0, + /* Compatibility upgrade */ kModifyAttr, kNewAttr, kNewInput, kNewOutput, kBugfixWithBehaviorChanged, + /* Incompatible upgrade, only for existing registration. */ + kDeleteAttr = 100, + kModifyInput, + kModifyOutput, + kDeleteInput, + kDeleteOutput, }; class OpUpdateBase { @@ -111,6 +120,7 @@ class OpUpdate : public OpUpdateBase { class OpVersionDesc { public: + /* Compatibility upgrade */ OpVersionDesc&& ModifyAttr(const std::string& name, const std::string& remark, const OpAttrVariantT& default_value); OpVersionDesc&& NewAttr(const std::string& name, const std::string& remark, @@ -118,10 +128,23 @@ class OpVersionDesc { OpVersionDesc&& NewInput(const std::string& name, const std::string& remark); OpVersionDesc&& NewOutput(const std::string& name, const std::string& remark); OpVersionDesc&& BugfixWithBehaviorChanged(const std::string& remark); + + /* Incompatible upgrade, only for existing registration. */ + OpVersionDesc&& DeleteAttr(const std::string& name, + const std::string& remark); + OpVersionDesc&& ModifyInput(const std::string& name, + const std::string& remark); + OpVersionDesc&& ModifyOutput(const std::string& name, + const std::string& remark); + OpVersionDesc&& DeleteInput(const std::string& name, + const std::string& remark); + OpVersionDesc&& DeleteOutput(const std::string& name, + const std::string& remark); + + public: const std::vector>& infos() const { return infos_; } - OpVersionDesc() = default; OpVersionDesc(OpVersionDesc&&) = default; OpVersionDesc& operator=(OpVersionDesc&&) = default; diff --git a/paddle/fluid/framework/op_version_registry_test.cc b/paddle/fluid/framework/op_version_registry_test.cc index 888dd6de0618ba1d996983eb4229f53899dc2e86..e66d0dc5a1f79108168337ef1f4b08344d6e3063 100644 --- a/paddle/fluid/framework/op_version_registry_test.cc +++ b/paddle/fluid/framework/op_version_registry_test.cc @@ -53,6 +53,19 @@ TEST(test_operator_version, test_operator_version) { framework::compatible::OpVersionDesc() .NewInput("X2", "The second input.") .NewOutput("Y2", "The second output.")); + + REGISTER_OP_VERSION(op_name_0__) + .AddCheckpoint( + R"ROC( + Incompatible upgrade of attribute [height], input [X2] and output [Y2] + )ROC", + framework::compatible::OpVersionDesc() + .DeleteAttr("height", + "Parameters deleted due to interface alignment.") + .ModifyInput("X2", "Modify input due to interface alignment.") + .ModifyOutput("Y2", "Modify output due to interface alignment.") + .DeleteInput("X2", "Delete input due to interface alignment.") + .DeleteOutput("Y2", "Delete output due to interface alignment.")); } TEST(test_pass_op_version_checker, test_pass_op_version_checker) {