From 3c9b59b8e8d787eb4d2c33e468f48048fdec2959 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 9 Jul 2018 19:00:53 +0800 Subject: [PATCH] update --- python/paddle/fluid/backward.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 23574582448..71a1653cf4d 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -136,7 +136,7 @@ def _addup_repetitive_outputs_(op_descs): "sum", {"X": renamed_vars[var_name]}, {"Out": [var_name]}, {"use_mkldnn": False}), idx)) renamed_vars[var_name] = [var_name] - for param_name in op_desc.output_names(): + for param_idx, param_name in enumerate(op_desc.output_names()): arg_names = op_desc.output(param_name) for arg_idx, var_name in enumerate(arg_names): if var_name == core.empty_var_name( @@ -156,12 +156,26 @@ def _addup_repetitive_outputs_(op_descs): _rename_arg_(op_descs, var_name, new_name, 0, idx) _rename_arg_(pending_sum_ops, var_name, new_name) + for p in op_desc.output_names()[:param_idx]: + p_arg_names = op_desc.output(p) + if var_name in p_arg_names: + op_desc.set_output(p, [ + new_name if x == var_name else x + for x in p_arg_names + ]) + + arg_names = [ + new_name if x == var_name else x + for x in arg_names[:arg_idx] + ] + arg_names[arg_idx:] + new_name = var_name + "@RENAME@" + \ str(var_rename_count[var_name]) var_rename_count[var_name] += 1 arg_names[arg_idx] = new_name op_desc.set_output(param_name, arg_names) renamed_vars[var_name].append(new_name) + for var_name, inputs in renamed_vars.iteritems(): if len(inputs) > 1: pending_sum_ops.append( -- GitLab