Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
f0df62f1
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f0df62f1
编写于
12月 13, 2018
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add more unittest case
test=develop
上级
f6741df4
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
60 addition
and
26 deletion
+60
-26
paddle/fluid/operators/py_func_op.cc
paddle/fluid/operators/py_func_op.cc
+20
-13
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+1
-1
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+23
-11
python/paddle/fluid/tests/unittests/test_py_func_op.py
python/paddle/fluid/tests/unittests/test_py_func_op.py
+16
-1
未找到文件。
paddle/fluid/operators/py_func_op.cc
浏览文件 @
f0df62f1
...
...
@@ -35,6 +35,9 @@ size_t AppendPythonCallableObjectAndReturnId(const py::object &py_obj) {
return
g_py_callables
.
size
()
-
1
;
}
// Return py::object* instead of py::object
// Returning py::object would cause reference count increasing
// but without GIL, reference count in Python may not be safe
static
py
::
object
*
GetPythonCallableObject
(
size_t
i
)
{
PADDLE_ENFORCE_LT
(
i
,
g_py_callables
.
size
(),
"Invalid python callable id"
);
return
&
g_py_callables
[
i
];
...
...
@@ -47,7 +50,7 @@ static std::string PythonObjectToString(const py::object &py_callable) {
static
void
CallPythonFunc
(
py
::
object
*
callable
,
const
std
::
vector
<
framework
::
LoDTensor
>
&
ins
,
std
::
vector
<
framework
::
LoDTensor
*>
*
out
)
{
std
::
vector
<
framework
::
LoDTensor
*>
*
out
s
)
{
py
::
gil_scoped_acquire
guard
;
py
::
tuple
in_args
(
ins
.
size
());
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
++
i
)
{
...
...
@@ -57,8 +60,8 @@ static void CallPythonFunc(py::object *callable,
auto
ret
=
(
*
callable
)(
*
in_args
);
auto
ret_tuple
=
py
::
cast
<
py
::
tuple
>
(
ret
);
size_t
ret_num
=
py
::
len
(
ret_tuple
);
size_t
out_num
=
out
->
size
();
if
(
ret_num
!=
out_num
)
{
size_t
out_num
=
out
s
->
size
();
if
(
UNLIKELY
(
ret_num
!=
out_num
)
)
{
// Python function has no return values or returns None
// In this case, ret_num = 1 && ret[0] == None && out_num should be 0
// Otherwise, ret_num must be equal to out_num
...
...
@@ -69,17 +72,18 @@ static void CallPythonFunc(py::object *callable,
}
for
(
size_t
i
=
0
;
i
<
out_num
;
++
i
)
{
if
((
*
out
)[
i
]
==
nullptr
)
{
auto
*
out
=
(
*
outs
)[
i
];
if
(
out
==
nullptr
)
{
continue
;
}
try
{
auto
*
out_tensor
=
py
::
cast
<
framework
::
LoDTensor
*>
(
ret_tuple
[
i
]);
PADDLE_ENFORCE_NOT_NULL
(
out_tensor
,
auto
*
py_
out_tensor
=
py
::
cast
<
framework
::
LoDTensor
*>
(
ret_tuple
[
i
]);
PADDLE_ENFORCE_NOT_NULL
(
py_
out_tensor
,
"Output tensor %d should not be nullptr"
,
i
);
(
*
out
)[
i
]
->
set_lod
(
out_tensor
->
lod
());
(
*
out
)[
i
]
->
ShareDataWith
(
*
out_tensor
);
out
->
set_lod
(
py_
out_tensor
->
lod
());
out
->
ShareDataWith
(
*
py_
out_tensor
);
}
catch
(
py
::
cast_error
&
)
{
PADDLE_THROW
(
"
Output %d is not
LoDTensor"
,
i
);
PADDLE_THROW
(
"
The %d-th output must be
LoDTensor"
,
i
);
}
}
}
...
...
@@ -94,6 +98,10 @@ class PyFuncOpShapeInference : public framework::InferShapeBase {
PADDLE_ENFORCE_GE
(
ctx
->
Attrs
().
Get
<
int
>
(
kForwardPythonCallableId
),
0
,
"Function id cannot be less than 0"
);
// Transverse all outputs
// If name of any output ends with @GRAD,
// set its shape, dtype, lod_level, type to be the same as
// the correponding forward variable
auto
*
op
=
boost
::
get
<
const
framework
::
OpDesc
*>
(
ctx
->
GetOp
());
auto
*
block
=
op
->
Block
();
const
std
::
string
kGradVarSuffix
=
framework
::
kGradVarSuffix
;
...
...
@@ -115,7 +123,7 @@ class PyFuncOpShapeInference : public framework::InferShapeBase {
auto
*
in_var_desc
=
block
->
FindVarRecursive
(
fwd_var_name
);
PADDLE_ENFORCE_NOT_NULL
(
in_var_desc
,
"Forward variable %s not found"
,
fwd_var_name
);
VLOG
(
10
)
<<
"Infer shape of Out("
<<
out_name
<<
") as Input("
VLOG
(
10
)
<<
"Infer shape of Out
put
("
<<
out_name
<<
") as Input("
<<
in_var_desc
->
Name
()
<<
")"
;
out_var_desc
->
SetShape
(
in_var_desc
->
GetShape
());
out_var_desc
->
SetDataType
(
in_var_desc
->
GetDataType
());
...
...
@@ -135,7 +143,7 @@ class PyFuncOpMaker : public framework::OpProtoAndCheckerMaker {
"Index of registered forward Python function."
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
kBackwardPythonCallableId
,
"Index of registered backward Python function"
)
"Index of registered backward Python function
.
"
)
.
SetDefault
(
-
1
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
kPyFuncBackwardSkipVars
,
"Unused forward in/out in backward op"
)
...
...
@@ -170,8 +178,7 @@ class PyFuncOpGradDescMaker : public framework::GradOpDescMakerBase {
auto
fwd_outs
=
Output
(
"Out"
);
// For memory reused, some inputs/output in forward part may be not needed
// in backward part
// Just skip these vars
// in backward part. Skipping these vars helps to save memory
auto
&
backward_skip_var_list
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
fwd_attrs
.
at
(
kPyFuncBackwardSkipVars
));
std
::
unordered_set
<
std
::
string
>
backward_skip_var_set
(
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
f0df62f1
...
...
@@ -104,7 +104,7 @@ PYBIND11_MODULE(core, m) {
BindException
(
&
m
);
m
.
def
(
"append_python_callable_object_and_return_id"
,
"
_
append_python_callable_object_and_return_id"
,
[](
py
::
object
py_obj
)
->
size_t
{
return
paddle
::
operators
::
AppendPythonCallableObjectAndReturnId
(
py_obj
);
});
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
f0df62f1
...
...
@@ -9137,8 +9137,13 @@ class PyFuncRegistry(object):
self
.
_func
=
func
# find named args using reflection
self
.
_named_args
=
inspect
.
getargspec
(
self
.
_func
)[
0
]
self
.
_id
=
core
.
append_python_callable_object_and_return_id
(
self
)
args
=
inspect
.
getargspec
(
self
.
_func
)
if
len
(
args
[
0
])
==
0
and
args
[
1
]
is
None
and
args
[
2
]
is
None
:
# Function with no inputs
self
.
_named_args
=
None
else
:
self
.
_named_args
=
args
[
0
]
self
.
_id
=
core
.
_append_python_callable_object_and_return_id
(
self
)
'''
Why record self here?
...
...
@@ -9168,13 +9173,16 @@ class PyFuncRegistry(object):
return
self
.
_id
def
__call__
(
self
,
*
args
):
if
self
.
_named_args
is
None
:
func_ret
=
self
.
_func
()
else
:
kwargs
=
dict
()
idx
=
0
for
arg
in
self
.
_named_args
:
kwargs
[
arg
]
=
args
[
idx
]
idx
+=
1
func_ret
=
self
.
_func
(
*
args
[
idx
:],
**
kwargs
)
if
not
isinstance
(
func_ret
,
(
list
,
tuple
)):
func_ret
=
(
func_ret
,
)
...
...
@@ -9207,14 +9215,18 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
User should set the right data type and shape of :code:`out` before
calling this function. However, data types and shapes of gradients of
:code:`out` and :code:`x` would be infered automatically.
:code:`out` and :code:`x` would be infer
r
ed automatically.
The orders of inputs of :code:`backward_func` would be: forward input
:code:`x`, forward output
:code:`out` and backward input gradient
of
Input orders of :code:`backward_func` would be: forward inputs
:code:`x`, forward output
s :code:`out` and backward input gradients
of
:code:`out`. If some variables of :code:`out` have no gradient, the input
tensor would be None in Python side. If some variables of :code:`in` have
no gradient, users should return None.
This function can also be used to debug the running network. User can
add a :code:`py_func` operator without output, and print input
:code:`x` inside :code:`func`.
Args:
func (callable): forward Python function.
x (Variable|list(Variable)|tuple(Variable)): inputs of :code:`func`.
...
...
python/paddle/fluid/tests/unittests/test_py_func_op.py
浏览文件 @
f0df62f1
...
...
@@ -25,6 +25,14 @@ if fluid.core.is_compiled_with_cuda():
os
.
environ
[
'CPU_NUM'
]
=
str
(
dev_cnt
)
def
dummy_func_with_no_input
():
return
float
(
1.0
)
def
dummy_func_with_no_output
(
x
):
pass
def
tanh
(
x
):
return
np
.
tanh
(
x
)
...
...
@@ -86,13 +94,20 @@ def simple_fc_net(img, label, use_py_func_op):
else
:
loss
=
fluid
.
default_main_program
().
current_block
().
create_var
(
name
=
'loss'
,
dtype
=
'float32'
,
shape
=
[
-
1
,
1
])
fluid
.
layers
.
py_func
(
loss
=
fluid
.
layers
.
py_func
(
func
=
cross_entropy
,
x
=
[
prediction
,
label
],
out
=
loss
,
backward_func
=
cross_entropy_grad
,
skip_vars_in_backward_input
=
loss
)
dummy_var
=
fluid
.
default_main_program
().
current_block
().
create_var
(
name
=
'test_tmp_var'
,
dtype
=
'float32'
,
shape
=
[
1
])
fluid
.
layers
.
py_func
(
func
=
dummy_func_with_no_input
,
x
=
None
,
out
=
dummy_var
)
fluid
.
layers
.
py_func
(
func
=
dummy_func_with_no_output
,
x
=
loss
,
out
=
None
)
loss
=
fluid
.
layers
.
mean
(
loss
)
return
loss
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录