diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 235745824489d78794d6f273c9a1484a60da1c4f..71a1653cf4d39cf71167116260d541fcb065e425 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -136,7 +136,7 @@ def _addup_repetitive_outputs_(op_descs): "sum", {"X": renamed_vars[var_name]}, {"Out": [var_name]}, {"use_mkldnn": False}), idx)) renamed_vars[var_name] = [var_name] - for param_name in op_desc.output_names(): + for param_idx, param_name in enumerate(op_desc.output_names()): arg_names = op_desc.output(param_name) for arg_idx, var_name in enumerate(arg_names): if var_name == core.empty_var_name( @@ -156,12 +156,26 @@ def _addup_repetitive_outputs_(op_descs): _rename_arg_(op_descs, var_name, new_name, 0, idx) _rename_arg_(pending_sum_ops, var_name, new_name) + for p in op_desc.output_names()[:param_idx]: + p_arg_names = op_desc.output(p) + if var_name in p_arg_names: + op_desc.set_output(p, [ + new_name if x == var_name else x + for x in p_arg_names + ]) + + arg_names = [ + new_name if x == var_name else x + for x in arg_names[:arg_idx] + ] + arg_names[arg_idx:] + new_name = var_name + "@RENAME@" + \ str(var_rename_count[var_name]) var_rename_count[var_name] += 1 arg_names[arg_idx] = new_name op_desc.set_output(param_name, arg_names) renamed_vars[var_name].append(new_name) + for var_name, inputs in renamed_vars.iteritems(): if len(inputs) > 1: pending_sum_ops.append(