未验证 提交 cbf26bb1 编写于 作者: H Haohongxiang 提交者: GitHub

fix perf bugs of mp in eager dygraph (#45646)

上级 c7c4cec7
...@@ -1402,9 +1402,9 @@ def _mp_allreduce(tensor, ...@@ -1402,9 +1402,9 @@ def _mp_allreduce(tensor,
""" """
if group is not None and not group.is_member(): if group is not None and not group.is_member():
return return
ring_id = 0 if group is None else group.id
if in_dygraph_mode(): if in_dygraph_mode():
group = _get_default_group() if group is None else group
assert op == ReduceOp.SUM, "Unknown parameter: {}.".format(op) assert op == ReduceOp.SUM, "Unknown parameter: {}.".format(op)
from paddle.autograd import PyLayer from paddle.autograd import PyLayer
...@@ -1412,14 +1412,19 @@ def _mp_allreduce(tensor, ...@@ -1412,14 +1412,19 @@ def _mp_allreduce(tensor,
class mp_allreduce_eager(PyLayer): class mp_allreduce_eager(PyLayer):
@staticmethod @staticmethod
def forward(ctx, tensor, use_calc_stream, ring_id, def forward(ctx, tensor, group, use_calc_stream,
use_model_parallel): use_model_parallel):
ctx.ring_id = ring_id ctx.ring_id = group.id
return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
use_calc_stream, if use_calc_stream:
'ring_id', ring_id, op_type = _get_reduce_op(op, "_mp_allreduce")
"use_model_parallel", group.process_group.allreduce_on_calc_stream(
use_model_parallel) tensor, op_type)
return tensor
else:
return _legacy_C_ops.c_allreduce_sum_(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id',
ring_id, "use_model_parallel", use_model_parallel)
@staticmethod @staticmethod
def backward(ctx, dy): def backward(ctx, dy):
...@@ -1427,10 +1432,11 @@ def _mp_allreduce(tensor, ...@@ -1427,10 +1432,11 @@ def _mp_allreduce(tensor,
'ring_id', ctx.ring_id, 'ring_id', ctx.ring_id,
'use_model_parallel', True) 'use_model_parallel', True)
return mp_allreduce_eager.apply(tensor, use_calc_stream, ring_id, return mp_allreduce_eager.apply(tensor, group, use_calc_stream,
use_model_parallel) use_model_parallel)
elif _in_legacy_dygraph(): ring_id = 0 if group is None else group.id
if _in_legacy_dygraph():
if op == ReduceOp.SUM: if op == ReduceOp.SUM:
return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream', return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
use_calc_stream, 'ring_id', use_calc_stream, 'ring_id',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册