提交 7be57de9 编写于 作者: F fengjiayi

enhance no_grad_var handling

上级 6a5cf28a
......@@ -57,6 +57,8 @@ def _all_in_set_(cands, s):
"""
Test if all elements of 'cands' are in set 's'
"""
if len(cands) == 0:
return False
for c in cands:
if not c in s:
return False
......@@ -138,10 +140,20 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
1. all outputs of the grad op are in 'no_grad_set'
2. (TODO) all grad inputs of the grad op are in 'no_grad_set'
"""
def _op_can_be_removed_(op_desc, no_grad_set):
if _all_in_set_(op_desc.output_arg_names(), no_grad_set):
return True
if _all_in_set_(
filter(lambda name: name.find(core.grad_var_suffix()) != -1,
op_desc.input_arg_names()), no_grad_set):
no_grad_set.union(op_desc.output_arg_names())
return True
return False
# Remove ops whose outputs are all in no_grad_dict
op_descs = filter(
lambda op_desc: not _all_in_set_(op_desc.output_arg_names(), no_grad_set),
op_descs)
lambda op_desc: not _op_can_be_removed_(op_desc, no_grad_set), op_descs)
# Insert fill_zeros_like_op
to_insert = []
for idx, op_desc in enumerate(op_descs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册