From efab2eb4e0eda8c9186f93f9c11a8bdb72c87a25 Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Thu, 25 Aug 2022 23:05:39 +0800 Subject: [PATCH] add support for double attributes (#45390) --- .../final_state_generator/python_c_gen.py | 1 + paddle/fluid/framework/attribute.h | 31 +++++++++++++++ paddle/fluid/framework/framework.proto | 2 + paddle/fluid/framework/infershape_utils.cc | 8 ++++ paddle/fluid/framework/op_desc.cc | 7 ++++ paddle/fluid/framework/operator.cc | 8 ++++ paddle/fluid/framework/type_defs.h | 3 +- paddle/fluid/imperative/prepared_operator.h | 7 ++++ paddle/fluid/pybind/op_function.h | 39 ------------------- paddle/fluid/pybind/op_function_common.cc | 25 ++++++++---- paddle/fluid/pybind/op_function_common.h | 6 +++ paddle/phi/core/enforce.cc | 3 +- paddle/phi/core/infermeta_utils.h | 1 + python/paddle/fluid/op.py | 2 + .../fluid/tests/unittests/test_operator.py | 9 +++++ 15 files changed, 104 insertions(+), 48 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py index 4d5f5c9d61e..4e8e4775ed6 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py @@ -36,6 +36,7 @@ atype_to_parsing_function = { "long": "CastPyArg2Long", "int64_t": "CastPyArg2Long", "float": "CastPyArg2Float", + "double": "CastPyArg2Double", "std::string": "CastPyArg2String", "std::vector": "CastPyArg2Booleans", "std::vector": "CastPyArg2Ints", diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index b4a939f822b..a82e8e7e768 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -180,6 +180,37 @@ struct ExtractAttribute { const std::string& attr_name_; }; +template <> +struct ExtractAttribute { + explicit ExtractAttribute(const std::string& attr_name) + : attr_name_(attr_name) {} + + double* operator()(Attribute& attr) const { + if (attr.type() == typeid(int)) { // NOLINT + int val = PADDLE_GET_CONST(int, attr); + attr = static_cast(val); + } else if (attr.type() == typeid(int64_t)) { // NOLINT + int64_t val = PADDLE_GET_CONST(int64_t, attr); + attr = static_cast(val); + } else if (attr.type() == typeid(float)) { // NOLINT + int64_t val = PADDLE_GET_CONST(float, attr); + attr = static_cast(val); + } + double* attr_value = nullptr; + try { + attr_value = &paddle::get(attr); + } catch (paddle::bad_variant_access const& bad_get) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Cannot get attribute (%s) by type double, its type is %s.", + attr_name_, + paddle::platform::demangle(attr.type().name()))); + } + return attr_value; + } + + const std::string& attr_name_; +}; + template <> struct ExtractAttribute> { explicit ExtractAttribute(const std::string& attr_name) diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index 61a495a59a9..2a56dc60335 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -38,6 +38,7 @@ enum AttrType { FLOAT64S = 12; VAR = 13; VARS = 14; + FLOAT64 = 15; } // OpDesc describes an instance of a C++ framework::OperatorBase @@ -62,6 +63,7 @@ message OpDesc { repeated double float64s = 16; optional string var_name = 17; repeated string vars_name = 18; + optional double float64 = 19; }; message Var { diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index debe43fab82..3a451c19ec2 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -482,6 +482,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, infer_meta_context.EmplaceBackAttr( phi::Scalar(PADDLE_GET_CONST(float, attr))); break; + case framework::proto::AttrType::FLOAT64: + infer_meta_context.EmplaceBackAttr( + phi::Scalar(PADDLE_GET_CONST(double, attr))); + break; case framework::proto::AttrType::INT: infer_meta_context.EmplaceBackAttr( phi::Scalar(PADDLE_GET_CONST(int, attr))); @@ -651,6 +655,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, case phi::AttributeType::FLOAT32: infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(float, attr)); break; + case phi::AttributeType::FLOAT64: + infer_meta_context.EmplaceBackAttr( + PADDLE_GET_CONST(double, attr)); + break; case phi::AttributeType::INT32: infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(int, attr)); break; diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index e5d8f6f9f0e..507f7cd166e 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -668,6 +668,12 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) { this->attrs_[name] = std::vector(); break; } + case proto::AttrType::FLOAT64S: { + VLOG(11) << "SetAttr: " << Type() << ", " << name + << " from INTS to FLOAT64S"; + this->attrs_[name] = std::vector(); + break; + } case proto::AttrType::STRINGS: { VLOG(11) << "SetAttr: " << Type() << ", " << name << " from INTS to STRINGS"; @@ -838,6 +844,7 @@ struct SetAttrDescVisitor { mutable proto::OpDesc::Attr *attr_; void operator()(int v) const { attr_->set_i(v); } void operator()(float v) const { attr_->set_f(v); } + void operator()(double v) const { attr_->set_float64(v); } void operator()(const std::string &v) const { attr_->set_s(v); } // Please refer to https://github.com/PaddlePaddle/Paddle/issues/7162 diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index b5d6a3786c3..23fce93ef30 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -2745,6 +2745,10 @@ void OperatorWithKernel::BuildPhiKernelContext( phi_kernel_context->EmplaceBackAttr(std::move( phi::Scalar(PADDLE_GET_CONST(float, attr_iter->second)))); break; + case proto::AttrType::FLOAT64: + phi_kernel_context->EmplaceBackAttr(std::move( + phi::Scalar(PADDLE_GET_CONST(double, attr_iter->second)))); + break; case proto::AttrType::INT: phi_kernel_context->EmplaceBackAttr(std::move( phi::Scalar(PADDLE_GET_CONST(int, attr_iter->second)))); @@ -2884,6 +2888,10 @@ void OperatorWithKernel::BuildPhiKernelContext( phi_kernel_context->EmplaceBackAttr( PADDLE_GET_CONST(float, attr_iter->second)); break; + case phi::AttributeType::FLOAT64: + phi_kernel_context->EmplaceBackAttr( + PADDLE_GET_CONST(double, attr_iter->second)); + break; case phi::AttributeType::INT32: phi_kernel_context->EmplaceBackAttr( PADDLE_GET_CONST(int, attr_iter->second)); diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index 31a006914ac..d4739209e7a 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -58,7 +58,8 @@ using Attribute = paddle::variant, std::vector, VarDesc*, - std::vector>; + std::vector, + double>; using AttributeMap = std::unordered_map; #ifdef PADDLE_WITH_ASCEND_CL diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index d043b4a5aad..1e76757e1c0 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -412,6 +412,10 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, kernel_ctx->EmplaceBackAttr( std::move(phi::Scalar(PADDLE_GET_CONST(float, attr)))); break; + case framework::proto::AttrType::FLOAT64: + kernel_ctx->EmplaceBackAttr( + std::move(phi::Scalar(PADDLE_GET_CONST(double, attr)))); + break; case framework::proto::AttrType::INT: kernel_ctx->EmplaceBackAttr( std::move(phi::Scalar(PADDLE_GET_CONST(int, attr)))); @@ -549,6 +553,9 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, case phi::AttributeType::FLOAT32: kernel_ctx->EmplaceBackAttr(PADDLE_GET_CONST(float, attr)); break; + case phi::AttributeType::FLOAT64: + kernel_ctx->EmplaceBackAttr(PADDLE_GET_CONST(double, attr)); + break; case phi::AttributeType::INT32: kernel_ctx->EmplaceBackAttr(PADDLE_GET_CONST(int, attr)); break; diff --git a/paddle/fluid/pybind/op_function.h b/paddle/fluid/pybind/op_function.h index 884136ec0d3..542860fa0dc 100644 --- a/paddle/fluid/pybind/op_function.h +++ b/paddle/fluid/pybind/op_function.h @@ -126,45 +126,6 @@ CastPyHandleToVarBaseList(const std::string& op_type, return result; } // namespace pybind -static inline void ConstructAttrMapFromPyArgs(const std::string& op_type, - int start_idx, - framework::AttributeMap* attrs, - const py::args& args) { - PADDLE_ENFORCE_EQ( - args.size() % 2, - 0, - platform::errors::InvalidArgument( - "The number of arguments for arributes should be even.")); - for (size_t i = 0; i < args.size(); i += 2) { - std::string name; - framework::Attribute value; - try { - name = args[i].cast(); - } catch (std::exception& e) { - PyObject* py_obj = args[i].ptr(); // get underlying PyObject - PADDLE_THROW(platform::errors::InvalidArgument( - "%s(): argument (position %d) must be str, but got " - "%s", - op_type, - start_idx + i, - Py_TYPE(py_obj)->tp_name)); - } - try { - value = args[i + 1].cast(); - } catch (std::exception& e) { - PyObject* py_obj = args[i + 1].ptr(); // get underlying PyObject - PADDLE_THROW(platform::errors::InvalidArgument( - "%s(): argument (position %d) must be " - "Attribute type (one of str, bool, int, int64, float, or list of " - "them), but got %s", - op_type, - start_idx + i + 1, - Py_TYPE(py_obj)->tp_name)); - } - (*attrs)[name] = value; - } -} - static inline std::vector> ConstructDuplicableOutput(const size_t num) { auto tracer = imperative::GetCurrentTracer(); diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index 28bdbf92d18..e7970f69e57 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -188,6 +188,14 @@ float CastPyArg2Float(PyObject* obj, return static_cast(CastPyArg2Double(obj, op_type, arg_pos)); } +void CastPyArg2AttrFloat(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, + const std::string& op_type, + ssize_t arg_pos) { + attrs[key] = CastPyArg2Float(obj, op_type, arg_pos); +} + double CastPyArg2Double(PyObject* obj, const std::string& op_type, ssize_t arg_pos) { @@ -196,7 +204,7 @@ double CastPyArg2Double(PyObject* obj, } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " - "float, but got %s", + "double, but got %s", op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT @@ -205,12 +213,12 @@ double CastPyArg2Double(PyObject* obj, return 0.0; } -void CastPyArg2AttrFloat(PyObject* obj, - paddle::framework::AttributeMap& attrs, // NOLINT - const std::string& key, - const std::string& op_type, - ssize_t arg_pos) { - attrs[key] = CastPyArg2Float(obj, op_type, arg_pos); +void CastPyArg2AttrDouble(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, + const std::string& op_type, + ssize_t arg_pos) { + attrs[key] = CastPyArg2Double(obj, op_type, arg_pos); } std::string CastPyArg2String(PyObject* obj, @@ -735,6 +743,9 @@ void ConstructAttrMapFromPyArgs( case paddle::framework::proto::AttrType::FLOAT: CastPyArg2AttrFloat(obj, attrs, key, op_type, arg_pos); break; + case paddle::framework::proto::AttrType::FLOAT64: + CastPyArg2AttrDouble(obj, attrs, key, op_type, arg_pos); + break; case paddle::framework::proto::AttrType::STRING: CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos); break; diff --git a/paddle/fluid/pybind/op_function_common.h b/paddle/fluid/pybind/op_function_common.h index 7bbfb8b85b8..efa16494e77 100644 --- a/paddle/fluid/pybind/op_function_common.h +++ b/paddle/fluid/pybind/op_function_common.h @@ -107,6 +107,12 @@ void CastPyArg2AttrFloat(PyObject* obj, const std::string& op_type, ssize_t arg_pos); +void CastPyArg2AttrDouble(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, + const std::string& op_type, + ssize_t arg_pos); + void CastPyArg2AttrString(PyObject* obj, paddle::framework::AttributeMap& attrs, // NOLINT const std::string& key, diff --git a/paddle/phi/core/enforce.cc b/paddle/phi/core/enforce.cc index 4eb580955a9..7d4efead494 100644 --- a/paddle/phi/core/enforce.cc +++ b/paddle/phi/core/enforce.cc @@ -44,7 +44,8 @@ using Attribute = paddle::variant, std::vector, VarDesc*, - std::vector>; + std::vector, + double>; using AttributeMap = std::unordered_map; } // namespace framework namespace imperative { diff --git a/paddle/phi/core/infermeta_utils.h b/paddle/phi/core/infermeta_utils.h index de2dcd6909f..729c56352cf 100644 --- a/paddle/phi/core/infermeta_utils.h +++ b/paddle/phi/core/infermeta_utils.h @@ -192,6 +192,7 @@ struct InferMetaFnImpl { PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(float); + PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(double); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout); diff --git a/python/paddle/fluid/op.py b/python/paddle/fluid/op.py index d5be4423775..4581248d06a 100644 --- a/python/paddle/fluid/op.py +++ b/python/paddle/fluid/op.py @@ -124,6 +124,8 @@ class OpDescCreationMethod(object): new_attr.bools.extend(user_defined_attr) elif attr.type == framework_pb2.LONGS: new_attr.longs.extend(user_defined_attr) + elif attr.type == framework_pb2.FLOAT64: + new_attr.float64 = user_defined_attr else: raise NotImplementedError( "A not supported attribute type: %s." % diff --git a/python/paddle/fluid/tests/unittests/test_operator.py b/python/paddle/fluid/tests/unittests/test_operator.py index a3ca52f78c9..923a7f21e2a 100644 --- a/python/paddle/fluid/tests/unittests/test_operator.py +++ b/python/paddle/fluid/tests/unittests/test_operator.py @@ -16,6 +16,8 @@ from __future__ import print_function import unittest +import numpy as np + import paddle.fluid.op as op import paddle.fluid.proto.framework_pb2 as framework_pb2 @@ -152,6 +154,7 @@ class TestOpDescCreationMethod(unittest.TestCase): __add_attr__("int_attr", framework_pb2.INT) __add_attr__("float_attr", framework_pb2.FLOAT) + __add_attr__("float64_attr", framework_pb2.FLOAT64) __add_attr__("string_attr", framework_pb2.STRING) __add_attr__("ints_attr", framework_pb2.INTS) __add_attr__("floats_attr", framework_pb2.FLOATS) @@ -165,6 +168,7 @@ class TestOpDescCreationMethod(unittest.TestCase): generated = method(X="a", int_attr=10, float_attr=3.2, + float64_attr=np.finfo("float64").max, string_attr="test_str", ints_attr=[0, 1, 2, 3, 4], floats_attr=[0.2, 3.2, 4.5], @@ -187,6 +191,11 @@ class TestOpDescCreationMethod(unittest.TestCase): attr.type = framework_pb2.FLOAT attr.f = 3.2 + attr = expected.attrs.add() + attr.name = "float64_attr" + attr.type = framework_pb2.FLOAT64 + attr.float64 = np.finfo("float64").max + attr = expected.attrs.add() attr.name = "string_attr" attr.type = framework_pb2.STRING -- GitLab