From 33e75201e9d3c14945bbe556267b8bae069de327 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 2 Jan 2018 20:00:00 +0800 Subject: [PATCH] fix bugs --- python/paddle/v2/fluid/backward.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py index a1be768daa2..ac60bf54360 100644 --- a/python/paddle/v2/fluid/backward.py +++ b/python/paddle/v2/fluid/backward.py @@ -142,12 +142,13 @@ def _remove_no_grad_branch_(op_descs, no_grad_set): """ def _op_can_be_removed_(op_desc, no_grad_set): - if _all_in_set_(op_desc.output_arg_names(), no_grad_set): + out_arg_names = op_desc.output_arg_names() + if len(out_arg_names) == 0 or _all_in_set_(out_arg_names, no_grad_set): return True if _all_in_set_( filter(lambda name: name.find(core.grad_var_suffix()) != -1, op_desc.input_arg_names()), no_grad_set): - no_grad_set.union(op_desc.output_arg_names()) + no_grad_set.union(out_arg_names) return True return False @@ -296,7 +297,9 @@ def append_backward(loss, parameter_list=None, no_grad_set=None): block_no_grad_set.add(_append_grad_suffix_(var.name)) no_grad_dict[block.idx] = block_no_grad_set elif isinstance(no_grad_set, set): - no_grad_dict = {0: no_grad_set} + no_grad_dict = { + 0: set([_append_grad_suffix_(name) for name in no_grad_set]) + } else: raise ValueError("'no_grad_set' should be a set or None.") -- GitLab