Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
286c2e0e
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
286c2e0e
编写于
4月 10, 2020
作者:
W
Wilber
提交者:
GitHub
4月 10, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
error message enhancement for py_func op. (#23565)
error message enhancement for py_func op.
上级
94a3789f
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
51 addition
and
18 deletion
+51
-18
paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc
paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc
+1
-1
paddle/fluid/operators/py_func_op.cc
paddle/fluid/operators/py_func_op.cc
+48
-16
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+2
-1
未找到文件。
paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc
浏览文件 @
286c2e0e
...
...
@@ -66,7 +66,7 @@ void FusionRepeatedFCReluOp::InferShape(
for
(
size_t
i
=
1
;
i
<
sz
;
++
i
)
{
PADDLE_ENFORCE_EQ
(
w_dims
[
i
].
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"Every weight shape size should be 2
.
, but received "
"Every weight shape size should be 2, but received "
"w_dims[%d].size() = %d."
,
i
,
w_dims
[
i
].
size
()));
PADDLE_ENFORCE_EQ
(
...
...
paddle/fluid/operators/py_func_op.cc
浏览文件 @
286c2e0e
...
...
@@ -42,7 +42,11 @@ size_t AppendPythonCallableObjectAndReturnId(const py::object &py_obj) {
// Returning py::object would cause reference count increasing
// but without GIL, reference count in Python may not be safe
static
py
::
object
*
GetPythonCallableObject
(
size_t
i
)
{
PADDLE_ENFORCE_LT
(
i
,
g_py_callables
.
size
(),
"Invalid python callable id"
);
PADDLE_ENFORCE_LT
(
i
,
g_py_callables
.
size
(),
platform
::
errors
::
InvalidArgument
(
"Invalid python callable id %d, which should be less than %d."
,
i
,
g_py_callables
.
size
()));
return
&
g_py_callables
[
i
];
}
...
...
@@ -71,10 +75,27 @@ static void CallPythonFunc(py::object *callable,
// Python function has no return values or returns None
// In this case, ret_num = 1 && ret[0] == None && out_num should be 0
// Otherwise, ret_num must be equal to out_num
PADDLE_ENFORCE
(
ret_num
==
1
&&
out_num
==
0
&&
py
::
cast
<
framework
::
LoDTensor
*>
(
ret_tuple
[
0
])
==
nullptr
,
"Output number not match. Expected %d, actual %d"
,
out_num
,
ret_num
);
PADDLE_ENFORCE_EQ
(
ret_num
==
1
,
true
,
platform
::
errors
::
InvalidArgument
(
"Python function has no return values or returns "
"None. In this case, ret_num = 1 && ret[0] == None "
"&& out_num should be 0. But ret_num is %d"
,
ret_num
));
PADDLE_ENFORCE_EQ
(
out_num
==
0
,
true
,
platform
::
errors
::
InvalidArgument
(
"Python function has no return values or returns None. In "
"this case, ret_num = 1 && ret[0] == None && out_num should "
"be 0. But out_num is %d"
,
out_num
));
PADDLE_ENFORCE_EQ
(
py
::
cast
<
framework
::
LoDTensor
*>
(
ret_tuple
[
0
])
==
nullptr
,
true
,
platform
::
errors
::
InvalidArgument
(
"Python function has no return values or returns None. In "
"this case, ret_num = 1 && ret[0] == None && out_num should "
"be 0. But ret[0] is not None"
));
}
for
(
size_t
i
=
0
;
i
<
out_num
;
++
i
)
{
...
...
@@ -85,7 +106,8 @@ static void CallPythonFunc(py::object *callable,
try
{
auto
*
py_out_tensor
=
py
::
cast
<
framework
::
LoDTensor
*>
(
ret_tuple
[
i
]);
PADDLE_ENFORCE_NOT_NULL
(
py_out_tensor
,
"Output tensor %d should not be nullptr"
,
i
);
platform
::
errors
::
InvalidArgument
(
"Output tensor %d should not be nullptr"
,
i
));
out
->
set_lod
(
py_out_tensor
->
lod
());
out
->
ShareDataWith
(
*
py_out_tensor
);
}
catch
(
py
::
cast_error
&
)
{
...
...
@@ -105,10 +127,17 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
* X or Out can be empty, so that py_func can be more flexible
* to support Python functions with no input or no output
*/
PADDLE_ENFORCE
(
has_in
||
has_out
,
"Input(X) or Output(Out) must exist"
);
PADDLE_ENFORCE_GE
(
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
kForwardPythonCallableId
)),
0
,
"Function id cannot be less than 0"
);
PADDLE_ENFORCE_EQ
(
has_in
||
has_out
,
true
,
platform
::
errors
::
InvalidArgument
(
"Input(X) or Output(Out) must exist, "
"but has_in is %d, has_out is %d."
,
has_in
,
has_out
));
PADDLE_ENFORCE_GE
(
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
kForwardPythonCallableId
)),
0
,
platform
::
errors
::
InvalidArgument
(
"Function id cannot be less than 0, but received value is %d."
,
boost
::
get
<
int
>
(
ctx
->
GetAttr
(
kForwardPythonCallableId
))));
if
(
!
has_out
)
return
;
...
...
@@ -128,10 +157,12 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
size_t
len
=
out_var_name
.
size
()
-
kGradVarSuffix
.
size
();
if
(
out_var_name
.
substr
(
len
)
==
kGradVarSuffix
)
{
auto
fwd_var_name
=
out_var_name
.
substr
(
0
,
len
);
PADDLE_ENFORCE
(
ctx
->
HasVar
(
out_var_name
),
"Backward variable %s not found"
,
out_var_name
);
PADDLE_ENFORCE
(
ctx
->
HasVar
(
fwd_var_name
),
"Backward variable %s not found"
,
fwd_var_name
);
PADDLE_ENFORCE_EQ
(
ctx
->
HasVar
(
out_var_name
),
true
,
platform
::
errors
::
InvalidArgument
(
"Backward variable %s not found"
,
out_var_name
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasVar
(
fwd_var_name
),
true
,
platform
::
errors
::
InvalidArgument
(
"Backward variable %s not found"
,
fwd_var_name
));
VLOG
(
10
)
<<
"Infer var_desc of Output("
<<
out_var_name
<<
") as Input("
<<
fwd_var_name
<<
")"
;
...
...
@@ -147,8 +178,9 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
class
PyFuncOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
!
ctx
->
IsRuntime
(),
"Infer shape cannot be called in runtime."
);
PADDLE_ENFORCE_EQ
(
!
ctx
->
IsRuntime
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Infer shape cannot be called in runtime."
));
}
};
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
286c2e0e
...
...
@@ -12820,6 +12820,7 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
# [7, 5, 2]] [1, 3, 3]] [8, 8, 5]], dtype=int32)]
"""
helper = LayerHelper('py_func', **locals())
check_type(x, 'X', (list, tuple, Variable, type(None)), 'py_func')
if x is None:
x = []
elif isinstance(x, Variable):
...
...
@@ -12828,7 +12829,7 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
x = list(x)
elif not isinstance(x, (list, tuple, Variable)):
raise TypeError('Input must be Variable/list(Variable)/tuple(Variable)')
check_type(out, 'Out', (list, tuple, Variable, type(None)), 'py_func')
if out is None:
out_list = []
elif isinstance(out, Variable):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录