From 3ece0ece6428f54e3e2060299e0a43dc005eb24f Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Wed, 12 Apr 2023 06:40:25 -0500 Subject: [PATCH] fix bug of mp (#52789) --- .../distributed/fleet/layers/mpu/mp_ops.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py index 08093710b3b..884af3a4414 100644 --- a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py +++ b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py @@ -46,7 +46,15 @@ def _c_identity(tensor, group=None): class c_identity_eager(PyLayer): @staticmethod def forward(ctx, tensor): - return tensor + return _legacy_C_ops.c_identity( + tensor, + 'use_calc_stream', + True, + 'ring_id', + group.id, + 'use_model_parallel', + True, + ) @staticmethod def backward(ctx, dy): @@ -249,7 +257,15 @@ def _mp_allreduce( @staticmethod def backward(ctx, dy): - return dy + return _legacy_C_ops.c_identity( + dy, + 'use_calc_stream', + True, + 'ring_id', + ctx.ring_id, + 'use_model_parallel', + True, + ) return mp_allreduce_eager.apply( tensor, group, use_calc_stream, use_model_parallel -- GitLab