Unverified commit e7711592, authored by Wen Sun, committed by GitHub

Add dynamic checks for collective communication on NCCL (#48915)

* chore: unify `SingleTensor`

* feat: dynamic check
Parent e66dbc38
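The checks added here are gated by a compile-time constant rather than a runtime GFlag: `FLAGS_enable_nccl_dynamic_check` is declared `constexpr bool` with a default of `false` near the top of ProcessGroupNCCL.cc, so enabling them requires editing that line and rebuilding. Once enabled, each collective first exchanges dtype and shape metadata over the communicator via CommDynamicCheck::CheckShape, and RunFnInNCCLEnv additionally blocks on the task (SetBlockCPUInWait() followed by Wait()), so the checks add latency and are intended for debugging rather than production runs. A minimal sketch of the change needed to turn them on, assuming no other build settings need to change:

    // ProcessGroupNCCL.cc: flip the compile-time switch, then rebuild Paddle.
    // In the default build the check branches are compiled out entirely.
    constexpr bool FLAGS_enable_nccl_dynamic_check = true;  // shipped default: false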
......@@ -21,7 +21,7 @@ endif()
if(WITH_NCCL OR WITH_RCCL)
cc_library(
processgroup_nccl
SRCS ProcessGroupNCCL.cc NCCLTools.cc Common.cc static_check.cc
SRCS ProcessGroupNCCL.cc NCCLTools.cc Common.cc check.cc
DEPS processgroup
processgroup_stream
place
......
......@@ -14,7 +14,7 @@
#include "paddle/fluid/distributed/collective/NCCLTools.h"
#include "paddle/fluid/distributed/collective/Types.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace distributed {
......
......@@ -21,42 +21,29 @@
#include <hip/hip_runtime.h>
#endif
#include <error.h>
#include <string>
#include "paddle/fluid/distributed/collective/Types.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/variable.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_RCCL
#include "paddle/fluid/platform/dynload/rccl.h"
#include "paddle/phi/backends/dynload/rccl.h"
#else
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/phi/backends/dynload/nccl.h"
#endif
#include "paddle/fluid/platform/enforce.h"
#include "paddle/utils/variant.h"
namespace paddle {
namespace distributed {
#define NCCL_CHECK(cmd) \
do { \
ncclResult_t r = cmd; \
if (r != ncclSuccess) { \
printf("Failed, NCCL error %s:%d '%s'\n", \
__FILE__, \
__LINE__, \
platform::dynload::ncclGetErrorString(r)); \
exit(EXIT_FAILURE); \
} \
#define NCCL_CHECK(cmd) \
do { \
ncclResult_t r = cmd; \
if (r != ncclSuccess) { \
printf("Failed, NCCL error %s:%d '%s'\n", \
__FILE__, \
__LINE__, \
phi::dynload::ncclGetErrorString(r)); \
exit(EXIT_FAILURE); \
} \
} while (0)
ncclRedOp_t ToNCCLRedType(ReduceOp reduction);
......
......@@ -16,7 +16,7 @@
#include "paddle/fluid/distributed/collective/Common.h"
#include "paddle/fluid/distributed/collective/NCCLTools.h"
#include "paddle/fluid/distributed/collective/static_check.h"
#include "paddle/fluid/distributed/collective/check.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/place.h"
......@@ -25,6 +25,8 @@
DECLARE_bool(nccl_blocking_wait);
DECLARE_bool(use_stream_safe_cuda_allocator);
// set this flag to `true` and recompile to enable dynamic checks
constexpr bool FLAGS_enable_nccl_dynamic_check = false;
constexpr int64_t kWaitBlockTImeout = 10;
namespace paddle {
......@@ -89,12 +91,10 @@ ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
: ProcessGroupStream(rank, size, gid), store_(store) {}
void ProcessGroupNCCL::GroupStart() {
NCCL_CHECK(platform::dynload::ncclGroupStart());
NCCL_CHECK(phi::dynload::ncclGroupStart());
}
void ProcessGroupNCCL::GroupEnd() {
NCCL_CHECK(platform::dynload::ncclGroupEnd());
}
void ProcessGroupNCCL::GroupEnd() { NCCL_CHECK(phi::dynload::ncclGroupEnd()); }
phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
const Place& place) const {
......@@ -146,7 +146,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclAllGather(
if (FLAGS_enable_nccl_dynamic_check) {
CommDynamicCheck::CheckShape(*out_tensor,
/*root_rank*/ 0,
rank_,
comm);
}
NCCL_CHECK(phi::dynload::ncclAllGather(
in_tensor_maybe_partial.data(),
out_tensor->data(),
in_tensor_maybe_partial.numel(),
......@@ -173,7 +179,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclAllReduce(
if (FLAGS_enable_nccl_dynamic_check) {
CommDynamicCheck::CheckShape(*out_tensor,
/*root_rank*/ 0,
rank_,
comm);
}
NCCL_CHECK(phi::dynload::ncclAllReduce(
in_tensor.data(),
out_tensor->data(),
in_tensor.numel(),
......@@ -219,9 +231,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
CheckSizeOnEachRank(out_dim, out_size_each_rank, size_);
CheckSizeOnEachRank(in_dim, in_size_each_rank, size_);
// NOTE: Since `all_to_all` needs other processes's participation, it cannot
// NOTE: Since `all_to_all` needs other processes' participation, it cannot
// simply be covered by static checks. Factors are set to 0 here to skip the
// shape check. Its shape check will be done by dynamic checks in debug mode.
// shape check. Its shape check will be done by dynamic checks with
// FLAGS_enable_nccl_dynamic_check.
CommStaticCheck::CheckShape(*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
......@@ -231,6 +244,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
/*in_size_factor*/ 0);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
if (FLAGS_enable_nccl_dynamic_check) {
CommDynamicCheck::CheckShape(
*out_tensor, in_tensor, in_size_each_rank, rank_, size_, comm);
}
int64_t in_row_size = in_tensor.numel() / in_dim[0],
out_row_size = out_tensor->numel() / out_dim[0];
int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0;
......@@ -240,7 +257,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
for (auto i = 0; i < size_; i++) {
in_numel = in_size_each_rank[i] * in_row_size;
input_partial = GetPartialTensor(in_tensor, in_offset, in_numel);
NCCL_CHECK(platform::dynload::ncclSend(
NCCL_CHECK(phi::dynload::ncclSend(
input_partial.data(),
in_numel,
platform::ToNCCLDataType(input_partial.dtype()),
......@@ -251,7 +268,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
out_numel = out_size_each_rank[i] * out_row_size;
output_partial = GetPartialTensor(*out_tensor, out_offset, out_numel);
NCCL_CHECK(platform::dynload::ncclRecv(
NCCL_CHECK(phi::dynload::ncclRecv(
output_partial.data(),
out_numel,
platform::ToNCCLDataType(output_partial.dtype()),
......@@ -304,7 +321,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
int root = opts.source_rank + opts.source_root;
NCCL_CHECK(platform::dynload::ncclBroadcast(
if (FLAGS_enable_nccl_dynamic_check) {
CommDynamicCheck::CheckShape(*out_tensor, root, rank_, comm);
}
NCCL_CHECK(phi::dynload::ncclBroadcast(
in_tensor.data(),
out_tensor->data(),
in_tensor.numel(),
......@@ -332,7 +352,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclReduce(
if (FLAGS_enable_nccl_dynamic_check) {
CommDynamicCheck::CheckShape(*out_tensor,
/*root_rank*/ opts.root_rank,
rank_,
comm);
}
NCCL_CHECK(phi::dynload::ncclReduce(
in_tensor.data(),
out_tensor->data(),
in_tensor.numel(),
......@@ -361,7 +387,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclReduceScatter(
if (FLAGS_enable_nccl_dynamic_check) {
CommDynamicCheck::CheckShape(*out_tensor,
/*root_rank*/ 0,
rank_,
comm);
}
NCCL_CHECK(phi::dynload::ncclReduceScatter(
in_tensor.data(),
out_tensor->data(),
out_tensor->numel(),
......@@ -389,6 +421,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
if (FLAGS_enable_nccl_dynamic_check) {
CommDynamicCheck::CheckShape(*out_tensor,
/*root_rank*/ opts.root_rank,
rank_,
comm);
}
int64_t numel = in_tensor.numel() / size_;
if (rank_ == opts.root_rank) {
int64_t offset = 0;
......@@ -396,7 +434,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
GroupStart();
for (auto i = 0; i < size_; i++) {
partial_tensor = GetPartialTensor(in_tensor, offset, numel);
NCCL_CHECK(platform::dynload::ncclSend(
NCCL_CHECK(phi::dynload::ncclSend(
partial_tensor.data(),
numel,
platform::ToNCCLDataType(partial_tensor.dtype()),
......@@ -405,7 +443,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
stream));
offset += numel;
}
NCCL_CHECK(platform::dynload::ncclRecv(
NCCL_CHECK(phi::dynload::ncclRecv(
out_tensor->data(),
numel,
platform::ToNCCLDataType(out_tensor->dtype()),
......@@ -414,7 +452,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
stream));
GroupEnd();
} else {
NCCL_CHECK(platform::dynload::ncclRecv(
NCCL_CHECK(phi::dynload::ncclRecv(
out_tensor->data(),
numel,
platform::ToNCCLDataType(out_tensor->dtype()),
......@@ -443,16 +481,22 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
tensor = &partial_tensor;
}
CommStaticCheck::SingleTensor(*tensor, rank_, size_);
CommStaticCheck::CheckShape(*tensor, rank_, size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclRecv(
tensor->data(),
tensor->numel(),
platform::ToNCCLDataType(tensor->dtype()),
src_rank,
comm,
stream));
if (FLAGS_enable_nccl_dynamic_check) {
CommDynamicCheck::CheckShape(*tensor,
/*root_rank*/ src_rank,
rank_,
comm);
}
NCCL_CHECK(
phi::dynload::ncclRecv(tensor->data(),
tensor->numel(),
platform::ToNCCLDataType(tensor->dtype()),
src_rank,
comm,
stream));
},
*tensor,
CommType::RECV,
......@@ -471,10 +515,16 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
const phi::DenseTensor& tensor_maybe_partial =
numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor;
CommStaticCheck::SingleTensor(tensor_maybe_partial, rank_, size_);
CommStaticCheck::CheckShape(tensor_maybe_partial, rank_, size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclSend(
if (FLAGS_enable_nccl_dynamic_check) {
CommDynamicCheck::CheckShape(tensor_maybe_partial,
/*root_rank*/ rank_,
rank_,
comm);
}
NCCL_CHECK(phi::dynload::ncclSend(
tensor_maybe_partial.data(),
tensor_maybe_partial.numel(),
platform::ToNCCLDataType(tensor_maybe_partial.dtype()),
......@@ -520,7 +570,7 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
ncclUniqueId nccl_id;
if (rank_ == 0) {
NCCL_CHECK(platform::dynload::ncclGetUniqueId(&nccl_id));
NCCL_CHECK(phi::dynload::ncclGetUniqueId(&nccl_id));
}
BroadcastUniqueNCCLID(&nccl_id);
......@@ -532,7 +582,7 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
platform::DeviceContextPool::Instance().Get(place));
auto comm_ctx = std::make_unique<phi::GPUContext>(place);
ncclComm_t nccl_comm;
NCCL_CHECK(platform::dynload::ncclCommInitRank(
NCCL_CHECK(phi::dynload::ncclCommInitRank(
&nccl_comm, GetSize(), nccl_id, GetRank()));
comm_ctx->set_nccl_comm(nccl_comm);
......@@ -589,6 +639,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::RunFnInNCCLEnv(
task->UpdateWaitChain(*comm_ctx);
}
if (FLAGS_enable_nccl_dynamic_check) {
task->SetBlockCPUInWait();
task->Wait();
}
return task;
}
......@@ -633,7 +687,7 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
ncclUniqueId nccl_id;
if (rank_ == 0) {
NCCL_CHECK(platform::dynload::ncclGetUniqueId(&nccl_id));
NCCL_CHECK(phi::dynload::ncclGetUniqueId(&nccl_id));
}
BroadcastUniqueNCCLID(&nccl_id);
......@@ -654,7 +708,7 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
dev_ctx[i].reset(new phi::GPUContext(places[i]));
ncclComm_t nccl_comm;
NCCL_CHECK(platform::dynload::ncclCommInitRank(
NCCL_CHECK(phi::dynload::ncclCommInitRank(
&nccl_comm, GetSize(), nccl_id, GetRank()));
dev_ctx[i]->set_nccl_comm(nccl_comm);
dev_ctx_raw[i] = dev_ctx[i].get();
......@@ -791,7 +845,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
return platform::dynload::ncclAllReduce(
return phi::dynload::ncclAllReduce(
input.data(),
output.data(),
input.numel(),
......@@ -821,7 +875,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
const gpuStream_t& stream) {
const auto root =
opts.source_rank * in_tensors.size() + opts.source_root;
return platform::dynload::ncclBroadcast(
return phi::dynload::ncclBroadcast(
input.data(),
output.data(),
input.numel(),
......@@ -871,13 +925,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
ncclComm_t comm,
const gpuStream_t& stream,
int dst_rank) {
return platform::dynload::ncclSend(
input.data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
dst_rank,
comm,
stream);
return phi::dynload::ncclSend(input.data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
dst_rank,
comm,
stream);
},
dst_rank,
CommType::SEND);
......@@ -894,13 +947,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
ncclComm_t comm,
const gpuStream_t& stream,
int src_rank) {
return platform::dynload::ncclRecv(
output.data(),
output.numel(),
platform::ToNCCLDataType(output.dtype()),
src_rank,
comm,
stream);
return phi::dynload::ncclRecv(output.data(),
output.numel(),
platform::ToNCCLDataType(output.dtype()),
src_rank,
comm,
stream);
},
src_rank,
CommType::RECV);
......@@ -925,7 +977,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
return platform::dynload::ncclAllGather(
return phi::dynload::ncclAllGather(
input.data(),
output.data(),
input.numel(),
......@@ -994,14 +1046,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
size_t offset = 0;
GroupStart();
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclSend(
GetPointerByOffset(input.data(), offset, input.dtype()),
input.numel() / size_,
platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRecv(
GetPointerByOffset(output.data(), offset, input.dtype()),
input.numel() / size_,
platform::ToNCCLDataType(input.dtype()),
......@@ -1030,15 +1082,15 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(
input.data(),
output.data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
ToNCCLRedType(opts.reduce_op),
opts.root_rank,
comm,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::ncclReduce(input.data(),
output.data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
ToNCCLRedType(opts.reduce_op),
opts.root_rank,
comm,
stream));
},
CommType::REDUCE);
}
......@@ -1066,7 +1118,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
if (rank_ == opts.root_rank) {
GroupStart();
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclSend(
GetPointerByOffset(input.data(), offset, input.dtype()),
input.numel() / size_,
platform::ToNCCLDataType(input.dtype()),
......@@ -1075,22 +1127,22 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
stream));
offset += input.numel() / size_;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output.data(),
input.numel() / size_,
platform::ToNCCLDataType(input.dtype()),
opts.root_rank,
comm,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::ncclRecv(output.data(),
input.numel() / size_,
platform::ToNCCLDataType(input.dtype()),
opts.root_rank,
comm,
stream));
GroupEnd();
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output.data(),
input.numel() / size_,
platform::ToNCCLDataType(input.dtype()),
opts.root_rank,
comm,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::ncclRecv(output.data(),
input.numel() / size_,
platform::ToNCCLDataType(input.dtype()),
opts.root_rank,
comm,
stream));
}
},
CommType::SCATTER);
......
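The wiring in ProcessGroupNCCL.cc above follows the same shape for every collective: inside the lambda passed to RunFnInNCCLEnv, the dynamic check runs first, guarded by the flag, and the NCCL call follows, with root_rank chosen per collective (rank 0 for the symmetric collectives, opts.root_rank or the source rank for the rooted and point-to-point ones). Below is a condensed, illustrative sketch of that pattern written as a free function; the includes and helper names mirror the diff above, and this is not code from the commit itself.

    // Illustrative sketch of the guard-then-call pattern used by each collective.
    #include "paddle/fluid/distributed/collective/NCCLTools.h"  // NCCL_CHECK, phi::dynload
    #include "paddle/fluid/distributed/collective/check.h"       // CommDynamicCheck
    #include "paddle/fluid/platform/device/gpu/nccl_helper.h"    // platform::ToNCCLDataType
    #include "paddle/phi/core/dense_tensor.h"

    namespace paddle {
    namespace distributed {

    void AllGatherWithDynamicCheck(phi::DenseTensor* out_tensor,
                                   const phi::DenseTensor& in_tensor,
                                   int rank,
                                   ncclComm_t comm,
                                   gpuStream_t stream,
                                   bool enable_dynamic_check) {
      if (enable_dynamic_check) {
        // Rank 0 broadcasts its dtype and numel; every other rank compares them
        // against its local output tensor and raises InvalidArgument on mismatch.
        CommDynamicCheck::CheckShape(*out_tensor, /*root_rank*/ 0, rank, comm);
      }
      NCCL_CHECK(phi::dynload::ncclAllGather(
          in_tensor.data(),
          out_tensor->data(),
          in_tensor.numel(),
          platform::ToNCCLDataType(in_tensor.dtype()),
          comm,
          stream));
    }

    }  // namespace distributed
    }  // namespace paddle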
......@@ -33,9 +33,9 @@
#endif
#ifdef PADDLE_WITH_RCCL
#include "paddle/fluid/platform/dynload/rccl.h"
#include "paddle/phi/backends/dynload/rccl.h"
#elif PADDLE_WITH_NCCL
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/phi/backends/dynload/nccl.h"
#endif
namespace paddle {
......
......@@ -12,16 +12,32 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/collective/static_check.h"
#include "paddle/fluid/distributed/collective/check.h"
#include "paddle/fluid/distributed/collective/NCCLTools.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/errors.h"
#ifdef PADDLE_WITH_HIP
#define gpuMalloc hipMalloc
#define gpuMemcpy hipMemcpy
#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost
#define gpuMemcpyHostToDevice hipMemcpyHostToDevice
#define gpuFree hipFree
#else
#define gpuMalloc cudaMalloc
#define gpuMemcpy cudaMemcpy
#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost
#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice
#define gpuFree cudaFree
#endif
namespace paddle {
namespace distributed {
// static checks
void CommStaticCheck::CheckRank(int rank, int world_size) {
PADDLE_ENFORCE_GE(rank,
0,
......@@ -102,9 +118,9 @@ void CommStaticCheck::CheckShape(const phi::DenseTensor& out_tensor,
}
}
void CommStaticCheck::SingleTensor(const phi::DenseTensor& tensor,
int rank,
int world_size) {
void CommStaticCheck::CheckShape(const phi::DenseTensor& tensor,
int rank,
int world_size) {
CheckPlace(tensor);
CheckRank(rank, world_size);
}
......@@ -151,5 +167,124 @@ void CommStaticCheck::GatherLikeShape(const phi::DenseTensor& out_tensor,
/*in_size_factor*/ world_size);
}
// dynamic checks
void CommDynamicCheck::CheckDataType(const phi::DenseTensor& tensor,
int64_t dtype) {
PADDLE_ENFORCE_EQ(
static_cast<int64_t>(tensor.dtype()),
dtype,
phi::errors::InvalidArgument(
"Tensors in communication are expected to have the same data type."));
}
void CommDynamicCheck::CheckDataType(const phi::DenseTensor& tensor,
int root_rank,
int cur_rank,
ncclComm_t comm) {
constexpr int kSize = sizeof(int64_t);
int64_t dtype_host = static_cast<int64_t>(tensor.dtype());
int64_t* dtype_device;
PADDLE_ENFORCE_GPU_SUCCESS(gpuMalloc(&dtype_device, kSize));
PADDLE_ENFORCE_GPU_SUCCESS(
gpuMemcpy(dtype_device, &dtype_host, kSize, gpuMemcpyHostToDevice));
NCCL_CHECK(phi::dynload::ncclBroadcast(dtype_device,
dtype_device,
kSize,
ncclInt64,
root_rank,
comm,
kDefaultStream));
if (root_rank == cur_rank) {
VLOG(3) << "Dynamic check broadcast metadata, dtype: " << dtype_host;
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
gpuMemcpy(&dtype_host, dtype_device, kSize, gpuMemcpyDeviceToHost));
VLOG(3) << "Dynamic check recv metadata, dtype: " << dtype_host;
CheckDataType(tensor, dtype_host);
}
PADDLE_ENFORCE_GPU_SUCCESS(gpuFree(dtype_device));
}
void CommDynamicCheck::CheckShape(const phi::DenseTensor& tensor,
int64_t shape) {
PADDLE_ENFORCE_EQ(
tensor.numel(),
shape,
phi::errors::InvalidArgument(
"Tensors in communication are expected to have matching sizes."));
}
void CommDynamicCheck::CheckShape(const phi::DenseTensor& tensor,
int root_rank,
int cur_rank,
ncclComm_t comm) {
CheckDataType(tensor, root_rank, cur_rank, comm);
constexpr int kSize = sizeof(int64_t);
int64_t shape_host = tensor.numel();
int64_t* shape_device;
PADDLE_ENFORCE_GPU_SUCCESS(gpuMalloc(&shape_device, kSize));
PADDLE_ENFORCE_GPU_SUCCESS(
gpuMemcpy(shape_device, &shape_host, kSize, gpuMemcpyHostToDevice));
NCCL_CHECK(phi::dynload::ncclBroadcast(shape_device,
shape_device,
kSize,
ncclInt64,
root_rank,
comm,
kDefaultStream));
if (root_rank == cur_rank) {
VLOG(3) << "Dynamic check broadcast metadata, shape: " << shape_host;
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
gpuMemcpy(&shape_host, shape_device, kSize, gpuMemcpyDeviceToHost));
VLOG(3) << "Dynamic check recv metadata, shape: " << shape_host;
CheckShape(tensor, shape_host);
}
PADDLE_ENFORCE_GPU_SUCCESS(gpuFree(shape_device));
}
void CommDynamicCheck::CheckShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& in_size_each_rank,
int cur_rank,
int world_size,
ncclComm_t comm) {
CheckDataType(out_tensor, /*root_rank*/ 0, cur_rank, comm);
CheckDataType(in_tensor, /*root_rank*/ 0, cur_rank, comm);
constexpr int kSize = sizeof(int64_t);
int64_t in_row_size = in_tensor.numel() / in_tensor.dims()[0];
for (int rank = 0; rank < world_size; ++rank) {
int64_t in_shape_host = in_size_each_rank[rank] * in_row_size;
int64_t* in_shape_device;
PADDLE_ENFORCE_GPU_SUCCESS(gpuMalloc(&in_shape_device, kSize));
PADDLE_ENFORCE_GPU_SUCCESS(gpuMemcpy(
in_shape_device, &in_shape_host, kSize, gpuMemcpyHostToDevice));
NCCL_CHECK(phi::dynload::ncclReduce(in_shape_device,
in_shape_device,
kSize,
ncclInt64,
ncclSum,
rank,
comm,
kDefaultStream));
if (rank == cur_rank) {
PADDLE_ENFORCE_GPU_SUCCESS(gpuMemcpy(
&in_shape_host, in_shape_device, kSize, gpuMemcpyDeviceToHost));
VLOG(3) << "Dynamic check recv metadata, shape: " << in_shape_host;
CheckShape(out_tensor, in_shape_host);
}
PADDLE_ENFORCE_GPU_SUCCESS(gpuFree(in_shape_device));
}
}
} // namespace distributed
} // namespace paddle
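To make the all_to_all overload above concrete: each process contributes in_size_each_rank[r] * in_row_size to an ncclSum reduce rooted at rank r, so rank r receives the total element count its peers are about to send it, and that total must equal its own out_tensor.numel(). A small host-side illustration of the arithmetic, using hypothetical sizes and assuming both ranks share the same row size:

    // Hypothetical two-rank example; the real check exchanges these values over NCCL.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      // Rows each rank sends to peers 0 and 1 (its local `in_size_each_rank`).
      std::vector<int64_t> send_rows_rank0 = {2, 3};
      std::vector<int64_t> send_rows_rank1 = {1, 4};
      const int64_t row_size = 8;  // elements per row, assumed identical on both ranks

      // Value an ncclSum reduce rooted at rank 1 would deliver: the total number of
      // elements rank 1 is about to receive, which must equal rank 1's out numel.
      const int64_t expected_out_numel_rank1 =
          (send_rows_rank0[1] + send_rows_rank1[1]) * row_size;
      assert(expected_out_numel_rank1 == 56);
      return 0;
    }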
......@@ -14,7 +14,18 @@
#pragma once
// forward declaration to reduce deps
#include <cstdint>
#include <vector>
#include "paddle/phi/backends/gpu/forwards.h"
#ifdef PADDLE_WITH_HIP
using gpuStream_t = hipStream_t;
#else
using gpuStream_t = cudaStream_t;
#endif
// forward declarations
namespace phi {
class DenseTensor;
}
......@@ -49,9 +60,9 @@ struct CommStaticCheck {
int in_size_factor);
// for p2p
static void SingleTensor(const phi::DenseTensor& tensor,
int rank,
int world_size);
static void CheckShape(const phi::DenseTensor& tensor,
int rank,
int world_size);
// for collective
static void SameShape(const phi::DenseTensor& out_tensor,
......@@ -73,5 +84,32 @@ struct CommStaticCheck {
int world_size);
};
struct CommDynamicCheck {
static void CheckDataType(const phi::DenseTensor& tensor, int64_t dtype);
static void CheckDataType(const phi::DenseTensor& tensor,
int root_rank,
int cur_rank,
ncclComm_t comm);
static void CheckShape(const phi::DenseTensor& tensor, int64_t shape);
static void CheckShape(const phi::DenseTensor& tensor,
int root_rank,
int cur_rank,
ncclComm_t comm);
static void CheckShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& in_size_each_rank,
int cur_rank,
int world_size,
ncclComm_t comm);
private:
// `0` represents default stream for both cuda & hip
static constexpr gpuStream_t kDefaultStream = 0;
};
} // namespace distributed
} // namespace paddle