Commit 7d10edc5 authored by zhongpu, committed by hong

add clear_gradients for Optimizer and add clear_gradients api description (#21948)

* add clear_gradients for Optimizer, add api description, test=develop

* fix optest for optimizer's clear_gradient interface, test=develop

* add sample code, test=develop

* polish sample code, test=develop
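For context, the new `Optimizer.clear_gradients` is intended to be called once per iteration in a dygraph training loop, after `minimize`, so gradients do not accumulate across batches. The following is a minimal sketch of that usage; the toy network, random input data, and loop structure are illustrative assumptions and are not part of this change.

```python
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    # A single linear layer stands in for a real model (illustrative only).
    linear = fluid.Linear(13, 5, dtype="float32")
    adam = fluid.optimizer.Adam(learning_rate=0.01,
                                parameter_list=linear.parameters())

    for step in range(3):
        x = fluid.dygraph.to_variable(
            np.random.rand(2, 13).astype("float32"))
        out = linear(x)
        loss = fluid.layers.reduce_mean(out)
        loss.backward()
        adam.minimize(loss)
        # New in this change: clear gradients through the optimizer
        # instead of calling clear_gradients() on each layer.
        adam.clear_gradients()
```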
Parent 51cb918a
...@@ -173,6 +173,30 @@ class Layer(core.Layer):
        return ret

    def clear_gradients(self):
        """
        Clear the gradients of all parameters for this layer.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                with fluid.dygraph.guard():
                    value = np.arange(26).reshape(2, 13).astype("float32")
                    a = fluid.dygraph.to_variable(value)
                    linear = fluid.Linear(13, 5, dtype="float32")
                    adam = fluid.optimizer.Adam(learning_rate=0.01,
                                                parameter_list=linear.parameters())
                    out = linear(a)
                    out.backward()
                    adam.minimize(out)
                    linear.clear_gradients()
        """
        for p in self.parameters():
            if p.trainable:
                p.clear_gradient()
...
...@@ -662,6 +662,37 @@ class Optimizer(object):
        return no_grad_set

    @framework.dygraph_only
    def clear_gradients(self):
        """
        Clear the gradients of all optimized parameters of the model.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                with fluid.dygraph.guard():
                    value = np.arange(26).reshape(2, 13).astype("float32")
                    a = fluid.dygraph.to_variable(value)
                    linear = fluid.Linear(13, 5, dtype="float32")
                    # This can be any optimizer supported by dygraph.
                    adam = fluid.optimizer.Adam(learning_rate=0.01,
                                                parameter_list=linear.parameters())
                    out = linear(a)
                    out.backward()
                    adam.minimize(out)
                    adam.clear_gradients()
        """
        for p in self._parameter_list:
            if p.trainable:
                p.clear_gradient()

    @imperative_base.no_grad
    def minimize(self,
                 loss,
...
...@@ -131,7 +131,7 @@ class TestDygraphSimpleNet(unittest.TestCase):
                        dy_param_init[param.name] = param.numpy()
                dy_loss.backward(backward_strategy)
                sgd.minimize(dy_loss)
-               simple_net.clear_gradients()
+               sgd.clear_gradients()
                if i == batch_num - 1:
                    for param in simple_net.parameters():
                        dy_param_updated[param.name] = param.numpy()
...
...@@ -137,7 +137,7 @@ class TestDygraphSimpleNet(unittest.TestCase):
                        dy_param_init[param.name] = param.numpy()
                dy_loss.backward(backward_strategy)
                sgd.minimize(dy_loss)
-               simple_net.clear_gradients()
+               sgd.clear_gradients()
                if i == batch_num - 1:
                    for param in simple_net.parameters():
                        dy_param_updated[param.name] = param.numpy()
...