Unverified · Commit e2b924bf · authored by JZ-LIANG · committed by GitHub

[AutoParallel] Prune D2H memcpy for fp16 pass (#45159)

* prune d2h memcpy for fp16 pass

Parent: fa890092
@@ -16,7 +16,7 @@ import abc
 import paddle
 from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
 from ..dist_attribute import OperatorDistributedAttribute
-from ..utils import _get_comm_group, _get_corresponding_rank
+from ..utils import _get_comm_group, _get_corresponding_rank, is_optimize_op
 from ..process_group import new_process_group

 _g_distributed_operator_impl_containers = {}
@@ -426,7 +426,8 @@ def gradient_synchronization(dist_ctx, op, act_grad_names, out_grad_names,
         rank (int): global ranks index for current process.
     """
-    if len(act_grad_names) == 0 or len(out_grad_names) == 0:
+    if is_optimize_op(op) or len(act_grad_names) == 0 or len(
+            out_grad_names) == 0:
         return

     dp_group = get_data_parallel_group(dist_ctx, op, act_grad_names, rank)
...
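The new is_optimize_op(op) guard makes gradient_synchronization return early for optimizer-role ops, so no data-parallel allreduce (and no associated synchronization) is emitted for them. Below is a minimal sketch of what such a role check could look like, assuming it inspects the standard op-role attribute this file already imports; the real helper lives in ..utils and its implementation may differ.

from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY


def is_optimize_op_sketch(op):
    # Hypothetical stand-in for ..utils.is_optimize_op: treat an op as an
    # optimizer op when its op-role attribute carries the Optimize bit.
    if not op.has_attr(OP_ROLE_KEY):
        return False
    return (int(op.attr(OP_ROLE_KEY)) & int(OpRole.Optimize)) != 0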
@@ -279,7 +279,7 @@ class Partitioner(object):
                 dist_op_opt_impl = _get_dist_op_backward_implement(
                     op, self._dist_context, forward_op_id2forward_op)
                 dist_op_opt_impl.backward(self._dist_context, **kinputs,
-                                          **koutputs)
+                                          **koutputs, **{"grad_var_to_var": {}})
             else:
                 raise NotImplementedError(
                     "partitioner only support forward and backward, optimize ops, but got {}"
...
@@ -491,6 +491,50 @@ def _set_op_dist_attr_with_ranks(new_op, ranks, block, dist_context):
     dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)


+def _get_memcopy_idx(block, found_inf_var):
+    # Use the reduce_any op from check_nan_inf as the insertion anchor for now.
+    for idx, op in enumerate(block.ops):
+        if op.type == 'reduce_any' and op.output_arg_names[
+                0] == found_inf_var.name:
+            return idx + 1
+
+    raise RuntimeError(
+        "could not find the correct location to insert memcopy for found_inf_var.")
+
+
+def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
+    # Insert a `memcpy` op at `idx` that copies `src_var` into a new variable
+    # on the destination place (D2H: device -> host/CPU, H2D: host -> device).
+    src_name = src_var.name
+    output_var = block.create_var(
+        name=unique_name.generate_with_ignorable_key(
+            src_name.join(['memcopy_'])),
+        dtype=src_var.dtype,
+        shape=src_var.shape,
+        type=core.VarDesc.VarType.LOD_TENSOR,
+        persistable=False,
+        stop_gradient=src_var.stop_gradient)
+    set_var_dist_attr(dist_context, output_var, [-1], world_process_group.ranks)
+
+    # TODO: support CUDAPinned/NPU/XPU places.
+    if direction == "D2H":
+        dst_place_type = 0
+    elif direction == "H2D":
+        dst_place_type = 1
+    else:
+        raise NotImplementedError(
+            "direction [{}] is not supported yet.".format(direction))
+
+    attrs = {'dst_place_type': dst_place_type}
+    new_op = block._insert_op_without_sync(index=idx,
+                                           type='memcpy',
+                                           inputs={'X': [src_var]},
+                                           outputs={'Out': [output_var]},
+                                           attrs=attrs)
+    _set_op_dist_attr_with_ranks(new_op, world_process_group.ranks, block,
+                                 dist_context)
+    block._sync_with_cpp()
+    return output_var
+
+
 @register_pass("auto_parallel_fp16")
 class FP16Pass(AMPPass):
@@ -577,9 +621,12 @@ class FP16Pass(AMPPass):
                 if isinstance(
                         base_opt,
                         (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW)):
-                    # with main_program._optimized_guard([]):
+                    with main_program._optimized_guard([]):
                         # found_inf = paddle.tensor.creation._memcpy(
                         #     found_inf, paddle.CPUPlace())
+                        insert_idx = _get_memcopy_idx(block, found_inf)
+                        found_inf = _insert_memcopy(block, insert_idx, found_inf,
+                                                    self.dist_context)
                     base_opt._set_auxiliary_var('found_inf', found_inf.name)
                 elif hasattr(base_opt, "_set_auxiliary_var"):
                     base_opt._set_auxiliary_var('found_inf', found_inf.name)
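To illustrate the pattern the pass now relies on, here is a minimal, self-contained sketch (not the pass itself) of inserting Paddle's static-graph memcpy op with dst_place_type=0, the D2H case handled by _insert_memcopy above: the found_inf-style flag is copied to the CPU inside the graph rather than through a Python-side paddle.tensor.creation._memcpy call. The toy program, the variable names, and the fluid-era imports are assumptions for illustration, matching the Paddle generation this commit targets.

import paddle
from paddle.fluid import core, unique_name

paddle.enable_static()

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.full(shape=[4], fill_value=0.0, dtype='float32')
    # Stand-in for the found_inf flag produced by the fp16 pass
    # (paddle.any lowers to a reduce_any op, the anchor used above).
    found_inf = paddle.any(paddle.isinf(x))

block = main_prog.global_block()
cpu_found_inf = block.create_var(
    name=unique_name.generate_with_ignorable_key("memcopy_found_inf"),
    dtype=found_inf.dtype,
    shape=found_inf.shape,
    type=core.VarDesc.VarType.LOD_TENSOR,
    persistable=False,
    stop_gradient=True)
block.append_op(type='memcpy',
                inputs={'X': [found_inf]},
                outputs={'Out': [cpu_found_inf]},
                attrs={'dst_place_type': 0})  # 0 == CPUPlace, i.e. D2H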