Unverified commit d9e63a81 authored by chentianyu03 and committed by GitHub

Add gradient with optimizer API (#34395)

* add gradients_with_optimizer api

* modify gradients_with_optimizer

* add gradients_with_optimizer api into paddle.autograd.backward_mode

* add gradients_with_optimizer test case

* add doc for gradients_with_optimizer

Parent 1f76a2f7
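For orientation (not part of the diff): a minimal sketch of how the new API is meant to be called through the paddle.autograd.backward_mode path added below. The network and the SGD optimizer are illustrative assumptions, not taken from the commit.

    import paddle
    import paddle.static as static

    paddle.enable_static()

    # Build a small static-graph program (shapes/sizes are illustrative).
    img = static.data(name='image', shape=[None, 784])
    pred = static.nn.fc(x=img, size=10, activation='relu')
    loss = paddle.mean(pred)

    # Any paddle.optimizer.Optimizer works; SGD is an assumption for this sketch.
    opt = paddle.optimizer.SGD(learning_rate=0.01)

    # Appends backward ops and the optimizer's update ops to the program.
    optimize_ops, params_grads = paddle.autograd.backward_mode.gradients_with_optimizer(
        static.default_main_program(), opt)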
@@ -14,6 +14,7 @@
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.backward import gradients_with_optimizer
import paddle
__all__ = []
@@ -16,6 +16,7 @@ from __future__ import print_function
from .proto import framework_pb2
from paddle.fluid import framework as framework
from paddle.fluid import program_guard
from . import core
import collections
import copy
@@ -2015,3 +2016,72 @@ def gradients(targets, inputs, target_gradients=None, no_grad_set=None):
    outs = calc_gradient(targets, inputs, target_gradients, no_grad_set)
    return _as_list(outs)

@framework.static_only
def gradients_with_optimizer(program, optimizer, inputs=None, outputs=None):
    """
    :api_attr: Static Graph

    Backpropagate the gradients of the program and apply the gradients with the given optimizer.

    Args:
        program (Program): The input program.
        optimizer (Optimizer): The optimizer used to apply the gradients.
        inputs (Tensor|list[Tensor]|tuple[Tensor], optional): The input Tensors.
            If None, the inputs are collected from the input variables of the given program. Default: None.
        outputs (Tensor|list[Tensor]|tuple[Tensor], optional): The output Tensors.
            If None, the outputs are collected from the output variables of the given program. Default: None.

    Returns:
        tuple: A tuple ``(optimize_ops, params_grads)``: a list of operators appended by
        gradients_with_optimizer, and a list of ``(param, grad)`` variable pairs, where param
        is a ``Parameter`` and grad is the gradient value corresponding to the parameter.
        The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
        indicate program pruning. If so, the program will be pruned by ``feed`` and
        ``fetch_list`` before running; see details in ``Executor``.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.static as static

            paddle.enable_static()

            img = static.data(name='image', shape=[None, 784])
            pred = static.nn.fc(x=img, size=10, activation='relu')
            loss = paddle.mean(pred)
            # An optimizer instance is required; SGD here is illustrative.
            opt = paddle.optimizer.SGD(learning_rate=0.01)
            opt_ops, pram_grads = paddle.fluid.backward.gradients_with_optimizer(
                static.default_main_program(), opt)
            print(opt_ops)

    """
    check_type(program, 'program', paddle.fluid.Program,
               'paddle.static.gradients_with_optimizer')
    check_type(optimizer, 'optimizer', paddle.optimizer.Optimizer,
               'paddle.static.gradients_with_optimizer')

    if inputs is None or outputs is None:
        in_set = set()
        out_set = set()
        for block in program.blocks:
            for op in block.ops:
                for name in op.input_arg_names:
                    in_set.add(block.vars[name])
                for name in op.output_arg_names:
                    out_set.add(block.vars[name])
        if inputs is None:
            # Program inputs: variables that are read but never produced by an op.
            inputs = list(in_set.difference(out_set))
        if outputs is None:
            # Program outputs: variables that are produced but never consumed by an op.
            outputs = list(out_set.difference(in_set))

    grads = gradients(outputs, inputs)

    with program_guard(program, None):
        pram_grads = [(pram, grad) for pram, grad in zip(inputs, grads)
                      if isinstance(pram, paddle.fluid.framework.Parameter) and
                      grad is not None]

        optimize_ops = optimizer.apply_gradients(pram_grads)

    return optimize_ops, pram_grads
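A hedged end-to-end sketch of consuming the returned ``(optimize_ops, params_grads)`` tuple; the Executor setup, CPUPlace, and random feed data are assumptions for illustration, not part of the diff:

    import numpy as np
    import paddle
    import paddle.static as static

    paddle.enable_static()

    main, startup = static.Program(), static.Program()
    with static.program_guard(main, startup):
        img = static.data(name='image', shape=[None, 784])
        pred = static.nn.fc(x=img, size=10, activation='relu')
        loss = paddle.mean(pred)
        opt = paddle.optimizer.SGD(learning_rate=0.01)
        optimize_ops, params_grads = paddle.fluid.backward.gradients_with_optimizer(
            main, opt)

    exe = static.Executor(paddle.CPUPlace())
    exe.run(startup)
    # One training step: feeding data runs the forward, backward, and update
    # ops that gradients_with_optimizer appended to `main`.
    loss_val, = exe.run(main,
                        feed={'image': np.random.rand(4, 784).astype('float32')},
                        fetch_list=[loss])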
@@ -16,6 +16,9 @@ from __future__ import print_function
import unittest
import paddle.fluid as fluid
import paddle.static as static
import paddle
import numpy as np
@@ -327,6 +330,35 @@ class TestAppendBackwardWithError(unittest.TestCase):
                loss=self.avg_loss, callbacks=callback)

class TestGradientsWithOptimizer(unittest.TestCase):
    def _check_grad_op_name(self, forward_list, optimized_list):
        backward_list = [op + "_grad" for op in reversed(forward_list)]
        idx = optimized_list.index(backward_list[0], len(backward_list))

        self.assertListEqual(backward_list,
                             optimized_list[idx:idx + len(backward_list)])

    def test_gradient_with_optimizer(self):
        main = fluid.Program()
        startup = fluid.Program()

        with fluid.program_guard(main, startup):
            img = static.data(name='image', shape=[None, 784])
            pred = static.nn.fc(x=img, size=10, activation='relu')
            loss = paddle.mean(pred)
            opt = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9)

            forward_list = [o.type for o in main.current_block().ops]
            optimize_ops, pram_grads = paddle.autograd.backward_mode.gradients_with_optimizer(
                main, opt)

            optimized_list = [o.type for o in main.current_block().ops]

            self.assertGreater(len(optimized_list), len(forward_list))
            self.assertIn(opt.type, optimized_list)
            self._check_grad_op_name(forward_list, optimized_list)
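# For intuition, a hedged illustration of what _check_grad_op_name verifies:
# every forward op should gain a matching "*_grad" op, appended in reverse
# order after the forward section. The op names below are typical for the
# fc + mean network above, but the exact list depends on the lowering.
example_forward_list = ["mul", "elementwise_add", "relu", "reduce_mean"]
example_backward_list = [op + "_grad" for op in reversed(example_forward_list)]
assert example_backward_list == [
    "reduce_mean_grad", "relu_grad", "elementwise_add_grad", "mul_grad"
]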
# TODO(Aurelius84): add conditional network test
class ConditionalNet(BackwardNet):
    def __init__(self):
@@ -334,4 +366,5 @@ class ConditionalNet(BackwardNet):
if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()