提交 f5894d22 编写于 作者: F fengjiayi

Fix a backward bug

上级 436bb450
...@@ -136,7 +136,9 @@ def _addup_repetitive_outputs_(op_descs): ...@@ -136,7 +136,9 @@ def _addup_repetitive_outputs_(op_descs):
"sum", {"X": renamed_vars[var_name]}, {"Out": [var_name]}, "sum", {"X": renamed_vars[var_name]}, {"Out": [var_name]},
{"use_mkldnn": False}), idx)) {"use_mkldnn": False}), idx))
renamed_vars[var_name] = [var_name] renamed_vars[var_name] = [var_name]
for var_name in op_desc.output_arg_names(): for param_name in op_desc.output_names():
arg_names = op_desc.output(param_name)
for arg_idx, var_name in enumerate(arg_names):
if var_name == core.empty_var_name( if var_name == core.empty_var_name(
) or var_name in op_desc.input_arg_names(): ) or var_name in op_desc.input_arg_names():
# empty variable or inplace op # empty variable or inplace op
...@@ -157,7 +159,8 @@ def _addup_repetitive_outputs_(op_descs): ...@@ -157,7 +159,8 @@ def _addup_repetitive_outputs_(op_descs):
new_name = var_name + "@RENAME@" + \ new_name = var_name + "@RENAME@" + \
str(var_rename_count[var_name]) str(var_rename_count[var_name])
var_rename_count[var_name] += 1 var_rename_count[var_name] += 1
op_desc.rename_output(var_name, new_name) arg_names[arg_idx] = new_name
op_desc.set_output(param_name, arg_names)
renamed_vars[var_name].append(new_name) renamed_vars[var_name].append(new_name)
for var_name, inputs in renamed_vars.iteritems(): for var_name, inputs in renamed_vars.iteritems():
if len(inputs) > 1: if len(inputs) > 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册