From db937b5a54e2f8210fd68939b7e5539cf6bcc427 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 23 Aug 2022 13:19:23 +0800 Subject: [PATCH] fix multi-targets bugs which this is common case in dy2static (#45277) --- python/paddle/fluid/backward.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 5ed01a01144..db7c03bb255 100755 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -642,12 +642,16 @@ def _addup_repetitive_outputs_(op_descs, return op_descs -def _remove_no_grad_branch_(op_descs, no_grad_set, grad_op_id_to_fwd_op=None): +def _remove_no_grad_branch_(op_descs, + no_grad_set, + grad_op_id_to_fwd_op=None, + target_vars=[]): """ Remove unnecessary grad ops A grad op can be removed in two cases: 1. all outputs of the grad op are in 'no_grad_set' 2. all grad inputs of the grad op are in 'no_grad_set' + NOTE: we will skip target_vars's grad name. """ def _op_can_be_removed_(op_desc, no_grad_set): @@ -658,11 +662,13 @@ def _remove_no_grad_branch_(op_descs, no_grad_set, grad_op_id_to_fwd_op=None): name for name in op_desc.input_arg_names() if name.find(core.grad_var_suffix()) != -1 ], no_grad_set): - no_grad_set.update(out_arg_names) + no_grad_set.update(set(out_arg_names) - target_grad_var_names) return True return False # Remove ops whose outputs are all in no_grad_dict + target_grad_var_names = set( + [var.name + core.grad_var_suffix() for var in target_vars]) op_descs = [ op_desc for op_desc in op_descs if not _op_can_be_removed_(op_desc, no_grad_set) @@ -824,6 +830,7 @@ def serialize_op_decs(op_desc): def _append_backward_ops_with_checkpoints_(block, ops, + target_vars, target_block, no_grad_dict, grad_to_var, @@ -835,6 +842,7 @@ def _append_backward_ops_with_checkpoints_(block, Args: block(Block): the block where forward ops are ops(Op): the forward operators whose forward recomputation backward ops need to be added + target_vars(list[Tensor]): the loss vars we want to calculate gradient. target_block(Block): the block which is going to hold new generated grad ops no_grad_dict(dict): key(int) block index @@ -1070,7 +1078,7 @@ def _append_backward_ops_with_checkpoints_(block, # 4) remove no grad branch as it is in _remove_no_grad_branch_ grad_op_descs = _remove_no_grad_branch_(grad_op_descs, no_grad_dict[block.idx], - grad_op_id_to_fwd_op) + grad_op_id_to_fwd_op, target_vars) added_descs = _add_descs_to_block(grad_op_descs, target_block, grad_op_id_to_fwd_op) return program_stat, checkpoints_name, vars_should_be_hold, recompute_segments @@ -1140,6 +1148,7 @@ def _rename_grad_name_(name, grad_order): def _append_backward_ops_(block, ops, + target_vars, target_block, no_grad_dict, grad_to_var, @@ -1155,6 +1164,7 @@ def _append_backward_ops_(block, Args: block(Block): the block where forward ops are ops(Op): the forward operators whose backward ops need to be added + target_vars(list[Tensor]): the loss vars we want to calculate gradient. target_block(Block): the block which is going to hold new generated grad ops no_grad_dict(dict): key(int) block index @@ -1212,6 +1222,7 @@ def _append_backward_ops_(block, sub_block_path = op_path_dict[op._block_attr_id("sub_block")] _append_backward_ops_(sub_block, sub_block_path, + target_vars, grad_sub_block, no_grad_dict, grad_to_var, @@ -1330,7 +1341,7 @@ def _append_backward_ops_(block, # if all inputs of the grad op are in no_grad_set, just remove this op grad_op_descs = _remove_no_grad_branch_(grad_op_descs, no_grad_dict[block.idx], - grad_op_id_to_fwd_op) + grad_op_id_to_fwd_op, target_vars) # remove some backward ops not_need_ops = _find_not_need_ops(grad_op_descs, ops, input_grad_names_set) @@ -1765,6 +1776,7 @@ def append_backward(loss, _append_backward_ops_with_checkpoints_( root_block, op_path, + [loss], root_block, no_grad_dict, grad_to_var, @@ -1774,6 +1786,7 @@ def append_backward(loss, _append_backward_ops_( block, # the block where forward ops are in op_path, + [loss], target_grad_block, no_grad_dict, grad_to_var, @@ -2135,6 +2148,7 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None): grad_info_map = dict() _append_backward_ops_(block, op_path, + targets, block, no_grad_dict, grad_to_var, -- GitLab