Unverified commit 078e8c78, authored by Haohongxiang, committed by GitHub

[Dygraph] Fix Perf of FusedFeedForward and FusedAttention with AllReduce (#46780)

Parent 2e217dbb
@@ -30,7 +30,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/math_function.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
@@ -50,13 +50,15 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT
if (map->has(ring_id)) {
paddle::distributed::ProcessGroup *pg = map->get(ring_id);
auto pg_nccl = static_cast<distributed::ProcessGroupNCCL *>(pg);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(tensor);
out_tensor.push_back(tensor);
paddle::distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM;
auto task = pg->AllReduce(in_tensor, out_tensor, opts);
auto task = pg_nccl->AllReduce(in_tensor, out_tensor, opts, true, true);
task->Wait();
} else {
auto dtype = platform::ToNCCLDataType(
......
@@ -23,7 +23,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
@@ -43,13 +43,15 @@ static void AllReduce(phi::DenseTensor& tensor, // NOLINT
if (map->has(ring_id)) {
paddle::distributed::ProcessGroup* pg = map->get(ring_id);
auto pg_nccl = static_cast<distributed::ProcessGroupNCCL*>(pg);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(tensor);
out_tensor.push_back(tensor);
paddle::distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM;
auto task = pg->AllReduce(in_tensor, out_tensor, opts);
auto task = pg_nccl->AllReduce(in_tensor, out_tensor, opts, true, true);
task->Wait();
} else {
auto dtype = platform::ToNCCLDataType(
......
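The change is identical in both files: the AllReduce helper used by the fused ops now casts the generic ProcessGroup to ProcessGroupNCCL and calls the NCCL-specific AllReduce overload with two extra boolean arguments. Below is a minimal sketch of the updated helper, assuming (not shown in this diff) that the process-group map comes from ProcessGroupMapFromGid::getInstance() and that the two trailing `true` arguments are the overload's sync_op and use_calc_stream flags.

```cpp
// Sketch of the updated AllReduce helper shared by fused_attention and
// fused_feedforward. Names marked as assumptions are not visible in the
// diff above.
#include <vector>

#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#include "paddle/phi/core/dense_tensor.h"

static void AllReduce(phi::DenseTensor &tensor,  // NOLINT
                      const int ring_id) {
  if (ring_id == -1) return;
  // Assumption: the map of process groups keyed by ring_id is obtained this
  // way, as elsewhere in Paddle; the has()/get() lookup is shown in the diff.
  auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
  if (map->has(ring_id)) {
    paddle::distributed::ProcessGroup *pg = map->get(ring_id);
    // New in this commit: downcast to the NCCL process group so the
    // NCCL-specific AllReduce overload can be called.
    auto pg_nccl = static_cast<paddle::distributed::ProcessGroupNCCL *>(pg);

    std::vector<phi::DenseTensor> in_tensor;
    std::vector<phi::DenseTensor> out_tensor;
    in_tensor.push_back(tensor);
    out_tensor.push_back(tensor);
    paddle::distributed::AllreduceOptions opts;
    opts.reduce_op = paddle::distributed::ReduceOp::SUM;

    // Old call:  pg->AllReduce(in_tensor, out_tensor, opts);
    // New call:  the two trailing booleans are assumed to be sync_op and
    // use_calc_stream, i.e. run the collective synchronously on the
    // calculation stream, which is what the PR title credits for the
    // performance fix.
    auto task = pg_nccl->AllReduce(in_tensor, out_tensor, opts, true, true);
    task->Wait();
  } else {
    // Fallback path via collective_helper / ncclAllReduce is unchanged by
    // this commit and elided here (the diff truncates it as well).
  }
}
```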