未验证 提交 ea6a251c 编写于 作者: Z zhongpu 提交者: GitHub

fix sample code, test=develop (#23448)

上级 25628587
...@@ -122,6 +122,7 @@ class Layer(core.Layer): ...@@ -122,6 +122,7 @@ class Layer(core.Layer):
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np
# the forward_post_hook change the output of the layer: output = output * 2 # the forward_post_hook change the output of the layer: output = output * 2
def forward_post_hook(layer, input, output): def forward_post_hook(layer, input, output):
...@@ -136,15 +137,15 @@ class Layer(core.Layer): ...@@ -136,15 +137,15 @@ class Layer(core.Layer):
# register the hook # register the hook
forward_post_hook_handle = linear.register_forward_post_hook(forward_post_hook) forward_post_hook_handle = linear.register_forward_post_hook(forward_post_hook)
value = np.arange(26).reshape(2, 13).astype("float32") value1 = np.arange(26).reshape(2, 13).astype("float32")
in = fluid.dygraph.to_variable(value0) in1 = fluid.dygraph.to_variable(value1)
out0 = linear(in) out0 = linear(in1)
# remove the hook # remove the hook
forward_post_hook_handle.remove() forward_post_hook_handle.remove()
out1 = linear(in) out1 = linear(in1)
# hook change the linear's output to output * 2, so out0 is equal to out1 * 2. # hook change the linear's output to output * 2, so out0 is equal to out1 * 2.
assert (out0.numpy() == (out1.numpy()) * 2).any() assert (out0.numpy() == (out1.numpy()) * 2).any()
...@@ -173,6 +174,7 @@ class Layer(core.Layer): ...@@ -173,6 +174,7 @@ class Layer(core.Layer):
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np
# the forward_post_hook change the input of the layer: input = input * 2 # the forward_post_hook change the input of the layer: input = input * 2
def forward_pre_hook(layer, input): def forward_pre_hook(layer, input):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册