提交 85b98070 编写于 作者: F fengjiayi

fix a bug of inplace

上级 77fffc60
...@@ -99,6 +99,9 @@ def _append_backward_ops_(target, ...@@ -99,6 +99,9 @@ def _append_backward_ops_(target,
attrs={}), idx)) attrs={}), idx))
var_inputs[var_name] = [var_name] var_inputs[var_name] = [var_name]
for var_name in op_desc.output_arg_names(): for var_name in op_desc.output_arg_names():
if var_name in op_desc.input_arg_names():
# in place operator
continue
if var_name == core.empty_var_name() or len(var_inputs[ if var_name == core.empty_var_name() or len(var_inputs[
var_name]) == 0: var_name]) == 0:
# it's the first time we get the variable # it's the first time we get the variable
...@@ -221,6 +224,9 @@ def append_backward(loss, parameter_list=None, no_grad_set=None): ...@@ -221,6 +224,9 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
if var.stop_gradient: if var.stop_gradient:
block_no_grad_set.add(_append_grad_suffix_(var.name)) block_no_grad_set.add(_append_grad_suffix_(var.name))
no_grad_set[block.idx] = block_no_grad_set no_grad_set[block.idx] = block_no_grad_set
else:
# FIX ME
no_grad_set = {0: no_grad_set}
grad_info_map = dict() grad_info_map = dict()
root_block = program.block(0) root_block = program.block(0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册