未验证 提交 9a2204ee 编写于 作者: C Chen Weihang 提交者: GitHub

Uniform append_backward & gradients parameter_list type to Variable (#21938)



* update doc, test=develop

* fix related unittests, test=develop

* fix str incompatible error, test=develop
上级 6b4c33ee
...@@ -1027,18 +1027,18 @@ def append_backward(loss, ...@@ -1027,18 +1027,18 @@ def append_backward(loss,
Parameters: Parameters:
loss( :ref:`api_guide_Variable_en` ): The loss variable of the network. loss( :ref:`api_guide_Variable_en` ): The loss variable of the network.
parameter_list(list of str, optional): Names of parameters that need parameter_list(list[Variable|str], optional): List of Parameters or Parameter.names
to be updated by optimizers. that need to be updated by optimizers.
If it is None, all parameters If it is None, all parameters
will be updated. will be updated.
Default: None. Default: None.
no_grad_set(set of str, optional): Variable names in the :ref:`api_guide_Block_en` 0 whose gradients no_grad_set(set[str], optional): Variable names in the :ref:`api_guide_Block_en` 0 whose gradients
should be ignored. All variables with should be ignored. All variables with
`stop_gradient=True` from all blocks will `stop_gradient=True` from all blocks will
be automatically added into this set. be automatically added into this set.
If this parameter is not None, the names in this set will be added to the default set. If this parameter is not None, the names in this set will be added to the default set.
Default: None. Default: None.
callbacks(list of callable object, optional): List of callback functions. callbacks(list[callable object], optional): List of callback functions.
The callbacks are used for The callbacks are used for
doing some custom jobs during doing some custom jobs during
backward part building. All backward part building. All
...@@ -1167,7 +1167,20 @@ def append_backward(loss, ...@@ -1167,7 +1167,20 @@ def append_backward(loss,
program._sync_with_cpp() program._sync_with_cpp()
if parameter_list is not None: if parameter_list is not None:
parameters = parameter_list if not isinstance(parameter_list, (list, tuple, set)):
raise TypeError(
"The type of parameter_list argument must be list or tuple or set, but received %s."
% (type(parameter_list)))
parameters = []
for i, param in enumerate(parameter_list):
if isinstance(param, framework.Variable):
parameters.append(param.name)
elif isinstance(param, six.string_types):
parameters.append(param)
else:
raise TypeError(
"The type of parameter_list's member must be paddle.fluid.Variable or str, but received %s."
% (type(param)))
else: else:
params = program.global_block().all_parameters() params = program.global_block().all_parameters()
parameters = [param.name for param in params if param.trainable] parameters = [param.name for param in params if param.trainable]
......
...@@ -519,7 +519,7 @@ class Optimizer(object): ...@@ -519,7 +519,7 @@ class Optimizer(object):
startup_program (Program, optional): :ref:`api_fluid_Program` for startup_program (Program, optional): :ref:`api_fluid_Program` for
initializing parameters in ``parameter_list``. The default value initializing parameters in ``parameter_list``. The default value
is None, at this time :ref:`api_fluid_default_startup_program` will be used. is None, at this time :ref:`api_fluid_default_startup_program` will be used.
parameter_list (list, optional): List of ``Variable`` names to update parameter_list (list, optional): List of ``Variable`` or ``Variable.name`` to update
to minimize ``loss``. The default value is None, at this time all parameters to minimize ``loss``. The default value is None, at this time all parameters
will be updated. will be updated.
no_grad_set (set, optional): Set of ``Variable`` objects that don't need no_grad_set (set, optional): Set of ``Variable`` objects that don't need
...@@ -666,7 +666,7 @@ class Optimizer(object): ...@@ -666,7 +666,7 @@ class Optimizer(object):
startup_program (Program, optional): :ref:`api_fluid_Program` for startup_program (Program, optional): :ref:`api_fluid_Program` for
initializing parameters in ``parameter_list``. The default value initializing parameters in ``parameter_list``. The default value
is None, at this time :ref:`api_fluid_default_startup_program` will be used. is None, at this time :ref:`api_fluid_default_startup_program` will be used.
parameter_list (list, optional): List of ``Variable`` names to update parameter_list (list, optional): List of ``Variable`` or ``Variable.name`` to update
to minimize ``loss``. The default value is None, at this time all parameters to minimize ``loss``. The default value is None, at this time all parameters
will be updated. will be updated.
no_grad_set (set, optional): Set of ``Variable`` objects that don't need no_grad_set (set, optional): Set of ``Variable`` objects that don't need
......
...@@ -127,6 +127,21 @@ class TestBackward(unittest.TestCase): ...@@ -127,6 +127,21 @@ class TestBackward(unittest.TestCase):
return no_grad_vars return no_grad_vars
def _check_error_param_list(self, net, parameter_list):
    """Build and run *net*, minimizing with the given ``parameter_list``.

    Used by error-path tests: any TypeError raised by
    ``Optimizer.minimize`` for a malformed ``parameter_list``
    propagates to the caller.
    """
    # Prefer GPU when this build supports CUDA; otherwise run on CPU.
    if fluid.core.is_compiled_with_cuda():
        place = fluid.CUDAPlace(0)
    else:
        place = fluid.CPUPlace()
    executor = fluid.Executor(place)

    main_prog, startup_prog = fluid.Program(), fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        loss = net.build_model()
        sgd = fluid.optimizer.SGD(learning_rate=0.1)
        sgd.minimize(loss, parameter_list=parameter_list)
    executor.run(startup_prog)
    executor.run(feed=net.init_data())
class SimpleNet(BackwardNet): class SimpleNet(BackwardNet):
def __init__(self): def __init__(self):
...@@ -211,6 +226,19 @@ class TestSimpleNet(TestBackward): ...@@ -211,6 +226,19 @@ class TestSimpleNet(TestBackward):
self._check_all(self.net) self._check_all(self.net)
class TestSimpleNetWithErrorParamList(TestBackward):
    """Checks that invalid ``parameter_list`` arguments raise TypeError."""

    def test_parameter_list_type_error(self):
        self.global_block_idx = 0
        self.net = SimpleNet()

        # parameter_list itself must be a list/tuple/set; a bare str fails.
        with self.assertRaises(TypeError):
            self._check_error_param_list(self.net, "test")

        # Every member must be a Variable or a str; the int 3 is rejected.
        var = fluid.data(name='test', shape=[None, 90], dtype='float32')
        with self.assertRaises(TypeError):
            self._check_error_param_list(self.net, [var, "test", 3])
# TODO(Aurelius84): add conditional network test # TODO(Aurelius84): add conditional network test
class ConditionalNet(BackwardNet): class ConditionalNet(BackwardNet):
def __init__(self): def __init__(self):
......
...@@ -54,7 +54,7 @@ class TestUnStackOpBase(OpTest): ...@@ -54,7 +54,7 @@ class TestUnStackOpBase(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad('X', self.get_y_names()) self.check_grad(['X'], self.get_y_names())
class TestStackOp3(TestUnStackOpBase): class TestStackOp3(TestUnStackOpBase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册