未验证 提交 53bb1265 编写于 作者: 石晓伟 提交者: GitHub

fix a bug in op_version_registry, test=develop, test=op_version (#29994)

上级 3e0c4929
...@@ -18,29 +18,6 @@ namespace paddle { ...@@ -18,29 +18,6 @@ namespace paddle {
namespace framework { namespace framework {
namespace compatible { namespace compatible {
namespace {
template <OpUpdateType type__, typename InfoType>
OpUpdate<InfoType, type__>* new_update(InfoType&& info) {
return new OpUpdate<InfoType, type__>(info);
}
}
OpVersionDesc&& OpVersionDesc::ModifyAttr(const std::string& name,
const std::string& remark,
const OpAttrVariantT& default_value) {
infos_.emplace_back(new_update<OpUpdateType::kModifyAttr>(
OpAttrInfo(name, remark, default_value)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::NewAttr(const std::string& name,
const std::string& remark,
const OpAttrVariantT& default_value) {
infos_.emplace_back(new_update<OpUpdateType::kNewAttr>(
OpAttrInfo(name, remark, default_value)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::NewInput(const std::string& name, OpVersionDesc&& OpVersionDesc::NewInput(const std::string& name,
const std::string& remark) { const std::string& remark) {
infos_.emplace_back( infos_.emplace_back(
......
...@@ -118,13 +118,44 @@ class OpUpdate : public OpUpdateBase { ...@@ -118,13 +118,44 @@ class OpUpdate : public OpUpdateBase {
OpUpdateType type_; OpUpdateType type_;
}; };
template <OpUpdateType type__, typename InfoType>
OpUpdate<InfoType, type__>* new_update(InfoType&& info) {
return new OpUpdate<InfoType, type__>(info);
}
template <typename T>
OpAttrVariantT op_attr_wrapper(const T& val) {
return OpAttrVariantT{val};
}
template <int N>
OpAttrVariantT op_attr_wrapper(const char (&val)[N]) {
PADDLE_ENFORCE_EQ(
val[N - 1], 0,
platform::errors::InvalidArgument(
"The argument of operator register %c is illegal.", val[N - 1]));
return OpAttrVariantT{std::string{val}};
}
class OpVersionDesc { class OpVersionDesc {
public: public:
/* Compatibility upgrade */ /* Compatibility upgrade */
template <typename T>
OpVersionDesc&& ModifyAttr(const std::string& name, const std::string& remark, OpVersionDesc&& ModifyAttr(const std::string& name, const std::string& remark,
const OpAttrVariantT& default_value); const T& default_value) {
infos_.emplace_back(new_update<OpUpdateType::kModifyAttr>(
OpAttrInfo(name, remark, op_attr_wrapper(default_value))));
return std::move(*this);
}
template <typename T>
OpVersionDesc&& NewAttr(const std::string& name, const std::string& remark, OpVersionDesc&& NewAttr(const std::string& name, const std::string& remark,
const OpAttrVariantT& default_value); const T& default_value) {
infos_.emplace_back(new_update<OpUpdateType::kNewAttr>(
OpAttrInfo(name, remark, op_attr_wrapper(default_value))));
return std::move(*this);
}
OpVersionDesc&& NewInput(const std::string& name, const std::string& remark); OpVersionDesc&& NewInput(const std::string& name, const std::string& remark);
OpVersionDesc&& NewOutput(const std::string& name, const std::string& remark); OpVersionDesc&& NewOutput(const std::string& name, const std::string& remark);
OpVersionDesc&& BugfixWithBehaviorChanged(const std::string& remark); OpVersionDesc&& BugfixWithBehaviorChanged(const std::string& remark);
......
...@@ -661,7 +661,7 @@ REGISTER_OP_VERSION(conv_transpose) ...@@ -661,7 +661,7 @@ REGISTER_OP_VERSION(conv_transpose)
"output_padding", "output_padding",
"In order to add additional size to one side of each dimension " "In order to add additional size to one side of each dimension "
"in the output", "in the output",
{})); std::vector<int>{}));
REGISTER_OP_VERSION(conv2d_transpose) REGISTER_OP_VERSION(conv2d_transpose)
.AddCheckpoint( .AddCheckpoint(
...@@ -672,7 +672,7 @@ REGISTER_OP_VERSION(conv2d_transpose) ...@@ -672,7 +672,7 @@ REGISTER_OP_VERSION(conv2d_transpose)
"output_padding", "output_padding",
"In order to add additional size to one side of each dimension " "In order to add additional size to one side of each dimension "
"in the output", "in the output",
{})); std::vector<int>{}));
REGISTER_OP_VERSION(conv3d_transpose) REGISTER_OP_VERSION(conv3d_transpose)
.AddCheckpoint( .AddCheckpoint(
...@@ -683,7 +683,7 @@ REGISTER_OP_VERSION(conv3d_transpose) ...@@ -683,7 +683,7 @@ REGISTER_OP_VERSION(conv3d_transpose)
"output_padding", "output_padding",
"In order to add additional size to one side of each dimension " "In order to add additional size to one side of each dimension "
"in the output", "in the output",
{})); std::vector<int>{}));
REGISTER_OP_VERSION(depthwise_conv2d_transpose) REGISTER_OP_VERSION(depthwise_conv2d_transpose)
.AddCheckpoint( .AddCheckpoint(
...@@ -694,4 +694,4 @@ REGISTER_OP_VERSION(depthwise_conv2d_transpose) ...@@ -694,4 +694,4 @@ REGISTER_OP_VERSION(depthwise_conv2d_transpose)
"output_padding", "output_padding",
"In order to add additional size to one side of each dimension " "In order to add additional size to one side of each dimension "
"in the output", "in the output",
{})); std::vector<int>{}));
...@@ -489,4 +489,4 @@ REGISTER_OP_VERSION(fusion_gru) ...@@ -489,4 +489,4 @@ REGISTER_OP_VERSION(fusion_gru)
"Scale_weights", "Scale_weights",
"The added attribute 'Scale_weights' is not yet " "The added attribute 'Scale_weights' is not yet "
"registered.", "registered.",
{1.0f})); std::vector<float>{1.0f}));
...@@ -184,7 +184,7 @@ REGISTER_OP_VERSION(unique) ...@@ -184,7 +184,7 @@ REGISTER_OP_VERSION(unique)
.NewAttr("axis", .NewAttr("axis",
"The axis to apply unique. If None, the input will be " "The axis to apply unique. If None, the input will be "
"flattened.", "flattened.",
{}) std::vector<int>{})
.NewAttr("is_sorted", .NewAttr("is_sorted",
"If True, the unique elements of X are in ascending order." "If True, the unique elements of X are in ascending order."
"Otherwise, the unique elements are not sorted.", "Otherwise, the unique elements are not sorted.",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册