未验证 提交 efab2eb4 编写于 作者: F Feiyu Chan 提交者: GitHub

add support for double attributes (#45390)

上级 0c363de8
...@@ -36,6 +36,7 @@ atype_to_parsing_function = { ...@@ -36,6 +36,7 @@ atype_to_parsing_function = {
"long": "CastPyArg2Long", "long": "CastPyArg2Long",
"int64_t": "CastPyArg2Long", "int64_t": "CastPyArg2Long",
"float": "CastPyArg2Float", "float": "CastPyArg2Float",
"double": "CastPyArg2Double",
"std::string": "CastPyArg2String", "std::string": "CastPyArg2String",
"std::vector<bool>": "CastPyArg2Booleans", "std::vector<bool>": "CastPyArg2Booleans",
"std::vector<int>": "CastPyArg2Ints", "std::vector<int>": "CastPyArg2Ints",
......
...@@ -180,6 +180,37 @@ struct ExtractAttribute<float> { ...@@ -180,6 +180,37 @@ struct ExtractAttribute<float> {
const std::string& attr_name_; const std::string& attr_name_;
}; };
template <>
struct ExtractAttribute<double> {
  explicit ExtractAttribute(const std::string& attr_name)
      : attr_name_(attr_name) {}

  // Returns a pointer to the attribute as a double, converting the stored
  // value in place when it currently holds an int, int64_t or float.
  // Throws InvalidArgument when the attribute cannot be read as a double.
  double* operator()(Attribute& attr) const {
    if (attr.type() == typeid(int)) {  // NOLINT
      int val = PADDLE_GET_CONST(int, attr);
      attr = static_cast<double>(val);
    } else if (attr.type() == typeid(int64_t)) {  // NOLINT
      int64_t val = PADDLE_GET_CONST(int64_t, attr);
      attr = static_cast<double>(val);
    } else if (attr.type() == typeid(float)) {  // NOLINT
      // Fix: the float value was previously stored in an int64_t local,
      // which truncated the fractional part before widening to double.
      float val = PADDLE_GET_CONST(float, attr);
      attr = static_cast<double>(val);
    }
    double* attr_value = nullptr;
    try {
      attr_value = &paddle::get<double>(attr);
    } catch (paddle::bad_variant_access const& bad_get) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Cannot get attribute (%s) by type double, its type is %s.",
          attr_name_,
          paddle::platform::demangle(attr.type().name())));
    }
    return attr_value;
  }

  const std::string& attr_name_;
};
template <> template <>
struct ExtractAttribute<std::vector<double>> { struct ExtractAttribute<std::vector<double>> {
explicit ExtractAttribute(const std::string& attr_name) explicit ExtractAttribute(const std::string& attr_name)
......
...@@ -38,6 +38,7 @@ enum AttrType { ...@@ -38,6 +38,7 @@ enum AttrType {
FLOAT64S = 12; FLOAT64S = 12;
VAR = 13; VAR = 13;
VARS = 14; VARS = 14;
FLOAT64 = 15;
} }
// OpDesc describes an instance of a C++ framework::OperatorBase // OpDesc describes an instance of a C++ framework::OperatorBase
...@@ -62,6 +63,7 @@ message OpDesc { ...@@ -62,6 +63,7 @@ message OpDesc {
repeated double float64s = 16; repeated double float64s = 16;
optional string var_name = 17; optional string var_name = 17;
repeated string vars_name = 18; repeated string vars_name = 18;
optional double float64 = 19;
}; };
message Var { message Var {
......
...@@ -482,6 +482,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -482,6 +482,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
phi::Scalar(PADDLE_GET_CONST(float, attr))); phi::Scalar(PADDLE_GET_CONST(float, attr)));
break; break;
case framework::proto::AttrType::FLOAT64:
infer_meta_context.EmplaceBackAttr(
phi::Scalar(PADDLE_GET_CONST(double, attr)));
break;
case framework::proto::AttrType::INT: case framework::proto::AttrType::INT:
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
phi::Scalar(PADDLE_GET_CONST(int, attr))); phi::Scalar(PADDLE_GET_CONST(int, attr)));
...@@ -651,6 +655,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -651,6 +655,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
case phi::AttributeType::FLOAT32: case phi::AttributeType::FLOAT32:
infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(float, attr)); infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(float, attr));
break; break;
case phi::AttributeType::FLOAT64:
infer_meta_context.EmplaceBackAttr(
PADDLE_GET_CONST(double, attr));
break;
case phi::AttributeType::INT32: case phi::AttributeType::INT32:
infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(int, attr)); infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(int, attr));
break; break;
......
...@@ -668,6 +668,12 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) { ...@@ -668,6 +668,12 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
this->attrs_[name] = std::vector<float>(); this->attrs_[name] = std::vector<float>();
break; break;
} }
case proto::AttrType::FLOAT64S: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to FLOAT64S";
this->attrs_[name] = std::vector<double>();
break;
}
case proto::AttrType::STRINGS: { case proto::AttrType::STRINGS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to STRINGS"; << " from INTS to STRINGS";
...@@ -838,6 +844,7 @@ struct SetAttrDescVisitor { ...@@ -838,6 +844,7 @@ struct SetAttrDescVisitor {
mutable proto::OpDesc::Attr *attr_; mutable proto::OpDesc::Attr *attr_;
void operator()(int v) const { attr_->set_i(v); } void operator()(int v) const { attr_->set_i(v); }
void operator()(float v) const { attr_->set_f(v); } void operator()(float v) const { attr_->set_f(v); }
void operator()(double v) const { attr_->set_float64(v); }
void operator()(const std::string &v) const { attr_->set_s(v); } void operator()(const std::string &v) const { attr_->set_s(v); }
// Please refer to https://github.com/PaddlePaddle/Paddle/issues/7162 // Please refer to https://github.com/PaddlePaddle/Paddle/issues/7162
......
...@@ -2745,6 +2745,10 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2745,6 +2745,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
phi_kernel_context->EmplaceBackAttr(std::move( phi_kernel_context->EmplaceBackAttr(std::move(
phi::Scalar(PADDLE_GET_CONST(float, attr_iter->second)))); phi::Scalar(PADDLE_GET_CONST(float, attr_iter->second))));
break; break;
case proto::AttrType::FLOAT64:
phi_kernel_context->EmplaceBackAttr(std::move(
phi::Scalar(PADDLE_GET_CONST(double, attr_iter->second))));
break;
case proto::AttrType::INT: case proto::AttrType::INT:
phi_kernel_context->EmplaceBackAttr(std::move( phi_kernel_context->EmplaceBackAttr(std::move(
phi::Scalar(PADDLE_GET_CONST(int, attr_iter->second)))); phi::Scalar(PADDLE_GET_CONST(int, attr_iter->second))));
...@@ -2884,6 +2888,10 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2884,6 +2888,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
phi_kernel_context->EmplaceBackAttr( phi_kernel_context->EmplaceBackAttr(
PADDLE_GET_CONST(float, attr_iter->second)); PADDLE_GET_CONST(float, attr_iter->second));
break; break;
case phi::AttributeType::FLOAT64:
phi_kernel_context->EmplaceBackAttr(
PADDLE_GET_CONST(double, attr_iter->second));
break;
case phi::AttributeType::INT32: case phi::AttributeType::INT32:
phi_kernel_context->EmplaceBackAttr( phi_kernel_context->EmplaceBackAttr(
PADDLE_GET_CONST(int, attr_iter->second)); PADDLE_GET_CONST(int, attr_iter->second));
......
...@@ -58,7 +58,8 @@ using Attribute = paddle::variant<paddle::blank, ...@@ -58,7 +58,8 @@ using Attribute = paddle::variant<paddle::blank,
std::vector<int64_t>, std::vector<int64_t>,
std::vector<double>, std::vector<double>,
VarDesc*, VarDesc*,
std::vector<VarDesc*>>; std::vector<VarDesc*>,
double>;
using AttributeMap = std::unordered_map<std::string, Attribute>; using AttributeMap = std::unordered_map<std::string, Attribute>;
#ifdef PADDLE_WITH_ASCEND_CL #ifdef PADDLE_WITH_ASCEND_CL
......
...@@ -412,6 +412,10 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -412,6 +412,10 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
kernel_ctx->EmplaceBackAttr( kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(PADDLE_GET_CONST(float, attr)))); std::move(phi::Scalar(PADDLE_GET_CONST(float, attr))));
break; break;
case framework::proto::AttrType::FLOAT64:
kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(PADDLE_GET_CONST(double, attr))));
break;
case framework::proto::AttrType::INT: case framework::proto::AttrType::INT:
kernel_ctx->EmplaceBackAttr( kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(PADDLE_GET_CONST(int, attr)))); std::move(phi::Scalar(PADDLE_GET_CONST(int, attr))));
...@@ -549,6 +553,9 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -549,6 +553,9 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
case phi::AttributeType::FLOAT32: case phi::AttributeType::FLOAT32:
kernel_ctx->EmplaceBackAttr(PADDLE_GET_CONST(float, attr)); kernel_ctx->EmplaceBackAttr(PADDLE_GET_CONST(float, attr));
break; break;
case phi::AttributeType::FLOAT64:
kernel_ctx->EmplaceBackAttr(PADDLE_GET_CONST(double, attr));
break;
case phi::AttributeType::INT32: case phi::AttributeType::INT32:
kernel_ctx->EmplaceBackAttr(PADDLE_GET_CONST(int, attr)); kernel_ctx->EmplaceBackAttr(PADDLE_GET_CONST(int, attr));
break; break;
......
...@@ -126,45 +126,6 @@ CastPyHandleToVarBaseList(const std::string& op_type, ...@@ -126,45 +126,6 @@ CastPyHandleToVarBaseList(const std::string& op_type,
return result; return result;
} // namespace pybind } // namespace pybind
// Parses the trailing (name, value) pairs of a Python call into `attrs`.
// `args` must hold an even number of entries: args[i] is the attribute name
// (str) and args[i + 1] its value (any supported framework::Attribute type).
// `start_idx` is the position offset used only for error reporting.
// Throws InvalidArgument on an odd pair count or a non-convertible entry.
static inline void ConstructAttrMapFromPyArgs(const std::string& op_type,
                                              int start_idx,
                                              framework::AttributeMap* attrs,
                                              const py::args& args) {
  PADDLE_ENFORCE_EQ(
      args.size() % 2,
      0,
      platform::errors::InvalidArgument(
          "The number of arguments for attributes should be even."));
  for (size_t i = 0; i < args.size(); i += 2) {
    std::string name;
    framework::Attribute value;
    try {
      name = args[i].cast<std::string>();
    } catch (std::exception& e) {
      PyObject* py_obj = args[i].ptr();  // get underlying PyObject
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument (position %d) must be str, but got "
          "%s",
          op_type,
          start_idx + i,
          Py_TYPE(py_obj)->tp_name));
    }
    try {
      value = args[i + 1].cast<framework::Attribute>();
    } catch (std::exception& e) {
      PyObject* py_obj = args[i + 1].ptr();  // get underlying PyObject
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument (position %d) must be "
          "Attribute type (one of str, bool, int, int64, float, or list of "
          "them), but got %s",
          op_type,
          start_idx + i + 1,
          Py_TYPE(py_obj)->tp_name));
    }
    (*attrs)[name] = value;
  }
}
static inline std::vector<std::shared_ptr<imperative::VarBase>> static inline std::vector<std::shared_ptr<imperative::VarBase>>
ConstructDuplicableOutput(const size_t num) { ConstructDuplicableOutput(const size_t num) {
auto tracer = imperative::GetCurrentTracer(); auto tracer = imperative::GetCurrentTracer();
......
...@@ -188,6 +188,14 @@ float CastPyArg2Float(PyObject* obj, ...@@ -188,6 +188,14 @@ float CastPyArg2Float(PyObject* obj,
return static_cast<float>(CastPyArg2Double(obj, op_type, arg_pos)); return static_cast<float>(CastPyArg2Double(obj, op_type, arg_pos));
} }
// Converts the Python object `obj` to a float and stores it in `attrs`
// under `key`. `op_type` and `arg_pos` are used only for error messages.
void CastPyArg2AttrFloat(PyObject* obj,
                         paddle::framework::AttributeMap& attrs,  // NOLINT
                         const std::string& key,
                         const std::string& op_type,
                         ssize_t arg_pos) {
  const float parsed = CastPyArg2Float(obj, op_type, arg_pos);
  attrs[key] = parsed;
}
double CastPyArg2Double(PyObject* obj, double CastPyArg2Double(PyObject* obj,
const std::string& op_type, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
...@@ -196,7 +204,7 @@ double CastPyArg2Double(PyObject* obj, ...@@ -196,7 +204,7 @@ double CastPyArg2Double(PyObject* obj,
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
"float, but got %s", "double, but got %s",
op_type, op_type,
arg_pos + 1, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
...@@ -205,12 +213,12 @@ double CastPyArg2Double(PyObject* obj, ...@@ -205,12 +213,12 @@ double CastPyArg2Double(PyObject* obj,
return 0.0; return 0.0;
} }
void CastPyArg2AttrFloat(PyObject* obj, void CastPyArg2AttrDouble(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& key,
const std::string& op_type, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
attrs[key] = CastPyArg2Float(obj, op_type, arg_pos); attrs[key] = CastPyArg2Double(obj, op_type, arg_pos);
} }
std::string CastPyArg2String(PyObject* obj, std::string CastPyArg2String(PyObject* obj,
...@@ -735,6 +743,9 @@ void ConstructAttrMapFromPyArgs( ...@@ -735,6 +743,9 @@ void ConstructAttrMapFromPyArgs(
case paddle::framework::proto::AttrType::FLOAT: case paddle::framework::proto::AttrType::FLOAT:
CastPyArg2AttrFloat(obj, attrs, key, op_type, arg_pos); CastPyArg2AttrFloat(obj, attrs, key, op_type, arg_pos);
break; break;
case paddle::framework::proto::AttrType::FLOAT64:
CastPyArg2AttrDouble(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::STRING: case paddle::framework::proto::AttrType::STRING:
CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos); CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos);
break; break;
......
...@@ -107,6 +107,12 @@ void CastPyArg2AttrFloat(PyObject* obj, ...@@ -107,6 +107,12 @@ void CastPyArg2AttrFloat(PyObject* obj,
const std::string& op_type, const std::string& op_type,
ssize_t arg_pos); ssize_t arg_pos);
void CastPyArg2AttrDouble(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
const std::string& op_type,
ssize_t arg_pos);
void CastPyArg2AttrString(PyObject* obj, void CastPyArg2AttrString(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& key,
......
...@@ -44,7 +44,8 @@ using Attribute = paddle::variant<paddle::blank, ...@@ -44,7 +44,8 @@ using Attribute = paddle::variant<paddle::blank,
std::vector<int64_t>, std::vector<int64_t>,
std::vector<double>, std::vector<double>,
VarDesc*, VarDesc*,
std::vector<VarDesc*>>; std::vector<VarDesc*>,
double>;
using AttributeMap = std::unordered_map<std::string, Attribute>; using AttributeMap = std::unordered_map<std::string, Attribute>;
} // namespace framework } // namespace framework
namespace imperative { namespace imperative {
......
...@@ -192,6 +192,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> { ...@@ -192,6 +192,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(float); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(float);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(double);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout);
......
...@@ -124,6 +124,8 @@ class OpDescCreationMethod(object): ...@@ -124,6 +124,8 @@ class OpDescCreationMethod(object):
new_attr.bools.extend(user_defined_attr) new_attr.bools.extend(user_defined_attr)
elif attr.type == framework_pb2.LONGS: elif attr.type == framework_pb2.LONGS:
new_attr.longs.extend(user_defined_attr) new_attr.longs.extend(user_defined_attr)
elif attr.type == framework_pb2.FLOAT64:
new_attr.float64 = user_defined_attr
else: else:
raise NotImplementedError( raise NotImplementedError(
"A not supported attribute type: %s." % "A not supported attribute type: %s." %
......
...@@ -16,6 +16,8 @@ from __future__ import print_function ...@@ -16,6 +16,8 @@ from __future__ import print_function
import unittest import unittest
import numpy as np
import paddle.fluid.op as op import paddle.fluid.op as op
import paddle.fluid.proto.framework_pb2 as framework_pb2 import paddle.fluid.proto.framework_pb2 as framework_pb2
...@@ -152,6 +154,7 @@ class TestOpDescCreationMethod(unittest.TestCase): ...@@ -152,6 +154,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
__add_attr__("int_attr", framework_pb2.INT) __add_attr__("int_attr", framework_pb2.INT)
__add_attr__("float_attr", framework_pb2.FLOAT) __add_attr__("float_attr", framework_pb2.FLOAT)
__add_attr__("float64_attr", framework_pb2.FLOAT64)
__add_attr__("string_attr", framework_pb2.STRING) __add_attr__("string_attr", framework_pb2.STRING)
__add_attr__("ints_attr", framework_pb2.INTS) __add_attr__("ints_attr", framework_pb2.INTS)
__add_attr__("floats_attr", framework_pb2.FLOATS) __add_attr__("floats_attr", framework_pb2.FLOATS)
...@@ -165,6 +168,7 @@ class TestOpDescCreationMethod(unittest.TestCase): ...@@ -165,6 +168,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
generated = method(X="a", generated = method(X="a",
int_attr=10, int_attr=10,
float_attr=3.2, float_attr=3.2,
float64_attr=np.finfo("float64").max,
string_attr="test_str", string_attr="test_str",
ints_attr=[0, 1, 2, 3, 4], ints_attr=[0, 1, 2, 3, 4],
floats_attr=[0.2, 3.2, 4.5], floats_attr=[0.2, 3.2, 4.5],
...@@ -187,6 +191,11 @@ class TestOpDescCreationMethod(unittest.TestCase): ...@@ -187,6 +191,11 @@ class TestOpDescCreationMethod(unittest.TestCase):
attr.type = framework_pb2.FLOAT attr.type = framework_pb2.FLOAT
attr.f = 3.2 attr.f = 3.2
attr = expected.attrs.add()
attr.name = "float64_attr"
attr.type = framework_pb2.FLOAT64
attr.float64 = np.finfo("float64").max
attr = expected.attrs.add() attr = expected.attrs.add()
attr.name = "string_attr" attr.name = "string_attr"
attr.type = framework_pb2.STRING attr.type = framework_pb2.STRING
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册