Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
efab2eb4
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
efab2eb4
编写于
8月 25, 2022
作者:
F
Feiyu Chan
提交者:
GitHub
8月 25, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add support for double attributes (#45390)
上级
0c363de8
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
104 addition
and
48 deletion
+104
-48
paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
...auto_code_generator/final_state_generator/python_c_gen.py
+1
-0
paddle/fluid/framework/attribute.h
paddle/fluid/framework/attribute.h
+31
-0
paddle/fluid/framework/framework.proto
paddle/fluid/framework/framework.proto
+2
-0
paddle/fluid/framework/infershape_utils.cc
paddle/fluid/framework/infershape_utils.cc
+8
-0
paddle/fluid/framework/op_desc.cc
paddle/fluid/framework/op_desc.cc
+7
-0
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+8
-0
paddle/fluid/framework/type_defs.h
paddle/fluid/framework/type_defs.h
+2
-1
paddle/fluid/imperative/prepared_operator.h
paddle/fluid/imperative/prepared_operator.h
+7
-0
paddle/fluid/pybind/op_function.h
paddle/fluid/pybind/op_function.h
+0
-39
paddle/fluid/pybind/op_function_common.cc
paddle/fluid/pybind/op_function_common.cc
+18
-7
paddle/fluid/pybind/op_function_common.h
paddle/fluid/pybind/op_function_common.h
+6
-0
paddle/phi/core/enforce.cc
paddle/phi/core/enforce.cc
+2
-1
paddle/phi/core/infermeta_utils.h
paddle/phi/core/infermeta_utils.h
+1
-0
python/paddle/fluid/op.py
python/paddle/fluid/op.py
+2
-0
python/paddle/fluid/tests/unittests/test_operator.py
python/paddle/fluid/tests/unittests/test_operator.py
+9
-0
未找到文件。
paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
浏览文件 @
efab2eb4
...
@@ -36,6 +36,7 @@ atype_to_parsing_function = {
...
@@ -36,6 +36,7 @@ atype_to_parsing_function = {
"long"
:
"CastPyArg2Long"
,
"long"
:
"CastPyArg2Long"
,
"int64_t"
:
"CastPyArg2Long"
,
"int64_t"
:
"CastPyArg2Long"
,
"float"
:
"CastPyArg2Float"
,
"float"
:
"CastPyArg2Float"
,
"double"
:
"CastPyArg2Double"
,
"std::string"
:
"CastPyArg2String"
,
"std::string"
:
"CastPyArg2String"
,
"std::vector<bool>"
:
"CastPyArg2Booleans"
,
"std::vector<bool>"
:
"CastPyArg2Booleans"
,
"std::vector<int>"
:
"CastPyArg2Ints"
,
"std::vector<int>"
:
"CastPyArg2Ints"
,
...
...
paddle/fluid/framework/attribute.h
浏览文件 @
efab2eb4
...
@@ -180,6 +180,37 @@ struct ExtractAttribute<float> {
...
@@ -180,6 +180,37 @@ struct ExtractAttribute<float> {
const
std
::
string
&
attr_name_
;
const
std
::
string
&
attr_name_
;
};
};
template
<
>
struct
ExtractAttribute
<
double
>
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
double
*
operator
()(
Attribute
&
attr
)
const
{
if
(
attr
.
type
()
==
typeid
(
int
))
{
// NOLINT
int
val
=
PADDLE_GET_CONST
(
int
,
attr
);
attr
=
static_cast
<
double
>
(
val
);
}
else
if
(
attr
.
type
()
==
typeid
(
int64_t
))
{
// NOLINT
int64_t
val
=
PADDLE_GET_CONST
(
int64_t
,
attr
);
attr
=
static_cast
<
double
>
(
val
);
}
else
if
(
attr
.
type
()
==
typeid
(
float
))
{
// NOLINT
int64_t
val
=
PADDLE_GET_CONST
(
float
,
attr
);
attr
=
static_cast
<
double
>
(
val
);
}
double
*
attr_value
=
nullptr
;
try
{
attr_value
=
&
paddle
::
get
<
double
>
(
attr
);
}
catch
(
paddle
::
bad_variant_access
const
&
bad_get
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Cannot get attribute (%s) by type double, its type is %s."
,
attr_name_
,
paddle
::
platform
::
demangle
(
attr
.
type
().
name
())));
}
return
attr_value
;
}
const
std
::
string
&
attr_name_
;
};
template
<
>
template
<
>
struct
ExtractAttribute
<
std
::
vector
<
double
>>
{
struct
ExtractAttribute
<
std
::
vector
<
double
>>
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
...
...
paddle/fluid/framework/framework.proto
浏览文件 @
efab2eb4
...
@@ -38,6 +38,7 @@ enum AttrType {
...
@@ -38,6 +38,7 @@ enum AttrType {
FLOAT64S
=
12
;
FLOAT64S
=
12
;
VAR
=
13
;
VAR
=
13
;
VARS
=
14
;
VARS
=
14
;
FLOAT64
=
15
;
}
}
// OpDesc describes an instance of a C++ framework::OperatorBase
// OpDesc describes an instance of a C++ framework::OperatorBase
...
@@ -62,6 +63,7 @@ message OpDesc {
...
@@ -62,6 +63,7 @@ message OpDesc {
repeated
double
float64s
=
16
;
repeated
double
float64s
=
16
;
optional
string
var_name
=
17
;
optional
string
var_name
=
17
;
repeated
string
vars_name
=
18
;
repeated
string
vars_name
=
18
;
optional
double
float64
=
19
;
};
};
message
Var
{
message
Var
{
...
...
paddle/fluid/framework/infershape_utils.cc
浏览文件 @
efab2eb4
...
@@ -482,6 +482,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
...
@@ -482,6 +482,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
infer_meta_context
.
EmplaceBackAttr
(
infer_meta_context
.
EmplaceBackAttr
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
float
,
attr
)));
phi
::
Scalar
(
PADDLE_GET_CONST
(
float
,
attr
)));
break
;
break
;
case
framework
::
proto
::
AttrType
::
FLOAT64
:
infer_meta_context
.
EmplaceBackAttr
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
double
,
attr
)));
break
;
case
framework
::
proto
::
AttrType
::
INT
:
case
framework
::
proto
::
AttrType
::
INT
:
infer_meta_context
.
EmplaceBackAttr
(
infer_meta_context
.
EmplaceBackAttr
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
int
,
attr
)));
phi
::
Scalar
(
PADDLE_GET_CONST
(
int
,
attr
)));
...
@@ -651,6 +655,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
...
@@ -651,6 +655,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
case
phi
::
AttributeType
::
FLOAT32
:
case
phi
::
AttributeType
::
FLOAT32
:
infer_meta_context
.
EmplaceBackAttr
(
PADDLE_GET_CONST
(
float
,
attr
));
infer_meta_context
.
EmplaceBackAttr
(
PADDLE_GET_CONST
(
float
,
attr
));
break
;
break
;
case
phi
::
AttributeType
::
FLOAT64
:
infer_meta_context
.
EmplaceBackAttr
(
PADDLE_GET_CONST
(
double
,
attr
));
break
;
case
phi
::
AttributeType
::
INT32
:
case
phi
::
AttributeType
::
INT32
:
infer_meta_context
.
EmplaceBackAttr
(
PADDLE_GET_CONST
(
int
,
attr
));
infer_meta_context
.
EmplaceBackAttr
(
PADDLE_GET_CONST
(
int
,
attr
));
break
;
break
;
...
...
paddle/fluid/framework/op_desc.cc
浏览文件 @
efab2eb4
...
@@ -668,6 +668,12 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
...
@@ -668,6 +668,12 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
this
->
attrs_
[
name
]
=
std
::
vector
<
float
>
();
this
->
attrs_
[
name
]
=
std
::
vector
<
float
>
();
break
;
break
;
}
}
case
proto
::
AttrType
::
FLOAT64S
:
{
VLOG
(
11
)
<<
"SetAttr: "
<<
Type
()
<<
", "
<<
name
<<
" from INTS to FLOAT64S"
;
this
->
attrs_
[
name
]
=
std
::
vector
<
double
>
();
break
;
}
case
proto
::
AttrType
::
STRINGS
:
{
case
proto
::
AttrType
::
STRINGS
:
{
VLOG
(
11
)
<<
"SetAttr: "
<<
Type
()
<<
", "
<<
name
VLOG
(
11
)
<<
"SetAttr: "
<<
Type
()
<<
", "
<<
name
<<
" from INTS to STRINGS"
;
<<
" from INTS to STRINGS"
;
...
@@ -838,6 +844,7 @@ struct SetAttrDescVisitor {
...
@@ -838,6 +844,7 @@ struct SetAttrDescVisitor {
mutable
proto
::
OpDesc
::
Attr
*
attr_
;
mutable
proto
::
OpDesc
::
Attr
*
attr_
;
void
operator
()(
int
v
)
const
{
attr_
->
set_i
(
v
);
}
void
operator
()(
int
v
)
const
{
attr_
->
set_i
(
v
);
}
void
operator
()(
float
v
)
const
{
attr_
->
set_f
(
v
);
}
void
operator
()(
float
v
)
const
{
attr_
->
set_f
(
v
);
}
void
operator
()(
double
v
)
const
{
attr_
->
set_float64
(
v
);
}
void
operator
()(
const
std
::
string
&
v
)
const
{
attr_
->
set_s
(
v
);
}
void
operator
()(
const
std
::
string
&
v
)
const
{
attr_
->
set_s
(
v
);
}
// Please refer to https://github.com/PaddlePaddle/Paddle/issues/7162
// Please refer to https://github.com/PaddlePaddle/Paddle/issues/7162
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
efab2eb4
...
@@ -2745,6 +2745,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2745,6 +2745,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
phi_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
float
,
attr_iter
->
second
))));
phi
::
Scalar
(
PADDLE_GET_CONST
(
float
,
attr_iter
->
second
))));
break
;
break
;
case
proto
::
AttrType
::
FLOAT64
:
phi_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
double
,
attr_iter
->
second
))));
break
;
case
proto
::
AttrType
::
INT
:
case
proto
::
AttrType
::
INT
:
phi_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
int
,
attr_iter
->
second
))));
phi
::
Scalar
(
PADDLE_GET_CONST
(
int
,
attr_iter
->
second
))));
...
@@ -2884,6 +2888,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2884,6 +2888,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
phi_kernel_context
->
EmplaceBackAttr
(
phi_kernel_context
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
float
,
attr_iter
->
second
));
PADDLE_GET_CONST
(
float
,
attr_iter
->
second
));
break
;
break
;
case
phi
::
AttributeType
::
FLOAT64
:
phi_kernel_context
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
double
,
attr_iter
->
second
));
break
;
case
phi
::
AttributeType
::
INT32
:
case
phi
::
AttributeType
::
INT32
:
phi_kernel_context
->
EmplaceBackAttr
(
phi_kernel_context
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
int
,
attr_iter
->
second
));
PADDLE_GET_CONST
(
int
,
attr_iter
->
second
));
...
...
paddle/fluid/framework/type_defs.h
浏览文件 @
efab2eb4
...
@@ -58,7 +58,8 @@ using Attribute = paddle::variant<paddle::blank,
...
@@ -58,7 +58,8 @@ using Attribute = paddle::variant<paddle::blank,
std
::
vector
<
int64_t
>
,
std
::
vector
<
int64_t
>
,
std
::
vector
<
double
>
,
std
::
vector
<
double
>
,
VarDesc
*
,
VarDesc
*
,
std
::
vector
<
VarDesc
*>>
;
std
::
vector
<
VarDesc
*>
,
double
>
;
using
AttributeMap
=
std
::
unordered_map
<
std
::
string
,
Attribute
>
;
using
AttributeMap
=
std
::
unordered_map
<
std
::
string
,
Attribute
>
;
#ifdef PADDLE_WITH_ASCEND_CL
#ifdef PADDLE_WITH_ASCEND_CL
...
...
paddle/fluid/imperative/prepared_operator.h
浏览文件 @
efab2eb4
...
@@ -412,6 +412,10 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
...
@@ -412,6 +412,10 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
kernel_ctx
->
EmplaceBackAttr
(
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
float
,
attr
))));
std
::
move
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
float
,
attr
))));
break
;
break
;
case
framework
::
proto
::
AttrType
::
FLOAT64
:
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
double
,
attr
))));
break
;
case
framework
::
proto
::
AttrType
::
INT
:
case
framework
::
proto
::
AttrType
::
INT
:
kernel_ctx
->
EmplaceBackAttr
(
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
int
,
attr
))));
std
::
move
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
int
,
attr
))));
...
@@ -549,6 +553,9 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
...
@@ -549,6 +553,9 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
case
phi
::
AttributeType
::
FLOAT32
:
case
phi
::
AttributeType
::
FLOAT32
:
kernel_ctx
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
float
,
attr
));
kernel_ctx
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
float
,
attr
));
break
;
break
;
case
phi
::
AttributeType
::
FLOAT64
:
kernel_ctx
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
double
,
attr
));
break
;
case
phi
::
AttributeType
::
INT32
:
case
phi
::
AttributeType
::
INT32
:
kernel_ctx
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
int
,
attr
));
kernel_ctx
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
int
,
attr
));
break
;
break
;
...
...
paddle/fluid/pybind/op_function.h
浏览文件 @
efab2eb4
...
@@ -126,45 +126,6 @@ CastPyHandleToVarBaseList(const std::string& op_type,
...
@@ -126,45 +126,6 @@ CastPyHandleToVarBaseList(const std::string& op_type,
return
result
;
return
result
;
}
// namespace pybind
}
// namespace pybind
static
inline
void
ConstructAttrMapFromPyArgs
(
const
std
::
string
&
op_type
,
int
start_idx
,
framework
::
AttributeMap
*
attrs
,
const
py
::
args
&
args
)
{
PADDLE_ENFORCE_EQ
(
args
.
size
()
%
2
,
0
,
platform
::
errors
::
InvalidArgument
(
"The number of arguments for arributes should be even."
));
for
(
size_t
i
=
0
;
i
<
args
.
size
();
i
+=
2
)
{
std
::
string
name
;
framework
::
Attribute
value
;
try
{
name
=
args
[
i
].
cast
<
std
::
string
>
();
}
catch
(
std
::
exception
&
e
)
{
PyObject
*
py_obj
=
args
[
i
].
ptr
();
// get underlying PyObject
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be str, but got "
"%s"
,
op_type
,
start_idx
+
i
,
Py_TYPE
(
py_obj
)
->
tp_name
));
}
try
{
value
=
args
[
i
+
1
].
cast
<
framework
::
Attribute
>
();
}
catch
(
std
::
exception
&
e
)
{
PyObject
*
py_obj
=
args
[
i
+
1
].
ptr
();
// get underlying PyObject
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"Attribute type (one of str, bool, int, int64, float, or list of "
"them), but got %s"
,
op_type
,
start_idx
+
i
+
1
,
Py_TYPE
(
py_obj
)
->
tp_name
));
}
(
*
attrs
)[
name
]
=
value
;
}
}
static
inline
std
::
vector
<
std
::
shared_ptr
<
imperative
::
VarBase
>>
static
inline
std
::
vector
<
std
::
shared_ptr
<
imperative
::
VarBase
>>
ConstructDuplicableOutput
(
const
size_t
num
)
{
ConstructDuplicableOutput
(
const
size_t
num
)
{
auto
tracer
=
imperative
::
GetCurrentTracer
();
auto
tracer
=
imperative
::
GetCurrentTracer
();
...
...
paddle/fluid/pybind/op_function_common.cc
浏览文件 @
efab2eb4
...
@@ -188,6 +188,14 @@ float CastPyArg2Float(PyObject* obj,
...
@@ -188,6 +188,14 @@ float CastPyArg2Float(PyObject* obj,
return
static_cast
<
float
>
(
CastPyArg2Double
(
obj
,
op_type
,
arg_pos
));
return
static_cast
<
float
>
(
CastPyArg2Double
(
obj
,
op_type
,
arg_pos
));
}
}
void
CastPyArg2AttrFloat
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Float
(
obj
,
op_type
,
arg_pos
);
}
double
CastPyArg2Double
(
PyObject
*
obj
,
double
CastPyArg2Double
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
ssize_t
arg_pos
)
{
...
@@ -196,7 +204,7 @@ double CastPyArg2Double(PyObject* obj,
...
@@ -196,7 +204,7 @@ double CastPyArg2Double(PyObject* obj,
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"%s(): argument (position %d) must be "
"
float
, but got %s"
,
"
double
, but got %s"
,
op_type
,
op_type
,
arg_pos
+
1
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
...
@@ -205,12 +213,12 @@ double CastPyArg2Double(PyObject* obj,
...
@@ -205,12 +213,12 @@ double CastPyArg2Double(PyObject* obj,
return
0.0
;
return
0.0
;
}
}
void
CastPyArg2Attr
Float
(
PyObject
*
obj
,
void
CastPyArg2Attr
Double
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2
Float
(
obj
,
op_type
,
arg_pos
);
attrs
[
key
]
=
CastPyArg2
Double
(
obj
,
op_type
,
arg_pos
);
}
}
std
::
string
CastPyArg2String
(
PyObject
*
obj
,
std
::
string
CastPyArg2String
(
PyObject
*
obj
,
...
@@ -735,6 +743,9 @@ void ConstructAttrMapFromPyArgs(
...
@@ -735,6 +743,9 @@ void ConstructAttrMapFromPyArgs(
case
paddle
::
framework
::
proto
::
AttrType
::
FLOAT
:
case
paddle
::
framework
::
proto
::
AttrType
::
FLOAT
:
CastPyArg2AttrFloat
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
CastPyArg2AttrFloat
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
break
;
break
;
case
paddle
::
framework
::
proto
::
AttrType
::
FLOAT64
:
CastPyArg2AttrDouble
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
break
;
case
paddle
::
framework
::
proto
::
AttrType
::
STRING
:
case
paddle
::
framework
::
proto
::
AttrType
::
STRING
:
CastPyArg2AttrString
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
CastPyArg2AttrString
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
break
;
break
;
...
...
paddle/fluid/pybind/op_function_common.h
浏览文件 @
efab2eb4
...
@@ -107,6 +107,12 @@ void CastPyArg2AttrFloat(PyObject* obj,
...
@@ -107,6 +107,12 @@ void CastPyArg2AttrFloat(PyObject* obj,
const
std
::
string
&
op_type
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
ssize_t
arg_pos
);
void
CastPyArg2AttrDouble
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
void
CastPyArg2AttrString
(
PyObject
*
obj
,
void
CastPyArg2AttrString
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
key
,
...
...
paddle/phi/core/enforce.cc
浏览文件 @
efab2eb4
...
@@ -44,7 +44,8 @@ using Attribute = paddle::variant<paddle::blank,
...
@@ -44,7 +44,8 @@ using Attribute = paddle::variant<paddle::blank,
std
::
vector
<
int64_t
>
,
std
::
vector
<
int64_t
>
,
std
::
vector
<
double
>
,
std
::
vector
<
double
>
,
VarDesc
*
,
VarDesc
*
,
std
::
vector
<
VarDesc
*>>
;
std
::
vector
<
VarDesc
*>
,
double
>
;
using
AttributeMap
=
std
::
unordered_map
<
std
::
string
,
Attribute
>
;
using
AttributeMap
=
std
::
unordered_map
<
std
::
string
,
Attribute
>
;
}
// namespace framework
}
// namespace framework
namespace
imperative
{
namespace
imperative
{
...
...
paddle/phi/core/infermeta_utils.h
浏览文件 @
efab2eb4
...
@@ -192,6 +192,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
...
@@ -192,6 +192,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
int
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
int
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
int64_t
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
int64_t
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
float
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
float
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
double
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
DataType
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
DataType
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
Backend
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
Backend
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
DataLayout
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
DataLayout
);
...
...
python/paddle/fluid/op.py
浏览文件 @
efab2eb4
...
@@ -124,6 +124,8 @@ class OpDescCreationMethod(object):
...
@@ -124,6 +124,8 @@ class OpDescCreationMethod(object):
new_attr
.
bools
.
extend
(
user_defined_attr
)
new_attr
.
bools
.
extend
(
user_defined_attr
)
elif
attr
.
type
==
framework_pb2
.
LONGS
:
elif
attr
.
type
==
framework_pb2
.
LONGS
:
new_attr
.
longs
.
extend
(
user_defined_attr
)
new_attr
.
longs
.
extend
(
user_defined_attr
)
elif
attr
.
type
==
framework_pb2
.
FLOAT64
:
new_attr
.
float64
=
user_defined_attr
else
:
else
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"A not supported attribute type: %s."
%
"A not supported attribute type: %s."
%
...
...
python/paddle/fluid/tests/unittests/test_operator.py
浏览文件 @
efab2eb4
...
@@ -16,6 +16,8 @@ from __future__ import print_function
...
@@ -16,6 +16,8 @@ from __future__ import print_function
import
unittest
import
unittest
import
numpy
as
np
import
paddle.fluid.op
as
op
import
paddle.fluid.op
as
op
import
paddle.fluid.proto.framework_pb2
as
framework_pb2
import
paddle.fluid.proto.framework_pb2
as
framework_pb2
...
@@ -152,6 +154,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
...
@@ -152,6 +154,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
__add_attr__
(
"int_attr"
,
framework_pb2
.
INT
)
__add_attr__
(
"int_attr"
,
framework_pb2
.
INT
)
__add_attr__
(
"float_attr"
,
framework_pb2
.
FLOAT
)
__add_attr__
(
"float_attr"
,
framework_pb2
.
FLOAT
)
__add_attr__
(
"float64_attr"
,
framework_pb2
.
FLOAT64
)
__add_attr__
(
"string_attr"
,
framework_pb2
.
STRING
)
__add_attr__
(
"string_attr"
,
framework_pb2
.
STRING
)
__add_attr__
(
"ints_attr"
,
framework_pb2
.
INTS
)
__add_attr__
(
"ints_attr"
,
framework_pb2
.
INTS
)
__add_attr__
(
"floats_attr"
,
framework_pb2
.
FLOATS
)
__add_attr__
(
"floats_attr"
,
framework_pb2
.
FLOATS
)
...
@@ -165,6 +168,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
...
@@ -165,6 +168,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
generated
=
method
(
X
=
"a"
,
generated
=
method
(
X
=
"a"
,
int_attr
=
10
,
int_attr
=
10
,
float_attr
=
3.2
,
float_attr
=
3.2
,
float64_attr
=
np
.
finfo
(
"float64"
).
max
,
string_attr
=
"test_str"
,
string_attr
=
"test_str"
,
ints_attr
=
[
0
,
1
,
2
,
3
,
4
],
ints_attr
=
[
0
,
1
,
2
,
3
,
4
],
floats_attr
=
[
0.2
,
3.2
,
4.5
],
floats_attr
=
[
0.2
,
3.2
,
4.5
],
...
@@ -187,6 +191,11 @@ class TestOpDescCreationMethod(unittest.TestCase):
...
@@ -187,6 +191,11 @@ class TestOpDescCreationMethod(unittest.TestCase):
attr
.
type
=
framework_pb2
.
FLOAT
attr
.
type
=
framework_pb2
.
FLOAT
attr
.
f
=
3.2
attr
.
f
=
3.2
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
"float64_attr"
attr
.
type
=
framework_pb2
.
FLOAT64
attr
.
float64
=
np
.
finfo
(
"float64"
).
max
attr
=
expected
.
attrs
.
add
()
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
"string_attr"
attr
.
name
=
"string_attr"
attr
.
type
=
framework_pb2
.
STRING
attr
.
type
=
framework_pb2
.
STRING
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录