Unverified commit 6678def9 authored by sneaxiy and committed by GitHub

Add nproc_per_node for DistributedFusedLamb (#43295)

* add nproc_per_node for DistributedFusedLamb

* fix nproc_per_node communicator bug

* fix ring_id = 1 init bug

* fix ci

* fix test_parallel_executor_mnist.py
Parent 2c8739e8
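
For context, the sketch below shows how the new nproc_per_node argument is intended to be passed to the optimizer from the Python side. It is a minimal, hedged example: the toy fc network, the value nproc_per_node=8, the paddle.incubate.optimizer import path, and the launch command are assumptions for illustration. The optimizer requires static-graph mode, and with more than one rank the job has to be started through paddle.distributed.launch so that PADDLE_TRAINER_ENDPOINTS is available to init_communicator.

# Minimal usage sketch (assumed toy network and hypothetical nproc_per_node value).
import paddle
from paddle.incubate.optimizer import DistributedFusedLamb

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    y = paddle.static.data(name='y', shape=[None, 1], dtype='float32')
    pred = paddle.static.nn.fc(x, size=1)
    loss = paddle.mean(paddle.nn.functional.square_error_cost(pred, y))

    opt = DistributedFusedLamb(
        learning_rate=1e-3,
        # New in this commit: shard the optimizer states only across the
        # processes of one node; gradients are then all-reduced globally
        # instead of reduce-scattered across all ranks.
        nproc_per_node=8)
    opt.minimize(loss)

# Assumed launch command for a multi-GPU job:
#   python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" train.py
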
@@ -793,8 +793,8 @@ void ParallelExecutor::BCastParamsToDevices(
std::vector<void *> buffers;
buffers.reserve(member_->places_.size());
size_t numel = main_tensor.numel();
ncclDataType_t data_type = platform::ToNCCLDataType(
framework::TransToProtoVarType(main_tensor.dtype()));
auto dtype = framework::TransToProtoVarType(main_tensor.dtype());
ncclDataType_t data_type = platform::ToNCCLDataType(dtype);
for (size_t i = 0; i < member_->places_.size(); ++i) {
auto place = member_->places_[i];
void *buffer;
@@ -815,7 +815,7 @@ void ParallelExecutor::BCastParamsToDevices(
"variables' buffer size to bcast is %d, which is "
"NOT equal to places size %d",
buffers.size(), member_->places_.size()));
{
if (member_->nccl_ctxs_ != nullptr) {
auto *nccl_ctxs = member_->nccl_ctxs_->DefaultFlatCtx();
platform::NCCLGroupGuard guard;
for (size_t i = 0; i < member_->places_.size(); ++i) {
@@ -824,6 +824,22 @@ void ParallelExecutor::BCastParamsToDevices(
nccl_ctx.comm_, nccl_ctx.stream());
}
nccl_ctxs->WaitAll();
} else {
auto src_place = member_->places_[0];
auto src_dev_ctx = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(src_place));
auto sizeof_dtype = framework::SizeOfType(dtype) * numel;
for (size_t i = 1; i < member_->places_.size(); ++i) {
auto dst_place = member_->places_[i];
auto dst_dev_ctx = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(dst_place));
src_dev_ctx->Wait();
dst_dev_ctx->Wait();
memory::Copy(dst_place, buffers[i], src_place, buffers[0],
sizeof_dtype, src_dev_ctx->stream());
src_dev_ctx->Wait();
dst_dev_ctx->Wait();
}
}
#endif
} else if (paddle::platform::is_xpu_place(main_tensor.place())) {
@@ -1348,6 +1364,11 @@ std::vector<ir::Graph *> ParallelExecutor::CloneGraphToMultiDevices(
}
void ParallelExecutor::PrepareNCCLCommunicator(Scope *global_scope) {
if (member_->build_strategy_.reduce_ ==
BuildStrategy::ReduceStrategy::kNoReduce) {
return;
}
if (member_->IsUseCUDA(member_->use_device_) && member_->nranks_ > 1) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
member_->InitOrGetNCCLCommunicator(global_scope, &member_->build_strategy_);
......
@@ -147,8 +147,10 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("is_grad_scaled_by_nranks",
"Whether the input gradient has been scaled by nranks.")
.SetDefault(true);
AddAttr<int>("ring_id", "The ring id of the NCCL communicator.")
.SetDefault(0);
AddAttr<int64_t>("nranks", "The world size.").SetDefault(1);
AddAttr<std::vector<int>>("ring_id",
"The ring id of the NCCL communicator.")
.SetDefault({0});
AddComment("The DistributedFusedLamb optimizer.");
}
};
......
@@ -806,23 +806,24 @@ static void LaunchScaleKernel(const platform::CUDADeviceContext &dev_ctx,
#undef PD_LAMB_VEC_SCALE_KERNEL_CASE
}
template <typename T>
static void NCCLReduceScatterWithScale(
const T *sendbuff, T *recvbuff, size_t recvcount, size_t nranks,
ncclComm_t comm, gpuStream_t stream,
const platform::CUDADeviceContext &dev_ctx, const T *scale = nullptr) {
template <typename T, bool UseReduceScatter>
static void NCCLSumWithScaleBase(const T *sendbuff, T *recvbuff,
size_t recvcount, size_t nranks,
ncclComm_t comm, gpuStream_t stream,
const platform::CUDADeviceContext &dev_ctx,
const T *scale = nullptr) {
static_assert(std::is_same<T, float>::value ||
std::is_same<T, platform::float16>::value,
"T must be either float32 or float16.");
if (recvcount == 0) return;
auto numel = UseReduceScatter ? (recvcount * nranks) : recvcount;
if (comm == nullptr) {
if (scale != nullptr) {
PADDLE_ENFORCE_EQ(nranks, 1,
platform::errors::InvalidArgument(
"nranks must be 1 when scale != nullptr."));
LaunchScaleKernel(dev_ctx, sendbuff, scale, recvbuff, recvcount * nranks,
stream);
LaunchScaleKernel(dev_ctx, sendbuff, scale, recvbuff, numel, stream);
}
return;
}
@@ -834,14 +835,18 @@ static void NCCLReduceScatterWithScale(
scale && CreatePreMulScaleOpIfSupported(dtype, comm, scale, &op);
memory::Buffer buffer(dev_ctx.GetPlace());
if (scale && !should_destroy_op) {
size_t numel = recvcount * nranks;
T *new_sendbuff = buffer.Alloc<T>(numel);
LaunchScaleKernel(dev_ctx, sendbuff, scale, new_sendbuff, numel, stream);
sendbuff = new_sendbuff;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter(
sendbuff, recvbuff, recvcount, dtype, op, comm, stream));
if (UseReduceScatter) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter(
sendbuff, recvbuff, recvcount, dtype, op, comm, stream));
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, recvcount, dtype, op, comm, stream));
}
#if NCCL_VERSION_CODE >= 21100
if (should_destroy_op) {
@@ -851,6 +856,26 @@ static void NCCLReduceScatterWithScale(
}
#endif
}
template <typename T>
static void NCCLReduceScatterWithScale(
const T *sendbuff, T *recvbuff, size_t recvcount, size_t nranks,
ncclComm_t comm, gpuStream_t stream,
const platform::CUDADeviceContext &dev_ctx, const T *scale = nullptr) {
NCCLSumWithScaleBase<T, true>(sendbuff, recvbuff, recvcount, nranks, comm,
stream, dev_ctx, scale);
}
template <typename T>
static void NCCLAllReduceWithScale(const T *sendbuff, T *recvbuff,
size_t recvcount, size_t nranks,
ncclComm_t comm, gpuStream_t stream,
const platform::CUDADeviceContext &dev_ctx,
const T *scale = nullptr) {
NCCLSumWithScaleBase<T, false>(sendbuff, recvbuff, recvcount, nranks, comm,
stream, dev_ctx, scale);
}
#endif
template <typename InputIteratorT, typename OutputIteratorT, typename ReduceOpT,
@@ -1321,6 +1346,9 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
"exactly by the element number %d of Moment1.",
numel, partial_numel));
// num_devices is the number of devices that shard a complete set of all
// parameters. It may be that num_devices < nranks or num_devices == nranks.
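// For illustration (hypothetical numbers): with nranks = 16 and
// nproc_per_node = 8, each node shards the complete parameter set across its
// 8 local devices, so num_devices = 8 and every shard is replicated on
// nranks / num_devices = 2 nodes.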
int64_t num_devices = numel / partial_numel;
VLOG(1) << "num_devices = " << num_devices
<< " , partial_numel = " << partial_numel;
@@ -1354,22 +1382,43 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
auto epsilon = ctx.Attr<float>("epsilon");
auto max_global_grad_norm = ctx.Attr<float>("max_global_grad_norm");
auto clip_after_allreduce = ctx.Attr<bool>("clip_after_allreduce");
auto ring_id = ctx.Attr<int>("ring_id");
auto nranks = ctx.Attr<int64_t>("nranks");
PADDLE_ENFORCE_GE(nranks, num_devices,
phi::errors::InvalidArgument(
"The nranks must not be less than num_devices."));
PADDLE_ENFORCE_EQ(
nranks % num_devices, 0,
phi::errors::InvalidArgument(
"The nranks must be exactly divisible by num_devices."));
bool local_shard = (nranks > num_devices);
const auto &ring_ids = ctx.Attr<std::vector<int>>("ring_id");
auto use_master_param_norm = ctx.Attr<bool>("use_master_param_norm");
auto is_grad_scaled_by_nranks = ctx.Attr<bool>("is_grad_scaled_by_nranks");
VLOG(10) << "max_global_grad_norm = " << max_global_grad_norm
<< " , clip_after_allreduce = " << clip_after_allreduce
<< " , use_master_param_norm = " << use_master_param_norm
<< " , is_grad_scaled_by_nranks = " << is_grad_scaled_by_nranks;
<< " , is_grad_scaled_by_nranks = " << is_grad_scaled_by_nranks
<< " , local_shard = " << local_shard;
// Step 6: allreduce + global norm gradient clip
int rank = 0;
ncclComm_t comm = nullptr;
if (num_devices > 1) {
int64_t global_rank = 0, local_rank = 0;
ncclComm_t global_comm = nullptr, local_comm = 0;
if (nranks > 1) {
auto *nccl_comm_handle =
platform::NCCLCommContext::Instance().Get(ring_id, place);
comm = nccl_comm_handle->comm();
rank = nccl_comm_handle->rank();
platform::NCCLCommContext::Instance().Get(ring_ids[0], place);
global_comm = nccl_comm_handle->comm();
global_rank = nccl_comm_handle->rank();
if (local_shard) {
auto *local_nccl_comm_handle =
platform::NCCLCommContext::Instance().Get(ring_ids[1], place);
local_comm = local_nccl_comm_handle->comm();
local_rank = local_nccl_comm_handle->rank();
} else {
local_comm = global_comm;
local_rank = global_rank;
}
}
memory::Buffer grad_norm_square_buffer(place);
@@ -1381,8 +1430,15 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
platform::float16 *fp16_sum_grad;
auto fp32_numel_each_device = fp32_numel / num_devices;
auto fp16_numel_each_device = fp16_numel / num_devices;
if (num_devices > 1 ||
(max_global_grad_norm > 0 && !clip_after_allreduce)) {
if (local_shard) {
auto ptr = sum_grad_buffer.Alloc<uint8_t>(
fp32_numel * sizeof(float) + fp16_numel * sizeof(platform::float16));
fp32_sum_grad = has_fp32_param ? reinterpret_cast<float *>(ptr) : nullptr;
fp16_sum_grad = has_fp16_param ? reinterpret_cast<platform::float16 *>(
ptr + fp32_numel * sizeof(float))
: nullptr;
} else if (nranks > 1 ||
(max_global_grad_norm > 0 && !clip_after_allreduce)) {
auto ptr = sum_grad_buffer.Alloc<uint8_t>(
fp32_numel_each_device * sizeof(float) +
fp16_numel_each_device * sizeof(platform::float16));
@@ -1404,18 +1460,27 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
float rescale_grad = 1.0f;
if (!is_grad_scaled_by_nranks) {
rescale_grad /= num_devices;
rescale_grad /= nranks;
}
if (max_global_grad_norm > 0) {
if (clip_after_allreduce) {
// (1) ReduceScatter first
NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
fp32_numel_each_device, num_devices, comm,
stream, dev_ctx);
NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
fp16_numel_each_device, num_devices, comm,
stream, dev_ctx);
if (local_shard) {
NCCLAllReduceWithScale(fp32_grad, fp32_sum_grad, fp32_numel, nranks,
global_comm, stream, dev_ctx);
NCCLAllReduceWithScale(fp16_grad, fp16_sum_grad, fp16_numel, nranks,
global_comm, stream, dev_ctx);
fp32_sum_grad += (local_rank * fp32_numel_each_device);
fp16_sum_grad += (local_rank * fp16_numel_each_device);
} else {
NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
fp32_numel_each_device, nranks,
global_comm, stream, dev_ctx);
NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
fp16_numel_each_device, nranks,
global_comm, stream, dev_ctx);
}
// (2) Calculate the global grad norm
GetSquareGradNorm(fp32_sum_grad, fp32_numel_each_device, fp16_sum_grad,
fp16_numel_each_device, fp32_square_grad_norm, stream,
@@ -1425,7 +1490,7 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
if (num_devices > 1) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
fp32_square_grad_norm, fp32_square_grad_norm, 1, ncclFloat32,
ncclSum, comm, stream));
ncclSum, local_comm, stream));
}
VLOG(1) << "Grad square norm after all reduce: "
<< FlattenToString(fp32_square_grad_norm, 1, place);
@@ -1452,7 +1517,7 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
float clip_scale = 1.0f;
if (is_grad_scaled_by_nranks) {
clip_scale *= num_devices;
clip_scale *= nranks;
}
CalcGradNormClipBeforeAllReduceScale<float, platform::float16>
<<<1, 1, 0, stream>>>(global_scale, max_global_grad_norm,
@@ -1463,36 +1528,54 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
} else {
VLOG(1) << "Grad scale: " << FlattenToString(fp16_scale, 1, place);
}
if (num_devices > 1) {
if (nranks > 1) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
fp32_square_grad_norm, fp32_square_grad_norm, 1, ncclFloat32,
ncclSum, comm, stream));
ncclSum, global_comm, stream));
}
// (3) Do ReduceScatter with scale
NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
fp32_numel_each_device, num_devices, comm,
stream, dev_ctx, fp32_scale);
NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
fp16_numel_each_device, num_devices, comm,
stream, dev_ctx, fp16_scale);
if (local_shard) {
NCCLAllReduceWithScale(fp32_grad, fp32_sum_grad, fp32_numel, nranks,
global_comm, stream, dev_ctx, fp32_scale);
NCCLAllReduceWithScale(fp16_grad, fp16_sum_grad, fp16_numel, nranks,
global_comm, stream, dev_ctx, fp16_scale);
fp32_sum_grad += (local_rank * fp32_numel_each_device);
fp16_sum_grad += (local_rank * fp16_numel_each_device);
} else {
NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
fp32_numel_each_device, nranks,
global_comm, stream, dev_ctx, fp32_scale);
NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
fp16_numel_each_device, nranks,
global_comm, stream, dev_ctx, fp16_scale);
}
// (4) Mark max_global_grad_norm as 0, meaning that clip has already been
// performed.
max_global_grad_norm = 0;
}
} else {
NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
fp32_numel_each_device, num_devices, comm,
stream, dev_ctx);
NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
fp16_numel_each_device, num_devices, comm,
stream, dev_ctx);
if (local_shard) {
NCCLAllReduceWithScale(fp32_grad, fp32_sum_grad, fp32_numel, nranks,
global_comm, stream, dev_ctx);
NCCLAllReduceWithScale(fp16_grad, fp16_sum_grad, fp16_numel, nranks,
global_comm, stream, dev_ctx);
fp32_sum_grad += (local_rank * fp32_numel_each_device);
fp16_sum_grad += (local_rank * fp16_numel_each_device);
} else {
NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
fp32_numel_each_device, num_devices,
global_comm, stream, dev_ctx);
NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
fp16_numel_each_device, num_devices,
global_comm, stream, dev_ctx);
}
CheckHasNanInfGrad(fp32_sum_grad, fp32_numel_each_device, fp16_sum_grad,
fp16_numel_each_device, fp32_square_grad_norm, stream,
&cub_tmp_buffer);
if (num_devices > 1) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
fp32_square_grad_norm, fp32_square_grad_norm, 1, ncclFloat32,
ncclSum, comm, stream));
ncclSum, local_comm, stream));
}
max_global_grad_norm = 0;
}
@@ -1526,8 +1609,8 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
memory::Buffer trust_ratio_div_buffer(place);
auto *trust_ratio_div = trust_ratio_div_buffer.Alloc<float>(partial_numel);
auto fp32_offset = rank * fp32_numel_each_device;
auto fp16_offset = rank * fp16_numel_each_device;
auto fp32_offset = local_rank * fp32_numel_each_device;
auto fp16_offset = local_rank * fp16_numel_each_device;
if (has_fp32_param) {
VLOG(10) << "Update FP32 Moment and TrustRatioDiv starts";
MultiTensorUpdateLambMomentAndTrustRatioDiv(
@@ -1598,12 +1681,12 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
param_square_norm + fp32_global_param_num,
param_square_norm + fp32_global_param_num,
2 * param_num - fp32_global_param_num, ncclFloat32, ncclSum, comm,
stream));
2 * param_num - fp32_global_param_num, ncclFloat32, ncclSum,
local_comm, stream));
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
trust_ratio_div_square_norm, trust_ratio_div_square_norm, param_num,
ncclFloat32, ncclSum, comm, stream));
ncclFloat32, ncclSum, local_comm, stream));
}
VLOG(10) << "ncclAllReduce done";
}
@@ -1623,7 +1706,7 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
// ncclAllGather
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
fp32_param + fp32_offset, fp32_param, fp32_numel_each_device,
ncclFloat32, comm, stream));
ncclFloat32, local_comm, stream));
}
beta1pow = nullptr;
@@ -1641,7 +1724,7 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
// ncclAllGather
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
fp16_param + fp16_offset, fp16_param, fp16_numel_each_device,
ncclFloat16, comm, stream));
ncclFloat16, local_comm, stream));
}
}
VLOG(10) << "Update Param done";
......
@@ -69,6 +69,9 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
if trainer_id == 0 and not paddle.is_compiled_with_npu():
wait_server_ready(other_trainers)
if build_strategy.reduce_strategy == BuildStrategy.ReduceStrategy._NoReduce:
return
if core.is_compiled_with_cuda():
comm_id_var = startup_program.global_block().create_var(
name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
......
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from paddle.fluid import framework, core, layers, unique_name
from paddle.fluid.framework import Variable
from paddle.fluid.clip import ClipGradByGlobalNorm
@@ -19,11 +20,69 @@ from paddle.fluid.initializer import Constant
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.optimizer import Optimizer
from paddle.distributed import get_rank, get_world_size
from paddle.distributed.collective import new_group
from paddle.fluid.executor import global_scope
from paddle.fluid.framework import name_scope
from paddle.fluid import core, unique_name
import numpy as np
def init_communicator(block, rank, ranks, ring_id):
eps = os.environ['PADDLE_TRAINER_ENDPOINTS']
eps = [ep.strip() for ep in eps.split(",") if ep.strip()]
cur_ep = eps[rank]
other_eps = [eps[r] for r in ranks if r != rank]
local_rank = ranks.index(rank)
comm_var_name = unique_name.generate('comm_id')
comm_id_var = block.create_var(name=comm_var_name,
persistable=True,
type=core.VarDesc.VarType.RAW)
block.append_op(type='c_gen_nccl_id',
inputs={},
outputs={'Out': comm_id_var},
attrs={
'rank': local_rank,
'endpoint': cur_ep,
'other_endpoints': other_eps,
'ring_id': ring_id
})
block.append_op(type='c_comm_init',
inputs={'X': comm_id_var},
outputs={},
attrs={
'nranks': len(ranks),
'rank': local_rank,
'ring_id': ring_id
})
tmp_var = block.create_var(name=unique_name.generate('tmp'))
block.append_op(type='fill_constant',
outputs={'Out': tmp_var},
attrs={'value': 1})
block.append_op(type='c_allreduce_sum',
inputs={'X': tmp_var},
outputs={'Out': tmp_var},
attrs={
'ring_id': ring_id,
'use_calc_stream': True
})
block.append_op(type='c_sync_calc_stream',
inputs={'X': tmp_var},
outputs={'Out': tmp_var})
return ring_id
def broadcast_parameters(block, parameters, ring_id):
for p in parameters:
block.append_op(type='c_broadcast',
inputs={'X': p},
outputs={'Out': p},
attrs={
'ring_id': ring_id,
'use_calc_stream': True
})
class DistributedFusedLamb(Optimizer):
def __init__(self,
@@ -41,6 +100,7 @@ class DistributedFusedLamb(Optimizer):
use_master_param_norm=True,
gradient_accumulation_steps=1,
use_master_acc_grad=True,
nproc_per_node=None,
name=None):
assert not framework._non_static_mode(
), "DistributedFusedLamb does not support dygraph mode"
@@ -65,10 +125,10 @@ class DistributedFusedLamb(Optimizer):
self._is_grad_scaled_by_nranks = is_grad_scaled_by_nranks
self._exclude_from_weight_decay_fn = exclude_from_weight_decay_fn
self._scale = None
self._ring_id = 0
self._use_master_param_norm = use_master_param_norm
self._gradient_accumulation_steps = gradient_accumulation_steps
self._use_master_acc_grad = use_master_acc_grad
self._nproc_per_node = nproc_per_node
assert self._gradient_accumulation_steps >= 1
self.helper = LayerHelper('distributed_fused_lamb')
@@ -228,6 +288,30 @@ class DistributedFusedLamb(Optimizer):
rank = get_rank()
nranks = get_world_size()
if self._nproc_per_node is None:
nproc_per_node = nranks
else:
nproc_per_node = self._nproc_per_node
assert nranks % nproc_per_node == 0, "nranks should be exactly divisible by nproc_per_node"
shard_inside_node = (nranks > nproc_per_node)
local_rank = rank % nproc_per_node
node_id = int(rank / nproc_per_node)
node_num = int(nranks / nproc_per_node)
ring_ids = []
startup_block = self.helper.startup_program.global_block()
if nranks > 1:
ring_id = init_communicator(startup_block, rank,
list(range(nranks)), 0)
ring_ids.append(ring_id)
if node_num > 1 and len(ring_ids) <= 1 and shard_inside_node:
local_group_ranks = list(
range(node_id * nproc_per_node, (node_id + 1) * nproc_per_node))
ring_id = init_communicator(startup_block, rank, local_group_ranks,
1)
ring_ids.append(ring_id)
scale = self._get_or_create_scale()
params = [p for p, _ in params_grads]
@@ -238,7 +322,6 @@ class DistributedFusedLamb(Optimizer):
if self._exclude_from_weight_decay_fn(p):
apply_weight_decay[i] = 0
startup_block = self.helper.startup_program.global_block()
for g in grads:
startup_block.create_var(name=g.name,
type=g.type,
@@ -246,46 +329,45 @@ class DistributedFusedLamb(Optimizer):
persistable=g.persistable,
shape=g.shape)
startup_block.append_op(type='distributed_fused_lamb_init',
inputs={
'Param': params,
'Grad': grads,
},
outputs={
'FP32FusedParam': [fp32_fused_param],
'FP32FusedGrad': [fp32_fused_grad],
'FP16FusedParam': [fp16_fused_param],
'FP16FusedGrad': [fp16_fused_grad],
'Moment1': [moment1],
'Moment2': [moment2],
'Beta1Pow': [beta1pow],
'Beta2Pow': [beta2pow],
'GlobalScale': [scale],
'ParamInfo': [param_info],
'ParamOut':
params,
'MasterParamOut':
master_params,
'GradOut':
grads,
'FP32ShardFusedParamOffsets':
[fp32_partial_fused_offsets],
'FP16ShardFusedParamOffsets':
[fp16_partial_fused_offsets],
'FusedParamOffsets': [fused_offsets],
'ParamOrder': [param_order],
'Step': [step],
},
attrs={
'alignment': self._alignment,
'rank': rank,
'nranks': nranks,
'apply_weight_decay': apply_weight_decay,
'moment1': 0.0,
'moment2': 0.0,
'beta1': self._beta1,
'beta2': self._beta2,
})
if nranks > 1:
broadcast_parameters(startup_block, params, ring_ids[0])
startup_block.append_op(
type='distributed_fused_lamb_init',
inputs={
'Param': params,
'Grad': grads,
},
outputs={
'FP32FusedParam': [fp32_fused_param],
'FP32FusedGrad': [fp32_fused_grad],
'FP16FusedParam': [fp16_fused_param],
'FP16FusedGrad': [fp16_fused_grad],
'Moment1': [moment1],
'Moment2': [moment2],
'Beta1Pow': [beta1pow],
'Beta2Pow': [beta2pow],
'GlobalScale': [scale],
'ParamInfo': [param_info],
'ParamOut': params,
'MasterParamOut': master_params,
'GradOut': grads,
'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets],
'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets],
'FusedParamOffsets': [fused_offsets],
'ParamOrder': [param_order],
'Step': [step],
},
attrs={
'alignment': self._alignment,
'rank': local_rank if shard_inside_node else rank,
'nranks': nproc_per_node if shard_inside_node else nranks,
'apply_weight_decay': apply_weight_decay,
'moment1': 0.0,
'moment2': 0.0,
'beta1': self._beta1,
'beta2': self._beta2,
})
main_block = self.helper.main_program.global_block()
self._create_global_learning_rate()
@@ -351,7 +433,8 @@ class DistributedFusedLamb(Optimizer):
'max_global_grad_norm': self._max_global_grad_norm,
'clip_after_allreduce': self._clip_after_allreduce,
'rank': rank,
'ring_id': self._ring_id,
'nranks': nranks,
'ring_id': ring_ids,
'use_master_param_norm': self._use_master_param_norm,
'is_grad_scaled_by_nranks': self._is_grad_scaled_by_nranks,
'acc_steps': self._gradient_accumulation_steps,
......