diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py
index 7b4eb27fc82baed02c35373991eba103b9382dab..3305660c1aa65a64a2c284872c065396aa0c8374 100644
--- a/python/paddle/distributed/auto_parallel/operators/common.py
+++ b/python/paddle/distributed/auto_parallel/operators/common.py
@@ -16,7 +16,7 @@ import abc
 import paddle
 from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
 from ..dist_attribute import OperatorDistributedAttribute
-from ..utils import _get_comm_group, _get_corresponding_rank
+from ..utils import _get_comm_group, _get_corresponding_rank, is_optimize_op
 from ..process_group import new_process_group

 _g_distributed_operator_impl_containers = {}
@@ -426,7 +426,8 @@ def gradient_synchronization(dist_ctx, op, act_grad_names, out_grad_names,
         rank (int): global ranks index for current process.

     """
-    if len(act_grad_names) == 0 or len(out_grad_names) == 0:
+    if is_optimize_op(op) or len(act_grad_names) == 0 or len(
+            out_grad_names) == 0:
         return

     dp_group = get_data_parallel_group(dist_ctx, op, act_grad_names, rank)
diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py
index 3eb8437db6b1f1b3157eb512fb55bb748bc670fb..853df2eea7f5d0a7e96f84c5e0b83a685b3606d1 100644
--- a/python/paddle/distributed/auto_parallel/partitioner.py
+++ b/python/paddle/distributed/auto_parallel/partitioner.py
@@ -279,7 +279,7 @@ class Partitioner(object):
                 dist_op_opt_impl = _get_dist_op_backward_implement(
                     op, self._dist_context, forward_op_id2forward_op)
                 dist_op_opt_impl.backward(self._dist_context, **kinputs,
-                                          **koutputs)
+                                          **koutputs, **{"grad_var_to_var": {}})
             else:
                 raise NotImplementedError(
                     "partitioner only support forward and backward, optimize ops, but got {}"
diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py
index 8bfde1cba1cabc053c9930c0beeff4700b5983fd..128cffcc6e9270d165cc6f4e8c4d1417e027dca9 100644
--- a/python/paddle/distributed/passes/auto_parallel_fp16.py
+++ b/python/paddle/distributed/passes/auto_parallel_fp16.py
@@ -491,6 +491,50 @@ def _set_op_dist_attr_with_ranks(new_op, ranks, block, dist_context):
     dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)


+def _get_memcopy_idx(block, found_inf_var):
+    # use reduce_any op for check_nan_inf as the anchor for now
+    for idx, op in enumerate(block.ops):
+        if op.type == 'reduce_any' and op.output_arg_names[
+                0] == found_inf_var.name:
+            return idx + 1
+
+    raise RuntimeError(
+        "not found the correct location for memcopy for found_inf_var.")
+
+
+def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
+    src_name = src_var.name
+    output_var = block.create_var(name=unique_name.generate_with_ignorable_key(
+        src_name.join(['memcopy_'])),
+                                  dtype=src_var.dtype,
+                                  shape=src_var.shape,
+                                  type=core.VarDesc.VarType.LOD_TENSOR,
+                                  persistable=False,
+                                  stop_gradient=src_var.stop_gradient)
+
+    set_var_dist_attr(dist_context, output_var, [-1], world_process_group.ranks)
+
+    # TODO to support CUDAPinned/NPU/XPU Places
+    if direction == "D2H":
+        dst_place_type = 0
+    elif direction == "H2D":
+        dst_place_type = 1
+    else:
+        raise NotImplementedError(
+            "direction [{}] is not supported yet.".format(direction))
+
+    attrs = {'dst_place_type': dst_place_type}
+    new_op = block._insert_op_without_sync(index=idx,
+                                           type='memcpy',
+                                           inputs={'X': [src_var]},
+                                           outputs={'Out': [output_var]},
+                                           attrs=attrs)
+    _set_op_dist_attr_with_ranks(new_op, world_process_group.ranks, block,
+                                 dist_context)
+    block._sync_with_cpp()
+    return output_var
+
+
 @register_pass("auto_parallel_fp16")
 class FP16Pass(AMPPass):

@@ -577,9 +621,12 @@ class FP16Pass(AMPPass):
             if isinstance(
                     base_opt,
                     (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW)):
-                # with main_program._optimized_guard([]):
-                #     found_inf = paddle.tensor.creation._memcpy(
-                #         found_inf, paddle.CPUPlace())
+                with main_program._optimized_guard([]):
+                    # found_inf = paddle.tensor.creation._memcpy(
+                    #     found_inf, paddle.CPUPlace())
+                    insert_idx = _get_memcopy_idx(block, found_inf)
+                    found_inf = _insert_memcopy(block, insert_idx, found_inf,
+                                                self.dist_context)
                 base_opt._set_auxiliary_var('found_inf', found_inf.name)
             elif hasattr(base_opt, "_set_auxiliary_var"):
                 base_opt._set_auxiliary_var('found_inf', found_inf.name)