diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
index 08093710b3b89196f3ed758a757e2b560d7ca191..884af3a441431861bc9757ea8d91c5c7997f880a 100644
--- a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
+++ b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
@@ -46,7 +46,15 @@ def _c_identity(tensor, group=None):
         class c_identity_eager(PyLayer):
             @staticmethod
             def forward(ctx, tensor):
-                return tensor
+                return _legacy_C_ops.c_identity(
+                    tensor,
+                    'use_calc_stream',
+                    True,
+                    'ring_id',
+                    group.id,
+                    'use_model_parallel',
+                    True,
+                )
 
             @staticmethod
             def backward(ctx, dy):
@@ -249,7 +257,15 @@ def _mp_allreduce(
 
             @staticmethod
             def backward(ctx, dy):
-                return dy
+                return _legacy_C_ops.c_identity(
+                    dy,
+                    'use_calc_stream',
+                    True,
+                    'ring_id',
+                    ctx.ring_id,
+                    'use_model_parallel',
+                    True,
+                )
 
         return mp_allreduce_eager.apply(
             tensor, group, use_calc_stream, use_model_parallel
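
Note on the change: instead of returning the tensor (or gradient) unchanged, both PyLayers now route the identity side through the c_identity op, so the pass-through is issued on the model-parallel ring: _c_identity gains the op call in forward, _mp_allreduce gains it in backward. The sketch below is only an illustration of the underlying conjugate-operator pattern of tensor parallelism (identity forward / all-reduce backward, and the reverse), written against the public paddle.distributed collectives rather than the internal _legacy_C_ops calls; the class names and the simplified collectives are assumptions for illustration, not the actual mp_ops.py implementation.

    # Illustrative sketch only -- not the mp_ops.py code. Assumes an initialized
    # distributed environment and a model-parallel communication group.
    import paddle.distributed as dist
    from paddle.autograd import PyLayer


    class IdentityForwardAllReduceBackward(PyLayer):
        """Analogue of _c_identity: pass-through forward, all-reduce backward."""

        @staticmethod
        def forward(ctx, x, group):
            ctx.group = group
            # Forward is an identity on this rank; the patched code routes it
            # through the c_identity op so the copy is recorded on the ring.
            return x.clone()

        @staticmethod
        def backward(ctx, dy):
            # Gradients flowing back must be summed across tensor-parallel ranks.
            dist.all_reduce(dy, op=dist.ReduceOp.SUM, group=ctx.group)
            return dy


    class AllReduceForwardIdentityBackward(PyLayer):
        """Analogue of _mp_allreduce: all-reduce forward, pass-through backward."""

        @staticmethod
        def forward(ctx, x, group):
            ctx.group = group
            # Partial results from each rank are summed in the forward pass.
            dist.all_reduce(x, op=dist.ReduceOp.SUM, group=group)
            return x

        @staticmethod
        def backward(ctx, dy):
            # Backward is an identity; the patched code routes it through
            # c_identity for the same ring-bookkeeping reason as above.
            return dy.clone()

Pairing the two operators this way keeps a column-parallel layer followed by a row-parallel layer numerically equivalent to the unsplit computation while issuing one all-reduce in the forward pass and one in the backward pass.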