Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
307801d5
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
307801d5
编写于
8月 16, 2022
作者:
F
Feiyu Chan
提交者:
GitHub
8月 16, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add strongly typed functions to set attributes to avoid unexpected type conversions. (#45107)
上级
642f6df9
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
74 addition
and
3 deletion
+74
-3
paddle/fluid/framework/op_desc.h
paddle/fluid/framework/op_desc.h
+9
-0
paddle/fluid/pybind/protobuf.cc
paddle/fluid/pybind/protobuf.cc
+17
-0
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+46
-1
python/paddle/fluid/tests/unittests/test_fold_op.py
python/paddle/fluid/tests/unittests/test_fold_op.py
+1
-1
python/paddle/fluid/tests/unittests/test_histogram_op.py
python/paddle/fluid/tests/unittests/test_histogram_op.py
+1
-1
未找到文件。
paddle/fluid/framework/op_desc.h
浏览文件 @
307801d5
...
@@ -96,6 +96,15 @@ class OpDesc {
...
@@ -96,6 +96,15 @@ class OpDesc {
void
SetAttr
(
const
std
::
string
&
name
,
const
Attribute
&
v
);
void
SetAttr
(
const
std
::
string
&
name
,
const
Attribute
&
v
);
void
RemoveAttr
(
const
std
::
string
&
name
);
void
RemoveAttr
(
const
std
::
string
&
name
);
// NOTE(chenfeiyu): this template is added to avoid using a variant(Attribute)
// as a parameter of a function which is bound to python, which causes
// unexpected type conversion due to the overload resolution mechanism
// https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers
template
<
typename
T
>
void
SetPlainAttr
(
const
std
::
string
&
name
,
const
T
&
value
)
{
SetAttr
(
name
,
value
);
}
void
SetVarAttr
(
const
std
::
string
&
name
,
VarDesc
*
var
);
void
SetVarAttr
(
const
std
::
string
&
name
,
VarDesc
*
var
);
void
SetVarsAttr
(
const
std
::
string
&
name
,
std
::
vector
<
VarDesc
*>
vars
);
void
SetVarsAttr
(
const
std
::
string
&
name
,
std
::
vector
<
VarDesc
*>
vars
);
...
...
paddle/fluid/pybind/protobuf.cc
浏览文件 @
307801d5
...
@@ -286,6 +286,8 @@ void BindOpDesc(pybind11::module *m) {
...
@@ -286,6 +286,8 @@ void BindOpDesc(pybind11::module *m) {
.
value
(
"LONGS"
,
pd
::
proto
::
AttrType
::
LONGS
)
.
value
(
"LONGS"
,
pd
::
proto
::
AttrType
::
LONGS
)
.
value
(
"FLOAT"
,
pd
::
proto
::
AttrType
::
FLOAT
)
.
value
(
"FLOAT"
,
pd
::
proto
::
AttrType
::
FLOAT
)
.
value
(
"FLOATS"
,
pd
::
proto
::
AttrType
::
FLOATS
)
.
value
(
"FLOATS"
,
pd
::
proto
::
AttrType
::
FLOATS
)
// .value("FLOAT64", pd::proto::AttrType::FLOAT64)
.
value
(
"FLOAT64S"
,
pd
::
proto
::
AttrType
::
FLOAT64S
)
.
value
(
"STRING"
,
pd
::
proto
::
AttrType
::
STRING
)
.
value
(
"STRING"
,
pd
::
proto
::
AttrType
::
STRING
)
.
value
(
"STRINGS"
,
pd
::
proto
::
AttrType
::
STRINGS
)
.
value
(
"STRINGS"
,
pd
::
proto
::
AttrType
::
STRINGS
)
.
value
(
"BOOL"
,
pd
::
proto
::
AttrType
::
BOOLEAN
)
.
value
(
"BOOL"
,
pd
::
proto
::
AttrType
::
BOOLEAN
)
...
@@ -361,6 +363,21 @@ void BindOpDesc(pybind11::module *m) {
...
@@ -361,6 +363,21 @@ void BindOpDesc(pybind11::module *m) {
py
::
arg
(
"with_attr_var"
)
=
false
)
py
::
arg
(
"with_attr_var"
)
=
false
)
.
def
(
"_set_attr"
,
&
pd
::
OpDesc
::
SetAttr
)
.
def
(
"_set_attr"
,
&
pd
::
OpDesc
::
SetAttr
)
.
def
(
"remove_attr"
,
&
pd
::
OpDesc
::
RemoveAttr
)
.
def
(
"remove_attr"
,
&
pd
::
OpDesc
::
RemoveAttr
)
.
def
(
"_set_bool_attr"
,
&
pd
::
OpDesc
::
SetPlainAttr
<
bool
>
)
.
def
(
"_set_int32_attr"
,
&
pd
::
OpDesc
::
SetPlainAttr
<
int
>
)
.
def
(
"_set_int64_attr"
,
&
pd
::
OpDesc
::
SetPlainAttr
<
int64_t
>
)
.
def
(
"_set_float32_attr"
,
&
pd
::
OpDesc
::
SetPlainAttr
<
float
>
)
// .def("_set_float64_attr", &pd::OpDesc::SetPlainAttr<double>)
.
def
(
"_set_str_attr"
,
&
pd
::
OpDesc
::
SetPlainAttr
<
std
::
string
>
)
.
def
(
"_set_bools_attr"
,
&
pd
::
OpDesc
::
SetPlainAttr
<
std
::
vector
<
bool
>>
)
.
def
(
"_set_int32s_attr"
,
&
pd
::
OpDesc
::
SetPlainAttr
<
std
::
vector
<
int
>>
)
.
def
(
"_set_int64s_attr"
,
&
pd
::
OpDesc
::
SetPlainAttr
<
std
::
vector
<
int64_t
>>
)
.
def
(
"_set_float32s_attr"
,
&
pd
::
OpDesc
::
SetPlainAttr
<
std
::
vector
<
float
>>
)
.
def
(
"_set_float64s_attr"
,
&
pd
::
OpDesc
::
SetPlainAttr
<
std
::
vector
<
double
>>
)
.
def
(
"_set_strs_attr"
,
&
pd
::
OpDesc
::
SetPlainAttr
<
std
::
vector
<
std
::
string
>>
)
.
def
(
.
def
(
"attr"
,
"attr"
,
[](
pd
::
OpDesc
&
self
,
const
std
::
string
&
name
,
bool
with_attr_var
)
{
[](
pd
::
OpDesc
&
self
,
const
std
::
string
&
name
,
bool
with_attr_var
)
{
...
...
python/paddle/fluid/framework.py
浏览文件 @
307801d5
...
@@ -2675,6 +2675,16 @@ class Operator(object):
...
@@ -2675,6 +2675,16 @@ class Operator(object):
inputs
=
None
,
inputs
=
None
,
outputs
=
None
,
outputs
=
None
,
attrs
=
None
):
attrs
=
None
):
# read attr type index from op proto to avoid unexpected type
# conversions, e.g. narrowing conversion like double to float
try
:
proto
=
OpProtoHolder
.
instance
().
get_op_proto
(
type
)
self
.
_attr_types
=
{}
for
attr
in
proto
.
attrs
:
self
.
_attr_types
[
attr
.
name
]
=
attr
.
type
except
ValueError
:
pass
if
_non_static_mode
():
if
_non_static_mode
():
if
type
is
None
:
if
type
is
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -3159,7 +3169,42 @@ class Operator(object):
...
@@ -3159,7 +3169,42 @@ class Operator(object):
isinstance
(
val
,
core
.
ProgramDesc
):
isinstance
(
val
,
core
.
ProgramDesc
):
self
.
desc
.
set_serialized_attr
(
name
,
val
.
serialize_to_string
())
self
.
desc
.
set_serialized_attr
(
name
,
val
.
serialize_to_string
())
else
:
else
:
self
.
desc
.
_set_attr
(
name
,
val
)
self
.
_update_desc_plain_attr
(
name
,
val
)
def
_update_desc_plain_attr
(
self
,
name
,
val
):
desc
=
self
.
desc
if
not
hasattr
(
self
,
"_attr_types"
)
or
(
name
not
in
self
.
_attr_types
):
desc
.
_set_attr
(
name
,
val
)
return
type_index
=
self
.
_attr_types
[
name
]
if
type_index
==
core
.
AttrType
.
BOOL
:
desc
.
_set_bool_attr
(
name
,
val
)
elif
type_index
==
core
.
AttrType
.
INT
:
desc
.
_set_int32_attr
(
name
,
val
)
elif
type_index
==
core
.
AttrType
.
LONG
:
desc
.
_set_int64_attr
(
name
,
val
)
elif
type_index
==
core
.
AttrType
.
FLOAT
:
desc
.
_set_float32_attr
(
name
,
val
)
# elif type_index == core.AttrType.FLOAT64:
# desc._set_float64_attr(name, val)
elif
type_index
==
core
.
AttrType
.
STRING
:
desc
.
_set_str_attr
(
name
,
val
)
elif
type_index
==
core
.
AttrType
.
BOOLS
:
desc
.
_set_bools_attr
(
name
,
val
)
elif
type_index
==
core
.
AttrType
.
INTS
:
desc
.
_set_int32s_attr
(
name
,
val
)
elif
type_index
==
core
.
AttrType
.
LONGS
:
desc
.
_set_int64s_attr
(
name
,
val
)
elif
type_index
==
core
.
AttrType
.
FLOATS
:
desc
.
_set_float32s_attr
(
name
,
val
)
elif
type_index
==
core
.
AttrType
.
FLOAT64S
:
desc
.
_set_float64s_attr
(
name
,
val
)
elif
type_index
==
core
.
AttrType
.
STRINGS
:
desc
.
_set_strs_attr
(
name
,
val
)
else
:
# defaults to old methods
desc
.
_set_attr
(
name
,
val
)
@
property
@
property
def
attr_names
(
self
):
def
attr_names
(
self
):
...
...
python/paddle/fluid/tests/unittests/test_fold_op.py
浏览文件 @
307801d5
...
@@ -206,7 +206,7 @@ class TestFoldOpError(unittest.TestCase):
...
@@ -206,7 +206,7 @@ class TestFoldOpError(unittest.TestCase):
self
.
assertRaises
(
AssertionError
,
test_dilations_shape
)
self
.
assertRaises
(
AssertionError
,
test_dilations_shape
)
self
.
assertRaises
(
AssertionError
,
test_strides_shape
)
self
.
assertRaises
(
AssertionError
,
test_strides_shape
)
self
.
assertRaises
(
ValueError
,
test_output_size
)
self
.
assertRaises
(
ValueError
,
test_output_size
)
self
.
assertRaises
(
Valu
eError
,
test_output_size_2
)
self
.
assertRaises
(
Typ
eError
,
test_output_size_2
)
self
.
assertRaises
(
ValueError
,
test_block_h_w
)
self
.
assertRaises
(
ValueError
,
test_block_h_w
)
self
.
assertRaises
(
ValueError
,
test_GT_0
)
self
.
assertRaises
(
ValueError
,
test_GT_0
)
...
...
python/paddle/fluid/tests/unittests/test_histogram_op.py
浏览文件 @
307801d5
...
@@ -111,7 +111,7 @@ class TestHistogramOpError(unittest.TestCase):
...
@@ -111,7 +111,7 @@ class TestHistogramOpError(unittest.TestCase):
value
=
3.0
)
value
=
3.0
)
paddle
.
histogram
(
input
=
input_value
,
bins
=
1
,
min
=-
np
.
inf
,
max
=
5
)
paddle
.
histogram
(
input
=
input_value
,
bins
=
1
,
min
=-
np
.
inf
,
max
=
5
)
with
self
.
assertRaises
(
Valu
eError
):
with
self
.
assertRaises
(
Typ
eError
):
self
.
run_network
(
net_func
)
self
.
run_network
(
net_func
)
def
test_type_errors
(
self
):
def
test_type_errors
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录