未验证 提交 052e7b07 编写于 作者: 张春乔 提交者: GitHub

[xdoctest] reformat example code with google style in No 101 (#55968)

* weight_norm_hook

* Update weight_norm_hook.py

* Update weight_norm_hook.py

* Update python/paddle/nn/utils/weight_norm_hook.py

* Update python/paddle/nn/utils/weight_norm_hook.py

* Update python/paddle/nn/utils/weight_norm_hook.py
Co-authored-by: NNyakku Shigure <sigure.qaq@gmail.com>

---------
Co-authored-by: NNyakku Shigure <sigure.qaq@gmail.com>
上级 2ac6a7e4
...@@ -191,15 +191,15 @@ def weight_norm(layer, name='weight', dim=0): ...@@ -191,15 +191,15 @@ def weight_norm(layer, name='weight', dim=0):
Examples: Examples:
.. code-block:: python .. code-block:: python
from paddle.nn import Conv2D >>> from paddle.nn import Conv2D
from paddle.nn.utils import weight_norm >>> from paddle.nn.utils import weight_norm
conv = Conv2D(3, 5, 3) >>> conv = Conv2D(3, 5, 3)
wn = weight_norm(conv) >>> wn = weight_norm(conv)
print(conv.weight_g.shape) >>> print(conv.weight_g.shape)
# [5] [5]
print(conv.weight_v.shape) >>> print(conv.weight_v.shape)
# [5, 3, 3, 3] [5, 3, 3, 3]
""" """
WeightNorm.apply(layer, name, dim) WeightNorm.apply(layer, name, dim)
return layer return layer
...@@ -219,22 +219,21 @@ def remove_weight_norm(layer, name='weight'): ...@@ -219,22 +219,21 @@ def remove_weight_norm(layer, name='weight'):
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle >>> import paddle
from paddle.nn import Conv2D >>> from paddle.nn import Conv2D
from paddle.nn.utils import weight_norm, remove_weight_norm >>> from paddle.nn.utils import weight_norm, remove_weight_norm
>>> paddle.seed(2023)
conv = Conv2D(3, 5, 3)
wn = weight_norm(conv) >>> conv = Conv2D(3, 5, 3)
print(conv.weight_g) >>> wn = weight_norm(conv)
# Parameter containing: >>> print(conv.weight_g)
# Tensor(shape=[5], dtype=float32, place=Place(gpu:0), stop_gradient=False, Parameter containing:
# [0., 0., 0., 0., 0.]) Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=False,
# Conv2D(3, 5, kernel_size=[3, 3], data_format=NCHW) [1.35883713, 1.32126212, 1.56303072, 1.20874095, 1.22893476])
>>> remove_weight_norm(conv)
remove_weight_norm(conv) >>> # The following is the effect after removing the weight norm:
# The following is the effect after removing the weight norm: >>> # print(conv.weight_g)
# print(conv.weight_g) >>> # AttributeError: 'Conv2D' object has no attribute 'weight_g'
# AttributeError: 'Conv2D' object has no attribute 'weight_g'
""" """
for k, hook in layer._forward_pre_hooks.items(): for k, hook in layer._forward_pre_hooks.items():
if isinstance(hook, WeightNorm) and hook.name == name: if isinstance(hook, WeightNorm) and hook.name == name:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册