Unverified commit 6678def9, authored by sneaxiy, committed by GitHub

Add nproc_per_node for DistributedFusedLamb (#43295)

* add nproc_per_node for DistributedFusedLamb

* fix nproc_per_node communicator bug

* fix ring_id = 1 init bug

* fix ci

* fix test_parallel_executor_mnist.py
Parent 2c8739e8
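
For context, a minimal usage sketch of the new nproc_per_node argument (a sketch only, assuming the optimizer is exposed as paddle.incubate.optimizer.DistributedFusedLamb, static-graph mode, and a 2-node x 8-GPU job launched with paddle.distributed.launch; the concrete numbers are illustrative, not part of the patch):

    import paddle
    from paddle.incubate.optimizer import DistributedFusedLamb

    paddle.enable_static()

    # With 16 ranks and nproc_per_node=8, gradients are summed across all
    # ranks, while optimizer states and parameter shards stay inside each
    # node, so the final parameter all-gather never crosses the network.
    opt = DistributedFusedLamb(
        learning_rate=1e-3,
        lamb_weight_decay=0.01,
        nproc_per_node=8,  # new argument added by this commit
    )
    # opt.minimize(loss) would follow in a real static-graph program.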
@@ -793,8 +793,8 @@ void ParallelExecutor::BCastParamsToDevices(
       std::vector<void *> buffers;
       buffers.reserve(member_->places_.size());
       size_t numel = main_tensor.numel();
-      ncclDataType_t data_type = platform::ToNCCLDataType(
-          framework::TransToProtoVarType(main_tensor.dtype()));
+      auto dtype = framework::TransToProtoVarType(main_tensor.dtype());
+      ncclDataType_t data_type = platform::ToNCCLDataType(dtype);
       for (size_t i = 0; i < member_->places_.size(); ++i) {
         auto place = member_->places_[i];
         void *buffer;
@@ -815,7 +815,7 @@ void ParallelExecutor::BCastParamsToDevices(
               "variables' buffer size to bcast is %d, which is "
               "NOT equal to places size %d",
               buffers.size(), member_->places_.size()));
-      {
+      if (member_->nccl_ctxs_ != nullptr) {
         auto *nccl_ctxs = member_->nccl_ctxs_->DefaultFlatCtx();
         platform::NCCLGroupGuard guard;
         for (size_t i = 0; i < member_->places_.size(); ++i) {
@@ -824,6 +824,22 @@ void ParallelExecutor::BCastParamsToDevices(
                            nccl_ctx.comm_, nccl_ctx.stream());
         }
         nccl_ctxs->WaitAll();
+      } else {
+        auto src_place = member_->places_[0];
+        auto src_dev_ctx = static_cast<platform::CUDADeviceContext *>(
+            platform::DeviceContextPool::Instance().Get(src_place));
+        auto sizeof_dtype = framework::SizeOfType(dtype) * numel;
+        for (size_t i = 1; i < member_->places_.size(); ++i) {
+          auto dst_place = member_->places_[i];
+          auto dst_dev_ctx = static_cast<platform::CUDADeviceContext *>(
+              platform::DeviceContextPool::Instance().Get(dst_place));
+          src_dev_ctx->Wait();
+          dst_dev_ctx->Wait();
+          memory::Copy(dst_place, buffers[i], src_place, buffers[0],
+                       sizeof_dtype, src_dev_ctx->stream());
+          src_dev_ctx->Wait();
+          dst_dev_ctx->Wait();
+        }
       }
 #endif
     } else if (paddle::platform::is_xpu_place(main_tensor.place())) {
@@ -1348,6 +1364,11 @@ std::vector<ir::Graph *> ParallelExecutor::CloneGraphToMultiDevices(
 }
 
 void ParallelExecutor::PrepareNCCLCommunicator(Scope *global_scope) {
+  if (member_->build_strategy_.reduce_ ==
+      BuildStrategy::ReduceStrategy::kNoReduce) {
+    return;
+  }
   if (member_->IsUseCUDA(member_->use_device_) && member_->nranks_ > 1) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     member_->InitOrGetNCCLCommunicator(global_scope, &member_->build_strategy_);
......
@@ -147,8 +147,10 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<bool>("is_grad_scaled_by_nranks",
                   "Whether the input gradient has been scaled by nranks.")
         .SetDefault(true);
-    AddAttr<int>("ring_id", "The ring id of the NCCL communicator.")
-        .SetDefault(0);
+    AddAttr<int64_t>("nranks", "The world size.").SetDefault(1);
+    AddAttr<std::vector<int>>("ring_id",
+                              "The ring id of the NCCL communicator.")
+        .SetDefault({0});
     AddComment("The DistributedFusedLamb optimizer.");
   }
 };
......
@@ -806,23 +806,24 @@ static void LaunchScaleKernel(const platform::CUDADeviceContext &dev_ctx,
 #undef PD_LAMB_VEC_SCALE_KERNEL_CASE
 }
 
-template <typename T>
-static void NCCLReduceScatterWithScale(
-    const T *sendbuff, T *recvbuff, size_t recvcount, size_t nranks,
-    ncclComm_t comm, gpuStream_t stream,
-    const platform::CUDADeviceContext &dev_ctx, const T *scale = nullptr) {
+template <typename T, bool UseReduceScatter>
+static void NCCLSumWithScaleBase(const T *sendbuff, T *recvbuff,
+                                 size_t recvcount, size_t nranks,
+                                 ncclComm_t comm, gpuStream_t stream,
+                                 const platform::CUDADeviceContext &dev_ctx,
+                                 const T *scale = nullptr) {
   static_assert(std::is_same<T, float>::value ||
                     std::is_same<T, platform::float16>::value,
                 "T must be either float32 or float16.");
   if (recvcount == 0) return;
 
+  auto numel = UseReduceScatter ? (recvcount * nranks) : recvcount;
   if (comm == nullptr) {
     if (scale != nullptr) {
       PADDLE_ENFORCE_EQ(nranks, 1,
                         platform::errors::InvalidArgument(
                             "nranks must be 1 when scale != nullptr."));
-      LaunchScaleKernel(dev_ctx, sendbuff, scale, recvbuff, recvcount * nranks,
-                        stream);
+      LaunchScaleKernel(dev_ctx, sendbuff, scale, recvbuff, numel, stream);
     }
     return;
   }
@@ -834,14 +835,18 @@ static void NCCLReduceScatterWithScale(
       scale && CreatePreMulScaleOpIfSupported(dtype, comm, scale, &op);
   memory::Buffer buffer(dev_ctx.GetPlace());
   if (scale && !should_destroy_op) {
-    size_t numel = recvcount * nranks;
     T *new_sendbuff = buffer.Alloc<T>(numel);
     LaunchScaleKernel(dev_ctx, sendbuff, scale, new_sendbuff, numel, stream);
     sendbuff = new_sendbuff;
   }
-  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter(
-      sendbuff, recvbuff, recvcount, dtype, op, comm, stream));
+  if (UseReduceScatter) {
+    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter(
+        sendbuff, recvbuff, recvcount, dtype, op, comm, stream));
+  } else {
+    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
+        sendbuff, recvbuff, recvcount, dtype, op, comm, stream));
+  }
 
 #if NCCL_VERSION_CODE >= 21100
   if (should_destroy_op) {
@@ -851,6 +856,26 @@ static void NCCLReduceScatterWithScale(
   }
 #endif
 }
+
+template <typename T>
+static void NCCLReduceScatterWithScale(
+    const T *sendbuff, T *recvbuff, size_t recvcount, size_t nranks,
+    ncclComm_t comm, gpuStream_t stream,
+    const platform::CUDADeviceContext &dev_ctx, const T *scale = nullptr) {
+  NCCLSumWithScaleBase<T, true>(sendbuff, recvbuff, recvcount, nranks, comm,
+                                stream, dev_ctx, scale);
+}
+
+template <typename T>
+static void NCCLAllReduceWithScale(const T *sendbuff, T *recvbuff,
+                                   size_t recvcount, size_t nranks,
+                                   ncclComm_t comm, gpuStream_t stream,
+                                   const platform::CUDADeviceContext &dev_ctx,
+                                   const T *scale = nullptr) {
+  NCCLSumWithScaleBase<T, false>(sendbuff, recvbuff, recvcount, nranks, comm,
+                                 stream, dev_ctx, scale);
+}
 #endif
 
 template <typename InputIteratorT, typename OutputIteratorT, typename ReduceOpT,
@@ -1321,6 +1346,9 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
             "exactly by the element number %d of Moment1.",
             numel, partial_numel));
 
+    // The num_devices means the number of devices that shard a complete set
+    // of all parameters. It may be num_devices < nranks or num_devices ==
+    // nranks.
     int64_t num_devices = numel / partial_numel;
     VLOG(1) << "num_devices = " << num_devices
             << " , partial_numel = " << partial_numel;
@@ -1354,22 +1382,43 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
     auto epsilon = ctx.Attr<float>("epsilon");
     auto max_global_grad_norm = ctx.Attr<float>("max_global_grad_norm");
     auto clip_after_allreduce = ctx.Attr<bool>("clip_after_allreduce");
-    auto ring_id = ctx.Attr<int>("ring_id");
+    auto nranks = ctx.Attr<int64_t>("nranks");
+    PADDLE_ENFORCE_GE(nranks, num_devices,
+                      phi::errors::InvalidArgument(
+                          "The nranks must be not less than num_devices."));
+    PADDLE_ENFORCE_EQ(
+        nranks % num_devices, 0,
+        phi::errors::InvalidArgument(
+            "The nranks must be exactly divided by num_devices."));
+    bool local_shard = (nranks > num_devices);
+
+    const auto &ring_ids = ctx.Attr<std::vector<int>>("ring_id");
     auto use_master_param_norm = ctx.Attr<bool>("use_master_param_norm");
     auto is_grad_scaled_by_nranks = ctx.Attr<bool>("is_grad_scaled_by_nranks");
     VLOG(10) << "max_global_grad_norm = " << max_global_grad_norm
              << " , clip_after_allreduce = " << clip_after_allreduce
              << " , use_master_param_norm = " << use_master_param_norm
-             << " , is_grad_scaled_by_nranks = " << is_grad_scaled_by_nranks;
+             << " , is_grad_scaled_by_nranks = " << is_grad_scaled_by_nranks
+             << " , local_shard = " << local_shard;
 
     // Step 6: allreduce + global norm gradient clip
-    int rank = 0;
-    ncclComm_t comm = nullptr;
-    if (num_devices > 1) {
+    int64_t global_rank = 0, local_rank = 0;
+    ncclComm_t global_comm = nullptr, local_comm = 0;
+    if (nranks > 1) {
       auto *nccl_comm_handle =
-          platform::NCCLCommContext::Instance().Get(ring_id, place);
-      comm = nccl_comm_handle->comm();
-      rank = nccl_comm_handle->rank();
+          platform::NCCLCommContext::Instance().Get(ring_ids[0], place);
+      global_comm = nccl_comm_handle->comm();
+      global_rank = nccl_comm_handle->rank();
+      if (local_shard) {
+        auto *local_nccl_comm_handle =
+            platform::NCCLCommContext::Instance().Get(ring_ids[1], place);
+        local_comm = local_nccl_comm_handle->comm();
+        local_rank = local_nccl_comm_handle->rank();
+      } else {
+        local_comm = global_comm;
+        local_rank = global_rank;
+      }
     }
 
     memory::Buffer grad_norm_square_buffer(place);
@@ -1381,8 +1430,15 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
     platform::float16 *fp16_sum_grad;
     auto fp32_numel_each_device = fp32_numel / num_devices;
     auto fp16_numel_each_device = fp16_numel / num_devices;
-    if (num_devices > 1 ||
-        (max_global_grad_norm > 0 && !clip_after_allreduce)) {
+    if (local_shard) {
+      auto ptr = sum_grad_buffer.Alloc<uint8_t>(
+          fp32_numel * sizeof(float) + fp16_numel * sizeof(platform::float16));
+      fp32_sum_grad = has_fp32_param ? reinterpret_cast<float *>(ptr) : nullptr;
+      fp16_sum_grad = has_fp16_param ? reinterpret_cast<platform::float16 *>(
+                                           ptr + fp32_numel * sizeof(float))
+                                     : nullptr;
+    } else if (nranks > 1 ||
+               (max_global_grad_norm > 0 && !clip_after_allreduce)) {
       auto ptr = sum_grad_buffer.Alloc<uint8_t>(
           fp32_numel_each_device * sizeof(float) +
           fp16_numel_each_device * sizeof(platform::float16));
@@ -1404,18 +1460,27 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
     float rescale_grad = 1.0f;
     if (!is_grad_scaled_by_nranks) {
-      rescale_grad /= num_devices;
+      rescale_grad /= nranks;
     }
 
     if (max_global_grad_norm > 0) {
       if (clip_after_allreduce) {
         // (1) ReduceScater first
-        NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
-                                   fp32_numel_each_device, num_devices, comm,
-                                   stream, dev_ctx);
-        NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
-                                   fp16_numel_each_device, num_devices, comm,
-                                   stream, dev_ctx);
+        if (local_shard) {
+          NCCLAllReduceWithScale(fp32_grad, fp32_sum_grad, fp32_numel, nranks,
+                                 global_comm, stream, dev_ctx);
+          NCCLAllReduceWithScale(fp16_grad, fp16_sum_grad, fp16_numel, nranks,
+                                 global_comm, stream, dev_ctx);
+          fp32_sum_grad += (local_rank * fp32_numel_each_device);
+          fp16_sum_grad += (local_rank * fp16_numel_each_device);
+        } else {
+          NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
+                                     fp32_numel_each_device, nranks,
+                                     global_comm, stream, dev_ctx);
+          NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
+                                     fp16_numel_each_device, nranks,
+                                     global_comm, stream, dev_ctx);
+        }
         // (2) Calculate the global grad norm
         GetSquareGradNorm(fp32_sum_grad, fp32_numel_each_device, fp16_sum_grad,
                           fp16_numel_each_device, fp32_square_grad_norm, stream,
@@ -1425,7 +1490,7 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
         if (num_devices > 1) {
           PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
               fp32_square_grad_norm, fp32_square_grad_norm, 1, ncclFloat32,
-              ncclSum, comm, stream));
+              ncclSum, local_comm, stream));
         }
         VLOG(1) << "Grad square norm after all reduce: "
                 << FlattenToString(fp32_square_grad_norm, 1, place);
@@ -1452,7 +1517,7 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
         float clip_scale = 1.0f;
         if (is_grad_scaled_by_nranks) {
-          clip_scale *= num_devices;
+          clip_scale *= nranks;
         }
         CalcGradNormClipBeforeAllReduceScale<float, platform::float16>
             <<<1, 1, 0, stream>>>(global_scale, max_global_grad_norm,
@@ -1463,36 +1528,54 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
         } else {
          VLOG(1) << "Grad scale: " << FlattenToString(fp16_scale, 1, place);
         }
-        if (num_devices > 1) {
+        if (nranks > 1) {
           PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
               fp32_square_grad_norm, fp32_square_grad_norm, 1, ncclFloat32,
-              ncclSum, comm, stream));
+              ncclSum, global_comm, stream));
         }
         // (3) Do ReduceScatter with scale
-        NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
-                                   fp32_numel_each_device, num_devices, comm,
-                                   stream, dev_ctx, fp32_scale);
-        NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
-                                   fp16_numel_each_device, num_devices, comm,
-                                   stream, dev_ctx, fp16_scale);
+        if (local_shard) {
+          NCCLAllReduceWithScale(fp32_grad, fp32_sum_grad, fp32_numel, nranks,
+                                 global_comm, stream, dev_ctx, fp32_scale);
+          NCCLAllReduceWithScale(fp16_grad, fp16_sum_grad, fp16_numel, nranks,
+                                 global_comm, stream, dev_ctx, fp16_scale);
+          fp32_sum_grad += (local_rank * fp32_numel_each_device);
+          fp16_sum_grad += (local_rank * fp16_numel_each_device);
+        } else {
+          NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
+                                     fp32_numel_each_device, nranks,
+                                     global_comm, stream, dev_ctx, fp32_scale);
+          NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
+                                     fp16_numel_each_device, nranks,
+                                     global_comm, stream, dev_ctx, fp16_scale);
+        }
         // (4) mark max_global_grad_norm as 0, meaning that clip has been
         // already performed
         max_global_grad_norm = 0;
       }
     } else {
-      NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
-                                 fp32_numel_each_device, num_devices, comm,
-                                 stream, dev_ctx);
-      NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
-                                 fp16_numel_each_device, num_devices, comm,
-                                 stream, dev_ctx);
+      if (local_shard) {
+        NCCLAllReduceWithScale(fp32_grad, fp32_sum_grad, fp32_numel, nranks,
+                               global_comm, stream, dev_ctx);
+        NCCLAllReduceWithScale(fp16_grad, fp16_sum_grad, fp16_numel, nranks,
+                               global_comm, stream, dev_ctx);
+        fp32_sum_grad += (local_rank * fp32_numel_each_device);
+        fp16_sum_grad += (local_rank * fp16_numel_each_device);
+      } else {
+        NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
+                                   fp32_numel_each_device, num_devices,
+                                   global_comm, stream, dev_ctx);
+        NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
+                                   fp16_numel_each_device, num_devices,
+                                   global_comm, stream, dev_ctx);
+      }
       CheckHasNanInfGrad(fp32_sum_grad, fp32_numel_each_device, fp16_sum_grad,
                          fp16_numel_each_device, fp32_square_grad_norm, stream,
                          &cub_tmp_buffer);
       if (num_devices > 1) {
         PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
             fp32_square_grad_norm, fp32_square_grad_norm, 1, ncclFloat32,
-            ncclSum, comm, stream));
+            ncclSum, local_comm, stream));
       }
       max_global_grad_norm = 0;
     }
@@ -1526,8 +1609,8 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
     memory::Buffer trust_ratio_div_buffer(place);
     auto *trust_ratio_div = trust_ratio_div_buffer.Alloc<float>(partial_numel);
-    auto fp32_offset = rank * fp32_numel_each_device;
-    auto fp16_offset = rank * fp16_numel_each_device;
+    auto fp32_offset = local_rank * fp32_numel_each_device;
+    auto fp16_offset = local_rank * fp16_numel_each_device;
     if (has_fp32_param) {
       VLOG(10) << "Update FP32 Moment and TrustRatioDiv starts";
       MultiTensorUpdateLambMomentAndTrustRatioDiv(
@@ -1598,12 +1681,12 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
         PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
             param_square_norm + fp32_global_param_num,
             param_square_norm + fp32_global_param_num,
-            2 * param_num - fp32_global_param_num, ncclFloat32, ncclSum, comm,
-            stream));
+            2 * param_num - fp32_global_param_num, ncclFloat32, ncclSum,
+            local_comm, stream));
       } else {
         PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
             trust_ratio_div_square_norm, trust_ratio_div_square_norm, param_num,
-            ncclFloat32, ncclSum, comm, stream));
+            ncclFloat32, ncclSum, local_comm, stream));
       }
       VLOG(10) << "ncclAllReduce done";
     }
@@ -1623,7 +1706,7 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
         // ncclAllGather
         PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
             fp32_param + fp32_offset, fp32_param, fp32_numel_each_device,
-            ncclFloat32, comm, stream));
+            ncclFloat32, local_comm, stream));
       }
       beta1pow = nullptr;
@@ -1641,7 +1724,7 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
         // ncclAllGather
         PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
             fp16_param + fp16_offset, fp16_param, fp16_numel_each_device,
-            ncclFloat16, comm, stream));
+            ncclFloat16, local_comm, stream));
       }
     }
     VLOG(10) << "Update Param done";
......
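
The kernel above now calls NCCLAllReduceWithScale plus a local_rank offset where it previously used NCCLReduceScatterWithScale. A small NumPy sketch of why the two paths leave each rank with the same shard of the summed gradient (shapes and rank counts are made up for illustration):

    import numpy as np

    nranks, shard = 4, 3                            # 4 ranks, 3 elements per shard
    grads = np.random.rand(nranks, nranks * shard)  # one full gradient per rank

    for local_rank in range(nranks):
        # ReduceScatter: rank r directly receives the sum of its own shard.
        rs = grads[:, local_rank * shard:(local_rank + 1) * shard].sum(axis=0)
        # AllReduce: every rank gets the full sum; the kernel then advances the
        # result pointer by local_rank * numel_each_device to the same shard.
        ar = grads.sum(axis=0)[local_rank * shard:(local_rank + 1) * shard]
        assert np.allclose(rs, ar)

In the local_shard case the all-reduce runs over all ranks while the offset uses the intra-node local_rank, so each node ends up holding an identical copy of the sharded sum.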
@@ -69,6 +69,9 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
         if trainer_id == 0 and not paddle.is_compiled_with_npu():
             wait_server_ready(other_trainers)
 
+        if build_strategy.reduce_strategy == BuildStrategy.ReduceStrategy._NoReduce:
+            return
+
         if core.is_compiled_with_cuda():
             comm_id_var = startup_program.global_block().create_var(
                 name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
......
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from paddle.fluid import framework, core, layers, unique_name
 from paddle.fluid.framework import Variable
 from paddle.fluid.clip import ClipGradByGlobalNorm
@@ -19,11 +20,69 @@ from paddle.fluid.initializer import Constant
 from paddle.fluid.layer_helper import LayerHelper
 from paddle.fluid.optimizer import Optimizer
 from paddle.distributed import get_rank, get_world_size
+from paddle.distributed.collective import new_group
 from paddle.fluid.executor import global_scope
 from paddle.fluid.framework import name_scope
+from paddle.fluid import core, unique_name
 import numpy as np
 
+
+def init_communicator(block, rank, ranks, ring_id):
+    eps = os.environ['PADDLE_TRAINER_ENDPOINTS']
+    eps = [ep.strip() for ep in eps.split(",") if ep.strip()]
+    cur_ep = eps[rank]
+    other_eps = [eps[r] for r in ranks if r != rank]
+
+    local_rank = ranks.index(rank)
+    comm_var_name = unique_name.generate('comm_id')
+    comm_id_var = block.create_var(name=comm_var_name,
+                                   persistable=True,
+                                   type=core.VarDesc.VarType.RAW)
+    block.append_op(type='c_gen_nccl_id',
+                    inputs={},
+                    outputs={'Out': comm_id_var},
+                    attrs={
+                        'rank': local_rank,
+                        'endpoint': cur_ep,
+                        'other_endpoints': other_eps,
+                        'ring_id': ring_id
+                    })
+    block.append_op(type='c_comm_init',
+                    inputs={'X': comm_id_var},
+                    outputs={},
+                    attrs={
+                        'nranks': len(ranks),
+                        'rank': local_rank,
+                        'ring_id': ring_id
+                    })
+    tmp_var = block.create_var(name=unique_name.generate('tmp'))
+    block.append_op(type='fill_constant',
+                    outputs={'Out': tmp_var},
+                    attrs={'value': 1})
+    block.append_op(type='c_allreduce_sum',
+                    inputs={'X': tmp_var},
+                    outputs={'Out': tmp_var},
+                    attrs={
+                        'ring_id': ring_id,
+                        'use_calc_stream': True
+                    })
+    block.append_op(type='c_sync_calc_stream',
+                    inputs={'X': tmp_var},
+                    outputs={'Out': tmp_var})
+    return ring_id
+
+
+def broadcast_parameters(block, parameters, ring_id):
+    for p in parameters:
+        block.append_op(type='c_broadcast',
+                        inputs={'X': p},
+                        outputs={'Out': p},
+                        attrs={
+                            'ring_id': ring_id,
+                            'use_calc_stream': True
+                        })
+
+
 class DistributedFusedLamb(Optimizer):
 
     def __init__(self,
@@ -41,6 +100,7 @@ class DistributedFusedLamb(Optimizer):
                  use_master_param_norm=True,
                  gradient_accumulation_steps=1,
                  use_master_acc_grad=True,
+                 nproc_per_node=None,
                  name=None):
         assert not framework._non_static_mode(
         ), "DistributedFusedLamb does not support dygraph mode"
@@ -65,10 +125,10 @@ class DistributedFusedLamb(Optimizer):
         self._is_grad_scaled_by_nranks = is_grad_scaled_by_nranks
         self._exclude_from_weight_decay_fn = exclude_from_weight_decay_fn
         self._scale = None
-        self._ring_id = 0
         self._use_master_param_norm = use_master_param_norm
         self._gradient_accumulation_steps = gradient_accumulation_steps
         self._use_master_acc_grad = use_master_acc_grad
+        self._nproc_per_node = nproc_per_node
         assert self._gradient_accumulation_steps >= 1
 
         self.helper = LayerHelper('distributed_fused_lamb')
@@ -228,6 +288,30 @@ class DistributedFusedLamb(Optimizer):
         rank = get_rank()
         nranks = get_world_size()
+        if self._nproc_per_node is None:
+            nproc_per_node = nranks
+        else:
+            nproc_per_node = self._nproc_per_node
+        assert nranks % nproc_per_node == 0, "nranks should be exactly divided by nproc_per_node"
+
+        shard_inside_node = (nranks > nproc_per_node)
+        local_rank = rank % nproc_per_node
+        node_id = int(rank / nproc_per_node)
+        node_num = int(nranks / nproc_per_node)
+        ring_ids = []
+        startup_block = self.helper.startup_program.global_block()
+        if nranks > 1:
+            ring_id = init_communicator(startup_block, rank,
+                                        list(range(nranks)), 0)
+            ring_ids.append(ring_id)
+
+        if node_num > 1 and len(ring_ids) <= 1 and shard_inside_node:
+            local_group_ranks = list(
+                range(node_id * nproc_per_node, (node_id + 1) * nproc_per_node))
+            ring_id = init_communicator(startup_block, rank, local_group_ranks,
+                                        1)
+            ring_ids.append(ring_id)
+
         scale = self._get_or_create_scale()
 
         params = [p for p, _ in params_grads]
@@ -238,7 +322,6 @@ class DistributedFusedLamb(Optimizer):
             if self._exclude_from_weight_decay_fn(p):
                 apply_weight_decay[i] = 0
 
-        startup_block = self.helper.startup_program.global_block()
         for g in grads:
             startup_block.create_var(name=g.name,
                                      type=g.type,
@@ -246,46 +329,45 @@ class DistributedFusedLamb(Optimizer):
                                      persistable=g.persistable,
                                      shape=g.shape)
-        startup_block.append_op(type='distributed_fused_lamb_init',
-                                inputs={
-                                    'Param': params,
-                                    'Grad': grads,
-                                },
-                                outputs={
-                                    'FP32FusedParam': [fp32_fused_param],
-                                    'FP32FusedGrad': [fp32_fused_grad],
-                                    'FP16FusedParam': [fp16_fused_param],
-                                    'FP16FusedGrad': [fp16_fused_grad],
-                                    'Moment1': [moment1],
-                                    'Moment2': [moment2],
-                                    'Beta1Pow': [beta1pow],
-                                    'Beta2Pow': [beta2pow],
-                                    'GlobalScale': [scale],
-                                    'ParamInfo': [param_info],
-                                    'ParamOut':
-                                    params,
-                                    'MasterParamOut':
-                                    master_params,
-                                    'GradOut':
-                                    grads,
-                                    'FP32ShardFusedParamOffsets':
-                                    [fp32_partial_fused_offsets],
-                                    'FP16ShardFusedParamOffsets':
-                                    [fp16_partial_fused_offsets],
-                                    'FusedParamOffsets': [fused_offsets],
-                                    'ParamOrder': [param_order],
-                                    'Step': [step],
-                                },
-                                attrs={
-                                    'alignment': self._alignment,
-                                    'rank': rank,
-                                    'nranks': nranks,
-                                    'apply_weight_decay': apply_weight_decay,
-                                    'moment1': 0.0,
-                                    'moment2': 0.0,
-                                    'beta1': self._beta1,
-                                    'beta2': self._beta2,
-                                })
+        if nranks > 1:
+            broadcast_parameters(startup_block, params, ring_ids[0])
+
+        startup_block.append_op(
+            type='distributed_fused_lamb_init',
+            inputs={
+                'Param': params,
+                'Grad': grads,
+            },
+            outputs={
+                'FP32FusedParam': [fp32_fused_param],
+                'FP32FusedGrad': [fp32_fused_grad],
+                'FP16FusedParam': [fp16_fused_param],
+                'FP16FusedGrad': [fp16_fused_grad],
+                'Moment1': [moment1],
+                'Moment2': [moment2],
+                'Beta1Pow': [beta1pow],
+                'Beta2Pow': [beta2pow],
+                'GlobalScale': [scale],
+                'ParamInfo': [param_info],
+                'ParamOut': params,
+                'MasterParamOut': master_params,
+                'GradOut': grads,
+                'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets],
+                'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets],
+                'FusedParamOffsets': [fused_offsets],
+                'ParamOrder': [param_order],
+                'Step': [step],
+            },
+            attrs={
+                'alignment': self._alignment,
+                'rank': local_rank if shard_inside_node else rank,
+                'nranks': nproc_per_node if shard_inside_node else nranks,
+                'apply_weight_decay': apply_weight_decay,
+                'moment1': 0.0,
+                'moment2': 0.0,
+                'beta1': self._beta1,
+                'beta2': self._beta2,
+            })
 
         main_block = self.helper.main_program.global_block()
         self._create_global_learning_rate()
@@ -351,7 +433,8 @@ class DistributedFusedLamb(Optimizer):
                 'max_global_grad_norm': self._max_global_grad_norm,
                 'clip_after_allreduce': self._clip_after_allreduce,
                 'rank': rank,
-                'ring_id': self._ring_id,
+                'nranks': nranks,
+                'ring_id': ring_ids,
                 'use_master_param_norm': self._use_master_param_norm,
                 'is_grad_scaled_by_nranks': self._is_grad_scaled_by_nranks,
                 'acc_steps': self._gradient_accumulation_steps,
......
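
A short sketch (illustrative values, not part of the patch) of how the Python changes above map ranks onto the two rings passed through the ring_id attribute: ring 0 always spans every rank, and ring 1 is the intra-node group created only when the shard stays inside a node:

    # Illustrative only: 16 ranks, 8 processes per node.
    nranks, nproc_per_node = 16, 8
    shard_inside_node = nranks > nproc_per_node   # True
    node_num = nranks // nproc_per_node           # 2

    for rank in range(nranks):
        local_rank = rank % nproc_per_node
        node_id = rank // nproc_per_node
        ring0 = list(range(nranks))               # gradient all-reduce group
        ring1 = list(range(node_id * nproc_per_node,
                           (node_id + 1) * nproc_per_node))
        # e.g. rank 9 -> node_id=1, local_rank=1, ring1=[8, 9, ..., 15]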