未验证 提交 8bc0a31b 编写于 作者: 石晓伟 提交者: GitHub

update the operator registration for incompatible upgrade, test=develop (#29720) (#29774)

上级 0e2f5bb1
...@@ -62,6 +62,37 @@ OpVersionDesc&& OpVersionDesc::BugfixWithBehaviorChanged( ...@@ -62,6 +62,37 @@ OpVersionDesc&& OpVersionDesc::BugfixWithBehaviorChanged(
return std::move(*this); return std::move(*this);
} }
OpVersionDesc&& OpVersionDesc::DeleteAttr(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kDeleteAttr>(OpAttrInfo(name, remark)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::ModifyInput(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kModifyInput>(OpInputOutputInfo(name, remark)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::ModifyOutput(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kModifyOutput>(OpInputOutputInfo(name, remark)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::DeleteInput(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kDeleteInput>(OpInputOutputInfo(name, remark)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::DeleteOutput(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kDeleteOutput>(OpInputOutputInfo(name, remark)));
return std::move(*this);
}
OpVersion& OpVersionRegistrar::Register(const std::string& op_type) { OpVersion& OpVersionRegistrar::Register(const std::string& op_type) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
op_version_map_.find(op_type), op_version_map_.end(), op_version_map_.find(op_type), op_version_map_.end(),
......
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <boost/none.hpp>
#include <boost/variant.hpp> #include <boost/variant.hpp>
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_version_proto.h" #include "paddle/fluid/framework/op_version_proto.h"
...@@ -30,16 +31,17 @@ namespace framework { ...@@ -30,16 +31,17 @@ namespace framework {
namespace compatible { namespace compatible {
using OpAttrVariantT = using OpAttrVariantT =
boost::variant<bool, /* AttrType::BOOL */ boost::variant<bool, /* AttrType::BOOL */
float, /* AttrType::FLOAT */ float, /* AttrType::FLOAT */
int32_t, /* AttrType::INT */ int32_t, /* AttrType::INT */
int64_t, /* AttrType::LONG*/ int64_t, /* AttrType::LONG*/
std::string, /* AttrType::STRING */ std::string, /* AttrType::STRING */
std::vector<bool>, /* AttrType::BOOLS */ std::vector<bool>, /* AttrType::BOOLS */
std::vector<float>, /* AttrType::FLOATS */ std::vector<float>, /* AttrType::FLOATS */
std::vector<int32_t>, /* AttrType::INTS */ std::vector<int32_t>, /* AttrType::INTS */
std::vector<int64_t>, /* AttrType::LONGS */ std::vector<int64_t>, /* AttrType::LONGS */
std::vector<std::string> /* AttrType::STRINGS */ std::vector<std::string>, /* AttrType::STRINGS */
boost::none_t /* None */
>; >;
struct OpUpdateInfo { struct OpUpdateInfo {
...@@ -48,7 +50,7 @@ struct OpUpdateInfo { ...@@ -48,7 +50,7 @@ struct OpUpdateInfo {
struct OpAttrInfo : OpUpdateInfo { struct OpAttrInfo : OpUpdateInfo {
OpAttrInfo(const std::string& name, const std::string& remark, OpAttrInfo(const std::string& name, const std::string& remark,
const OpAttrVariantT& default_value) const OpAttrVariantT& default_value = boost::none)
: name_{name}, default_value_{default_value}, remark_{remark} {} : name_{name}, default_value_{default_value}, remark_{remark} {}
const std::string& name() const { return name_; } const std::string& name() const { return name_; }
...@@ -83,11 +85,18 @@ struct OpBugfixInfo : OpUpdateInfo { ...@@ -83,11 +85,18 @@ struct OpBugfixInfo : OpUpdateInfo {
enum class OpUpdateType { enum class OpUpdateType {
kInvalid = 0, kInvalid = 0,
/* Compatibility upgrade */
kModifyAttr, kModifyAttr,
kNewAttr, kNewAttr,
kNewInput, kNewInput,
kNewOutput, kNewOutput,
kBugfixWithBehaviorChanged, kBugfixWithBehaviorChanged,
/* Incompatible upgrade, only for existing registration. */
kDeleteAttr = 100,
kModifyInput,
kModifyOutput,
kDeleteInput,
kDeleteOutput,
}; };
class OpUpdateBase { class OpUpdateBase {
...@@ -111,6 +120,7 @@ class OpUpdate : public OpUpdateBase { ...@@ -111,6 +120,7 @@ class OpUpdate : public OpUpdateBase {
class OpVersionDesc { class OpVersionDesc {
public: public:
/* Compatibility upgrade */
OpVersionDesc&& ModifyAttr(const std::string& name, const std::string& remark, OpVersionDesc&& ModifyAttr(const std::string& name, const std::string& remark,
const OpAttrVariantT& default_value); const OpAttrVariantT& default_value);
OpVersionDesc&& NewAttr(const std::string& name, const std::string& remark, OpVersionDesc&& NewAttr(const std::string& name, const std::string& remark,
...@@ -118,10 +128,23 @@ class OpVersionDesc { ...@@ -118,10 +128,23 @@ class OpVersionDesc {
OpVersionDesc&& NewInput(const std::string& name, const std::string& remark); OpVersionDesc&& NewInput(const std::string& name, const std::string& remark);
OpVersionDesc&& NewOutput(const std::string& name, const std::string& remark); OpVersionDesc&& NewOutput(const std::string& name, const std::string& remark);
OpVersionDesc&& BugfixWithBehaviorChanged(const std::string& remark); OpVersionDesc&& BugfixWithBehaviorChanged(const std::string& remark);
/* Incompatible upgrade, only for existing registration. */
OpVersionDesc&& DeleteAttr(const std::string& name,
const std::string& remark);
OpVersionDesc&& ModifyInput(const std::string& name,
const std::string& remark);
OpVersionDesc&& ModifyOutput(const std::string& name,
const std::string& remark);
OpVersionDesc&& DeleteInput(const std::string& name,
const std::string& remark);
OpVersionDesc&& DeleteOutput(const std::string& name,
const std::string& remark);
public:
const std::vector<std::unique_ptr<OpUpdateBase>>& infos() const { const std::vector<std::unique_ptr<OpUpdateBase>>& infos() const {
return infos_; return infos_;
} }
OpVersionDesc() = default; OpVersionDesc() = default;
OpVersionDesc(OpVersionDesc&&) = default; OpVersionDesc(OpVersionDesc&&) = default;
OpVersionDesc& operator=(OpVersionDesc&&) = default; OpVersionDesc& operator=(OpVersionDesc&&) = default;
......
...@@ -53,6 +53,19 @@ TEST(test_operator_version, test_operator_version) { ...@@ -53,6 +53,19 @@ TEST(test_operator_version, test_operator_version) {
framework::compatible::OpVersionDesc() framework::compatible::OpVersionDesc()
.NewInput("X2", "The second input.") .NewInput("X2", "The second input.")
.NewOutput("Y2", "The second output.")); .NewOutput("Y2", "The second output."));
REGISTER_OP_VERSION(op_name_0__)
.AddCheckpoint(
R"ROC(
Incompatible upgrade of attribute [height], input [X2] and output [Y2]
)ROC",
framework::compatible::OpVersionDesc()
.DeleteAttr("height",
"Parameters deleted due to interface alignment.")
.ModifyInput("X2", "Modify input due to interface alignment.")
.ModifyOutput("Y2", "Modify output due to interface alignment.")
.DeleteInput("X2", "Delete input due to interface alignment.")
.DeleteOutput("Y2", "Delete output due to interface alignment."));
} }
TEST(test_pass_op_version_checker, test_pass_op_version_checker) { TEST(test_pass_op_version_checker, test_pass_op_version_checker) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册