Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
efab2eb4
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
efab2eb4
编写于
8月 25, 2022
作者:
F
Feiyu Chan
提交者:
GitHub
8月 25, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add support for double attributes (#45390)
上级
0c363de8
变更
15
显示空白变更内容
内联
并排
Showing
15 changed file
with
104 addition
and
48 deletion
+104
-48
paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
...auto_code_generator/final_state_generator/python_c_gen.py
+1
-0
paddle/fluid/framework/attribute.h
paddle/fluid/framework/attribute.h
+31
-0
paddle/fluid/framework/framework.proto
paddle/fluid/framework/framework.proto
+2
-0
paddle/fluid/framework/infershape_utils.cc
paddle/fluid/framework/infershape_utils.cc
+8
-0
paddle/fluid/framework/op_desc.cc
paddle/fluid/framework/op_desc.cc
+7
-0
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+8
-0
paddle/fluid/framework/type_defs.h
paddle/fluid/framework/type_defs.h
+2
-1
paddle/fluid/imperative/prepared_operator.h
paddle/fluid/imperative/prepared_operator.h
+7
-0
paddle/fluid/pybind/op_function.h
paddle/fluid/pybind/op_function.h
+0
-39
paddle/fluid/pybind/op_function_common.cc
paddle/fluid/pybind/op_function_common.cc
+18
-7
paddle/fluid/pybind/op_function_common.h
paddle/fluid/pybind/op_function_common.h
+6
-0
paddle/phi/core/enforce.cc
paddle/phi/core/enforce.cc
+2
-1
paddle/phi/core/infermeta_utils.h
paddle/phi/core/infermeta_utils.h
+1
-0
python/paddle/fluid/op.py
python/paddle/fluid/op.py
+2
-0
python/paddle/fluid/tests/unittests/test_operator.py
python/paddle/fluid/tests/unittests/test_operator.py
+9
-0
未找到文件。
paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
浏览文件 @
efab2eb4
...
@@ -36,6 +36,7 @@ atype_to_parsing_function = {
...
@@ -36,6 +36,7 @@ atype_to_parsing_function = {
"long"
:
"CastPyArg2Long"
,
"long"
:
"CastPyArg2Long"
,
"int64_t"
:
"CastPyArg2Long"
,
"int64_t"
:
"CastPyArg2Long"
,
"float"
:
"CastPyArg2Float"
,
"float"
:
"CastPyArg2Float"
,
"double"
:
"CastPyArg2Double"
,
"std::string"
:
"CastPyArg2String"
,
"std::string"
:
"CastPyArg2String"
,
"std::vector<bool>"
:
"CastPyArg2Booleans"
,
"std::vector<bool>"
:
"CastPyArg2Booleans"
,
"std::vector<int>"
:
"CastPyArg2Ints"
,
"std::vector<int>"
:
"CastPyArg2Ints"
,
...
...
paddle/fluid/framework/attribute.h
浏览文件 @
efab2eb4
...
@@ -180,6 +180,37 @@ struct ExtractAttribute<float> {
...
@@ -180,6 +180,37 @@ struct ExtractAttribute<float> {
const
std
::
string
&
attr_name_
;
const
std
::
string
&
attr_name_
;
};
};
template
<
>
struct
ExtractAttribute
<
double
>
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
double
*
operator
()(
Attribute
&
attr
)
const
{
if
(
attr
.
type
()
==
typeid
(
int
))
{
// NOLINT
int
val
=
PADDLE_GET_CONST
(
int
,
attr
);
attr
=
static_cast
<
double
>
(
val
);
}
else
if
(
attr
.
type
()
==
typeid
(
int64_t
))
{
// NOLINT
int64_t
val
=
PADDLE_GET_CONST
(
int64_t
,
attr
);
attr
=
static_cast
<
double
>
(
val
);
}
else
if
(
attr
.
type
()
==
typeid
(
float
))
{
// NOLINT
int64_t
val
=
PADDLE_GET_CONST
(
float
,
attr
);
attr
=
static_cast
<
double
>
(
val
);
}
double
*
attr_value
=
nullptr
;
try
{
attr_value
=
&
paddle
::
get
<
double
>
(
attr
);
}
catch
(
paddle
::
bad_variant_access
const
&
bad_get
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Cannot get attribute (%s) by type double, its type is %s."
,
attr_name_
,
paddle
::
platform
::
demangle
(
attr
.
type
().
name
())));
}
return
attr_value
;
}
const
std
::
string
&
attr_name_
;
};
template
<
>
template
<
>
struct
ExtractAttribute
<
std
::
vector
<
double
>>
{
struct
ExtractAttribute
<
std
::
vector
<
double
>>
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
...
...
paddle/fluid/framework/framework.proto
浏览文件 @
efab2eb4
...
@@ -38,6 +38,7 @@ enum AttrType {
...
@@ -38,6 +38,7 @@ enum AttrType {
FLOAT64S
=
12
;
FLOAT64S
=
12
;
VAR
=
13
;
VAR
=
13
;
VARS
=
14
;
VARS
=
14
;
FLOAT64
=
15
;
}
}
// OpDesc describes an instance of a C++ framework::OperatorBase
// OpDesc describes an instance of a C++ framework::OperatorBase
...
@@ -62,6 +63,7 @@ message OpDesc {
...
@@ -62,6 +63,7 @@ message OpDesc {
repeated
double
float64s
=
16
;
repeated
double
float64s
=
16
;
optional
string
var_name
=
17
;
optional
string
var_name
=
17
;
repeated
string
vars_name
=
18
;
repeated
string
vars_name
=
18
;
optional
double
float64
=
19
;
};
};
message
Var
{
message
Var
{
...
...
paddle/fluid/framework/infershape_utils.cc
浏览文件 @
efab2eb4
...
@@ -482,6 +482,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
...
@@ -482,6 +482,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
infer_meta_context
.
EmplaceBackAttr
(
infer_meta_context
.
EmplaceBackAttr
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
float
,
attr
)));
phi
::
Scalar
(
PADDLE_GET_CONST
(
float
,
attr
)));
break
;
break
;
case
framework
::
proto
::
AttrType
::
FLOAT64
:
infer_meta_context
.
EmplaceBackAttr
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
double
,
attr
)));
break
;
case
framework
::
proto
::
AttrType
::
INT
:
case
framework
::
proto
::
AttrType
::
INT
:
infer_meta_context
.
EmplaceBackAttr
(
infer_meta_context
.
EmplaceBackAttr
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
int
,
attr
)));
phi
::
Scalar
(
PADDLE_GET_CONST
(
int
,
attr
)));
...
@@ -651,6 +655,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
...
@@ -651,6 +655,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
case
phi
::
AttributeType
::
FLOAT32
:
case
phi
::
AttributeType
::
FLOAT32
:
infer_meta_context
.
EmplaceBackAttr
(
PADDLE_GET_CONST
(
float
,
attr
));
infer_meta_context
.
EmplaceBackAttr
(
PADDLE_GET_CONST
(
float
,
attr
));
break
;
break
;
case
phi
::
AttributeType
::
FLOAT64
:
infer_meta_context
.
EmplaceBackAttr
(
PADDLE_GET_CONST
(
double
,
attr
));
break
;
case
phi
::
AttributeType
::
INT32
:
case
phi
::
AttributeType
::
INT32
:
infer_meta_context
.
EmplaceBackAttr
(
PADDLE_GET_CONST
(
int
,
attr
));
infer_meta_context
.
EmplaceBackAttr
(
PADDLE_GET_CONST
(
int
,
attr
));
break
;
break
;
...
...
paddle/fluid/framework/op_desc.cc
浏览文件 @
efab2eb4
...
@@ -668,6 +668,12 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
...
@@ -668,6 +668,12 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
this
->
attrs_
[
name
]
=
std
::
vector
<
float
>
();
this
->
attrs_
[
name
]
=
std
::
vector
<
float
>
();
break
;
break
;
}
}
case
proto
::
AttrType
::
FLOAT64S
:
{
VLOG
(
11
)
<<
"SetAttr: "
<<
Type
()
<<
", "
<<
name
<<
" from INTS to FLOAT64S"
;
this
->
attrs_
[
name
]
=
std
::
vector
<
double
>
();
break
;
}
case
proto
::
AttrType
::
STRINGS
:
{
case
proto
::
AttrType
::
STRINGS
:
{
VLOG
(
11
)
<<
"SetAttr: "
<<
Type
()
<<
", "
<<
name
VLOG
(
11
)
<<
"SetAttr: "
<<
Type
()
<<
", "
<<
name
<<
" from INTS to STRINGS"
;
<<
" from INTS to STRINGS"
;
...
@@ -838,6 +844,7 @@ struct SetAttrDescVisitor {
...
@@ -838,6 +844,7 @@ struct SetAttrDescVisitor {
mutable
proto
::
OpDesc
::
Attr
*
attr_
;
mutable
proto
::
OpDesc
::
Attr
*
attr_
;
void
operator
()(
int
v
)
const
{
attr_
->
set_i
(
v
);
}
void
operator
()(
int
v
)
const
{
attr_
->
set_i
(
v
);
}
void
operator
()(
float
v
)
const
{
attr_
->
set_f
(
v
);
}
void
operator
()(
float
v
)
const
{
attr_
->
set_f
(
v
);
}
void
operator
()(
double
v
)
const
{
attr_
->
set_float64
(
v
);
}
void
operator
()(
const
std
::
string
&
v
)
const
{
attr_
->
set_s
(
v
);
}
void
operator
()(
const
std
::
string
&
v
)
const
{
attr_
->
set_s
(
v
);
}
// Please refer to https://github.com/PaddlePaddle/Paddle/issues/7162
// Please refer to https://github.com/PaddlePaddle/Paddle/issues/7162
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
efab2eb4
...
@@ -2745,6 +2745,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2745,6 +2745,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
phi_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
float
,
attr_iter
->
second
))));
phi
::
Scalar
(
PADDLE_GET_CONST
(
float
,
attr_iter
->
second
))));
break
;
break
;
case
proto
::
AttrType
::
FLOAT64
:
phi_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
double
,
attr_iter
->
second
))));
break
;
case
proto
::
AttrType
::
INT
:
case
proto
::
AttrType
::
INT
:
phi_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
int
,
attr_iter
->
second
))));
phi
::
Scalar
(
PADDLE_GET_CONST
(
int
,
attr_iter
->
second
))));
...
@@ -2884,6 +2888,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2884,6 +2888,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
phi_kernel_context
->
EmplaceBackAttr
(
phi_kernel_context
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
float
,
attr_iter
->
second
));
PADDLE_GET_CONST
(
float
,
attr_iter
->
second
));
break
;
break
;
case
phi
::
AttributeType
::
FLOAT64
:
phi_kernel_context
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
double
,
attr_iter
->
second
));
break
;
case
phi
::
AttributeType
::
INT32
:
case
phi
::
AttributeType
::
INT32
:
phi_kernel_context
->
EmplaceBackAttr
(
phi_kernel_context
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
int
,
attr_iter
->
second
));
PADDLE_GET_CONST
(
int
,
attr_iter
->
second
));
...
...
paddle/fluid/framework/type_defs.h
浏览文件 @
efab2eb4
...
@@ -58,7 +58,8 @@ using Attribute = paddle::variant<paddle::blank,
...
@@ -58,7 +58,8 @@ using Attribute = paddle::variant<paddle::blank,
std
::
vector
<
int64_t
>
,
std
::
vector
<
int64_t
>
,
std
::
vector
<
double
>
,
std
::
vector
<
double
>
,
VarDesc
*
,
VarDesc
*
,
std
::
vector
<
VarDesc
*>>
;
std
::
vector
<
VarDesc
*>
,
double
>
;
using
AttributeMap
=
std
::
unordered_map
<
std
::
string
,
Attribute
>
;
using
AttributeMap
=
std
::
unordered_map
<
std
::
string
,
Attribute
>
;
#ifdef PADDLE_WITH_ASCEND_CL
#ifdef PADDLE_WITH_ASCEND_CL
...
...
paddle/fluid/imperative/prepared_operator.h
浏览文件 @
efab2eb4
...
@@ -412,6 +412,10 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
...
@@ -412,6 +412,10 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
kernel_ctx
->
EmplaceBackAttr
(
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
float
,
attr
))));
std
::
move
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
float
,
attr
))));
break
;
break
;
case
framework
::
proto
::
AttrType
::
FLOAT64
:
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
double
,
attr
))));
break
;
case
framework
::
proto
::
AttrType
::
INT
:
case
framework
::
proto
::
AttrType
::
INT
:
kernel_ctx
->
EmplaceBackAttr
(
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
int
,
attr
))));
std
::
move
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
int
,
attr
))));
...
@@ -549,6 +553,9 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
...
@@ -549,6 +553,9 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
case
phi
::
AttributeType
::
FLOAT32
:
case
phi
::
AttributeType
::
FLOAT32
:
kernel_ctx
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
float
,
attr
));
kernel_ctx
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
float
,
attr
));
break
;
break
;
case
phi
::
AttributeType
::
FLOAT64
:
kernel_ctx
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
double
,
attr
));
break
;
case
phi
::
AttributeType
::
INT32
:
case
phi
::
AttributeType
::
INT32
:
kernel_ctx
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
int
,
attr
));
kernel_ctx
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
int
,
attr
));
break
;
break
;
...
...
paddle/fluid/pybind/op_function.h
浏览文件 @
efab2eb4
...
@@ -126,45 +126,6 @@ CastPyHandleToVarBaseList(const std::string& op_type,
...
@@ -126,45 +126,6 @@ CastPyHandleToVarBaseList(const std::string& op_type,
return
result
;
return
result
;
}
// namespace pybind
}
// namespace pybind
static
inline
void
ConstructAttrMapFromPyArgs
(
const
std
::
string
&
op_type
,
int
start_idx
,
framework
::
AttributeMap
*
attrs
,
const
py
::
args
&
args
)
{
PADDLE_ENFORCE_EQ
(
args
.
size
()
%
2
,
0
,
platform
::
errors
::
InvalidArgument
(
"The number of arguments for arributes should be even."
));
for
(
size_t
i
=
0
;
i
<
args
.
size
();
i
+=
2
)
{
std
::
string
name
;
framework
::
Attribute
value
;
try
{
name
=
args
[
i
].
cast
<
std
::
string
>
();
}
catch
(
std
::
exception
&
e
)
{
PyObject
*
py_obj
=
args
[
i
].
ptr
();
// get underlying PyObject
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be str, but got "
"%s"
,
op_type
,
start_idx
+
i
,
Py_TYPE
(
py_obj
)
->
tp_name
));
}
try
{
value
=
args
[
i
+
1
].
cast
<
framework
::
Attribute
>
();
}
catch
(
std
::
exception
&
e
)
{
PyObject
*
py_obj
=
args
[
i
+
1
].
ptr
();
// get underlying PyObject
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"Attribute type (one of str, bool, int, int64, float, or list of "
"them), but got %s"
,
op_type
,
start_idx
+
i
+
1
,
Py_TYPE
(
py_obj
)
->
tp_name
));
}
(
*
attrs
)[
name
]
=
value
;
}
}
static
inline
std
::
vector
<
std
::
shared_ptr
<
imperative
::
VarBase
>>
static
inline
std
::
vector
<
std
::
shared_ptr
<
imperative
::
VarBase
>>
ConstructDuplicableOutput
(
const
size_t
num
)
{
ConstructDuplicableOutput
(
const
size_t
num
)
{
auto
tracer
=
imperative
::
GetCurrentTracer
();
auto
tracer
=
imperative
::
GetCurrentTracer
();
...
...
paddle/fluid/pybind/op_function_common.cc
浏览文件 @
efab2eb4
...
@@ -188,6 +188,14 @@ float CastPyArg2Float(PyObject* obj,
...
@@ -188,6 +188,14 @@ float CastPyArg2Float(PyObject* obj,
return
static_cast
<
float
>
(
CastPyArg2Double
(
obj
,
op_type
,
arg_pos
));
return
static_cast
<
float
>
(
CastPyArg2Double
(
obj
,
op_type
,
arg_pos
));
}
}
void
CastPyArg2AttrFloat
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Float
(
obj
,
op_type
,
arg_pos
);
}
double
CastPyArg2Double
(
PyObject
*
obj
,
double
CastPyArg2Double
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
ssize_t
arg_pos
)
{
...
@@ -196,7 +204,7 @@ double CastPyArg2Double(PyObject* obj,
...
@@ -196,7 +204,7 @@ double CastPyArg2Double(PyObject* obj,
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"%s(): argument (position %d) must be "
"
float
, but got %s"
,
"
double
, but got %s"
,
op_type
,
op_type
,
arg_pos
+
1
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
...
@@ -205,12 +213,12 @@ double CastPyArg2Double(PyObject* obj,
...
@@ -205,12 +213,12 @@ double CastPyArg2Double(PyObject* obj,
return
0.0
;
return
0.0
;
}
}
void
CastPyArg2Attr
Float
(
PyObject
*
obj
,
void
CastPyArg2Attr
Double
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2
Float
(
obj
,
op_type
,
arg_pos
);
attrs
[
key
]
=
CastPyArg2
Double
(
obj
,
op_type
,
arg_pos
);
}
}
std
::
string
CastPyArg2String
(
PyObject
*
obj
,
std
::
string
CastPyArg2String
(
PyObject
*
obj
,
...
@@ -735,6 +743,9 @@ void ConstructAttrMapFromPyArgs(
...
@@ -735,6 +743,9 @@ void ConstructAttrMapFromPyArgs(
case
paddle
::
framework
::
proto
::
AttrType
::
FLOAT
:
case
paddle
::
framework
::
proto
::
AttrType
::
FLOAT
:
CastPyArg2AttrFloat
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
CastPyArg2AttrFloat
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
break
;
break
;
case
paddle
::
framework
::
proto
::
AttrType
::
FLOAT64
:
CastPyArg2AttrDouble
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
break
;
case
paddle
::
framework
::
proto
::
AttrType
::
STRING
:
case
paddle
::
framework
::
proto
::
AttrType
::
STRING
:
CastPyArg2AttrString
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
CastPyArg2AttrString
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
break
;
break
;
...
...
paddle/fluid/pybind/op_function_common.h
浏览文件 @
efab2eb4
...
@@ -107,6 +107,12 @@ void CastPyArg2AttrFloat(PyObject* obj,
...
@@ -107,6 +107,12 @@ void CastPyArg2AttrFloat(PyObject* obj,
const
std
::
string
&
op_type
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
ssize_t
arg_pos
);
void
CastPyArg2AttrDouble
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
void
CastPyArg2AttrString
(
PyObject
*
obj
,
void
CastPyArg2AttrString
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
key
,
...
...
paddle/phi/core/enforce.cc
浏览文件 @
efab2eb4
...
@@ -44,7 +44,8 @@ using Attribute = paddle::variant<paddle::blank,
...
@@ -44,7 +44,8 @@ using Attribute = paddle::variant<paddle::blank,
std
::
vector
<
int64_t
>
,
std
::
vector
<
int64_t
>
,
std
::
vector
<
double
>
,
std
::
vector
<
double
>
,
VarDesc
*
,
VarDesc
*
,
std
::
vector
<
VarDesc
*>>
;
std
::
vector
<
VarDesc
*>
,
double
>
;
using
AttributeMap
=
std
::
unordered_map
<
std
::
string
,
Attribute
>
;
using
AttributeMap
=
std
::
unordered_map
<
std
::
string
,
Attribute
>
;
}
// namespace framework
}
// namespace framework
namespace
imperative
{
namespace
imperative
{
...
...
paddle/phi/core/infermeta_utils.h
浏览文件 @
efab2eb4
...
@@ -192,6 +192,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
...
@@ -192,6 +192,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
int
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
int
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
int64_t
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
int64_t
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
float
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
float
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
double
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
DataType
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
DataType
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
Backend
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
Backend
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
DataLayout
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
DataLayout
);
...
...
python/paddle/fluid/op.py
浏览文件 @
efab2eb4
...
@@ -124,6 +124,8 @@ class OpDescCreationMethod(object):
...
@@ -124,6 +124,8 @@ class OpDescCreationMethod(object):
new_attr
.
bools
.
extend
(
user_defined_attr
)
new_attr
.
bools
.
extend
(
user_defined_attr
)
elif
attr
.
type
==
framework_pb2
.
LONGS
:
elif
attr
.
type
==
framework_pb2
.
LONGS
:
new_attr
.
longs
.
extend
(
user_defined_attr
)
new_attr
.
longs
.
extend
(
user_defined_attr
)
elif
attr
.
type
==
framework_pb2
.
FLOAT64
:
new_attr
.
float64
=
user_defined_attr
else
:
else
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"A not supported attribute type: %s."
%
"A not supported attribute type: %s."
%
...
...
python/paddle/fluid/tests/unittests/test_operator.py
浏览文件 @
efab2eb4
...
@@ -16,6 +16,8 @@ from __future__ import print_function
...
@@ -16,6 +16,8 @@ from __future__ import print_function
import
unittest
import
unittest
import
numpy
as
np
import
paddle.fluid.op
as
op
import
paddle.fluid.op
as
op
import
paddle.fluid.proto.framework_pb2
as
framework_pb2
import
paddle.fluid.proto.framework_pb2
as
framework_pb2
...
@@ -152,6 +154,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
...
@@ -152,6 +154,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
__add_attr__
(
"int_attr"
,
framework_pb2
.
INT
)
__add_attr__
(
"int_attr"
,
framework_pb2
.
INT
)
__add_attr__
(
"float_attr"
,
framework_pb2
.
FLOAT
)
__add_attr__
(
"float_attr"
,
framework_pb2
.
FLOAT
)
__add_attr__
(
"float64_attr"
,
framework_pb2
.
FLOAT64
)
__add_attr__
(
"string_attr"
,
framework_pb2
.
STRING
)
__add_attr__
(
"string_attr"
,
framework_pb2
.
STRING
)
__add_attr__
(
"ints_attr"
,
framework_pb2
.
INTS
)
__add_attr__
(
"ints_attr"
,
framework_pb2
.
INTS
)
__add_attr__
(
"floats_attr"
,
framework_pb2
.
FLOATS
)
__add_attr__
(
"floats_attr"
,
framework_pb2
.
FLOATS
)
...
@@ -165,6 +168,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
...
@@ -165,6 +168,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
generated
=
method
(
X
=
"a"
,
generated
=
method
(
X
=
"a"
,
int_attr
=
10
,
int_attr
=
10
,
float_attr
=
3.2
,
float_attr
=
3.2
,
float64_attr
=
np
.
finfo
(
"float64"
).
max
,
string_attr
=
"test_str"
,
string_attr
=
"test_str"
,
ints_attr
=
[
0
,
1
,
2
,
3
,
4
],
ints_attr
=
[
0
,
1
,
2
,
3
,
4
],
floats_attr
=
[
0.2
,
3.2
,
4.5
],
floats_attr
=
[
0.2
,
3.2
,
4.5
],
...
@@ -187,6 +191,11 @@ class TestOpDescCreationMethod(unittest.TestCase):
...
@@ -187,6 +191,11 @@ class TestOpDescCreationMethod(unittest.TestCase):
attr
.
type
=
framework_pb2
.
FLOAT
attr
.
type
=
framework_pb2
.
FLOAT
attr
.
f
=
3.2
attr
.
f
=
3.2
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
"float64_attr"
attr
.
type
=
framework_pb2
.
FLOAT64
attr
.
float64
=
np
.
finfo
(
"float64"
).
max
attr
=
expected
.
attrs
.
add
()
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
"string_attr"
attr
.
name
=
"string_attr"
attr
.
type
=
framework_pb2
.
STRING
attr
.
type
=
framework_pb2
.
STRING
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录