提交 33e75201 编写于 作者: F fengjiayi

fix bugs

上级 8d4a607f
......@@ -142,12 +142,13 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
"""
def _op_can_be_removed_(op_desc, no_grad_set):
if _all_in_set_(op_desc.output_arg_names(), no_grad_set):
out_arg_names = op_desc.output_arg_names()
if len(out_arg_names) == 0 or _all_in_set_(out_arg_names, no_grad_set):
return True
if _all_in_set_(
filter(lambda name: name.find(core.grad_var_suffix()) != -1,
op_desc.input_arg_names()), no_grad_set):
no_grad_set.union(op_desc.output_arg_names())
no_grad_set.union(out_arg_names)
return True
return False
......@@ -296,7 +297,9 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
block_no_grad_set.add(_append_grad_suffix_(var.name))
no_grad_dict[block.idx] = block_no_grad_set
elif isinstance(no_grad_set, set):
no_grad_dict = {0: no_grad_set}
no_grad_dict = {
0: set([_append_grad_suffix_(name) for name in no_grad_set])
}
else:
raise ValueError("'no_grad_set' should be a set or None.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册