提交 33e75201 编写于 作者: F fengjiayi

fix bugs

上级 8d4a607f
...@@ -142,12 +142,13 @@ def _remove_no_grad_branch_(op_descs, no_grad_set): ...@@ -142,12 +142,13 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
""" """
def _op_can_be_removed_(op_desc, no_grad_set): def _op_can_be_removed_(op_desc, no_grad_set):
if _all_in_set_(op_desc.output_arg_names(), no_grad_set): out_arg_names = op_desc.output_arg_names()
if len(out_arg_names) == 0 or _all_in_set_(out_arg_names, no_grad_set):
return True return True
if _all_in_set_( if _all_in_set_(
filter(lambda name: name.find(core.grad_var_suffix()) != -1, filter(lambda name: name.find(core.grad_var_suffix()) != -1,
op_desc.input_arg_names()), no_grad_set): op_desc.input_arg_names()), no_grad_set):
no_grad_set.union(op_desc.output_arg_names()) no_grad_set.union(out_arg_names)
return True return True
return False return False
...@@ -296,7 +297,9 @@ def append_backward(loss, parameter_list=None, no_grad_set=None): ...@@ -296,7 +297,9 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
block_no_grad_set.add(_append_grad_suffix_(var.name)) block_no_grad_set.add(_append_grad_suffix_(var.name))
no_grad_dict[block.idx] = block_no_grad_set no_grad_dict[block.idx] = block_no_grad_set
elif isinstance(no_grad_set, set): elif isinstance(no_grad_set, set):
no_grad_dict = {0: no_grad_set} no_grad_dict = {
0: set([_append_grad_suffix_(name) for name in no_grad_set])
}
else: else:
raise ValueError("'no_grad_set' should be a set or None.") raise ValueError("'no_grad_set' should be a set or None.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册