提交 7be57de9 编写于 作者: F fengjiayi

enhance no_grad_var handling

上级 6a5cf28a
...@@ -57,6 +57,8 @@ def _all_in_set_(cands, s): ...@@ -57,6 +57,8 @@ def _all_in_set_(cands, s):
""" """
Test if all elements of 'cands' are in set 's' Test if all elements of 'cands' are in set 's'
""" """
if len(cands) == 0:
return False
for c in cands: for c in cands:
if not c in s: if not c in s:
return False return False
...@@ -138,10 +140,20 @@ def _remove_no_grad_branch_(op_descs, no_grad_set): ...@@ -138,10 +140,20 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
1. all outputs of the grad op are in 'no_grad_set' 1. all outputs of the grad op are in 'no_grad_set'
2. (TODO) all grad inputs of the grad op are in 'no_grad_set' 2. (TODO) all grad inputs of the grad op are in 'no_grad_set'
""" """
def _op_can_be_removed_(op_desc, no_grad_set):
if _all_in_set_(op_desc.output_arg_names(), no_grad_set):
return True
if _all_in_set_(
filter(lambda name: name.find(core.grad_var_suffix()) != -1,
op_desc.input_arg_names()), no_grad_set):
no_grad_set.union(op_desc.output_arg_names())
return True
return False
# Remove ops whose outputs are all in no_grad_dict # Remove ops whose outputs are all in no_grad_dict
op_descs = filter( op_descs = filter(
lambda op_desc: not _all_in_set_(op_desc.output_arg_names(), no_grad_set), lambda op_desc: not _op_can_be_removed_(op_desc, no_grad_set), op_descs)
op_descs)
# Insert fill_zeros_like_op # Insert fill_zeros_like_op
to_insert = [] to_insert = []
for idx, op_desc in enumerate(op_descs): for idx, op_desc in enumerate(op_descs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册