diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py index b90949838ea5ccaae3111164214f15a8b5579e87..6966cc75804b6b5a49ceb45a26994c23d2936bdb 100644 --- a/python/paddle/v2/fluid/backward.py +++ b/python/paddle/v2/fluid/backward.py @@ -195,7 +195,7 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): _infer_var_data_type_(arg, block) -def append_backward(loss, parameter_list=None, no_grad_dict=None): +def append_backward(loss, parameter_list=None, no_grad_set=None): """ Create and add gradient Operators in BlockDesc to compute gradients of `loss` for parameters in parameter_list @@ -213,8 +213,8 @@ def append_backward(loss, parameter_list=None, no_grad_dict=None): assert isinstance(loss, framework.Variable) program = loss.block.program - if no_grad_dict is None: - no_grad_dict = dict() + no_grad_dict = dict() + if no_grad_set is None: assert isinstance(program, framework.Program) for block in program.blocks: assert isinstance(block, framework.Block) @@ -224,8 +224,10 @@ def append_backward(loss, parameter_list=None, no_grad_dict=None): if var.stop_gradient: block_no_grad_set.add(_append_grad_suffix_(var.name)) no_grad_dict[block.idx] = block_no_grad_set - elif isinstance(no_grad_dict, set): - no_grad_dict = {0: no_grad_dict} + elif isinstance(no_grad_set, set): + no_grad_dict = {0: no_grad_set} + else: + raise ValueError("'no_grad_set' should be a set or None.") grad_info_map = dict() root_block = program.block(0)