未验证 提交 fe2cf39f 编写于 作者: W Wilber 提交者: GitHub

[2.0] Update py_func English doc. (#28646)

上级 16a80814
......@@ -13496,16 +13496,16 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
principe of py_func is that Tensor and numpy array can be converted to each
other easily. So you can use Python and numpy API to register a python OP.
The forward function of the registered OP is ``func`` and the backward function
of that is ``backward_func``. Paddle will call ``func`` at forward runtime and
The forward function of the registered OP is ``func`` and the backward function
of that is ``backward_func``. Paddle will call ``func`` at forward runtime and
call ``backward_func`` at backward runtime(if ``backward_func`` is not None).
``x`` is the input of ``func``, whose type must be Tensor; ``out`` is
the output of ``func``, whose type can be either Tensor or numpy array.
The input of the backward function ``backward_func`` is ``x``, ``out`` and
the gradient of ``out``. If some variables of ``out`` have no gradient, the
relevant input variable of ``backward_func`` is None. If some variables of
``x`` do not have a gradient, the user should return None in ``backward_func``.
the gradient of ``out``. If ``out`` have no gradient, the relevant input of
``backward_func`` is None. If ``x`` do not have a gradient, the user should
return None in ``backward_func``.
The data type and shape of ``out`` should also be set correctly before this
API is called, and the data type and shape of the gradient of ``out`` and
......@@ -13520,27 +13520,26 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
function and the forward input ``x``. In ``func`` , it's suggested that we
actively convert Tensor into a numpy array, so that we can use Python and
numpy API arbitrarily. If not, some operations of numpy may not be compatible.
x (Variable|tuple(Variale)|list[Variale]): The input of the forward function ``func``.
It can be Variable|tuple(Variale)|list[Variale], where Variable is Tensor or
Tenosor. In addition, Multiple Variable should be passed in the form of tuple(Variale)
or list[Variale].
out (Variable|tuple(Variale)|list[Variale]): The output of the forward function ``func``,
it can be Variable|tuple(Variale)|list[Variale], where Variable can be either Tensor
or numpy array. Since Paddle cannot automatically infer the shape and type of ``out``,
you must create ``out`` in advance.
x (Tensor|tuple(Tensor)|list[Tensor]): The input of the forward function ``func``.
It can be Tensor|tuple(Tensor)|list[Tensor]. In addition, Multiple Tensor
should be passed in the form of tuple(Tensor) or list[Tensor].
out (T|tuple(T)|list[T]): The output of the forward function ``func``, it can be
T|tuple(T)|list[T], where T can be either Tensor or numpy array. Since Paddle
cannot automatically infer the shape and type of ``out``, you must create
``out`` in advance.
backward_func (callable, optional): The backward function of the registered OP.
Its default value is None, which means there is no reverse calculation. If
it is not None, ``backward_func`` is called to calculate the gradient of
``x`` when the network is at backward runtime.
skip_vars_in_backward_input (Variable, optional): It's used to limit the input
variable list of ``backward_func``, and it can be Variable|tuple(Variale)|list[Variale].
skip_vars_in_backward_input (Tensor, optional): It's used to limit the input
list of ``backward_func``, and it can be Tensor|tuple(Tensor)|list[Tensor].
It must belong to either ``x`` or ``out``. The default value is None, which means
that no variables need to be removed from ``x`` and ``out``. If it is not None,
these variables will not be the input of ``backward_func``. This parameter is only
that no tensors need to be removed from ``x`` and ``out``. If it is not None,
these tensors will not be the input of ``backward_func``. This parameter is only
useful when ``backward_func`` is not None.
Returns:
Variable|tuple(Variale)|list[Variale]: The output ``out`` of the forward function ``func``.
Tensor|tuple(Tensor)|list[Tensor]: The output ``out`` of the forward function ``func``.
Examples:
.. code-block:: python
......@@ -13548,6 +13547,7 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
# example 1:
import paddle
import six
import numpy as np
paddle.enable_static()
......@@ -13578,16 +13578,31 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
dtype=hidden.dtype, shape=hidden.shape)
# User-defined forward and backward
hidden = paddle.static.nn.py_func(func=tanh, x=hidden,
hidden = paddle.static.py_func(func=tanh, x=hidden,
out=new_hidden, backward_func=tanh_grad,
skip_vars_in_backward_input=hidden)
# User-defined debug functions that print out the input Tensor
paddle.static.nn.py_func(func=debug_func, x=hidden, out=None)
paddle.static.py_func(func=debug_func, x=hidden, out=None)
prediction = paddle.static.nn.fc(hidden, size=10, activation='softmax')
loss = paddle.static.nn.cross_entropy(input=prediction, label=label)
return paddle.mean(loss)
ce_loss = paddle.nn.loss.CrossEntropyLoss()
return ce_loss(prediction, label)
x = paddle.static.data(name='x', shape=[1,4], dtype='float32')
y = paddle.static.data(name='y', shape=[1,10], dtype='int64')
res = simple_net(x, y)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
input1 = np.random.random(size=[1,4]).astype('float32')
input2 = np.random.randint(1, 10, size=[1,10], dtype='int64')
out = exe.run(paddle.static.default_main_program(),
feed={'x':input1, 'y':input2},
fetch_list=[res.name])
print(out)
.. code-block:: python
# example 2:
# This example shows how to turn Tensor into numpy array and
......@@ -13629,7 +13644,7 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
output = create_tmp_var('output','int32', [3,1])
# Multiple Variable should be passed in the form of tuple(Variale) or list[Variale]
paddle.static.nn.py_func(func=element_wise_add, x=[x,y], out=output)
paddle.static.py_func(func=element_wise_add, x=[x,y], out=output)
exe=paddle.static.Executor(paddle.CPUPlace())
exe.run(start_program)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册