From e2b924bfd2c4e9dc36dfbf0a03af6e39ac69c6d3 Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Tue, 16 Aug 2022 11:35:56 +0800
Subject: [PATCH] [AutoParallel] Prune D2H memcpy for fp16 pass (#45159)

* prune d2h memcpy for fp16 pass
---
 .../auto_parallel/operators/common.py         |  5 +-
 .../distributed/auto_parallel/partitioner.py  |  2 +-
 .../distributed/passes/auto_parallel_fp16.py  | 53 +++++++++++++++++--
 3 files changed, 54 insertions(+), 6 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py
index 7b4eb27fc82..3305660c1aa 100644
--- a/python/paddle/distributed/auto_parallel/operators/common.py
+++ b/python/paddle/distributed/auto_parallel/operators/common.py
@@ -16,7 +16,7 @@ import abc
 import paddle
 from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
 from ..dist_attribute import OperatorDistributedAttribute
-from ..utils import _get_comm_group, _get_corresponding_rank
+from ..utils import _get_comm_group, _get_corresponding_rank, is_optimize_op
 from ..process_group import new_process_group
 
 _g_distributed_operator_impl_containers = {}
@@ -426,7 +426,8 @@ def gradient_synchronization(dist_ctx, op, act_grad_names, out_grad_names,
         rank (int): global ranks index for current process.
 
     """
-    if len(act_grad_names) == 0 or len(out_grad_names) == 0:
+    if is_optimize_op(op) or len(act_grad_names) == 0 or len(
+            out_grad_names) == 0:
         return
 
     dp_group = get_data_parallel_group(dist_ctx, op, act_grad_names, rank)
diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py
index 3eb8437db6b..853df2eea7f 100644
--- a/python/paddle/distributed/auto_parallel/partitioner.py
+++ b/python/paddle/distributed/auto_parallel/partitioner.py
@@ -279,7 +279,7 @@ class Partitioner(object):
                 dist_op_opt_impl = _get_dist_op_backward_implement(
                     op, self._dist_context, forward_op_id2forward_op)
                 dist_op_opt_impl.backward(self._dist_context, **kinputs,
-                                          **koutputs)
+                                          **koutputs, **{"grad_var_to_var": {}})
             else:
                 raise NotImplementedError(
                     "partitioner only support forward and backward, optimize ops, but got {}"
diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py
index 8bfde1cba1c..128cffcc6e9 100644
--- a/python/paddle/distributed/passes/auto_parallel_fp16.py
+++ b/python/paddle/distributed/passes/auto_parallel_fp16.py
@@ -491,6 +491,50 @@ def _set_op_dist_attr_with_ranks(new_op, ranks, block, dist_context):
     dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
 
 
+def _get_memcopy_idx(block, found_inf_var):
+    # use reduce_any op for check_nan_inf as the anchor for now
+    for idx, op in enumerate(block.ops):
+        if op.type == 'reduce_any' and op.output_arg_names[
+                0] == found_inf_var.name:
+            return idx + 1
+
+    raise RuntimeError(
+        "not found the correct location for memcopy for found_inf_var.")
+
+
+def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
+    src_name = src_var.name
+    output_var = block.create_var(name=unique_name.generate_with_ignorable_key(
+        src_name.join(['memcopy_'])),
+                                  dtype=src_var.dtype,
+                                  shape=src_var.shape,
+                                  type=core.VarDesc.VarType.LOD_TENSOR,
+                                  persistable=False,
+                                  stop_gradient=src_var.stop_gradient)
+
+    set_var_dist_attr(dist_context, output_var, [-1], world_process_group.ranks)
+
+    # TODO to support CUDAPinned/NPU/XPU Places
+    if direction == "D2H":
+        dst_place_type = 0
+    elif direction == "H2D":
+        dst_place_type = 1
+    else:
+        raise NotImplementedError(
+            "direction [{}] is not supported yet.".format(direction))
+
+    attrs = {'dst_place_type': dst_place_type}
+    new_op = block._insert_op_without_sync(index=idx,
+                                           type='memcpy',
+                                           inputs={'X': [src_var]},
+                                           outputs={'Out': [output_var]},
+                                           attrs=attrs)
+    _set_op_dist_attr_with_ranks(new_op, world_process_group.ranks, block,
+                                 dist_context)
+    block._sync_with_cpp()
+    return output_var
+
+
 @register_pass("auto_parallel_fp16")
 class FP16Pass(AMPPass):
 
@@ -577,9 +621,12 @@ class FP16Pass(AMPPass):
             if isinstance(
                     base_opt,
                 (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW)):
-                # with main_program._optimized_guard([]):
-                #     found_inf = paddle.tensor.creation._memcpy(
-                #         found_inf, paddle.CPUPlace())
+                with main_program._optimized_guard([]):
+                    # found_inf = paddle.tensor.creation._memcpy(
+                    #     found_inf, paddle.CPUPlace())
+                    insert_idx = _get_memcopy_idx(block, found_inf)
+                    found_inf = _insert_memcopy(block, insert_idx, found_inf,
+                                                self.dist_context)
                 base_opt._set_auxiliary_var('found_inf', found_inf.name)
             elif hasattr(base_opt, "_set_auxiliary_var"):
                 base_opt._set_auxiliary_var('found_inf', found_inf.name)
-- 
GitLab
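
Below is a minimal, standalone sketch (not part of the patch) of the insertion pattern the new _get_memcopy_idx/_insert_memcopy helpers rely on: find the reduce_any op that produces the found_inf flag and insert a memcpy op with dst_place_type=0 (CPUPlace) right after it, so the device-to-host copy is scheduled once inside the main program instead of through a separate paddle.tensor.creation._memcpy call. The program, the found_inf stand-in, and variable names here are illustrative only, assuming a Paddle release contemporary with this patch.

import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name="x", shape=[4], dtype="float32")
    # stand-in for the found_inf flag produced by check_finite_and_unscale
    found_inf = paddle.any(paddle.isinf(x))

block = main_prog.global_block()

# anchor: the position right after the op that outputs found_inf (reduce_any)
anchor_idx = next(i for i, op in enumerate(block.ops)
                  if found_inf.name in op.output_arg_names) + 1

# destination variable that will live on the host side
found_inf_cpu = block.create_var(name="found_inf_cpu",
                                 dtype=found_inf.dtype,
                                 shape=found_inf.shape,
                                 persistable=False)

# memcpy with dst_place_type=0 copies the tensor to CPUPlace (D2H), as in the patch
block._insert_op_without_sync(index=anchor_idx,
                              type='memcpy',
                              inputs={'X': [found_inf]},
                              outputs={'Out': [found_inf_cpu]},
                              attrs={'dst_place_type': 0})
block._sync_with_cpp()

print([op.type for op in block.ops])  # the reduce_any op is now followed by memcpy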