diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
index 82a072fe056404c227f374c08a79f5592369e49a..ccfba7acafffd7f154ec92f44d971b6a66e2edb2 100644
--- a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
+++ b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
@@ -48,6 +48,38 @@ class c_identity_eager(PyLayer):
         return dy
 
 
+class c_split_eager(PyLayer):
+    @staticmethod
+    def forward(ctx, tensor, group, rank, nranks):
+        ctx.group = group
+        ctx.nranks = nranks
+        return _legacy_C_ops.c_split(
+            tensor,
+            'use_calc_stream',
+            True,
+            'ring_id',
+            group.id,
+            'rank',
+            rank,
+            'nranks',
+            nranks,
+            'use_model_parallel',
+            True,
+        )
+
+    @staticmethod
+    def backward(ctx, dy):
+        group = ctx.group
+        out_shape = dy.shape
+        out_shape[0] = out_shape[0] * ctx.nranks
+        out = paddle.empty(out_shape, dtype=dy.dtype)
+        group.process_group.all_gather_into_tensor_on_calc_stream(
+            out,
+            dy,
+        )
+        return out
+
+
 def _c_identity(tensor, group=None, skip_c_identity_dynamic=False):
     """
     Return a copy of the tensor, mainly used with model parallel.
@@ -179,19 +211,7 @@ def _c_split(tensor, group=None):
     )
 
     if in_dynamic_mode():
-        return _legacy_C_ops.c_split(
-            tensor,
-            'use_calc_stream',
-            True,
-            'ring_id',
-            ring_id,
-            'rank',
-            rank,
-            'nranks',
-            nranks,
-            'use_model_parallel',
-            True,
-        )
+        return c_split_eager.apply(tensor, group, rank, nranks)
     else:
         op_type = 'c_split'
         helper = LayerHelper(op_type, **locals())
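
The change wraps the eager-mode c_split call in a PyLayer so the backward pass is an explicit all_gather into a tensor of the pre-split shape, instead of relying on the legacy op's implicit gradient. The snippet below is a minimal, single-process sketch of that forward/backward pattern, not code from this patch: ToySplit is a hypothetical layer that shards locally in forward (standing in for c_split) and rebuilds a full-shape gradient in backward (standing in for all_gather_into_tensor_on_calc_stream), with zero-filled slots where the other ranks' shards would be gathered.

# Illustrative sketch only; performs no real communication.
import paddle
from paddle.autograd import PyLayer


class ToySplit(PyLayer):
    @staticmethod
    def forward(ctx, x, rank, nranks):
        ctx.rank = rank
        ctx.nranks = nranks
        shard = x.shape[0] // nranks
        # Keep only this rank's slice along dim 0, like c_split would.
        return x[rank * shard:(rank + 1) * shard]

    @staticmethod
    def backward(ctx, dy):
        # The real backward all_gathers every rank's gradient shard into a
        # tensor of the original, un-split shape. With a single process we
        # only hold the local shard, so zero-fill the other ranks' slots.
        shards = [
            dy if r == ctx.rank else paddle.zeros_like(dy)
            for r in range(ctx.nranks)
        ]
        return paddle.concat(shards, axis=0)


x = paddle.randn([4, 8])
x.stop_gradient = False
y = ToySplit.apply(x, 0, 2)   # shape [2, 8]
y.sum().backward()
print(x.grad.shape)           # [4, 8]: gradient has the pre-split shape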