Unverified commit efab2eb4, authored by Feiyu Chan, committed by GitHub

add support for double attributes (#45390)

Parent 0c363de8
......@@ -36,6 +36,7 @@ atype_to_parsing_function = {
"long": "CastPyArg2Long",
"int64_t": "CastPyArg2Long",
"float": "CastPyArg2Float",
"double": "CastPyArg2Double",
"std::string": "CastPyArg2String",
"std::vector<bool>": "CastPyArg2Booleans",
"std::vector<int>": "CastPyArg2Ints",
......
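The mapping above tells the eager code generator which pybind parsing helper to call for each C++ attribute type; this commit wires "double" to CastPyArg2Double. As a hedged sketch only (not Paddle's exact implementation), a double parser against the CPython C API could look like the following; the name CastPyArg2DoubleSketch and its error handling are illustrative assumptions:

// Hedged sketch only: accept a Python float or int and return a C++ double.
#include <Python.h>

#include <stdexcept>
#include <string>

double CastPyArg2DoubleSketch(PyObject* obj, const std::string& op_type,
                              Py_ssize_t arg_pos) {
  if (PyFloat_Check(obj) || PyLong_Check(obj)) {
    return PyFloat_AsDouble(obj);  // CPython widens int or float to double
  }
  throw std::invalid_argument(op_type + "(): argument (position " +
                              std::to_string(arg_pos + 1) +
                              ") must be double");
}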
......@@ -180,6 +180,37 @@ struct ExtractAttribute<float> {
const std::string& attr_name_;
};
template <>
struct ExtractAttribute<double> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
double* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = PADDLE_GET_CONST(int, attr);
attr = static_cast<double>(val);
} else if (attr.type() == typeid(int64_t)) { // NOLINT
int64_t val = PADDLE_GET_CONST(int64_t, attr);
attr = static_cast<double>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
float val = PADDLE_GET_CONST(float, attr);
attr = static_cast<double>(val);
}
double* attr_value = nullptr;
try {
attr_value = &paddle::get<double>(attr);
} catch (paddle::bad_variant_access const& bad_get) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Cannot get attribute (%s) by type double, its type is %s.",
attr_name_,
paddle::platform::demangle(attr.type().name())));
}
return attr_value;
}
const std::string& attr_name_;
};
template <>
struct ExtractAttribute<std::vector<double>> {
explicit ExtractAttribute(const std::string& attr_name)
......
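ExtractAttribute&lt;double&gt; first promotes any narrower numeric value held by the variant (int, int64_t, float) to double in place, then returns a pointer into the variant. A self-contained sketch of the same promote-then-extract pattern, with std::variant standing in for paddle::variant, illustrates the control flow:

#include <cstdint>
#include <stdexcept>
#include <variant>

using Attr = std::variant<int, int64_t, float, double>;

// Promote narrower numeric alternatives in place, then return a stable double*.
double* ExtractDouble(Attr& attr) {
  if (auto* i = std::get_if<int>(&attr)) {
    attr = static_cast<double>(*i);
  } else if (auto* l = std::get_if<int64_t>(&attr)) {
    attr = static_cast<double>(*l);
  } else if (auto* f = std::get_if<float>(&attr)) {
    attr = static_cast<double>(*f);
  }
  if (auto* d = std::get_if<double>(&attr)) return d;
  throw std::runtime_error("attribute is not convertible to double");
}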
......@@ -38,6 +38,7 @@ enum AttrType {
FLOAT64S = 12;
VAR = 13;
VARS = 14;
FLOAT64 = 15;
}
// OpDesc describes an instance of a C++ framework::OperatorBase
......@@ -62,6 +63,7 @@ message OpDesc {
repeated double float64s = 16;
optional string var_name = 17;
repeated string vars_name = 18;
optional double float64 = 19;
};
message Var {
......
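The proto change adds a FLOAT64 enum value and a matching optional double field to OpDesc.Attr. Assuming the standard accessors protoc generates for an optional scalar field, writing and reading it would look like this sketch (field and type names are taken from the diff; the attribute name, function, and include path are hypothetical):

#include "paddle/fluid/framework/framework.pb.h"  // path assumed; generated from framework.proto

void FillFloat64AttrSketch() {
  paddle::framework::proto::OpDesc::Attr attr;
  attr.set_name("epsilon");  // hypothetical attribute name
  attr.set_type(paddle::framework::proto::AttrType::FLOAT64);
  attr.set_float64(1e-12);    // stored without float truncation
  double v = attr.float64();  // reads back exactly 1e-12
  (void)v;
}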
......@@ -482,6 +482,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
infer_meta_context.EmplaceBackAttr(
phi::Scalar(PADDLE_GET_CONST(float, attr)));
break;
case framework::proto::AttrType::FLOAT64:
infer_meta_context.EmplaceBackAttr(
phi::Scalar(PADDLE_GET_CONST(double, attr)));
break;
case framework::proto::AttrType::INT:
infer_meta_context.EmplaceBackAttr(
phi::Scalar(PADDLE_GET_CONST(int, attr)));
......@@ -651,6 +655,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
case phi::AttributeType::FLOAT32:
infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(float, attr));
break;
case phi::AttributeType::FLOAT64:
infer_meta_context.EmplaceBackAttr(
PADDLE_GET_CONST(double, attr));
break;
case phi::AttributeType::INT32:
infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(int, attr));
break;
......
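Both the Scalar path and the plain-attribute path get their own FLOAT64 case rather than reusing FLOAT, and the same case is added to BuildPhiKernelContext and BuildDygraphPhiKernelContext below. The point is precision: routing a double through the float branch would silently truncate it, as this small standalone demonstration shows:

#include <cstdio>

int main() {
  double d = 1.0 + 1e-12;            // representable in double, not in float
  float f = static_cast<float>(d);   // what a FLOAT-only path would do
  // Prints roughly "1.0000000000010001 vs 1": the 1e-12 term is lost.
  std::printf("%.17g vs %.17g\n", d, static_cast<double>(f));
  return 0;
}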
......@@ -668,6 +668,12 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
this->attrs_[name] = std::vector<float>();
break;
}
case proto::AttrType::FLOAT64S: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to FLOAT64S";
this->attrs_[name] = std::vector<double>();
break;
}
case proto::AttrType::STRINGS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to STRINGS";
......@@ -838,6 +844,7 @@ struct SetAttrDescVisitor {
mutable proto::OpDesc::Attr *attr_;
void operator()(int v) const { attr_->set_i(v); }
void operator()(float v) const { attr_->set_f(v); }
void operator()(double v) const { attr_->set_float64(v); }
void operator()(const std::string &v) const { attr_->set_s(v); }
// Please refer to https://github.com/PaddlePaddle/Paddle/issues/7162
......
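SetAttrDescVisitor serializes each variant alternative through an overloaded operator(); the new double overload routes to set_float64. The dispatch mechanism is the classic variant-visitor pattern, sketched here with std::variant and std::visit in place of Paddle's types:

#include <iostream>
#include <variant>

struct PrintAttrVisitor {
  void operator()(int v) const { std::cout << "i: " << v << "\n"; }
  void operator()(float v) const { std::cout << "f: " << v << "\n"; }
  void operator()(double v) const { std::cout << "float64: " << v << "\n"; }
};

int main() {
  std::variant<int, float, double> attr = 2.5;  // holds a double
  std::visit(PrintAttrVisitor{}, attr);         // picks the double overload
  return 0;
}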
......@@ -2745,6 +2745,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
phi_kernel_context->EmplaceBackAttr(std::move(
phi::Scalar(PADDLE_GET_CONST(float, attr_iter->second))));
break;
case proto::AttrType::FLOAT64:
phi_kernel_context->EmplaceBackAttr(std::move(
phi::Scalar(PADDLE_GET_CONST(double, attr_iter->second))));
break;
case proto::AttrType::INT:
phi_kernel_context->EmplaceBackAttr(std::move(
phi::Scalar(PADDLE_GET_CONST(int, attr_iter->second))));
......@@ -2884,6 +2888,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
phi_kernel_context->EmplaceBackAttr(
PADDLE_GET_CONST(float, attr_iter->second));
break;
case phi::AttributeType::FLOAT64:
phi_kernel_context->EmplaceBackAttr(
PADDLE_GET_CONST(double, attr_iter->second));
break;
case phi::AttributeType::INT32:
phi_kernel_context->EmplaceBackAttr(
PADDLE_GET_CONST(int, attr_iter->second));
......
......@@ -58,7 +58,8 @@ using Attribute = paddle::variant<paddle::blank,
std::vector<int64_t>,
std::vector<double>,
VarDesc*,
std::vector<VarDesc*>>;
std::vector<VarDesc*>,
double>;
using AttributeMap = std::unordered_map<std::string, Attribute>;
#ifdef PADDLE_WITH_ASCEND_CL
......
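In both Attribute variants touched by this commit (here and in the imperative header below), double is appended as the last alternative rather than inserted beside float. Appending keeps the indices of all existing alternatives unchanged, which matters if the variant's runtime index is mapped to the AttrType enum; a std::variant sketch of that property, with a trimmed-down alternative list for illustration:

#include <string>
#include <type_traits>
#include <variant>

using AttrBefore = std::variant<int, float, std::string>;
using AttrAfter  = std::variant<int, float, std::string, double>;

// Appending a new alternative leaves the existing indices untouched.
static_assert(std::is_same_v<std::variant_alternative_t<1, AttrBefore>, float>);
static_assert(std::is_same_v<std::variant_alternative_t<1, AttrAfter>, float>);
static_assert(std::is_same_v<std::variant_alternative_t<3, AttrAfter>, double>);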
......@@ -412,6 +412,10 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(PADDLE_GET_CONST(float, attr))));
break;
case framework::proto::AttrType::FLOAT64:
kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(PADDLE_GET_CONST(double, attr))));
break;
case framework::proto::AttrType::INT:
kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(PADDLE_GET_CONST(int, attr))));
......@@ -549,6 +553,9 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
case phi::AttributeType::FLOAT32:
kernel_ctx->EmplaceBackAttr(PADDLE_GET_CONST(float, attr));
break;
case phi::AttributeType::FLOAT64:
kernel_ctx->EmplaceBackAttr(PADDLE_GET_CONST(double, attr));
break;
case phi::AttributeType::INT32:
kernel_ctx->EmplaceBackAttr(PADDLE_GET_CONST(int, attr));
break;
......
......@@ -126,45 +126,6 @@ CastPyHandleToVarBaseList(const std::string& op_type,
return result;
} // namespace pybind
static inline void ConstructAttrMapFromPyArgs(const std::string& op_type,
int start_idx,
framework::AttributeMap* attrs,
const py::args& args) {
PADDLE_ENFORCE_EQ(
args.size() % 2,
0,
platform::errors::InvalidArgument(
"The number of arguments for arributes should be even."));
for (size_t i = 0; i < args.size(); i += 2) {
std::string name;
framework::Attribute value;
try {
name = args[i].cast<std::string>();
} catch (std::exception& e) {
PyObject* py_obj = args[i].ptr(); // get underlying PyObject
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be str, but got "
"%s",
op_type,
start_idx + i,
Py_TYPE(py_obj)->tp_name));
}
try {
value = args[i + 1].cast<framework::Attribute>();
} catch (std::exception& e) {
PyObject* py_obj = args[i + 1].ptr(); // get underlying PyObject
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"Attribute type (one of str, bool, int, int64, float, or list of "
"them), but got %s",
op_type,
start_idx + i + 1,
Py_TYPE(py_obj)->tp_name));
}
(*attrs)[name] = value;
}
}
static inline std::vector<std::shared_ptr<imperative::VarBase>>
ConstructDuplicableOutput(const size_t num) {
auto tracer = imperative::GetCurrentTracer();
......
......@@ -188,6 +188,14 @@ float CastPyArg2Float(PyObject* obj,
return static_cast<float>(CastPyArg2Double(obj, op_type, arg_pos));
}
void CastPyArg2AttrFloat(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Float(obj, op_type, arg_pos);
}
double CastPyArg2Double(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
......@@ -196,7 +204,7 @@ double CastPyArg2Double(PyObject* obj,
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"float, but got %s",
"double, but got %s",
op_type,
arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
......@@ -205,12 +213,12 @@ double CastPyArg2Double(PyObject* obj,
return 0.0;
}
void CastPyArg2AttrFloat(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Float(obj, op_type, arg_pos);
void CastPyArg2AttrDouble(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Double(obj, op_type, arg_pos);
}
std::string CastPyArg2String(PyObject* obj,
......@@ -735,6 +743,9 @@ void ConstructAttrMapFromPyArgs(
case paddle::framework::proto::AttrType::FLOAT:
CastPyArg2AttrFloat(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::FLOAT64:
CastPyArg2AttrDouble(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::STRING:
CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos);
break;
......
......@@ -107,6 +107,12 @@ void CastPyArg2AttrFloat(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
void CastPyArg2AttrDouble(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
const std::string& op_type,
ssize_t arg_pos);
void CastPyArg2AttrString(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
......
......@@ -44,7 +44,8 @@ using Attribute = paddle::variant<paddle::blank,
std::vector<int64_t>,
std::vector<double>,
VarDesc*,
std::vector<VarDesc*>>;
std::vector<VarDesc*>,
double>;
using AttributeMap = std::unordered_map<std::string, Attribute>;
} // namespace framework
namespace imperative {
......
......@@ -192,6 +192,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(float);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(double);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout);
......
......@@ -124,6 +124,8 @@ class OpDescCreationMethod(object):
new_attr.bools.extend(user_defined_attr)
elif attr.type == framework_pb2.LONGS:
new_attr.longs.extend(user_defined_attr)
elif attr.type == framework_pb2.FLOAT64:
new_attr.float64 = user_defined_attr
else:
raise NotImplementedError(
"A not supported attribute type: %s." %
......
......@@ -16,6 +16,8 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.op as op
import paddle.fluid.proto.framework_pb2 as framework_pb2
......@@ -152,6 +154,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
__add_attr__("int_attr", framework_pb2.INT)
__add_attr__("float_attr", framework_pb2.FLOAT)
__add_attr__("float64_attr", framework_pb2.FLOAT64)
__add_attr__("string_attr", framework_pb2.STRING)
__add_attr__("ints_attr", framework_pb2.INTS)
__add_attr__("floats_attr", framework_pb2.FLOATS)
......@@ -165,6 +168,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
generated = method(X="a",
int_attr=10,
float_attr=3.2,
float64_attr=np.finfo("float64").max,
string_attr="test_str",
ints_attr=[0, 1, 2, 3, 4],
floats_attr=[0.2, 3.2, 4.5],
......@@ -187,6 +191,11 @@ class TestOpDescCreationMethod(unittest.TestCase):
attr.type = framework_pb2.FLOAT
attr.f = 3.2
attr = expected.attrs.add()
attr.name = "float64_attr"
attr.type = framework_pb2.FLOAT64
attr.float64 = np.finfo("float64").max
attr = expected.attrs.add()
attr.name = "string_attr"
attr.type = framework_pb2.STRING
......