Unverified · Commit 50af6b5d authored by Aurelius84, committed by GitHub

polish no_grad_set of gradient and append_backward (#22440)

* polish backward api doc test=develop, test=document_preview, test=document_fix

* polish backward api doc test=develop, test=document_preview, test=document_fix

* no_grad supports set of Variable test=develop, test=document_preview

* polish sample code of append_backward test=develop, test=document_preview

* modify assert into Raise TypeError test=develop,test=document_preview

* fix unittest failed test=develop

* rm useless file test=develop

* polish en doc test=develop

* polish code of no_grad_set test=develop

* polish code of no_grad_set test=develop
Parent 7c9ce097
@@ -1110,6 +1110,26 @@ def _get_son_parent_block_idx_dict(program, current_block_idx):

    return son_parent_block_idx_dict
def _get_no_grad_set_name(no_grad_set):
    no_grad_set_name = set()
    if no_grad_set is not None:
        if isinstance(no_grad_set, (set, list, tuple)):
            for i, no_grad_var in enumerate(no_grad_set):
                if isinstance(no_grad_var, framework.Variable):
                    no_grad_set_name.add(no_grad_var.name)
                elif isinstance(no_grad_var, six.string_types):
                    no_grad_set_name.add(no_grad_var)
                else:
                    raise TypeError(
                        "The type of no_grad_set's member must be paddle.fluid.Variable or str, but received %s."
                        % (type(no_grad_var)))
        else:
            raise TypeError(
                "The type of no_grad_set should be set or list or tuple, but received {}".
                format(type(no_grad_set)))
    return no_grad_set_name
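
For readers skimming the diff, a minimal usage sketch of this private helper; the variable `x` and the parameter name 'fc_0.w_0' below are illustrative and not taken from the patch:

    import paddle.fluid as fluid
    from paddle.fluid.backward import _get_no_grad_set_name

    x = fluid.data(name='x', shape=[None, 13], dtype='float32')
    # Variables are reduced to their names, plain strings pass through,
    # and any other member type (e.g. an int) raises TypeError.
    print(_get_no_grad_set_name([x, 'fc_0.w_0']))   # {'x', 'fc_0.w_0'}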

def append_backward(loss,
                    parameter_list=None,
                    no_grad_set=None,

@@ -1133,11 +1153,11 @@ def append_backward(loss,
            If it is None, all parameters
            will be updated.
            Default: None.
        no_grad_set(set[Variable|str], optional): Set of Variables or Variable.names in the :ref:`api_guide_Block_en` 0 whose gradients
            should be ignored. All variables with
            `stop_gradient=True` from all blocks will
            be automatically added into this set.
            If this parameter is not None, the Variables or Variable.names in this set will be added to the default set.
            Default: None.
        callbacks(list[callable object], optional): List of callback functions.
            The callbacks are used for

@@ -1174,18 +1194,40 @@ def append_backward(loss,
        .. code-block:: python

            import paddle.fluid as fluid

            x = fluid.data(name='x', shape=[None, 13], dtype='int64')
            y = fluid.data(name='y', shape=[None, 1], dtype='float32')
            x_emb = fluid.embedding(x, size=[100, 256])
            y_predict = fluid.layers.fc(input=x_emb, size=1, act=None, name='my_fc')
            loss = fluid.layers.square_error_cost(input=y_predict, label=y)
            avg_loss = fluid.layers.mean(loss)

            # Get all weights in main_program, not include bias.
            all_weights = [param for param in fluid.default_main_program().block(0).all_parameters() if 'w_' in param.name]
            all_weights_name = [w.name for w in all_weights]

            # return all param_grads needed to be updated if parameter_list set default None.
            p_g_list1 = fluid.backward.append_backward(loss=avg_loss)
            # output: [(embedding_0.w_0, embedding_0.w_0@GRAD), (my_fc.w_0, my_fc.w_0@GRAD), (my_fc.b_0, my_fc.b_0@GRAD)]

            # return the param_grads corresponding to parameter_list that can be list of param (Variable).
            p_g_list2 = fluid.backward.append_backward(loss=avg_loss, parameter_list=all_weights)
            # output: [(embedding_0.w_0, embedding_0.w_0@GRAD), (my_fc.w_0, my_fc.w_0@GRAD)]

            # parameter_list can be list of param.name (str).
            p_g_list3 = fluid.backward.append_backward(loss=avg_loss, parameter_list=all_weights_name)
            # output: [(embedding_0.w_0, embedding_0.w_0@GRAD), (my_fc.w_0, my_fc.w_0@GRAD)]

            # no_grad_set can be set of Variables that means grad will be cut off from these Variables.
            p_g_list4 = fluid.backward.append_backward(loss=avg_loss, no_grad_set=set([x_emb]))
            # output: [(my_fc.w_0, my_fc.w_0@GRAD), (my_fc.b_0, my_fc.b_0@GRAD)]

            # no_grad_set can be set of Variable.name when the Variable is created inside layers and can't be specified explicitly.
            p_g_list5 = fluid.backward.append_backward(loss=avg_loss, no_grad_set=set(['my_fc.b_0']))
            # output: [(embedding_0.w_0, embedding_0.w_0@GRAD), (my_fc.w_0, my_fc.w_0@GRAD)]

            # return [] because all param_grads are filtered by no_grad_set.
            p_g_list6 = fluid.backward.append_backward(loss=avg_loss, parameter_list=all_weights, no_grad_set=set(all_weights))

        """
    assert isinstance(loss, framework.Variable)

@@ -1215,7 +1257,8 @@ def append_backward(loss,

    if no_grad_set is None:
        no_grad_set = set()
    else:
        no_grad_set = _get_no_grad_set_name(copy.copy(no_grad_set))
    no_grad_dict = _get_stop_gradients_(program)
    # no_grad_set only contains vars in block 0
    # Todo(liym27): support vars in sub block
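
A minimal sketch of the resulting behavior at the call site; the tiny network below is illustrative and not part of the patch:

    import paddle.fluid as fluid

    x = fluid.data(name='x', shape=[None, 13], dtype='float32')
    loss = fluid.layers.mean(fluid.layers.fc(input=x, size=1))
    # Variables and names may be mixed; an unsupported member type (here an int)
    # now raises TypeError up front instead of failing later.
    try:
        fluid.backward.append_backward(loss, no_grad_set=[x, 'x', 3])
    except TypeError as e:
        print(e)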

@@ -1501,12 +1544,15 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):

    Args:
        targets(Variable|list[Variable]): The target variables
        inputs(Variable|list[Variable]): The input variables
        target_gradients (Variable|list[Variable], optional): The gradient variables
            of targets which have the same shape as targets. If None, ones will
            be created for them.
        no_grad_set(set[Variable|str], optional): Set of Variables or Variable.names in the :ref:`api_guide_Block_en` 0 whose gradients
            should be ignored. All variables with
            `stop_gradient=True` from all blocks will
            be automatically added into this set.
            If this parameter is not None, the Variables or Variable.names in this set will be added to the default set.
            Default: None.

    Return:
        (list[Variable]): A list of gradients for inputs

@@ -1532,7 +1578,8 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):

    if no_grad_set is None:
        no_grad_set = set()
    else:
        no_grad_set = _get_no_grad_set_name(copy.copy(no_grad_set))
    no_grad_dict = _get_stop_gradients_(prog)
    no_grad_dict[0].update(list(map(_append_grad_suffix_, no_grad_set)))
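
The same normalization is exercised through the public `fluid.gradients` API, which wraps `calc_gradient`. A rough sketch under illustrative names (the conv layer and `img` are not from the patch):

    import paddle.fluid as fluid

    x = fluid.data(name='img', shape=[None, 2, 8, 8], dtype='float32')
    x.stop_gradient = False
    y = fluid.layers.relu(fluid.layers.conv2d(x, 4, 1, bias_attr=False))
    conv_w = fluid.default_main_program().global_block().all_parameters()[0]
    # Exclude the conv weight from gradient computation; the Variable itself is accepted.
    dx = fluid.gradients([y], x, no_grad_set=set([conv_w]))
    # Equivalently, by name:
    # dx = fluid.gradients([y], x, no_grad_set=set([conv_w.name]))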

@@ -1623,12 +1670,13 @@ def gradients(targets, inputs, target_gradients=None, no_grad_set=None):

    Args:
        targets (Variable|list[Variable]): The target variables.
        inputs (Variable|list[Variable]): The input variables.
        target_gradients (Variable|list[Variable], optional): The gradient variables
            of targets which have the same shape as targets. If None, ones will
            be created for them.
        no_grad_set (set[Variable|str], optional): Set of Variables or Variable.names in the :ref:`api_guide_Block_en` 0 whose gradients
            should be ignored. All variables with `stop_gradient=True` from all blocks will
            be automatically added into this set. If this parameter is not None, the Variables or Variable.names
            in this set will be added to the default set. Default: None.

    Return:
        (list[Variable]): A list of gradients for inputs

@@ -1640,7 +1688,7 @@ def gradients(targets, inputs, target_gradients=None, no_grad_set=None):

            import paddle.fluid as fluid

            x = fluid.data(name='x', shape=[None, 2, 8, 8], dtype='float32')
            x.stop_gradient = False
            y = fluid.layers.conv2d(x, 4, 1, bias_attr=False)
            y = fluid.layers.relu(y)
...

@@ -23,7 +23,7 @@ from paddle.fluid.framework import Program, Variable, name_scope, default_main_p

from . import framework
from . import layers
from . import unique_name
from .backward import append_backward, _some_in_set_, _append_grad_suffix_, _get_no_grad_set_name
from .clip import append_gradient_clip_ops, error_clip_callback
from .framework import program_guard
from .initializer import Constant

@@ -592,7 +592,7 @@ class Optimizer(object):

            parameter_list (list, optional): List of ``Variable`` or ``Variable.name`` to update
                to minimize ``loss``. The default value is None, at this time all parameters
                will be updated.
            no_grad_set (set, optional): Set of ``Variable`` or ``Variable.name`` that don't need
                to be updated. The default value is None.
            callbacks (list, optional): list of callable objects to run when appending backward
                operator for one parameter. The default value is None.

@@ -705,14 +705,7 @@ class Optimizer(object):

        return optimize_ops

    def _get_no_grad_set(self, loss, no_grad_set=None):
        no_grad_set = _get_no_grad_set_name(no_grad_set)
        parameters = loss.block.program.global_block().all_parameters()
        param_no_trainable = set(
            [param.name for param in parameters if param.trainable is False])
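
With the normalization delegated to `_get_no_grad_set_name`, `Optimizer.minimize` accepts Variables as well as names in `no_grad_set`. A minimal sketch of the user-facing effect, assuming the illustrative toy network below:

    import paddle.fluid as fluid

    x = fluid.data(name='x', shape=[None, 13], dtype='float32')
    y = fluid.data(name='y', shape=[None, 1], dtype='float32')
    pred = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))

    fc_w = fluid.default_main_program().global_block().all_parameters()[0]
    sgd = fluid.optimizer.SGD(learning_rate=0.01)
    # Freeze the fc weight: pass the parameter Variable directly ...
    sgd.minimize(loss, no_grad_set={fc_w})
    # ... or, equivalently, its name:
    # sgd.minimize(loss, no_grad_set={fc_w.name})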

@@ -770,7 +763,7 @@ class Optimizer(object):

            parameter_list (list, optional): List of ``Variable`` or ``Variable.name`` to update
                to minimize ``loss``. The default value is None, at this time all parameters
                will be updated.
            no_grad_set (set, optional): Set of ``Variable`` or ``Variable.name`` that don't need
                to be updated. The default value is None.
            grad_clip (GradClipBase, optional) : Gradient clipping strategy, static
                graph mode does not need to use this argument. Currently, this argument

@@ -3843,8 +3836,8 @@ class RecomputeOptimizer(Optimizer):

            loss (Variable): loss variable to run optimizations.
            startup_program (Program): startup_program for initializing parameters
                in `parameter_list`.
            parameter_list (list): list of Variables or Variable.names to update.
            no_grad_set (set|None): set of Variables or Variable.names that should be ignored.
            callbacks (list|None): list of callables to run when appending backward
                operator for one parameter.
            checkpoints (list): list of Variables as checkpoints
...

@@ -142,6 +142,21 @@ class TestBackward(unittest.TestCase):

        exe.run(startup)
        exe.run(feed=net.init_data())
    def _check_error_no_grad_set(self, net, no_grad_set):
        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
        ) else fluid.CPUPlace()
        exe = fluid.Executor(place)
        main = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(main, startup):
            loss = net.build_model()
            optimizer = fluid.optimizer.SGD(learning_rate=0.1)
            optimizer.minimize(loss, no_grad_set=no_grad_set)
            exe.run(startup)
            exe.run(feed=net.init_data())

class SimpleNet(BackwardNet):
    def __init__(self):

@@ -233,12 +248,25 @@ class TestSimpleNetWithErrorParamList(TestBackward):

        # The type of parameter_list argument must be list or tuple
        with self.assertRaises(TypeError):
            self._check_error_param_list(self.net, "test")
        # The type of parameter_list's member must be Variable or str
        test = fluid.data(name='test', shape=[None, 90], dtype='float32')
        with self.assertRaises(TypeError):
            self._check_error_param_list(self.net, [test, "test", 3])

class TestSimpleNetWithErrorNoGradSet(TestBackward):
    def test_no_grad_set_type_error(self):
        self.global_block_idx = 0
        self.net = SimpleNet()
        # The type of no_grad_set argument must be set or list or tuple
        with self.assertRaises(TypeError):
            self._check_error_no_grad_set(self.net, "test")
        # The type of no_grad_set's member must be Variable or str
        test = fluid.data(name='test', shape=[None, 90], dtype='float32')
        with self.assertRaises(TypeError):
            self._check_error_no_grad_set(self.net, [test, "test", 3])
# TODO(Aurelius84): add conditional network test # TODO(Aurelius84): add conditional network test
class ConditionalNet(BackwardNet):
    def __init__(self):
...

@@ -55,7 +55,7 @@ class TestFusedEmbeddingSeqPoolOp(OpTest):

        if ver.mkl() == "ON" and 'Linux' in platform.platform():
            self.attrs = {'is_sparse': False}
            self.check_grad(
                ['W'], 'Out', no_grad_set=['Ids'], check_dygraph=False)
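
For context on this fix (and the identical one below): in Python, `('Ids')` is not a one-element tuple, it is just the string `'Ids'` with redundant parentheses, so the old call passed a bare string where a collection of input names was expected; `['Ids']` (or `('Ids',)`) makes the intent explicit. A quick illustration:

    assert ('Ids') == 'Ids'                  # parentheses alone do not make a tuple
    assert ('Ids',) == tuple(['Ids'])        # a trailing comma (or a list) does
    assert list('Ids') == ['I', 'd', 's']    # iterating the bare string yields characters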

class TestLookupTableOpWithPadding(TestFusedEmbeddingSeqPoolOp):

@@ -89,7 +89,7 @@ class TestLookupTableOpWithPadding(TestFusedEmbeddingSeqPoolOp):

        self.attrs = {'padding_idx': int(padding_idx), 'is_sparse': False}
        # TODO(wangzhongpu): support lod in dygraph mode
        self.check_grad(
            ['W'], 'Out', no_grad_set=['Ids'], check_dygraph=False)

class TestFusedEmbeddingSeqPoolApi(unittest.TestCase):
...