From f5894d22c505ab688c7970fb72681aeb5816fc63 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 9 Jul 2018 16:11:20 +0800 Subject: [PATCH] Fix a backward bug --- python/paddle/fluid/backward.py | 43 ++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 4faa0630317..23574582448 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -136,29 +136,32 @@ def _addup_repetitive_outputs_(op_descs): "sum", {"X": renamed_vars[var_name]}, {"Out": [var_name]}, {"use_mkldnn": False}), idx)) renamed_vars[var_name] = [var_name] - for var_name in op_desc.output_arg_names(): - if var_name == core.empty_var_name( - ) or var_name in op_desc.input_arg_names(): - # empty variable or inplace op - continue - if len(renamed_vars[var_name]) == 0: - # it's the first time we get the variable - renamed_vars[var_name] = [var_name] - else: - if len(renamed_vars[var_name]) == 1: + for param_name in op_desc.output_names(): + arg_names = op_desc.output(param_name) + for arg_idx, var_name in enumerate(arg_names): + if var_name == core.empty_var_name( + ) or var_name in op_desc.input_arg_names(): + # empty variable or inplace op + continue + if len(renamed_vars[var_name]) == 0: + # it's the first time we get the variable + renamed_vars[var_name] = [var_name] + else: + if len(renamed_vars[var_name]) == 1: + new_name = var_name + "@RENAME@" + \ + str(var_rename_count[var_name]) + var_rename_count[var_name] += 1 + # rename original var_name + renamed_vars[var_name][0] = new_name + _rename_arg_(op_descs, var_name, new_name, 0, idx) + _rename_arg_(pending_sum_ops, var_name, new_name) + new_name = var_name + "@RENAME@" + \ str(var_rename_count[var_name]) var_rename_count[var_name] += 1 - # rename original var_name - renamed_vars[var_name][0] = new_name - _rename_arg_(op_descs, var_name, new_name, 0, idx) - _rename_arg_(pending_sum_ops, var_name, new_name) - - new_name = var_name + "@RENAME@" + \ - str(var_rename_count[var_name]) - var_rename_count[var_name] += 1 - op_desc.rename_output(var_name, new_name) - renamed_vars[var_name].append(new_name) + arg_names[arg_idx] = new_name + op_desc.set_output(param_name, arg_names) + renamed_vars[var_name].append(new_name) for var_name, inputs in renamed_vars.iteritems(): if len(inputs) > 1: pending_sum_ops.append( -- GitLab