Unverified · Commit c24e7fe1 authored by Wen Sun, committed by GitHub

Migrate collective communication checks to PHI (#49754)

* refactor: migrate comm checks

* refactor: add check in comm context

* feat: add gloo static check

* refactor: add place param in static check
Parent 69d01eb9
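
Summary of the move: the check helpers previously declared in paddle/fluid/distributed/collective/check.h now live under paddle/phi/core/distributed/check/, split into a backend-agnostic CommStaticCheck and an NCCL-only NCCLDynamicCheck. A minimal sketch of the resulting call-site pattern (the wrapper name below is illustrative; the actual call sites in process_group_nccl.cc are not shown in this diff):

#include "paddle/phi/core/distributed/check/static_check.h"

// Hypothetical preflight helper showing the post-refactor namespace.
void PreflightAllReduce(const phi::DenseTensor& out,
                        const phi::DenseTensor& in,
                        int cur_rank, int world_size) {
  // Formerly paddle::distributed::CommStaticCheck in fluid; the interface
  // is unchanged except for the new optional place parameter.
  phi::distributed::CommStaticCheck::SameShape(
      out, in, /*dst_rank*/ cur_rank, /*cur_rank*/ cur_rank, world_size);
}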
@@ -18,7 +18,7 @@ endif()
if(WITH_NCCL OR WITH_RCCL)
cc_library(
process_group_nccl
SRCS process_group_nccl.cc nccl_tools.cc common.cc check.cc
SRCS process_group_nccl.cc nccl_tools.cc common.cc
DEPS process_group
tcp_store
place
@@ -26,7 +26,9 @@ if(WITH_NCCL OR WITH_RCCL)
collective_helper
device_context
${DEVICE_EVENT_LIBS}
dense_tensor)
dense_tensor
comm_static_check
nccl_dynamic_check)
endif()
if(WITH_XPU_BKCL)
......
@@ -23,20 +23,11 @@
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_with_stream.h"
#include "paddle/fluid/platform/device_event.h"
#include "paddle/phi/backends/gpu/forwards.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/store/store.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/nccl_tools.h"
#endif
#ifdef PADDLE_WITH_RCCL
#include "paddle/phi/backends/dynload/rccl.h"
#else
#include "paddle/phi/backends/dynload/nccl.h"
#endif
namespace paddle {
namespace distributed {
......
@@ -13,9 +13,11 @@
// limitations under the License.
#pragma once
#include <chrono>
#include <cstdint>
#include <vector>
#include "paddle/phi/common/place.h"
namespace paddle {
......
add_subdirectory(check)
add_subdirectory(store)
set(COMM_CONTEXT_MANAGER_DEPS tcp_store)
@@ -6,7 +7,8 @@ if(WITH_NCCL OR WITH_RCCL)
cc_library(
nccl_comm_context
SRCS nccl_comm_context.cc
DEPS dense_tensor)
DEPS dense_tensor comm_static_check nccl_dynamic_check)
list(APPEND COMM_CONTEXT_MANAGER_DEPS nccl_comm_context)
endif()
@@ -19,7 +21,7 @@ if(WITH_GLOO)
cc_library(
gloo_comm_context
SRCS gloo_comm_context.cc
DEPS gloo_utils)
DEPS gloo_utils comm_static_check)
list(APPEND COMM_CONTEXT_MANAGER_DEPS gloo_comm_context gloo_store)
endif()
......
cc_library(
comm_static_check
SRCS static_check.cc
DEPS place dense_tensor enforce)
if(WITH_NCCL OR WITH_RCCL)
cc_library(
nccl_dynamic_check
SRCS nccl_dynamic_check.cc
DEPS dense_tensor)
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,21 +12,27 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/collective/check.h"
#include "paddle/phi/core/distributed/check/nccl_dynamic_check.h"
#include "paddle/fluid/distributed/collective/nccl_tools.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
#ifdef PADDLE_WITH_HIP
#if defined(PADDLE_WITH_RCCL)
#include <hip/hip_runtime.h>
#include "paddle/phi/backends/dynload/rccl.h"
#define gpuMalloc hipMalloc
#define gpuMemcpy hipMemcpy
#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost
#define gpuMemcpyHostToDevice hipMemcpyHostToDevice
#define gpuFree hipFree
#else
#include <cuda_runtime.h>
#include "paddle/phi/backends/dynload/nccl.h"
#define gpuMalloc cudaMalloc
#define gpuMemcpy cudaMemcpy
#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost
@@ -34,141 +40,9 @@
#define gpuFree cudaFree
#endif
namespace paddle {
namespace phi {
namespace distributed {
// static checks
void CommStaticCheck::CheckRank(int rank, int world_size) {
PADDLE_ENFORCE_GE(rank,
0,
phi::errors::InvalidArgument(
"Rank should be greater than or equal to 0."));
PADDLE_ENFORCE_LT(
rank,
world_size,
phi::errors::InvalidArgument("Rank is out of the process group."));
}
void CommStaticCheck::CheckPlace(const phi::DenseTensor& tensor) {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(tensor.place()),
true,
platform::errors::InvalidArgument("Tensor should be in GPU place."));
}
void CommStaticCheck::CheckPlace(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor) {
CheckPlace(out_tensor);
CheckPlace(in_tensor);
PADDLE_ENFORCE_EQ(
out_tensor.place(),
in_tensor.place(),
phi::errors::InvalidArgument(
"Input and output tensors should be on the same place."));
}
void CommStaticCheck::CheckDataType(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor) {
PADDLE_ENFORCE_EQ(
out_tensor.dtype(),
in_tensor.dtype(),
phi::errors::InvalidArgument(
"Input and output tensors should have the same data type."));
}
void CommStaticCheck::CheckShape(const phi::DenseTensor& tensor) {
PADDLE_ENFORCE_GT(
tensor.numel(),
0,
phi::errors::InvalidArgument("Size of tensor should be greater than 0."));
}
void CommStaticCheck::CheckShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int out_size_factor,
int in_size_factor) {
CheckShape(out_tensor);
CheckShape(in_tensor);
int64_t out_size = out_tensor.numel(), in_size = in_tensor.numel();
PADDLE_ENFORCE_EQ(
out_size * out_size_factor,
in_size * in_size_factor,
phi::errors::InvalidArgument(
"Input and output tensors should have matching sizes."));
}
void CommStaticCheck::CheckShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int dst_rank,
int cur_rank,
int world_size,
int out_size_factor,
int in_size_factor) {
CheckRank(dst_rank, world_size);
CheckRank(cur_rank, world_size);
CheckPlace(out_tensor, in_tensor);
CheckDataType(out_tensor, in_tensor);
if (dst_rank == cur_rank) {
CheckShape(out_tensor, in_tensor, out_size_factor, in_size_factor);
} else {
CheckShape(out_tensor);
CheckShape(in_tensor);
}
}
void CommStaticCheck::CheckShape(const phi::DenseTensor& tensor,
int rank,
int world_size) {
CheckPlace(tensor);
CheckRank(rank, world_size);
}
void CommStaticCheck::SameShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int dst_rank,
int cur_rank,
int world_size) {
CheckShape(out_tensor,
in_tensor,
dst_rank,
cur_rank,
world_size,
/*out_size_factor*/ 1,
/*in_size_factor*/ 1);
}
void CommStaticCheck::ScatterLikeShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int dst_rank,
int cur_rank,
int world_size) {
CheckShape(out_tensor,
in_tensor,
dst_rank,
cur_rank,
world_size,
/*out_size_factor*/ world_size,
/*in_size_factor*/ 1);
}
void CommStaticCheck::GatherLikeShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int dst_rank,
int cur_rank,
int world_size) {
CheckShape(out_tensor,
in_tensor,
dst_rank,
cur_rank,
world_size,
/*out_size_factor*/ 1,
/*in_size_factor*/ world_size);
}
// dynamic checks
void CommDynamicCheck::CheckDataType(const phi::DenseTensor& tensor,
void NCCLDynamicCheck::CheckDataType(const phi::DenseTensor& tensor,
int64_t dtype) {
PADDLE_ENFORCE_EQ(
static_cast<int64_t>(tensor.dtype()),
@@ -177,7 +51,7 @@ void CommDynamicCheck::CheckDataType(const phi::DenseTensor& tensor,
"Tensors in communication are expected to have the same data type."));
}
void CommDynamicCheck::CheckDataType(const phi::DenseTensor& tensor,
void NCCLDynamicCheck::CheckDataType(const phi::DenseTensor& tensor,
int root_rank,
int cur_rank,
ncclComm_t comm) {
@@ -188,7 +62,7 @@ void CommDynamicCheck::CheckDataType(const phi::DenseTensor& tensor,
PADDLE_ENFORCE_GPU_SUCCESS(
gpuMemcpy(dtype_device, &dtype_host, kSize, gpuMemcpyHostToDevice));
NCCL_CHECK(phi::dynload::ncclBroadcast(dtype_device,
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclBroadcast(dtype_device,
dtype_device,
kSize,
ncclInt64,
@@ -207,7 +81,7 @@ void CommDynamicCheck::CheckDataType(const phi::DenseTensor& tensor,
PADDLE_ENFORCE_GPU_SUCCESS(gpuFree(dtype_device));
}
void CommDynamicCheck::CheckShape(const phi::DenseTensor& tensor,
void NCCLDynamicCheck::CheckShape(const phi::DenseTensor& tensor,
int64_t shape) {
PADDLE_ENFORCE_EQ(
tensor.numel(),
@@ -216,7 +90,7 @@ void CommDynamicCheck::CheckShape(const phi::DenseTensor& tensor,
"Tensors in communication are expected to have matching sizes."));
}
void CommDynamicCheck::CheckShape(const phi::DenseTensor& tensor,
void NCCLDynamicCheck::CheckShape(const phi::DenseTensor& tensor,
int root_rank,
int cur_rank,
ncclComm_t comm) {
@@ -230,7 +104,7 @@ void CommDynamicCheck::CheckShape(const phi::DenseTensor& tensor,
PADDLE_ENFORCE_GPU_SUCCESS(
gpuMemcpy(shape_device, &shape_host, kSize, gpuMemcpyHostToDevice));
NCCL_CHECK(phi::dynload::ncclBroadcast(shape_device,
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclBroadcast(shape_device,
shape_device,
kSize,
ncclInt64,
@@ -249,7 +123,7 @@ void CommDynamicCheck::CheckShape(const phi::DenseTensor& tensor,
PADDLE_ENFORCE_GPU_SUCCESS(gpuFree(shape_device));
}
void CommDynamicCheck::CheckShape(const phi::DenseTensor& out_tensor,
void NCCLDynamicCheck::CheckShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& in_size_each_rank,
int cur_rank,
@@ -268,7 +142,7 @@ void CommDynamicCheck::CheckShape(const phi::DenseTensor& out_tensor,
PADDLE_ENFORCE_GPU_SUCCESS(gpuMemcpy(
in_shape_device, &in_shape_host, kSize, gpuMemcpyHostToDevice));
NCCL_CHECK(phi::dynload::ncclReduce(in_shape_device,
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclReduce(in_shape_device,
in_shape_device,
kSize,
ncclInt64,
@@ -285,6 +159,5 @@ void CommDynamicCheck::CheckShape(const phi::DenseTensor& out_tensor,
PADDLE_ENFORCE_GPU_SUCCESS(gpuFree(in_shape_device));
}
}
} // namespace distributed
} // namespace paddle
} // namespace phi
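
The dynamic check reaches agreement over the communicator itself: the root rank copies its dtype (or numel) to device memory, broadcasts it as a single int64 with ncclBroadcast, and every rank copies the value back to the host and compares it with its local tensor. A hedged usage sketch, assuming an initialized ncclComm_t (the wrapper name is illustrative):

#include "paddle/phi/core/distributed/check/nccl_dynamic_check.h"

void PreflightBroadcast(const phi::DenseTensor& tensor,
                        int root, int cur_rank, ncclComm_t comm) {
  // Each call runs one tiny collective on the default stream before the
  // real kernel launches, so a mismatched rank fails fast with a clear error.
  phi::distributed::NCCLDynamicCheck::CheckDataType(tensor, root, cur_rank, comm);
  phi::distributed::NCCLDynamicCheck::CheckShape(tensor, root, cur_rank, comm);
}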
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <vector>
#include "paddle/phi/backends/gpu/forwards.h"
#if defined(PADDLE_WITH_RCCL)
using gpuStream_t = hipStream_t;
#else
using gpuStream_t = cudaStream_t;
#endif
namespace phi {
// forward declaration
class DenseTensor;
namespace distributed {
struct NCCLDynamicCheck {
static void CheckDataType(const phi::DenseTensor& tensor, int64_t dtype);
static void CheckDataType(const phi::DenseTensor& tensor,
int root_rank,
int cur_rank,
ncclComm_t comm);
static void CheckShape(const phi::DenseTensor& tensor, int64_t shape);
static void CheckShape(const phi::DenseTensor& tensor,
int root_rank,
int cur_rank,
ncclComm_t comm);
static void CheckShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& in_size_each_rank,
int cur_rank,
int world_size,
ncclComm_t comm);
private:
// `0` represents default stream for both cuda & hip
static constexpr gpuStream_t kDefaultStream = 0;
};
} // namespace distributed
} // namespace phi
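
The vector overload above targets all-to-all style collectives, where rank i contributes in_size_each_rank[i] elements; the implementation reduces the input sizes across ranks (the ncclReduce in nccl_dynamic_check.cc) and verifies them against the output tensor. A minimal sketch, assuming an even split and tensors and communicator prepared by the caller:

#include <vector>

#include "paddle/phi/core/distributed/check/nccl_dynamic_check.h"

void PreflightAllToAll(const phi::DenseTensor& out,
                       const phi::DenseTensor& in,
                       int cur_rank, int world_size, ncclComm_t comm) {
  // Assumption: every rank sends an equal slice of its input to each peer.
  std::vector<int64_t> in_size_each_rank(world_size,
                                         in.numel() / world_size);
  phi::distributed::NCCLDynamicCheck::CheckShape(
      out, in, in_size_each_rank, cur_rank, world_size, comm);
}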
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/check/static_check.h"
#include <cstdlib>
#include <cstring>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
namespace phi {
namespace distributed {
void CommStaticCheck::CheckRank(int rank, int world_size) {
PADDLE_ENFORCE_GE(rank,
0,
phi::errors::InvalidArgument(
"Rank should be greater than or equal to 0."));
PADDLE_ENFORCE_LT(
rank,
world_size,
phi::errors::InvalidArgument("Rank is out of the process group."));
}
void CommStaticCheck::CheckPlace(const phi::DenseTensor& tensor,
phi::AllocationType place) {
PADDLE_ENFORCE_EQ(
tensor.place().GetType(),
place,
phi::errors::InvalidArgument("Tensor should be in backend's place."));
}
void CommStaticCheck::CheckPlace(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
phi::AllocationType place) {
CheckPlace(out_tensor, place);
CheckPlace(in_tensor, place);
PADDLE_ENFORCE_EQ(
out_tensor.place(),
in_tensor.place(),
phi::errors::InvalidArgument(
"Input and output tensors should be on the same place."));
}
void CommStaticCheck::CheckDataType(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor) {
PADDLE_ENFORCE_EQ(
out_tensor.dtype(),
in_tensor.dtype(),
phi::errors::InvalidArgument(
"Input and output tensors should have the same data type."));
}
void CommStaticCheck::CheckShape(const phi::DenseTensor& tensor) {
PADDLE_ENFORCE_GT(
tensor.numel(),
0,
phi::errors::InvalidArgument("Size of tensor should be greater than 0."));
}
void CommStaticCheck::CheckShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int out_size_factor,
int in_size_factor) {
CheckShape(out_tensor);
CheckShape(in_tensor);
int64_t out_size = out_tensor.numel(), in_size = in_tensor.numel();
PADDLE_ENFORCE_EQ(
out_size * out_size_factor,
in_size * in_size_factor,
phi::errors::InvalidArgument(
"Input and output tensors should have matching sizes."));
}
void CommStaticCheck::CheckShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int dst_rank,
int cur_rank,
int world_size,
int out_size_factor,
int in_size_factor,
phi::AllocationType place) {
CheckRank(dst_rank, world_size);
CheckRank(cur_rank, world_size);
CheckPlace(out_tensor, in_tensor, place);
CheckDataType(out_tensor, in_tensor);
if (dst_rank == cur_rank) {
CheckShape(out_tensor, in_tensor, out_size_factor, in_size_factor);
} else {
CheckShape(out_tensor);
CheckShape(in_tensor);
}
}
void CommStaticCheck::CheckShape(const phi::DenseTensor& tensor,
int rank,
int world_size,
phi::AllocationType place) {
CheckPlace(tensor, place);
CheckRank(rank, world_size);
}
void CommStaticCheck::SameShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int dst_rank,
int cur_rank,
int world_size,
phi::AllocationType place) {
CheckShape(out_tensor,
in_tensor,
dst_rank,
cur_rank,
world_size,
/*out_size_factor*/ 1,
/*in_size_factor*/ 1,
place);
}
void CommStaticCheck::ScatterLikeShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int dst_rank,
int cur_rank,
int world_size,
phi::AllocationType place) {
CheckShape(out_tensor,
in_tensor,
dst_rank,
cur_rank,
world_size,
/*out_size_factor*/ world_size,
/*in_size_factor*/ 1,
place);
}
void CommStaticCheck::GatherLikeShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int dst_rank,
int cur_rank,
int world_size,
phi::AllocationType place) {
CheckShape(out_tensor,
in_tensor,
dst_rank,
cur_rank,
world_size,
/*out_size_factor*/ 1,
/*in_size_factor*/ world_size,
place);
}
} // namespace distributed
} // namespace phi
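
The size-factor convention used by the three wrappers reduces to a single invariant checked on the root rank: out_size * out_size_factor == in_size * in_size_factor. A worked example with world_size = 4 (the numbers are illustrative):

#include <cassert>
#include <cstdint>

void SizeFactorInvariant() {
  const int world_size = 4;
  const int64_t in_size = 32;
  assert(32 * 1 == in_size * 1);            // SameShape: out == in
  assert(8 * world_size == in_size * 1);    // ScatterLikeShape: out == in / world_size
  assert(128 * 1 == in_size * world_size);  // GatherLikeShape: out == in * world_size
}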
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -14,32 +14,22 @@
#pragma once
#include <cstdint>
#include <vector>
#include "paddle/phi/common/place.h"
#include "paddle/phi/backends/gpu/forwards.h"
#ifdef PADDLE_WITH_HIP
using gpuStream_t = hipStream_t;
#else
using gpuStream_t = cudaStream_t;
#endif
// forward declarations
namespace phi {
// forward declaration
class DenseTensor;
}
namespace paddle {
namespace distributed {
struct CommStaticCheck {
static void CheckRank(int rank, int world_size);
static void CheckPlace(const phi::DenseTensor& tensor);
static void CheckPlace(const phi::DenseTensor& tensor,
phi::AllocationType place);
static void CheckPlace(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor);
const phi::DenseTensor& in_tensor,
phi::AllocationType place);
static void CheckDataType(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor);
@@ -57,59 +47,39 @@ struct CommStaticCheck {
int cur_rank,
int world_size,
int out_size_factor,
int in_size_factor);
int in_size_factor,
phi::AllocationType place = phi::AllocationType::GPU);
// for p2p
static void CheckShape(const phi::DenseTensor& tensor,
int rank,
int world_size);
int world_size,
phi::AllocationType place = phi::AllocationType::GPU);
// for collective
static void SameShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int dst_rank,
int cur_rank,
int world_size);
int world_size,
phi::AllocationType place = phi::AllocationType::GPU);
static void ScatterLikeShape(const phi::DenseTensor& out_tensor,
static void ScatterLikeShape(
const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int dst_rank,
int cur_rank,
int world_size);
int world_size,
phi::AllocationType place = phi::AllocationType::GPU);
static void GatherLikeShape(const phi::DenseTensor& out_tensor,
static void GatherLikeShape(
const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int dst_rank,
int cur_rank,
int world_size);
};
struct CommDynamicCheck {
static void CheckDataType(const phi::DenseTensor& tensor, int64_t dtype);
static void CheckDataType(const phi::DenseTensor& tensor,
int root_rank,
int cur_rank,
ncclComm_t comm);
static void CheckShape(const phi::DenseTensor& tensor, int64_t shape);
static void CheckShape(const phi::DenseTensor& tensor,
int root_rank,
int cur_rank,
ncclComm_t comm);
static void CheckShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& in_size_each_rank,
int cur_rank,
int world_size,
ncclComm_t comm);
private:
// `0` represents default stream for both cuda & hip
static constexpr gpuStream_t kDefaultStream = 0;
phi::AllocationType place = phi::AllocationType::GPU);
};
} // namespace distributed
} // namespace paddle
} // namespace phi
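
Defaulting place to phi::AllocationType::GPU keeps every existing NCCL call site source-compatible, while CPU backends such as gloo pass the place explicitly. A minimal sketch of both call shapes, assuming the declarations above:

#include "paddle/phi/core/distributed/check/static_check.h"

void PlaceParamExamples(const phi::DenseTensor& out,
                        const phi::DenseTensor& in,
                        int rank, int size) {
  using phi::distributed::CommStaticCheck;
  // GPU collective: identical to the pre-refactor call.
  CommStaticCheck::SameShape(out, in, rank, rank, size);
  // CPU (gloo) collective: the place is passed explicitly.
  CommStaticCheck::SameShape(out, in, rank, rank, size,
                             phi::AllocationType::CPU);
}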
@@ -13,7 +13,7 @@
// limitations under the License.
#if defined(PADDLE_WITH_GLOO)
#include "gloo/rendezvous/prefix_store.h"
#include <gloo/rendezvous/prefix_store.h>
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#include "paddle/phi/core/distributed/gloo_utils.h"
......
@@ -14,11 +14,12 @@
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#include "gloo/broadcast.h"
#include "gloo/types.h"
#include <gloo/broadcast.h>
#include <gloo/types.h>
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/check/static_check.h"
#include "paddle/phi/core/distributed/gloo_utils.h"
#include "paddle/phi/core/enforce.h"
@@ -38,6 +39,13 @@ GlooCommContext::GlooCommContext(
void GlooCommContext::Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root) {
// gloo only uses CPU now
CommStaticCheck::SameShape(*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CPU);
gloo::BroadcastOptions opts(gloo_context_);
const auto& dtype = in_tensor.dtype();
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
......
@@ -13,11 +13,11 @@
// limitations under the License.
#pragma once
#include <memory>
#include <gloo/rendezvous/context.h>
#include <gloo/rendezvous/store.h>
#include <gloo/transport/tcp/device.h>
#include "gloo/rendezvous/context.h"
#include "gloo/rendezvous/store.h"
#include "gloo/transport/tcp/device.h"
#include <memory>
#include "paddle/phi/core/distributed/comm_context.h"
#include "paddle/phi/core/macros.h"
......
@@ -13,9 +13,9 @@
// limitations under the License.
#ifdef _WIN32
#include <gloo/common/win.h>
#include <winsock2.h>
#include <ws2tcpip.h>
#include "gloo/common/win.h"
#else
#include <netdb.h>
#include <sys/socket.h>
......
@@ -14,13 +14,13 @@
#pragma once
#include <gloo/transport/tcp/device.h>
#include <gloo/types.h>
#include <climits>
#include <memory>
#include <string>
#include "gloo/transport/tcp/device.h"
#include "gloo/types.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
......
@@ -14,12 +14,17 @@
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/check/nccl_dynamic_check.h"
#include "paddle/phi/core/distributed/check/static_check.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
namespace distributed {
// set this flag to `true` and recompile to enable dynamic checks
constexpr bool FLAGS_enable_nccl_dynamic_check = false;
NCCLCommContext::NCCLCommContext(int rank, int size, ncclUniqueId nccl_id)
: CommContext(rank, size) {
PADDLE_ENFORCE_GPU_SUCCESS(
@@ -30,13 +35,22 @@ void NCCLCommContext::Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root,
gpuStream_t stream) {
CommStaticCheck::SameShape(*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_);
if (FLAGS_enable_nccl_dynamic_check) {
NCCLDynamicCheck::CheckShape(*out_tensor, root, rank_, nccl_comm_);
}
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::ncclBroadcast(in_tensor.data(),
out_tensor->data(),
in_tensor.numel(),
ToNCCLDataType(in_tensor.type()),
root,
nccl_comm_,
stream);
stream));
}
} // namespace distributed
......
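
End-to-end, the new PHI comm context wires both check layers into each collective. A hedged usage sketch, assuming rank 0 has already distributed the ncclUniqueId out of band (for example through a TCP store); the setup details are illustrative:

#include "paddle/phi/core/distributed/nccl_comm_context.h"

void RunBroadcast(int rank, int size, ncclUniqueId id,
                  phi::DenseTensor* out, const phi::DenseTensor& in,
                  gpuStream_t stream) {
  phi::distributed::NCCLCommContext ctx(rank, size, id);
  // Static shape/dtype/place checks always run; the NCCL dynamic check
  // only runs if FLAGS_enable_nccl_dynamic_check is flipped to true and
  // the library is recompiled (it is a constexpr, not a runtime flag).
  ctx.Broadcast(out, in, /*root*/ 0, stream);
}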