Unverified commit db937b5a, authored by xiongkun, committed by GitHub

fix multi-target bugs, a common case in dy2static (#45277)

Parent 229befc8
@@ -642,12 +642,16 @@ def _addup_repetitive_outputs_(op_descs,
     return op_descs


-def _remove_no_grad_branch_(op_descs, no_grad_set, grad_op_id_to_fwd_op=None):
+def _remove_no_grad_branch_(op_descs,
+                            no_grad_set,
+                            grad_op_id_to_fwd_op=None,
+                            target_vars=[]):
     """
     Remove unnecessary grad ops
     A grad op can be removed in two cases:
     1. all outputs of the grad op are in 'no_grad_set'
     2. all grad inputs of the grad op are in 'no_grad_set'
+    NOTE: we will skip target_vars's grad name.
     """

     def _op_can_be_removed_(op_desc, no_grad_set):
@@ -658,11 +662,13 @@ def _remove_no_grad_branch_(op_descs, no_grad_set, grad_op_id_to_fwd_op=None):
                 name for name in op_desc.input_arg_names()
                 if name.find(core.grad_var_suffix()) != -1
         ], no_grad_set):
-            no_grad_set.update(out_arg_names)
+            no_grad_set.update(set(out_arg_names) - target_grad_var_names)
             return True
         return False

     # Remove ops whose outputs are all in no_grad_dict
+    target_grad_var_names = set(
+        [var.name + core.grad_var_suffix() for var in target_vars])
     op_descs = [
         op_desc for op_desc in op_descs
         if not _op_can_be_removed_(op_desc, no_grad_set)
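Note: the hunk above is the core of the fix — the grad names of the requested targets are subtracted from the set before `no_grad_set` is widened, so one target's pruned branch can no longer mark another target's gradient as dead. A minimal sketch of the rule, assuming simplified dict-based op descriptors in place of the real `OpDesc` objects (`GRAD_SUFFIX` stands in for `core.grad_var_suffix()`):

```python
# Minimal sketch, NOT the real implementation: ops are plain dicts with
# "inputs"/"outputs" name lists instead of C++ OpDesc objects.
GRAD_SUFFIX = "@GRAD"  # stand-in for core.grad_var_suffix()

def remove_no_grad_branch(op_descs, no_grad_set, target_var_names):
    # Grad names of the requested targets must survive pruning; otherwise
    # the ops producing d(target)/d(input) get removed as "dead".
    target_grad_names = {name + GRAD_SUFFIX for name in target_var_names}

    def op_can_be_removed(op):
        outs = op["outputs"]
        if all(name in no_grad_set for name in outs):
            return True  # case 1: every output is already no-grad
        grad_ins = [n for n in op["inputs"] if GRAD_SUFFIX in n]
        if grad_ins and all(n in no_grad_set for n in grad_ins):
            # case 2: all grad inputs are no-grad; propagate that to the
            # outputs, but never to the targets' own grad names (the fix).
            no_grad_set.update(set(outs) - target_grad_names)
            return True
        return False

    return [op for op in op_descs if not op_can_be_removed(op)]
```

With targets `y` and `z`, `{"y@GRAD", "z@GRAD"}` is excluded from every `no_grad_set.update`, which is exactly what `target_grad_var_names` does above.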
@@ -824,6 +830,7 @@ def serialize_op_decs(op_desc):
 def _append_backward_ops_with_checkpoints_(block,
                                            ops,
+                                           target_vars,
                                            target_block,
                                            no_grad_dict,
                                            grad_to_var,
@@ -835,6 +842,7 @@ def _append_backward_ops_with_checkpoints_(block,
     Args:
         block(Block): the block where forward ops are
         ops(Op): the forward operators whose forward recomputation backward ops need to be added
+        target_vars(list[Tensor]): the loss vars we want to calculate gradient.
         target_block(Block): the block which is going to hold new generated grad ops
         no_grad_dict(dict):
             key(int) block index
@@ -1070,7 +1078,7 @@ def _append_backward_ops_with_checkpoints_(block,
     # 4) remove no grad branch as it is in _remove_no_grad_branch_
     grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
                                             no_grad_dict[block.idx],
-                                            grad_op_id_to_fwd_op)
+                                            grad_op_id_to_fwd_op, target_vars)
     added_descs = _add_descs_to_block(grad_op_descs, target_block,
                                       grad_op_id_to_fwd_op)
     return program_stat, checkpoints_name, vars_should_be_hold, recompute_segments
@@ -1140,6 +1148,7 @@ def _rename_grad_name_(name, grad_order):
 def _append_backward_ops_(block,
                           ops,
+                          target_vars,
                           target_block,
                           no_grad_dict,
                           grad_to_var,
@@ -1155,6 +1164,7 @@ def _append_backward_ops_(block,
     Args:
         block(Block): the block where forward ops are
         ops(Op): the forward operators whose backward ops need to be added
+        target_vars(list[Tensor]): the loss vars we want to calculate gradient.
         target_block(Block): the block which is going to hold new generated grad ops
         no_grad_dict(dict):
             key(int) block index
@@ -1212,6 +1222,7 @@ def _append_backward_ops_(block,
             sub_block_path = op_path_dict[op._block_attr_id("sub_block")]
             _append_backward_ops_(sub_block,
                                   sub_block_path,
+                                  target_vars,
                                   grad_sub_block,
                                   no_grad_dict,
                                   grad_to_var,
@@ -1330,7 +1341,7 @@ def _append_backward_ops_(block,
     # if all inputs of the grad op are in no_grad_set, just remove this op
     grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
                                             no_grad_dict[block.idx],
-                                            grad_op_id_to_fwd_op)
+                                            grad_op_id_to_fwd_op, target_vars)

     # remove some backward ops
     not_need_ops = _find_not_need_ops(grad_op_descs, ops, input_grad_names_set)
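`_append_backward_ops_` also recurses into control-flow sub-blocks (the `sub_block` call two hunks up), so `target_vars` has to ride along with the recursion. A hedged sketch of that propagation, using hypothetical simplified names rather than the real framework types:

```python
# Hypothetical, simplified mirror of _append_backward_ops_: each op is a
# dict with "type", "inputs", and an optional nested "sub_block" op list.
def append_backward_ops(ops, target_vars, no_grad_set, grad_op_descs):
    for op in reversed(ops):
        sub_block = op.get("sub_block")  # body of a cond/while op, if any
        if sub_block is not None:
            # The fix: forward target_vars into the recursion so grads of
            # targets referenced inside the sub-block are not pruned there.
            append_backward_ops(sub_block, target_vars, no_grad_set,
                                grad_op_descs)
        grad_op_descs.append({
            "type": op["type"] + "_grad",
            "inputs": op["inputs"],
            "outputs": [name + "@GRAD" for name in op["inputs"]],
        })
    # ...followed by the pruning pass sketched earlier, which now also
    # receives target_vars.
    return grad_op_descs
```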
@@ -1765,6 +1776,7 @@ def append_backward(loss,
         _append_backward_ops_with_checkpoints_(
             root_block,
             op_path,
+            [loss],
             root_block,
             no_grad_dict,
             grad_to_var,
@@ -1774,6 +1786,7 @@ def append_backward(loss,
         _append_backward_ops_(
             block,  # the block where forward ops are in
             op_path,
+            [loss],
             target_grad_block,
             no_grad_dict,
             grad_to_var,
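At the public entry point, `append_backward` now forwards its single loss as the target list. A usage sketch under the Paddle 2.x static-graph API (shapes and variable names are illustrative):

```python
import paddle

paddle.enable_static()

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data("x", [None, 8], "float32")
    w = paddle.static.create_parameter([8, 1], "float32")
    loss = paddle.mean(paddle.matmul(x, w))
    # Internally this now calls _append_backward_ops_ with [loss] as
    # target_vars, so loss@GRAD cannot be swept into the no-grad set.
    params_and_grads = paddle.static.append_backward(loss)
```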
@@ -2135,6 +2148,7 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
     grad_info_map = dict()
     _append_backward_ops_(block,
                           op_path,
+                          targets,
                           block,
                           no_grad_dict,
                           grad_to_var,
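`calc_gradient` is where the multi-target dy2static case actually bites: with several targets, pruning driven by one target could delete grad ops another target still needs. A hedged reproduction sketch via the public wrapper `paddle.static.gradients` (program contents are illustrative):

```python
import paddle

paddle.enable_static()

main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data("x", [None, 4], "float32")
    x.stop_gradient = False
    y = paddle.nn.functional.relu(x)
    z = paddle.mean(y)
    # Two targets, as dy2static commonly emits; targets is forwarded to
    # _append_backward_ops_ so y@GRAD and z@GRAD both survive pruning.
    dx = paddle.static.gradients([y, z], [x])
```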