diff --git a/paddle/fluid/framework/op_version_registry.h b/paddle/fluid/framework/op_version_registry.h index 0ffa38a037ea431c8c860da8db14c799d1524c0d..2a85c60305bd36e78c071f5703885c23e33b403e 100644 --- a/paddle/fluid/framework/op_version_registry.h +++ b/paddle/fluid/framework/op_version_registry.h @@ -29,14 +29,20 @@ namespace framework { namespace compatible { struct OpUpdateRecord { - enum class Type { kInvalid = 0, kModifyAttr, kNewAttr }; + enum class Type { + kInvalid = 0, + kModifyAttr, + kNewAttr, + kNewInput, + kNewOutput + }; Type type_; std::string remark_; }; struct ModifyAttr : OpUpdateRecord { ModifyAttr(const std::string& name, const std::string& remark, - boost::any default_value) + const boost::any& default_value) : OpUpdateRecord({Type::kModifyAttr, remark}), name_(name), default_value_(default_value) { @@ -47,9 +53,10 @@ struct ModifyAttr : OpUpdateRecord { std::string name_; boost::any default_value_; }; + struct NewAttr : OpUpdateRecord { NewAttr(const std::string& name, const std::string& remark, - boost::any default_value) + const boost::any& default_value) : OpUpdateRecord({Type::kNewAttr, remark}), name_(name), default_value_(default_value) {} @@ -59,6 +66,22 @@ struct NewAttr : OpUpdateRecord { boost::any default_value_; }; +struct NewInput : OpUpdateRecord { + NewInput(const std::string& name, const std::string& remark) + : OpUpdateRecord({Type::kNewInput, remark}), name_(name) {} + + private: + std::string name_; +}; + +struct NewOutput : OpUpdateRecord { + NewOutput(const std::string& name, const std::string& remark) + : OpUpdateRecord({Type::kNewOutput, remark}), name_(name) {} + + private: + std::string name_; +}; + class OpVersionDesc { public: OpVersionDesc& ModifyAttr(const std::string& name, const std::string& remark, @@ -75,6 +98,18 @@ class OpVersionDesc { return *this; } + OpVersionDesc& NewInput(const std::string& name, const std::string& remark) { + infos_.push_back(std::shared_ptr( + new compatible::NewInput(name, remark))); + return *this; + } + + OpVersionDesc& NewOutput(const std::string& name, const std::string& remark) { + infos_.push_back(std::shared_ptr( + new compatible::NewOutput(name, remark))); + return *this; + } + private: std::vector> infos_; }; diff --git a/paddle/fluid/framework/op_version_registry_test.cc b/paddle/fluid/framework/op_version_registry_test.cc index 77891dafc81b3a96af28cb480f1620543caab0b8..052bf3a4b882be749e70704f18f09a7b24551ed7 100644 --- a/paddle/fluid/framework/op_version_registry_test.cc +++ b/paddle/fluid/framework/op_version_registry_test.cc @@ -42,7 +42,14 @@ TEST(test_operator_version, test_operator_version) { "height", "In order to represent a two-dimensional rectangle, the " "parameter height is added.", - 0)); + 0)) + .AddCheckpoint( + R"ROC( + Add a input [X2] and a output [Y2] + )ROC", + framework::compatible::OpVersionDesc() + .NewInput("X2", "The second input.") + .NewOutput("Y2", "The second output.")); } } // namespace compatible } // namespace framework