提交 f5894d22 编写于 作者: F fengjiayi

Fix a backward bug

上级 436bb450
...@@ -136,29 +136,32 @@ def _addup_repetitive_outputs_(op_descs): ...@@ -136,29 +136,32 @@ def _addup_repetitive_outputs_(op_descs):
"sum", {"X": renamed_vars[var_name]}, {"Out": [var_name]}, "sum", {"X": renamed_vars[var_name]}, {"Out": [var_name]},
{"use_mkldnn": False}), idx)) {"use_mkldnn": False}), idx))
renamed_vars[var_name] = [var_name] renamed_vars[var_name] = [var_name]
for var_name in op_desc.output_arg_names(): for param_name in op_desc.output_names():
if var_name == core.empty_var_name( arg_names = op_desc.output(param_name)
) or var_name in op_desc.input_arg_names(): for arg_idx, var_name in enumerate(arg_names):
# empty variable or inplace op if var_name == core.empty_var_name(
continue ) or var_name in op_desc.input_arg_names():
if len(renamed_vars[var_name]) == 0: # empty variable or inplace op
# it's the first time we get the variable continue
renamed_vars[var_name] = [var_name] if len(renamed_vars[var_name]) == 0:
else: # it's the first time we get the variable
if len(renamed_vars[var_name]) == 1: renamed_vars[var_name] = [var_name]
else:
if len(renamed_vars[var_name]) == 1:
new_name = var_name + "@RENAME@" + \
str(var_rename_count[var_name])
var_rename_count[var_name] += 1
# rename original var_name
renamed_vars[var_name][0] = new_name
_rename_arg_(op_descs, var_name, new_name, 0, idx)
_rename_arg_(pending_sum_ops, var_name, new_name)
new_name = var_name + "@RENAME@" + \ new_name = var_name + "@RENAME@" + \
str(var_rename_count[var_name]) str(var_rename_count[var_name])
var_rename_count[var_name] += 1 var_rename_count[var_name] += 1
# rename original var_name arg_names[arg_idx] = new_name
renamed_vars[var_name][0] = new_name op_desc.set_output(param_name, arg_names)
_rename_arg_(op_descs, var_name, new_name, 0, idx) renamed_vars[var_name].append(new_name)
_rename_arg_(pending_sum_ops, var_name, new_name)
new_name = var_name + "@RENAME@" + \
str(var_rename_count[var_name])
var_rename_count[var_name] += 1
op_desc.rename_output(var_name, new_name)
renamed_vars[var_name].append(new_name)
for var_name, inputs in renamed_vars.iteritems(): for var_name, inputs in renamed_vars.iteritems():
if len(inputs) > 1: if len(inputs) > 1:
pending_sum_ops.append( pending_sum_ops.append(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册