Unverified commit 372505be authored by Haohongxiang, committed by GitHub

[Dygraph] Fix bugs of mp in eager mode (#46303) (#46396)

* fix bugs of mp

* fix bugs of mp

* update

* update

* fix bug
Parent 083853cd
@@ -265,7 +265,7 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> {
     auto map = distributed::ProcessGroupMapFromGid::getInstance();
     distributed::ProcessGroup* pg = map->get(rid);
     distributed::AllreduceOptions opts;
-    opts.reduce_op = distributed::ReduceOp::SUM;
+    opts.reduce_op = distributed::ReduceOp::MAX;
 
     // allocate memory on device.
     softmax->mutable_data<T>(place);
@@ -348,6 +348,7 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> {
     in_out.clear();
     in_out.push_back(predicted_logits);
+    opts.reduce_op = distributed::ReduceOp::SUM;
     pg->AllReduce(in_out, in_out, opts)->Synchronize();
 
     // step 4, obtain exp(logit)
@@ -364,6 +365,7 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> {
     in_out.clear();
     in_out.push_back(sum_exp_logits);
+    opts.reduce_op = distributed::ReduceOp::SUM;
     pg->AllReduce(in_out, in_out, opts)->Synchronize();
 
     auto eigen_loss = math::EigenMatrix<T>::From(loss_2d);
...
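A note on the kernel change above (not part of the patch): the first allreduce in this functor combines the per-rank row maxima of the logits, so it must use ReduceOp::MAX, while the later reductions of predicted_logits and sum_exp_logits are genuine sums, which is why ReduceOp::SUM is now re-set explicitly before each of them. A rough single-process sketch of that reduction order, written with NumPy and made-up shard values:

```python
# Single-process sketch of the reduction order the kernel relies on.
# "shards" stands in for the per-rank slices of one row of logits;
# names and values here are illustrative only.
import numpy as np

shards = [np.array([2.0, -1.0]), np.array([0.5, 3.0])]  # two mp ranks

# step 1: global row max across ranks -> allreduce with ReduceOp::MAX
logit_max = max(s.max() for s in shards)

# later steps: sums of exp(logit - max) across ranks -> allreduce with ReduceOp::SUM
sum_exp_logits = sum(np.exp(s - logit_max).sum() for s in shards)

# log-sum-exp term of the numerically stable softmax cross entropy
log_sum_exp = np.log(sum_exp_logits) + logit_max
```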
@@ -107,13 +107,26 @@ def _broadcast_data_help(data, shape, dtype, hcg):
                                  group=model_parallel_group,
                                  use_calc_stream=True)
 
+    if mp_rank != 0:
+        if in_dygraph_mode():
+            data._clear_data()
+            input_data._share_buffer_to(data)
+        else:
+            data.value().get_tensor()._clear()
+            data.value().get_tensor()._share_data_with(
+                input_data.value().get_tensor())
+
 
 def broadcast_input_data(hcg, *inputs, **kwargs):
     cur_device = paddle.get_device()
     for v in inputs:
         if isinstance(v, (core.VarBase, core.eager.Tensor)):
             with framework.no_grad():
-                v = v.cuda() if "gpu" in cur_device else v
+                if "gpu" in cur_device and in_dygraph_mode() \
+                        and not v.place.is_gpu_place():
+                    v_gpu = v.cuda(int(cur_device.split(":")[1]))
+                    v._clear_data()
+                    v_gpu._share_buffer_to(v)
                 _broadcast_data_help(v, v.shape, v.dtype, hcg)
         else:
             logger.error("it doesn't support data type {}".format(type(v)))
@@ -121,7 +134,11 @@ def broadcast_input_data(hcg, *inputs, **kwargs):
     for k, v in kwargs.items():
         if isinstance(v, (core.VarBase, core.eager.Tensor)):
             with framework.no_grad():
-                v = v.cuda() if "gpu" in cur_device else v
+                if "gpu" in cur_device and in_dygraph_mode() \
+                        and not v.place.is_gpu_place():
+                    v_gpu = v.cuda(int(cur_device.split(":")[1]))
+                    v._clear_data()
+                    v_gpu._share_buffer_to(v)
                 _broadcast_data_help(v, v.shape, v.dtype, hcg)
             kwargs[k] = v
         else:
...
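One way to read the Python change: in eager mode, v.cuda() returns a new Tensor object, so rebinding the local name v (as the old code did) leaves the tensor still held in inputs/kwargs, and by the caller, pointing at the old CPU storage. The patch instead clears v's storage and shares the GPU copy's buffer back into v, so the original object ends up on the GPU. A hedged, standalone sketch of that identity-preserving move, assuming an eager-mode, CUDA-enabled Paddle build (_clear_data and _share_buffer_to are the same internal eager-Tensor methods used in the diff):

```python
# Sketch only: demonstrates the buffer-sharing device move, not the patched helper itself.
import paddle

v = paddle.to_tensor([1.0, 2.0, 3.0], place=paddle.CPUPlace())
held_elsewhere = v  # e.g. the entry in *inputs that the caller still references

v_gpu = v.cuda(0)          # copy the data to GPU 0 (a new Tensor object)
v._clear_data()            # drop v's old CPU storage
v_gpu._share_buffer_to(v)  # repoint v at the GPU copy's buffer

assert held_elsewhere is v     # same Python object as before
assert v.place.is_gpu_place()  # but its data now lives on the GPU
```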