From c532fdab29f17d3aa7edc7902d9a5a94346660b4 Mon Sep 17 00:00:00 2001
From: fengjiayi
Date: Tue, 26 Dec 2017 23:44:02 +0800
Subject: [PATCH] fix errors

---
 python/paddle/v2/fluid/backward.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py
index b90949838ea..6966cc75804 100644
--- a/python/paddle/v2/fluid/backward.py
+++ b/python/paddle/v2/fluid/backward.py
@@ -195,7 +195,7 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
             _infer_var_data_type_(arg, block)
 
 
-def append_backward(loss, parameter_list=None, no_grad_dict=None):
+def append_backward(loss, parameter_list=None, no_grad_set=None):
     """
     Create and add gradient Operators in BlockDesc to compute
     gradients of `loss` for parameters in parameter_list
@@ -213,8 +213,8 @@ def append_backward(loss, parameter_list=None, no_grad_dict=None):
     assert isinstance(loss, framework.Variable)
 
     program = loss.block.program
-    if no_grad_dict is None:
-        no_grad_dict = dict()
+    no_grad_dict = dict()
+    if no_grad_set is None:
         assert isinstance(program, framework.Program)
         for block in program.blocks:
             assert isinstance(block, framework.Block)
@@ -224,8 +224,10 @@ def append_backward(loss, parameter_list=None, no_grad_dict=None):
                 if var.stop_gradient:
                     block_no_grad_set.add(_append_grad_suffix_(var.name))
             no_grad_dict[block.idx] = block_no_grad_set
-    elif isinstance(no_grad_dict, set):
-        no_grad_dict = {0: no_grad_dict}
+    elif isinstance(no_grad_set, set):
+        no_grad_dict = {0: no_grad_set}
+    else:
+        raise ValueError("'no_grad_set' should be a set or None.")
 
     grad_info_map = dict()
     root_block = program.block(0)
--
GitLab
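
Note: below is a minimal usage sketch of the revised interface, not part of the patch. The forward network (the layers.data/fc/square_error_cost/mean calls and the variable names) is illustrative and assumes the paddle.v2.fluid layer API of the same era; only the append_backward signature and the ValueError path come from the diff above.

    import paddle.v2.fluid as fluid
    from paddle.v2.fluid.backward import append_backward

    # Illustrative forward program; any loss Variable is handled the same way.
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    y_pred = fluid.layers.fc(input=x, size=1)
    cost = fluid.layers.square_error_cost(input=y_pred, label=y)
    avg_cost = fluid.layers.mean(x=cost)

    # Default: no_grad_set=None, so a per-block no-gradient set is rebuilt
    # from each variable's stop_gradient flag.
    append_backward(loss=avg_cost)

    # After this patch, anything other than a set (or None) fails fast
    # instead of being silently misused as a dict keyed by block index:
    try:
        append_backward(loss=avg_cost, no_grad_set=['x'])  # a list, not a set
    except ValueError as e:
        print(e)  # 'no_grad_set' should be a set or None.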