From 7be57de9434053e7aa2e7b1d78da62ee1cb41ba7 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 2 Jan 2018 16:55:51 +0800 Subject: [PATCH] enhance no_grad_var handling --- python/paddle/v2/fluid/backward.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py index f11c83f59c9..43e9abc354d 100644 --- a/python/paddle/v2/fluid/backward.py +++ b/python/paddle/v2/fluid/backward.py @@ -57,6 +57,8 @@ def _all_in_set_(cands, s): """ Test if all elements of 'cands' are in set 's' """ + if len(cands) == 0: + return False for c in cands: if not c in s: return False @@ -138,10 +140,20 @@ def _remove_no_grad_branch_(op_descs, no_grad_set): 1. all outputs of the grad op are in 'no_grad_set' 2. (TODO) all grad inputs of the grad op are in 'no_grad_set' """ + + def _op_can_be_removed_(op_desc, no_grad_set): + if _all_in_set_(op_desc.output_arg_names(), no_grad_set): + return True + if _all_in_set_( + filter(lambda name: name.find(core.grad_var_suffix()) != -1, + op_desc.input_arg_names()), no_grad_set): + no_grad_set.union(op_desc.output_arg_names()) + return True + return False + # Remove ops whose outputs are all in no_grad_dict op_descs = filter( - lambda op_desc: not _all_in_set_(op_desc.output_arg_names(), no_grad_set), - op_descs) + lambda op_desc: not _op_can_be_removed_(op_desc, no_grad_set), op_descs) # Insert fill_zeros_like_op to_insert = [] for idx, op_desc in enumerate(op_descs): -- GitLab