Unverified commit db937b5a, authored by xiongkun, committed by GitHub

fix multi-target bugs, which is a common case in dy2static (#45277)

Parent 229befc8
@@ -642,12 +642,16 @@ def _addup_repetitive_outputs_(op_descs,
     return op_descs


-def _remove_no_grad_branch_(op_descs, no_grad_set, grad_op_id_to_fwd_op=None):
+def _remove_no_grad_branch_(op_descs,
+                            no_grad_set,
+                            grad_op_id_to_fwd_op=None,
+                            target_vars=[]):
     """
     Remove unnecessary grad ops
     A grad op can be removed in two cases:
         1. all outputs of the grad op are in 'no_grad_set'
         2. all grad inputs of the grad op are in 'no_grad_set'
+    NOTE: we will skip target_vars's grad name.
     """

     def _op_can_be_removed_(op_desc, no_grad_set):
@@ -658,11 +662,13 @@ def _remove_no_grad_branch_(op_descs, no_grad_set, grad_op_id_to_fwd_op=None):
                 name for name in op_desc.input_arg_names()
                 if name.find(core.grad_var_suffix()) != -1
         ], no_grad_set):
-            no_grad_set.update(out_arg_names)
+            no_grad_set.update(set(out_arg_names) - target_grad_var_names)
             return True
         return False

     # Remove ops whose outputs are all in no_grad_dict
+    target_grad_var_names = set(
+        [var.name + core.grad_var_suffix() for var in target_vars])
     op_descs = [
         op_desc for op_desc in op_descs
         if not _op_can_be_removed_(op_desc, no_grad_set)
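The core of the fix is the set difference above: when all grad inputs of a grad op are in `no_grad_set`, its outputs are added to `no_grad_set` as well, except for the grad names of the requested targets, so a second target's gradient can no longer be pruned away. A standalone sketch of that bookkeeping (hypothetical `GRAD_SUFFIX`, `prune_outputs`, and variable names, not the actual Paddle structures):

```python
# Standalone sketch of the pruning rule, with made-up names.
GRAD_SUFFIX = "@GRAD"  # stands in for core.grad_var_suffix()


def prune_outputs(out_names, no_grad_set, target_names):
    """Add grad outputs to no_grad_set, but never a target's own grad name."""
    target_grad_names = {name + GRAD_SUFFIX for name in target_names}
    no_grad_set.update(set(out_names) - target_grad_names)
    return no_grad_set


# With two targets y1 and y2, y2@GRAD must survive even if the op that
# produces it gets pruned; before the fix it could be swallowed into
# no_grad_set along with the other outputs.
no_grad = prune_outputs(["tmp_0@GRAD", "y2@GRAD"], set(), ["y1", "y2"])
print(no_grad)  # {'tmp_0@GRAD'} -- y2@GRAD stays out of no_grad_set
```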
@@ -824,6 +830,7 @@ def serialize_op_decs(op_desc):

 def _append_backward_ops_with_checkpoints_(block,
                                            ops,
+                                           target_vars,
                                            target_block,
                                            no_grad_dict,
                                            grad_to_var,
@@ -835,6 +842,7 @@ def _append_backward_ops_with_checkpoints_(block,
     Args:
         block(Block): the block where forward ops are
         ops(Op): the forward operators whose forward recomputation backward ops need to be added
+        target_vars(list[Tensor]): the loss vars we want to calculate gradient.
         target_block(Block): the block which is going to hold new generated grad ops
         no_grad_dict(dict):
             key(int) block index
@@ -1070,7 +1078,7 @@ def _append_backward_ops_with_checkpoints_(block,
     # 4) remove no grad branch as it is in _remove_no_grad_branch_
     grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
                                             no_grad_dict[block.idx],
-                                            grad_op_id_to_fwd_op)
+                                            grad_op_id_to_fwd_op, target_vars)
     added_descs = _add_descs_to_block(grad_op_descs, target_block,
                                       grad_op_id_to_fwd_op)
     return program_stat, checkpoints_name, vars_should_be_hold, recompute_segments
@@ -1140,6 +1148,7 @@ def _rename_grad_name_(name, grad_order):

 def _append_backward_ops_(block,
                           ops,
+                          target_vars,
                           target_block,
                           no_grad_dict,
                           grad_to_var,
@@ -1155,6 +1164,7 @@ def _append_backward_ops_(block,
     Args:
         block(Block): the block where forward ops are
         ops(Op): the forward operators whose backward ops need to be added
+        target_vars(list[Tensor]): the loss vars we want to calculate gradient.
         target_block(Block): the block which is going to hold new generated grad ops
         no_grad_dict(dict):
             key(int) block index
@@ -1212,6 +1222,7 @@ def _append_backward_ops_(block,
             sub_block_path = op_path_dict[op._block_attr_id("sub_block")]
             _append_backward_ops_(sub_block,
                                   sub_block_path,
+                                  target_vars,
                                   grad_sub_block,
                                   no_grad_dict,
                                   grad_to_var,
@@ -1330,7 +1341,7 @@ def _append_backward_ops_(block,
     # if all inputs of the grad op are in no_grad_set, just remove this op
     grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
                                             no_grad_dict[block.idx],
-                                            grad_op_id_to_fwd_op)
+                                            grad_op_id_to_fwd_op, target_vars)

     # remove some backward ops
     not_need_ops = _find_not_need_ops(grad_op_descs, ops, input_grad_names_set)
@@ -1765,6 +1776,7 @@ def append_backward(loss,
         _append_backward_ops_with_checkpoints_(
             root_block,
             op_path,
+            [loss],
             root_block,
             no_grad_dict,
             grad_to_var,
@@ -1774,6 +1786,7 @@ def append_backward(loss,
         _append_backward_ops_(
             block,  # the block where forward ops are in
             op_path,
+            [loss],
             target_grad_block,
             no_grad_dict,
             grad_to_var,
@@ -2135,6 +2148,7 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
     grad_info_map = dict()
     _append_backward_ops_(block,
                           op_path,
+                          targets,
                           block,
                           no_grad_dict,
                           grad_to_var,
...
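For context (not part of the diff), the last hunk is what the multi-target case exercises: the static-graph gradients API forwards all targets into calc_gradient, which now passes them on as target_vars. A hedged usage sketch, assuming the standard paddle.static entry points and hypothetical variable names:

```python
# Illustrative multi-target example (assumed usage, not taken from the commit).
import paddle

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    x.stop_gradient = False
    y1 = paddle.sum(x * 2.0)  # first target
    y2 = paddle.sum(x * x)    # second target
    # Several targets in one call: their grad vars must not be pruned by
    # _remove_no_grad_branch_, which the new target_vars argument prevents.
    grads = paddle.static.gradients(targets=[y1, y2], inputs=[x])
print([g.name for g in grads])
```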