diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py index a1be768daa2dfa20f0c3b71f429581e9ca2932b5..ac60bf543600008fd5339c1a378951374afc4ad6 100644 --- a/python/paddle/v2/fluid/backward.py +++ b/python/paddle/v2/fluid/backward.py @@ -142,12 +142,13 @@ def _remove_no_grad_branch_(op_descs, no_grad_set): """ def _op_can_be_removed_(op_desc, no_grad_set): - if _all_in_set_(op_desc.output_arg_names(), no_grad_set): + out_arg_names = op_desc.output_arg_names() + if len(out_arg_names) == 0 or _all_in_set_(out_arg_names, no_grad_set): return True if _all_in_set_( filter(lambda name: name.find(core.grad_var_suffix()) != -1, op_desc.input_arg_names()), no_grad_set): - no_grad_set.union(op_desc.output_arg_names()) + no_grad_set.union(out_arg_names) return True return False @@ -296,7 +297,9 @@ def append_backward(loss, parameter_list=None, no_grad_set=None): block_no_grad_set.add(_append_grad_suffix_(var.name)) no_grad_dict[block.idx] = block_no_grad_set elif isinstance(no_grad_set, set): - no_grad_dict = {0: no_grad_set} + no_grad_dict = { + 0: set([_append_grad_suffix_(name) for name in no_grad_set]) + } else: raise ValueError("'no_grad_set' should be a set or None.")