未验证 提交 53bb1265 编写于 作者: 石晓伟 提交者: GitHub

fix a bug in op_version_registry, test=develop, test=op_version (#29994)

上级 3e0c4929
......@@ -18,29 +18,6 @@ namespace paddle {
namespace framework {
namespace compatible {
namespace {
template <OpUpdateType type__, typename InfoType>
OpUpdate<InfoType, type__>* new_update(InfoType&& info) {
return new OpUpdate<InfoType, type__>(info);
}
}
OpVersionDesc&& OpVersionDesc::ModifyAttr(const std::string& name,
const std::string& remark,
const OpAttrVariantT& default_value) {
infos_.emplace_back(new_update<OpUpdateType::kModifyAttr>(
OpAttrInfo(name, remark, default_value)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::NewAttr(const std::string& name,
const std::string& remark,
const OpAttrVariantT& default_value) {
infos_.emplace_back(new_update<OpUpdateType::kNewAttr>(
OpAttrInfo(name, remark, default_value)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::NewInput(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
......
......@@ -118,13 +118,44 @@ class OpUpdate : public OpUpdateBase {
OpUpdateType type_;
};
template <OpUpdateType type__, typename InfoType>
OpUpdate<InfoType, type__>* new_update(InfoType&& info) {
return new OpUpdate<InfoType, type__>(info);
}
template <typename T>
OpAttrVariantT op_attr_wrapper(const T& val) {
return OpAttrVariantT{val};
}
template <int N>
OpAttrVariantT op_attr_wrapper(const char (&val)[N]) {
PADDLE_ENFORCE_EQ(
val[N - 1], 0,
platform::errors::InvalidArgument(
"The argument of operator register %c is illegal.", val[N - 1]));
return OpAttrVariantT{std::string{val}};
}
class OpVersionDesc {
public:
/* Compatibility upgrade */
template <typename T>
OpVersionDesc&& ModifyAttr(const std::string& name, const std::string& remark,
const OpAttrVariantT& default_value);
const T& default_value) {
infos_.emplace_back(new_update<OpUpdateType::kModifyAttr>(
OpAttrInfo(name, remark, op_attr_wrapper(default_value))));
return std::move(*this);
}
template <typename T>
OpVersionDesc&& NewAttr(const std::string& name, const std::string& remark,
const OpAttrVariantT& default_value);
const T& default_value) {
infos_.emplace_back(new_update<OpUpdateType::kNewAttr>(
OpAttrInfo(name, remark, op_attr_wrapper(default_value))));
return std::move(*this);
}
OpVersionDesc&& NewInput(const std::string& name, const std::string& remark);
OpVersionDesc&& NewOutput(const std::string& name, const std::string& remark);
OpVersionDesc&& BugfixWithBehaviorChanged(const std::string& remark);
......
......@@ -661,7 +661,7 @@ REGISTER_OP_VERSION(conv_transpose)
"output_padding",
"In order to add additional size to one side of each dimension "
"in the output",
{}));
std::vector<int>{}));
REGISTER_OP_VERSION(conv2d_transpose)
.AddCheckpoint(
......@@ -672,7 +672,7 @@ REGISTER_OP_VERSION(conv2d_transpose)
"output_padding",
"In order to add additional size to one side of each dimension "
"in the output",
{}));
std::vector<int>{}));
REGISTER_OP_VERSION(conv3d_transpose)
.AddCheckpoint(
......@@ -683,7 +683,7 @@ REGISTER_OP_VERSION(conv3d_transpose)
"output_padding",
"In order to add additional size to one side of each dimension "
"in the output",
{}));
std::vector<int>{}));
REGISTER_OP_VERSION(depthwise_conv2d_transpose)
.AddCheckpoint(
......@@ -694,4 +694,4 @@ REGISTER_OP_VERSION(depthwise_conv2d_transpose)
"output_padding",
"In order to add additional size to one side of each dimension "
"in the output",
{}));
std::vector<int>{}));
......@@ -489,4 +489,4 @@ REGISTER_OP_VERSION(fusion_gru)
"Scale_weights",
"The added attribute 'Scale_weights' is not yet "
"registered.",
{1.0f}));
std::vector<float>{1.0f}));
......@@ -184,7 +184,7 @@ REGISTER_OP_VERSION(unique)
.NewAttr("axis",
"The axis to apply unique. If None, the input will be "
"flattened.",
{})
std::vector<int>{})
.NewAttr("is_sorted",
"If True, the unique elements of X are in ascending order."
"Otherwise, the unique elements are not sorted.",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册