From 32ceacf317c096cfd9d4b72b2e29ec194e3f68ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9F=B3=E6=99=93=E4=BC=9F?= <39303645+Shixiaowei02@users.noreply.github.com> Date: Tue, 25 Aug 2020 18:55:08 +0800 Subject: [PATCH] update op_version_registry, test=develop (#26644) --- paddle/fluid/framework/op_version_registry.h | 41 +++++++++++++++++-- .../framework/op_version_registry_test.cc | 9 +++- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/op_version_registry.h b/paddle/fluid/framework/op_version_registry.h index 0ffa38a037e..2a85c60305b 100644 --- a/paddle/fluid/framework/op_version_registry.h +++ b/paddle/fluid/framework/op_version_registry.h @@ -29,14 +29,20 @@ namespace framework { namespace compatible { struct OpUpdateRecord { - enum class Type { kInvalid = 0, kModifyAttr, kNewAttr }; + enum class Type { + kInvalid = 0, + kModifyAttr, + kNewAttr, + kNewInput, + kNewOutput + }; Type type_; std::string remark_; }; struct ModifyAttr : OpUpdateRecord { ModifyAttr(const std::string& name, const std::string& remark, - boost::any default_value) + const boost::any& default_value) : OpUpdateRecord({Type::kModifyAttr, remark}), name_(name), default_value_(default_value) { @@ -47,9 +53,10 @@ struct ModifyAttr : OpUpdateRecord { std::string name_; boost::any default_value_; }; + struct NewAttr : OpUpdateRecord { NewAttr(const std::string& name, const std::string& remark, - boost::any default_value) + const boost::any& default_value) : OpUpdateRecord({Type::kNewAttr, remark}), name_(name), default_value_(default_value) {} @@ -59,6 +66,22 @@ struct NewAttr : OpUpdateRecord { boost::any default_value_; }; +struct NewInput : OpUpdateRecord { + NewInput(const std::string& name, const std::string& remark) + : OpUpdateRecord({Type::kNewInput, remark}), name_(name) {} + + private: + std::string name_; +}; + +struct NewOutput : OpUpdateRecord { + NewOutput(const std::string& name, const std::string& remark) + : OpUpdateRecord({Type::kNewOutput, remark}), name_(name) {} + + private: + std::string name_; +}; + class OpVersionDesc { public: OpVersionDesc& ModifyAttr(const std::string& name, const std::string& remark, @@ -75,6 +98,18 @@ class OpVersionDesc { return *this; } + OpVersionDesc& NewInput(const std::string& name, const std::string& remark) { + infos_.push_back(std::shared_ptr( + new compatible::NewInput(name, remark))); + return *this; + } + + OpVersionDesc& NewOutput(const std::string& name, const std::string& remark) { + infos_.push_back(std::shared_ptr( + new compatible::NewOutput(name, remark))); + return *this; + } + private: std::vector> infos_; }; diff --git a/paddle/fluid/framework/op_version_registry_test.cc b/paddle/fluid/framework/op_version_registry_test.cc index 77891dafc81..052bf3a4b88 100644 --- a/paddle/fluid/framework/op_version_registry_test.cc +++ b/paddle/fluid/framework/op_version_registry_test.cc @@ -42,7 +42,14 @@ TEST(test_operator_version, test_operator_version) { "height", "In order to represent a two-dimensional rectangle, the " "parameter height is added.", - 0)); + 0)) + .AddCheckpoint( + R"ROC( + Add a input [X2] and a output [Y2] + )ROC", + framework::compatible::OpVersionDesc() + .NewInput("X2", "The second input.") + .NewOutput("Y2", "The second output.")); } } // namespace compatible } // namespace framework -- GitLab