From 201480d5295be18d312b6d543602b979659db801 Mon Sep 17 00:00:00 2001
From: ShenLiang <1422485404@qq.com>
Date: Mon, 4 Sep 2023 16:47:16 +0800
Subject: [PATCH] fix bug in c_split (#56917)

---
Review notes (this section sits after the three-dash line, so it is
commentary for readers of the patch and is not recorded in the commit):

* The dynamic-mode branch of _c_split no longer calls
  _legacy_C_ops.c_split directly; it calls c_split_eager.apply(), a new
  PyLayer with an explicit backward.
* c_split_eager.forward issues the same c_split op with the same
  attribute list as the removed call, except that 'ring_id' is now read
  from group.id instead of the local ring_id variable.
* c_split_eager.backward allocates an output whose first dimension is
  dy.shape[0] * nranks (other dimensions unchanged) and fills it with
  group.process_group.all_gather_into_tensor_on_calc_stream(out, dy).
* NOTE(review): _c_split still declares group=None, but c_split_eager
  dereferences group.id (forward) and group.process_group (backward).
  Confirm that every dynamic-mode caller passes a group.
* NOTE(review): backward scales dimension 0 only. Confirm this matches
  the axis along which c_split partitions the tensor in forward.

 .../distributed/fleet/layers/mpu/mp_ops.py    | 46 +++++++++++++------
 1 file changed, 33 insertions(+), 13 deletions(-)

diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
index 82a072fe056..ccfba7acaff 100644
--- a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
+++ b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
@@ -48,6 +48,38 @@ class c_identity_eager(PyLayer):
         return dy
 
 
+class c_split_eager(PyLayer):
+    @staticmethod
+    def forward(ctx, tensor, group, rank, nranks):
+        ctx.group = group
+        ctx.nranks = nranks
+        return _legacy_C_ops.c_split(
+            tensor,
+            'use_calc_stream',
+            True,
+            'ring_id',
+            group.id,
+            'rank',
+            rank,
+            'nranks',
+            nranks,
+            'use_model_parallel',
+            True,
+        )
+
+    @staticmethod
+    def backward(ctx, dy):
+        group = ctx.group
+        out_shape = dy.shape
+        out_shape[0] = out_shape[0] * ctx.nranks
+        out = paddle.empty(out_shape, dtype=dy.dtype)
+        group.process_group.all_gather_into_tensor_on_calc_stream(
+            out,
+            dy,
+        )
+        return out
+
+
 def _c_identity(tensor, group=None, skip_c_identity_dynamic=False):
     """
     Return a copy of the tensor, mainly used with model parallel.
@@ -179,19 +211,7 @@ def _c_split(tensor, group=None):
     )
 
     if in_dynamic_mode():
-        return _legacy_C_ops.c_split(
-            tensor,
-            'use_calc_stream',
-            True,
-            'ring_id',
-            ring_id,
-            'rank',
-            rank,
-            'nranks',
-            nranks,
-            'use_model_parallel',
-            True,
-        )
+        return c_split_eager.apply(tensor, group, rank, nranks)
     else:
         op_type = 'c_split'
         helper = LayerHelper(op_type, **locals())
-- 
GitLab