Unverified commit cbf26bb1, authored by Haohongxiang, committed by GitHub

fix perf bugs of mp in eager dygraph (#45646)

Parent: c7c4cec7
@@ -1402,9 +1402,9 @@ def _mp_allreduce(tensor,
     """
     if group is not None and not group.is_member():
         return
-    ring_id = 0 if group is None else group.id
 
     if in_dygraph_mode():
+        group = _get_default_group() if group is None else group
         assert op == ReduceOp.SUM, "Unknown parameter: {}.".format(op)
 
         from paddle.autograd import PyLayer
@@ -1412,14 +1412,19 @@ def _mp_allreduce(tensor,
         class mp_allreduce_eager(PyLayer):
 
             @staticmethod
-            def forward(ctx, tensor, use_calc_stream, ring_id,
+            def forward(ctx, tensor, group, use_calc_stream,
                         use_model_parallel):
-                ctx.ring_id = ring_id
-                return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
-                                                      use_calc_stream,
-                                                      'ring_id', ring_id,
-                                                      "use_model_parallel",
-                                                      use_model_parallel)
+                ctx.ring_id = group.id
+
+                if use_calc_stream:
+                    op_type = _get_reduce_op(op, "_mp_allreduce")
+                    group.process_group.allreduce_on_calc_stream(
+                        tensor, op_type)
+                    return tensor
+                else:
+                    return _legacy_C_ops.c_allreduce_sum_(
+                        tensor, 'use_calc_stream', use_calc_stream, 'ring_id',
+                        ring_id, "use_model_parallel", use_model_parallel)
 
             @staticmethod
             def backward(ctx, dy):
@@ -1427,10 +1432,11 @@ def _mp_allreduce(tensor,
                                                 'ring_id', ctx.ring_id,
                                                 'use_model_parallel', True)
 
-        return mp_allreduce_eager.apply(tensor, use_calc_stream, ring_id,
+        return mp_allreduce_eager.apply(tensor, group, use_calc_stream,
                                         use_model_parallel)
 
-    elif _in_legacy_dygraph():
+    ring_id = 0 if group is None else group.id
+    if _in_legacy_dygraph():
         if op == ReduceOp.SUM:
             return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
                                                   use_calc_stream, 'ring_id',
......
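For context, the patch changes the eager-mode branch of _mp_allreduce in two ways: the default communication group is now resolved inside the function via _get_default_group(), and when use_calc_stream is True the allreduce is issued directly on the calculation stream through group.process_group.allreduce_on_calc_stream instead of going through the legacy c_allreduce_sum_ op. The following is a minimal usage sketch, not part of the patch: it assumes a multi-GPU job started with python -m paddle.distributed.launch and that _mp_allreduce and ReduceOp are importable from paddle.distributed.collective (the module this hunk appears to live in); the exact import path and default arguments are assumptions, not taken from the diff.

# Minimal sketch (assumed setup, not part of the patch) of exercising the
# patched eager-mode path of _mp_allreduce.
import paddle
import paddle.distributed as dist
# Assumption: _mp_allreduce and ReduceOp live in paddle.distributed.collective.
from paddle.distributed.collective import ReduceOp, _mp_allreduce

# Builds the default process group that _get_default_group() returns.
dist.init_parallel_env()

x = paddle.to_tensor([1.0, 2.0, 3.0])

# With group=None, the patched eager branch resolves the default group itself;
# with use_calc_stream=True it now calls
# group.process_group.allreduce_on_calc_stream(tensor, op_type)
# rather than the legacy c_allreduce_sum_ op.
y = _mp_allreduce(x, op=ReduceOp.SUM, group=None, use_calc_stream=True)
print(y)  # elementwise sum of x across all ranks in the group

In practice _mp_allreduce is an internal helper invoked by Paddle's model-parallel layers rather than user code, so the sketch is only meant to show how the new forward signature (tensor, group, use_calc_stream, use_model_parallel) is driven.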