Unverified commit 201480d5, authored by ShenLiang, committed by GitHub

fix bug in c_split (#56917)

Parent: 79bfb184
@@ -48,6 +48,38 @@ class c_identity_eager(PyLayer):
         return dy
 
 
+class c_split_eager(PyLayer):
+    @staticmethod
+    def forward(ctx, tensor, group, rank, nranks):
+        ctx.group = group
+        ctx.nranks = nranks
+        return _legacy_C_ops.c_split(
+            tensor,
+            'use_calc_stream',
+            True,
+            'ring_id',
+            group.id,
+            'rank',
+            rank,
+            'nranks',
+            nranks,
+            'use_model_parallel',
+            True,
+        )
+
+    @staticmethod
+    def backward(ctx, dy):
+        group = ctx.group
+        out_shape = dy.shape
+        out_shape[0] = out_shape[0] * ctx.nranks
+        out = paddle.empty(out_shape, dtype=dy.dtype)
+        group.process_group.all_gather_into_tensor_on_calc_stream(
+            out,
+            dy,
+        )
+        return out
+
+
 def _c_identity(tensor, group=None, skip_c_identity_dynamic=False):
     """
     Return a copy of the tensor, mainly used with model parallel.
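The new c_split_eager class above follows Paddle's PyLayer protocol: forward stashes whatever backward will need on ctx, and backward receives the upstream gradient and returns the gradient for the tensor input. Below is a minimal single-process sketch of the same pattern, using a toy scaling op in place of the collective call; the class and variable names are illustrative, not part of this commit.

import paddle
from paddle.autograd import PyLayer

class scale_eager(PyLayer):  # hypothetical toy layer mirroring the structure above
    @staticmethod
    def forward(ctx, tensor, nranks):
        # Stash non-tensor state for backward, as c_split_eager does with group/nranks.
        ctx.nranks = nranks
        return tensor * nranks

    @staticmethod
    def backward(ctx, dy):
        # Gradient of (x * nranks) w.r.t. x is nranks, so scale the upstream gradient.
        return dy * ctx.nranks

x = paddle.ones([2, 2])
x.stop_gradient = False
y = scale_eager.apply(x, 4)  # PyLayer subclasses are invoked via .apply(...)
y.sum().backward()
print(x.grad)                # every entry is 4.0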
@@ -179,19 +211,7 @@ def _c_split(tensor, group=None):
     )
 
     if in_dynamic_mode():
-        return _legacy_C_ops.c_split(
-            tensor,
-            'use_calc_stream',
-            True,
-            'ring_id',
-            ring_id,
-            'rank',
-            rank,
-            'nranks',
-            nranks,
-            'use_model_parallel',
-            True,
-        )
+        return c_split_eager.apply(tensor, group, rank, nranks)
     else:
         op_type = 'c_split'
         helper = LayerHelper(op_type, **locals())
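The rewritten backward gathers the per-rank gradient shards back into a full tensor: all_gather_into_tensor stacks one shard per rank along dimension 0, which is why the output buffer is allocated with out_shape[0] multiplied by ctx.nranks. A single-process numpy emulation of that shape arithmetic (the concatenation stands in for the collective; the shard shape is made up for illustration):

import numpy as np

nranks = 4
dy = np.ones((3, 8), dtype=np.float32)    # per-rank gradient shard
out_shape = list(dy.shape)
out_shape[0] *= nranks                    # same shape math as c_split_eager.backward
# all_gather_into_tensor places rank i's shard at rows [i*3, (i+1)*3);
# with identical shards this reduces to a repeat along axis 0.
gathered = np.concatenate([dy] * nranks, axis=0)
assert gathered.shape == tuple(out_shape)  # (12, 8)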