未验证 提交 3ece0ece 编写于 作者: S ShenLiang 提交者: GitHub

fix bug of mp (#52789)

上级 cbdba509
......@@ -46,7 +46,15 @@ def _c_identity(tensor, group=None):
class c_identity_eager(PyLayer):
@staticmethod
def forward(ctx, tensor):
return tensor
return _legacy_C_ops.c_identity(
tensor,
'use_calc_stream',
True,
'ring_id',
group.id,
'use_model_parallel',
True,
)
@staticmethod
def backward(ctx, dy):
......@@ -249,7 +257,15 @@ def _mp_allreduce(
@staticmethod
def backward(ctx, dy):
return dy
return _legacy_C_ops.c_identity(
dy,
'use_calc_stream',
True,
'ring_id',
ctx.ring_id,
'use_model_parallel',
True,
)
return mp_allreduce_eager.apply(
tensor, group, use_calc_stream, use_model_parallel
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册