未验证 提交 74fa603c 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #12043 from JiayiFeng/fix_backward_bug

Fix a backward bug
...@@ -123,7 +123,8 @@ def _append_grad_suffix_(name): ...@@ -123,7 +123,8 @@ def _append_grad_suffix_(name):
def _addup_repetitive_outputs_(op_descs): def _addup_repetitive_outputs_(op_descs):
""" """
In backward part, an variable may be the output of more than one ops. In backward part, an variable may be the output of more than one ops.
In this case, the variable should be the accumulation of all the outputs. And one op may yield its multiple outputs to the same variable.
In these cases, the variable should be the accumulation of all the outputs.
`sum_op`s are added to implement the accumulate. `sum_op`s are added to implement the accumulate.
""" """
pending_sum_ops = [] pending_sum_ops = []
...@@ -136,7 +137,9 @@ def _addup_repetitive_outputs_(op_descs): ...@@ -136,7 +137,9 @@ def _addup_repetitive_outputs_(op_descs):
"sum", {"X": renamed_vars[var_name]}, {"Out": [var_name]}, "sum", {"X": renamed_vars[var_name]}, {"Out": [var_name]},
{"use_mkldnn": False}), idx)) {"use_mkldnn": False}), idx))
renamed_vars[var_name] = [var_name] renamed_vars[var_name] = [var_name]
for var_name in op_desc.output_arg_names(): for param_idx, param_name in enumerate(op_desc.output_names()):
arg_names = op_desc.output(param_name)
for arg_idx, var_name in enumerate(arg_names):
if var_name == core.empty_var_name( if var_name == core.empty_var_name(
) or var_name in op_desc.input_arg_names(): ) or var_name in op_desc.input_arg_names():
# empty variable or inplace op # empty variable or inplace op
...@@ -154,11 +157,26 @@ def _addup_repetitive_outputs_(op_descs): ...@@ -154,11 +157,26 @@ def _addup_repetitive_outputs_(op_descs):
_rename_arg_(op_descs, var_name, new_name, 0, idx) _rename_arg_(op_descs, var_name, new_name, 0, idx)
_rename_arg_(pending_sum_ops, var_name, new_name) _rename_arg_(pending_sum_ops, var_name, new_name)
for p in op_desc.output_names()[:param_idx]:
p_arg_names = op_desc.output(p)
if var_name in p_arg_names:
op_desc.set_output(p, [
new_name if x == var_name else x
for x in p_arg_names
])
arg_names = [
new_name if x == var_name else x
for x in arg_names[:arg_idx]
] + arg_names[arg_idx:]
new_name = var_name + "@RENAME@" + \ new_name = var_name + "@RENAME@" + \
str(var_rename_count[var_name]) str(var_rename_count[var_name])
var_rename_count[var_name] += 1 var_rename_count[var_name] += 1
op_desc.rename_output(var_name, new_name) arg_names[arg_idx] = new_name
op_desc.set_output(param_name, arg_names)
renamed_vars[var_name].append(new_name) renamed_vars[var_name].append(new_name)
for var_name, inputs in renamed_vars.iteritems(): for var_name, inputs in renamed_vars.iteritems():
if len(inputs) > 1: if len(inputs) > 1:
pending_sum_ops.append( pending_sum_ops.append(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册