未验证 提交 8bc0a31b 编写于 作者: 石晓伟 提交者: GitHub

update the operator registration for incompatible upgrade, test=develop (#29720) (#29774)

上级 0e2f5bb1
......@@ -62,6 +62,37 @@ OpVersionDesc&& OpVersionDesc::BugfixWithBehaviorChanged(
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::DeleteAttr(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kDeleteAttr>(OpAttrInfo(name, remark)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::ModifyInput(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kModifyInput>(OpInputOutputInfo(name, remark)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::ModifyOutput(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kModifyOutput>(OpInputOutputInfo(name, remark)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::DeleteInput(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kDeleteInput>(OpInputOutputInfo(name, remark)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::DeleteOutput(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kDeleteOutput>(OpInputOutputInfo(name, remark)));
return std::move(*this);
}
OpVersion& OpVersionRegistrar::Register(const std::string& op_type) {
PADDLE_ENFORCE_EQ(
op_version_map_.find(op_type), op_version_map_.end(),
......
......@@ -20,6 +20,7 @@ limitations under the License. */
#include <utility>
#include <vector>
#include <boost/none.hpp>
#include <boost/variant.hpp>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_version_proto.h"
......@@ -30,16 +31,17 @@ namespace framework {
namespace compatible {
using OpAttrVariantT =
boost::variant<bool, /* AttrType::BOOL */
float, /* AttrType::FLOAT */
int32_t, /* AttrType::INT */
int64_t, /* AttrType::LONG*/
std::string, /* AttrType::STRING */
std::vector<bool>, /* AttrType::BOOLS */
std::vector<float>, /* AttrType::FLOATS */
std::vector<int32_t>, /* AttrType::INTS */
std::vector<int64_t>, /* AttrType::LONGS */
std::vector<std::string> /* AttrType::STRINGS */
boost::variant<bool, /* AttrType::BOOL */
float, /* AttrType::FLOAT */
int32_t, /* AttrType::INT */
int64_t, /* AttrType::LONG*/
std::string, /* AttrType::STRING */
std::vector<bool>, /* AttrType::BOOLS */
std::vector<float>, /* AttrType::FLOATS */
std::vector<int32_t>, /* AttrType::INTS */
std::vector<int64_t>, /* AttrType::LONGS */
std::vector<std::string>, /* AttrType::STRINGS */
boost::none_t /* None */
>;
struct OpUpdateInfo {
......@@ -48,7 +50,7 @@ struct OpUpdateInfo {
struct OpAttrInfo : OpUpdateInfo {
OpAttrInfo(const std::string& name, const std::string& remark,
const OpAttrVariantT& default_value)
const OpAttrVariantT& default_value = boost::none)
: name_{name}, default_value_{default_value}, remark_{remark} {}
const std::string& name() const { return name_; }
......@@ -83,11 +85,18 @@ struct OpBugfixInfo : OpUpdateInfo {
enum class OpUpdateType {
kInvalid = 0,
/* Compatibility upgrade */
kModifyAttr,
kNewAttr,
kNewInput,
kNewOutput,
kBugfixWithBehaviorChanged,
/* Incompatible upgrade, only for existing registration. */
kDeleteAttr = 100,
kModifyInput,
kModifyOutput,
kDeleteInput,
kDeleteOutput,
};
class OpUpdateBase {
......@@ -111,6 +120,7 @@ class OpUpdate : public OpUpdateBase {
class OpVersionDesc {
public:
/* Compatibility upgrade */
OpVersionDesc&& ModifyAttr(const std::string& name, const std::string& remark,
const OpAttrVariantT& default_value);
OpVersionDesc&& NewAttr(const std::string& name, const std::string& remark,
......@@ -118,10 +128,23 @@ class OpVersionDesc {
OpVersionDesc&& NewInput(const std::string& name, const std::string& remark);
OpVersionDesc&& NewOutput(const std::string& name, const std::string& remark);
OpVersionDesc&& BugfixWithBehaviorChanged(const std::string& remark);
/* Incompatible upgrade, only for existing registration. */
OpVersionDesc&& DeleteAttr(const std::string& name,
const std::string& remark);
OpVersionDesc&& ModifyInput(const std::string& name,
const std::string& remark);
OpVersionDesc&& ModifyOutput(const std::string& name,
const std::string& remark);
OpVersionDesc&& DeleteInput(const std::string& name,
const std::string& remark);
OpVersionDesc&& DeleteOutput(const std::string& name,
const std::string& remark);
public:
const std::vector<std::unique_ptr<OpUpdateBase>>& infos() const {
return infos_;
}
OpVersionDesc() = default;
OpVersionDesc(OpVersionDesc&&) = default;
OpVersionDesc& operator=(OpVersionDesc&&) = default;
......
......@@ -53,6 +53,19 @@ TEST(test_operator_version, test_operator_version) {
framework::compatible::OpVersionDesc()
.NewInput("X2", "The second input.")
.NewOutput("Y2", "The second output."));
REGISTER_OP_VERSION(op_name_0__)
.AddCheckpoint(
R"ROC(
Incompatible upgrade of attribute [height], input [X2] and output [Y2]
)ROC",
framework::compatible::OpVersionDesc()
.DeleteAttr("height",
"Parameters deleted due to interface alignment.")
.ModifyInput("X2", "Modify input due to interface alignment.")
.ModifyOutput("Y2", "Modify output due to interface alignment.")
.DeleteInput("X2", "Delete input due to interface alignment.")
.DeleteOutput("Y2", "Delete output due to interface alignment."));
}
TEST(test_pass_op_version_checker, test_pass_op_version_checker) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册