未验证 提交 9a2204ee 编写于 作者: C Chen Weihang 提交者: GitHub

Uniform append_backward & gradients parameter_list type to Variable (#21938)



* update doc, test=develop

* fix related unittests, test=develop

* fix str incompatible error, test=develop
上级 6b4c33ee
......@@ -1027,18 +1027,18 @@ def append_backward(loss,
Parameters:
loss( :ref:`api_guide_Variable_en` ): The loss variable of the network.
parameter_list(list of str, optional): Names of parameters that need
to be updated by optimizers.
parameter_list(list[Variable|str], optional): List of Parameters or Parameter.names
that need to be updated by optimizers.
If it is None, all parameters
will be updated.
Default: None.
no_grad_set(set of str, optional): Variable names in the :ref:`api_guide_Block_en` 0 whose gradients
no_grad_set(set[str], optional): Variable names in the :ref:`api_guide_Block_en` 0 whose gradients
should be ignored. All variables with
`stop_gradient=True` from all blocks will
be automatically added into this set.
If this parameter is not None, the names in this set will be added to the default set.
Default: None.
callbacks(list of callable object, optional): List of callback functions.
callbacks(list[callable object], optional): List of callback functions.
The callbacks are used for
doing some custom jobs during
backward part building. All
......@@ -1167,7 +1167,20 @@ def append_backward(loss,
program._sync_with_cpp()
if parameter_list is not None:
parameters = parameter_list
if not isinstance(parameter_list, (list, tuple, set)):
raise TypeError(
"The type of parameter_list argument must be list or tuple or set, but received %s."
% (type(parameter_list)))
parameters = []
for i, param in enumerate(parameter_list):
if isinstance(param, framework.Variable):
parameters.append(param.name)
elif isinstance(param, six.string_types):
parameters.append(param)
else:
raise TypeError(
"The type of parameter_list's member must be paddle.fluid.Variable or str, but received %s."
% (type(param)))
else:
params = program.global_block().all_parameters()
parameters = [param.name for param in params if param.trainable]
......
......@@ -519,7 +519,7 @@ class Optimizer(object):
startup_program (Program, optional): :ref:`api_fluid_Program` for
initializing parameters in ``parameter_list``. The default value
is None, at this time :ref:`api_fluid_default_startup_program` will be used.
parameter_list (list, optional): List of ``Variable`` names to update
parameter_list (list, optional): List of ``Variable`` or ``Variable.name`` to update
to minimize ``loss``. The default value is None, at this time all parameters
will be updated.
no_grad_set (set, optional): Set of ``Variable`` objects that don't need
......@@ -666,7 +666,7 @@ class Optimizer(object):
startup_program (Program, optional): :ref:`api_fluid_Program` for
initializing parameters in ``parameter_list``. The default value
is None, at this time :ref:`api_fluid_default_startup_program` will be used.
parameter_list (list, optional): List of ``Variable`` names to update
parameter_list (list, optional): List of ``Variable`` or ``Variable.name`` to update
to minimize ``loss``. The default value is None, at this time all parameters
will be updated.
no_grad_set (set, optional): Set of ``Variable`` objects that don't need
......
......@@ -127,6 +127,21 @@ class TestBackward(unittest.TestCase):
return no_grad_vars
def _check_error_param_list(self, net, parameter_list):
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = net.build_model()
optimizer = fluid.optimizer.SGD(learning_rate=0.1)
optimizer.minimize(loss, parameter_list=parameter_list)
exe.run(startup)
exe.run(feed=net.init_data())
class SimpleNet(BackwardNet):
def __init__(self):
......@@ -211,6 +226,19 @@ class TestSimpleNet(TestBackward):
self._check_all(self.net)
class TestSimpleNetWithErrorParamList(TestBackward):
def test_parameter_list_type_error(self):
self.global_block_idx = 0
self.net = SimpleNet()
# The type of parameter_list argument must be list or tuple
with self.assertRaises(TypeError):
self._check_error_param_list(self.net, "test")
# The type of parameter_list's member must be varable or str
test = fluid.data(name='test', shape=[None, 90], dtype='float32')
with self.assertRaises(TypeError):
self._check_error_param_list(self.net, [test, "test", 3])
# TODO(Aurelius84): add conditional network test
class ConditionalNet(BackwardNet):
def __init__(self):
......
......@@ -54,7 +54,7 @@ class TestUnStackOpBase(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad('X', self.get_y_names())
self.check_grad(['X'], self.get_y_names())
class TestStackOp3(TestUnStackOpBase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册