未验证 提交 d17961ed 编写于 作者: S sneaxiy 提交者: GitHub

Optimize the CUDA kernel in DistributedFusedLamb optimizer (#39972)

* vectorize lamb kernel

* remove flags, add ut

* remove useless codes

* refine code, add param order
上级 1b585b28
...@@ -61,30 +61,31 @@ class DistributedFusedLambInitOpMaker ...@@ -61,30 +61,31 @@ class DistributedFusedLambInitOpMaker
"The fp32 beta1 power accumulator tensor. Its shape is [1]."); "The fp32 beta1 power accumulator tensor. Its shape is [1].");
AddOutput("Beta2Pow", AddOutput("Beta2Pow",
"The fp32 beta2 power accumulator tensor. Its shape is [1]."); "The fp32 beta2 power accumulator tensor. Its shape is [1].");
AddOutput("FusedIndices",
"The param index of each element in FP32FusedParam. Its shape is "
"[M1+M2]. It is like [0,0,0,1,1,1,1,2,2,...].");
AddOutput( AddOutput(
"FusedParamOffsets", "FusedParamOffsets",
"The numel offset of each parameter inside the FP32FusedParam. Its " "The numel offset of each parameter inside the FP32FusedParam. Its "
"shape is [param_num + 1]. It is like [0, n_0, n_0 + n_1, n_0 + n_1 " "shape is [param_num + 1]. It is like [0, n_0, n_0 + n_1, n_0 + n_1 "
"+ n_2, ...]."); "+ n_2, ...]. It should be in CPUPlace.");
AddOutput("FP32ShardFusedParamOffsets", AddOutput(
"The sharded numel offset of each parameter in the local rank. " "FP32ShardFusedParamOffsets",
"Its shape is [fp32_local_param_num + 1].");
AddOutput("FP16ShardFusedParamOffsets",
"The sharded numel offset of each parameter in the local rank. " "The sharded numel offset of each parameter in the local rank. "
"Its shape is [fp16_local_param_num + 1]."); "Its shape is [fp32_local_param_num + 1]. It should be in CPUPlace.");
AddOutput( AddOutput(
"WeightDecay", "FP16ShardFusedParamOffsets",
"The sharded fp32 weight decay tensor. Its shape is [(M1+M2)/N]."); "The sharded numel offset of each parameter in the local rank. "
"Its shape is [fp16_local_param_num + 1]. It should be in CPUPlace.");
AddOutput("ParamInfo", AddOutput("ParamInfo",
"The param info. It should be in CPUPlace, and its shape is [6]" "The param info. It should be in CPUPlace, and its shape is [6]"
"CPUPlace, and its shape is [6]. It is " "CPUPlace, and its shape is [8]. It is "
"[fp32_shard_param_start_idx, fp32_local_param_num, " "[fp32_shard_param_start_idx, fp32_local_param_num, "
"fp32_global_param_num, fp16_shard_param_start_idx, " "fp32_global_param_num, fp32_weight_decay_end_idx, "
"fp16_local_param_num, fp16_global_param_num]."); "fp16_shard_param_start_idx, "
"fp16_local_param_num, fp16_global_param_num, "
"fp16_weight_decay_end_idx].");
AddOutput("ParamOrder",
"The reordered parameter order. Inside this op, "
"the parameter would be reordered by data type and weight decay "
"value.");
AddOutput("ParamOut", "The output parameter list.").AsDuplicable(); AddOutput("ParamOut", "The output parameter list.").AsDuplicable();
AddOutput("MasterParamOut", AddOutput("MasterParamOut",
"The output master parameter list. It would share the memory of " "The output master parameter list. It would share the memory of "
...@@ -96,10 +97,8 @@ class DistributedFusedLambInitOpMaker ...@@ -96,10 +97,8 @@ class DistributedFusedLambInitOpMaker
AddAttr<float>("beta1", "The initial value of Beta1Pow."); AddAttr<float>("beta1", "The initial value of Beta1Pow.");
AddAttr<float>("beta2", "The initial value of Beta2Pow."); AddAttr<float>("beta2", "The initial value of Beta2Pow.");
AddAttr<std::vector<float>>( AddAttr<std::vector<int>>("apply_weight_decay",
"weight_decay", "Whether to apply weight decay.");
"The weight decay for each parameter. Its "
"shape is equal to the global parameter number.");
AddAttr<int>("alignment", "The alignment in bytes for the fused tensors."); AddAttr<int>("alignment", "The alignment in bytes for the fused tensors.");
AddAttr<int>("rank", "The global rank of the current process."); AddAttr<int>("rank", "The global rank of the current process.");
AddAttr<int>("nranks", "The global world size."); AddAttr<int>("nranks", "The global world size.");
......
...@@ -258,32 +258,6 @@ static void ShareBufferForNonInitedTensor(framework::Tensor *origin, ...@@ -258,32 +258,6 @@ static void ShareBufferForNonInitedTensor(framework::Tensor *origin,
<< ") , dtype = " << fused_out->dtype(); << ") , dtype = " << fused_out->dtype();
} }
template <typename OffsetT, typename IndexT>
static __global__ void LambFillFusedIndicesCUDAKernel(const OffsetT *offsets,
IndexT *out,
int offset_num,
int out_num) {
CUDA_KERNEL_LOOP_TYPE(i, out_num, int) {
auto idx = phi::funcs::LowerBound(offsets, offset_num, i);
if (idx == offset_num || offsets[idx] != i) {
--idx;
}
out[i] = idx;
}
}
template <typename T>
static void CopyVectorToTensor(const std::vector<T> &src,
framework::Tensor *dst,
const platform::Place &place,
gpuStream_t stream) {
dst->Resize({static_cast<int64_t>(src.size())});
T *dst_ptr = dst->mutable_data<T>(place);
const T *src_ptr = src.data();
auto nbytes = src.size() * sizeof(T);
memory::Copy(place, dst_ptr, platform::CPUPlace(), src_ptr, nbytes, stream);
}
template <typename T> template <typename T>
static void CopyVectorToCPUTensor(const std::vector<T> &src, static void CopyVectorToCPUTensor(const std::vector<T> &src,
framework::Tensor *dst) { framework::Tensor *dst) {
...@@ -294,6 +268,42 @@ static void CopyVectorToCPUTensor(const std::vector<T> &src, ...@@ -294,6 +268,42 @@ static void CopyVectorToCPUTensor(const std::vector<T> &src,
std::memcpy(dst_ptr, src_ptr, nbytes); std::memcpy(dst_ptr, src_ptr, nbytes);
} }
static size_t ReorderParamGradInfoList(const std::vector<int> &flags,
std::vector<ParamGradInfo> *infos) {
size_t n = infos->size();
std::vector<int> cur_flags;
cur_flags.reserve(n);
for (size_t i = 0; i < n; ++i) {
auto idx = (*infos)[i].idx;
cur_flags.push_back(flags[idx]);
}
auto origin_infos = *infos;
size_t j = 0;
for (size_t i = 0; i < n; ++i) {
if (cur_flags[i]) {
(*infos)[j] = origin_infos[i];
++j;
}
}
size_t ret_idx = j;
for (size_t i = 0; i < n; ++i) {
if (!cur_flags[i]) {
(*infos)[j] = origin_infos[i];
++j;
}
}
return ret_idx;
}
template <typename T>
static T ClipByBound(T x, T low_value, T high_value) {
if (x < low_value) return low_value;
if (x > high_value) return high_value;
return x;
}
template <typename T> template <typename T>
class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T> class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
...@@ -404,6 +414,24 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T> ...@@ -404,6 +414,24 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
info->numel_offset = 0; // not determined yet info->numel_offset = 0; // not determined yet
} }
} }
const auto &apply_weight_decay =
ctx.Attr<std::vector<int>>("apply_weight_decay");
size_t fp32_wd_end_idx =
ReorderParamGradInfoList(apply_weight_decay, &fp32_infos);
size_t fp16_wd_end_idx =
ReorderParamGradInfoList(apply_weight_decay, &fp16_infos);
auto *param_order_t = ctx.Output<framework::Tensor>("ParamOrder");
auto param_num = fp32_infos.size() + fp16_infos.size();
param_order_t->Resize({static_cast<int16_t>(param_num)});
auto *param_order = param_order_t->mutable_data<int>(platform::CPUPlace());
for (size_t i = 0; i < fp32_infos.size(); ++i) {
param_order[i] = static_cast<int>(fp32_infos[i].idx);
}
for (size_t i = 0; i < fp16_infos.size(); ++i) {
param_order[i + fp32_infos.size()] = static_cast<int>(fp16_infos[i].idx);
}
VLOG(10) << "Fill ParamGradInfo ends"; VLOG(10) << "Fill ParamGradInfo ends";
// Step 2: determine the numel_with_padding and numel_offset // Step 2: determine the numel_with_padding and numel_offset
...@@ -568,45 +596,29 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T> ...@@ -568,45 +596,29 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
VLOG(10) << "Found the sharding arguments"; VLOG(10) << "Found the sharding arguments";
auto *param_info_t = ctx.Output<framework::Tensor>("ParamInfo"); auto *param_info_t = ctx.Output<framework::Tensor>("ParamInfo");
param_info_t->Resize({6}); param_info_t->Resize({8});
auto *param_info = param_info_t->mutable_data<int>(platform::CPUPlace()); auto *param_info = param_info_t->mutable_data<int>(platform::CPUPlace());
param_info[0] = static_cast<int>(fp32_start_idx); param_info[0] = static_cast<int>(fp32_start_idx);
param_info[1] = static_cast<int>(fp32_local_param_num); param_info[1] = static_cast<int>(fp32_local_param_num);
param_info[2] = static_cast<int>(fp32_infos.size()); param_info[2] = static_cast<int>(fp32_infos.size());
param_info[3] = static_cast<int>(fp16_start_idx + fp32_infos.size()); param_info[3] = ClipByBound<int>(fp32_wd_end_idx, fp32_start_idx,
param_info[4] = static_cast<int>(fp16_local_param_num); fp32_start_idx + fp32_local_param_num) -
param_info[5] = static_cast<int>(fp16_infos.size()); static_cast<int>(fp32_start_idx);
param_info[4] = static_cast<int>(fp16_start_idx + fp32_infos.size());
param_info[5] = static_cast<int>(fp16_local_param_num);
param_info[6] = static_cast<int>(fp16_infos.size());
param_info[7] = ClipByBound<int>(fp16_wd_end_idx, fp16_start_idx,
fp16_start_idx + fp16_local_param_num) -
static_cast<int>(fp16_start_idx);
VLOG(10) << "Start FP32 idx: " << param_info[0]; VLOG(10) << "Start FP32 idx: " << param_info[0];
VLOG(10) << "Local FP32 param num: " << param_info[1]; VLOG(10) << "Local FP32 param num: " << param_info[1];
VLOG(10) << "Global FP32 param num: " << param_info[2]; VLOG(10) << "Global FP32 param num: " << param_info[2];
VLOG(10) << "Start FP16 idx: " << param_info[3]; VLOG(10) << "Start FP16 idx: " << param_info[4];
VLOG(10) << "Local FP16 param num: " << param_info[4]; VLOG(10) << "Local FP16 param num: " << param_info[5];
VLOG(10) << "Global FP16 param num: " << param_info[5]; VLOG(10) << "Global FP16 param num: " << param_info[6];
// For WeightDecay, shard and perform H2D copy
const auto &origin_weight_decay =
ctx.Attr<std::vector<float>>("weight_decay");
PADDLE_ENFORCE_EQ(params.size(), origin_weight_decay.size(),
platform::errors::InvalidArgument(
"The attr(weight_decay) should have the "
"same length with Input(Param)."));
std::vector<float> shard_weight_decay;
shard_weight_decay.reserve(total_local_param_num);
for (size_t i = 0; i < fp32_local_param_num; ++i) {
shard_weight_decay.push_back(
origin_weight_decay[fp32_infos[i + fp32_start_idx].idx]);
}
for (size_t i = 0; i < fp16_local_param_num; ++i) {
shard_weight_decay.push_back(
origin_weight_decay[fp16_infos[i + fp16_start_idx].idx]);
}
// For FusedIndices, launch CUDA kernel to do binary search
auto *fused_indices_t = ctx.Output<framework::Tensor>("FusedIndices");
fused_indices_t->Resize({static_cast<int64_t>(total_numel)});
auto *fused_indices = fused_indices_t->mutable_data<int>(place);
std::vector<int> numel_offsets; std::vector<int> numel_offsets;
numel_offsets.reserve(params.size() + 1); numel_offsets.reserve(params.size() + 1);
for (const auto &info : fp32_infos) { for (const auto &info : fp32_infos) {
...@@ -621,21 +633,6 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T> ...@@ -621,21 +633,6 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
"The numel_offsets number must be one larger than " "The numel_offsets number must be one larger than "
"the parameter number.")); "the parameter number."));
VLOG(10) << "Total numel offset: " << FlattenToString(numel_offsets); VLOG(10) << "Total numel offset: " << FlattenToString(numel_offsets);
auto *fused_param_offset_t =
ctx.Output<framework::Tensor>("FusedParamOffsets");
fused_param_offset_t->Resize({static_cast<int64_t>(numel_offsets.size())});
auto *fused_param_offset = fused_param_offset_t->mutable_data<int>(place);
memory::Copy(place, fused_param_offset, platform::CPUPlace(),
numel_offsets.data(),
numel_offsets.size() * sizeof(numel_offsets[0]), stream);
auto config = platform::GetGpuLaunchConfig1D(dev_ctx, total_numel);
LambFillFusedIndicesCUDAKernel<<<config.block_per_grid,
config.thread_per_block, 0, stream>>>(
fused_param_offset, fused_indices, numel_offsets.size() - 1,
total_numel);
std::vector<int> lengths;
lengths.reserve(fp32_local_param_num + fp16_local_param_num);
std::vector<int> fp32_partial_numel_offsets; std::vector<int> fp32_partial_numel_offsets;
fp32_partial_numel_offsets.reserve(fp32_local_param_num + 1); fp32_partial_numel_offsets.reserve(fp32_local_param_num + 1);
...@@ -659,9 +656,9 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T> ...@@ -659,9 +656,9 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
VLOG(10) << "FP32 Partial numel = [" VLOG(10) << "FP32 Partial numel = ["
<< valid_start_n + fp32_infos[i].numel << "," << valid_start_n + fp32_infos[i].numel << ","
<< end_n + fp32_infos[i].numel; << end_n + fp32_infos[i].numel;
lengths.push_back(end_n - valid_start_n); auto len = end_n - valid_start_n;
fp32_partial_numel_offsets.push_back(fp32_partial_numel_offsets.back() + fp32_partial_numel_offsets.push_back(fp32_partial_numel_offsets.back() +
lengths.back()); len);
} }
std::vector<int> fp16_partial_numel_offsets; std::vector<int> fp16_partial_numel_offsets;
...@@ -682,9 +679,9 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T> ...@@ -682,9 +679,9 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
PADDLE_ENFORCE_NE(valid_start_n, end_n, PADDLE_ENFORCE_NE(valid_start_n, end_n,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Indices sharding error. This may be a bug.")); "Indices sharding error. This may be a bug."));
lengths.push_back(end_n - valid_start_n); auto len = end_n - valid_start_n;
fp16_partial_numel_offsets.push_back(fp16_partial_numel_offsets.back() + fp16_partial_numel_offsets.push_back(fp16_partial_numel_offsets.back() +
lengths.back()); len);
} }
CopyVectorToCPUTensor(numel_offsets, CopyVectorToCPUTensor(numel_offsets,
...@@ -696,23 +693,6 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T> ...@@ -696,23 +693,6 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
fp16_partial_numel_offsets, fp16_partial_numel_offsets,
ctx.Output<framework::Tensor>("FP16ShardFusedParamOffsets")); ctx.Output<framework::Tensor>("FP16ShardFusedParamOffsets"));
// Fill the weight decay tensor
PADDLE_ENFORCE_EQ(lengths.size(), shard_weight_decay.size(),
platform::errors::InvalidArgument(
"Invalid weight decay sharding. This may be a bug."));
std::vector<float> wd_cpu;
for (size_t i = 0; i < shard_weight_decay.size(); ++i) {
int len = lengths[i];
for (int j = 0; j < len; ++j) {
wd_cpu.push_back(shard_weight_decay[i]);
}
}
PADDLE_ENFORCE_EQ(wd_cpu.size() * nranks, fp32_numel + fp16_numel,
platform::errors::InvalidArgument(
"Invalid weight decay sharding. This may be a bug."));
CopyVectorToTensor(wd_cpu, ctx.Output<framework::Tensor>("WeightDecay"),
place, stream);
auto *global_scale = ctx.Output<framework::Tensor>("GlobalScale"); auto *global_scale = ctx.Output<framework::Tensor>("GlobalScale");
if (!global_scale->IsInitialized()) { if (!global_scale->IsInitialized()) {
TensorFillConstant<float>(dev_ctx, global_scale, {1}, 1.0f); TensorFillConstant<float>(dev_ctx, global_scale, {1}, 1.0f);
......
...@@ -66,28 +66,31 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -66,28 +66,31 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker {
"The fp32 beta1 power accumulator tensor. Its shape is [1]."); "The fp32 beta1 power accumulator tensor. Its shape is [1].");
AddInput("Beta2Pow", AddInput("Beta2Pow",
"The fp32 beta2 power accumulator tensor. Its shape is [1]."); "The fp32 beta2 power accumulator tensor. Its shape is [1].");
AddInput("FusedIndices",
"The param index of each element in FP32FusedParam. Its shape is "
"[M1+M2]. It is like [0,0,0,1,1,1,1,2,2,...].");
AddInput( AddInput(
"FusedParamOffsets", "FusedParamOffsets",
"The numel offset of each parameter inside the FP32FusedParam. Its " "The numel offset of each parameter inside the FP32FusedParam. Its "
"shape is [param_num + 1]. It is like [0, n_0, n_0 + n_1, n_0 + n_1 " "shape is [param_num + 1]. It is like [0, n_0, n_0 + n_1, n_0 + n_1 "
"+ n_2, ...]."); "+ n_2, ...]. It should be in CPUPlace.");
AddInput("FP32ShardFusedParamOffsets", AddInput(
"FP32ShardFusedParamOffsets",
"The sharded numel offset of each parameter in the local rank. " "The sharded numel offset of each parameter in the local rank. "
"Its shape is [fp32_local_param_num + 1]."); "Its shape is [fp32_local_param_num + 1]. It should be in CPUPlace.");
AddInput("FP16ShardFusedParamOffsets", AddInput(
"FP16ShardFusedParamOffsets",
"The sharded numel offset of each parameter in the local rank. " "The sharded numel offset of each parameter in the local rank. "
"Its shape is [fp16_local_param_num + 1]."); "Its shape is [fp16_local_param_num + 1]. It should be in CPUPlace.");
AddInput("WeightDecay",
"The sharded fp32 weight decay tensor. Its shape is [(M1+M2)/N].");
AddInput("ParamInfo", AddInput("ParamInfo",
"The param info. It should be in CPUPlace, and its shape is [6]" "The param info. It should be in CPUPlace, and its shape is [6]"
"CPUPlace, and its shape is [6]. It is " "CPUPlace, and its shape is [8]. It is "
"[fp32_shard_param_start_idx, fp32_local_param_num, " "[fp32_shard_param_start_idx, fp32_local_param_num, "
"fp32_global_param_num, fp16_shard_param_start_idx, " "fp32_global_param_num, fp32_weight_decay_end_idx, "
"fp16_local_param_num, fp16_global_param_num]."); "fp16_shard_param_start_idx, "
"fp16_local_param_num, fp16_global_param_num, "
"fp16_weight_decay_end_idx].");
AddInput("ParamOrder",
"The reordered parameter order. Inside this op, "
"the parameter would be reordered by data type and weight decay "
"value.");
AddInput("LearningRate", AddInput("LearningRate",
"The fp32 learning rate tensor. Its shape is [1]."); "The fp32 learning rate tensor. Its shape is [1].");
...@@ -116,6 +119,7 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -116,6 +119,7 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker {
"max_global_grad_norm", "max_global_grad_norm",
"The maximum global gradient l2-norm value for clipping. If " "The maximum global gradient l2-norm value for clipping. If "
"max_global_grad_norm <= 0, no clipping would be performed."); "max_global_grad_norm <= 0, no clipping would be performed.");
AddAttr<float>("weight_decay", "The weight decay value.");
AddAttr<bool>("clip_after_allreduce", AddAttr<bool>("clip_after_allreduce",
"Whether to clip before allreduce, only valid when the " "Whether to clip before allreduce, only valid when the "
"world size is larger than 1."); "world size is larger than 1.");
......
...@@ -87,7 +87,7 @@ struct L2NormFunctor { ...@@ -87,7 +87,7 @@ struct L2NormFunctor {
} }
}; };
template <typename InT, typename OutT, int BlockDim, bool NeedSqrt> template <typename InT, typename OutT, int BlockDim>
static __global__ void MultiTensorL2NormReduceAgainCUDAKernel( static __global__ void MultiTensorL2NormReduceAgainCUDAKernel(
const InT *x, OutT *y, int max_chunk_num) { const InT *x, OutT *y, int max_chunk_num) {
int tensor_id = blockIdx.x; int tensor_id = blockIdx.x;
...@@ -100,12 +100,8 @@ static __global__ void MultiTensorL2NormReduceAgainCUDAKernel( ...@@ -100,12 +100,8 @@ static __global__ void MultiTensorL2NormReduceAgainCUDAKernel(
} }
sum = BlockReduce(storage).Reduce(sum, cub::Sum()); sum = BlockReduce(storage).Reduce(sum, cub::Sum());
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
if (NeedSqrt) {
y[blockIdx.x] = static_cast<OutT>(sqrtf(sum));
} else {
y[blockIdx.x] = static_cast<OutT>(sum); y[blockIdx.x] = static_cast<OutT>(sum);
} }
}
} }
template <typename T> template <typename T>
...@@ -118,6 +114,7 @@ static int GetChunkedVecSize(const T *ptr, int chunk_size) { ...@@ -118,6 +114,7 @@ static int GetChunkedVecSize(const T *ptr, int chunk_size) {
constexpr int vec8 = alignof(platform::AlignedVector<T, 8>); constexpr int vec8 = alignof(platform::AlignedVector<T, 8>);
constexpr int vec4 = alignof(platform::AlignedVector<T, 4>); constexpr int vec4 = alignof(platform::AlignedVector<T, 4>);
constexpr int vec2 = alignof(platform::AlignedVector<T, 2>); constexpr int vec2 = alignof(platform::AlignedVector<T, 2>);
chunk_size *= sizeof(T);
if (address % vec8 == 0 && chunk_size % vec8 == 0) { if (address % vec8 == 0 && chunk_size % vec8 == 0) {
return std::min(8, valid_vec_size); return std::min(8, valid_vec_size);
} else if (address % vec4 == 0 && chunk_size % vec4 == 0) { } else if (address % vec4 == 0 && chunk_size % vec4 == 0) {
...@@ -129,27 +126,26 @@ static int GetChunkedVecSize(const T *ptr, int chunk_size) { ...@@ -129,27 +126,26 @@ static int GetChunkedVecSize(const T *ptr, int chunk_size) {
} }
} }
#define PD_VEC_MULTI_TENSOR_APPLY_CASE(__vec_size, ...) \ #define PD_VEC_LAUNCH_KERNEL_CASE(__vec_size, ...) \
case __vec_size: { \ case __vec_size: { \
constexpr int kVecSize = __vec_size; \ constexpr int kVecSize = __vec_size; \
__VA_ARGS__; \ __VA_ARGS__; \
break; \ break; \
} }
#define PD_VEC_MULTI_TENSOR_APPLY(__vec_size, ...) \ #define PD_VEC_LAUNCH_KERNEL(__vec_size, ...) \
do { \ do { \
switch (__vec_size) { \ switch (__vec_size) { \
PD_VEC_MULTI_TENSOR_APPLY_CASE(8, __VA_ARGS__); \ PD_VEC_LAUNCH_KERNEL_CASE(8, __VA_ARGS__); \
PD_VEC_MULTI_TENSOR_APPLY_CASE(4, __VA_ARGS__); \ PD_VEC_LAUNCH_KERNEL_CASE(4, __VA_ARGS__); \
PD_VEC_MULTI_TENSOR_APPLY_CASE(2, __VA_ARGS__); \ PD_VEC_LAUNCH_KERNEL_CASE(2, __VA_ARGS__); \
PD_VEC_MULTI_TENSOR_APPLY_CASE(1, __VA_ARGS__); \ PD_VEC_LAUNCH_KERNEL_CASE(1, __VA_ARGS__); \
} \ } \
} while (0) } while (0)
// TODO(zengjinle): which chunk_size is better? // TODO(zengjinle): which chunk_size is better?
template <typename InT, typename OutT, bool NeedSqrt = false, template <typename InT, typename OutT, int MaxTensorNumPerLaunch = 160,
int MaxTensorNumPerLaunch = 50, int MaxChunkNumPerLaunch = 680, int MaxChunkNumPerLaunch = 780>
int BlockDim = 512>
static void MultiTensorL2Norm(const platform::CUDAPlace &place, static void MultiTensorL2Norm(const platform::CUDAPlace &place,
gpuStream_t stream, const InT *x, gpuStream_t stream, const InT *x,
const int *offsets, int n, OutT *y, const int *offsets, int n, OutT *y,
...@@ -158,7 +154,7 @@ static void MultiTensorL2Norm(const platform::CUDAPlace &place, ...@@ -158,7 +154,7 @@ static void MultiTensorL2Norm(const platform::CUDAPlace &place,
constexpr int kNumTensor = MaxTensorNumPerLaunch; constexpr int kNumTensor = MaxTensorNumPerLaunch;
constexpr int kNumChunk = MaxChunkNumPerLaunch; constexpr int kNumChunk = MaxChunkNumPerLaunch;
constexpr int kBlockDim = BlockDim; constexpr int kBlockDim = 512;
int max_chunk_num = -1; int max_chunk_num = -1;
int vec_size = 8; int vec_size = 8;
...@@ -181,22 +177,22 @@ static void MultiTensorL2Norm(const platform::CUDAPlace &place, ...@@ -181,22 +177,22 @@ static void MultiTensorL2Norm(const platform::CUDAPlace &place,
auto *tmp_out_ptr = tmp_out.Alloc<MT>(n * max_chunk_num); auto *tmp_out_ptr = tmp_out.Alloc<MT>(n * max_chunk_num);
FillZeroWithPtr(tmp_out_ptr, n * max_chunk_num, stream); FillZeroWithPtr(tmp_out_ptr, n * max_chunk_num, stream);
#define PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL \ #define PD_LAUNCH_MULTI_TENSOR_APPLY_L2_NORM_KERNEL \
do { \ do { \
using FunctorT = L2NormFunctor<InT, kBlockDim, kVecSize>; \ using FunctorT = L2NormFunctor<InT, kBlockDim, kVecSize>; \
VLOG(10) << __func__ << " " << typeid(InT).name() \ VLOG(10) << __func__ << " " << typeid(InT).name() \
<< " VecSize = " << kVecSize; \ << " VecSize = " << kVecSize; \
MultiTensorApply<FunctorT, kBlockDim, kNumTensor, kNumChunk>( \ MultiTensorApply<FunctorT, kNumTensor, kNumChunk>( \
FunctorT(), stream, offsets, n, chunk_size, x, tmp_out_ptr, \ FunctorT(), stream, offsets, n, chunk_size, kBlockDim, x, tmp_out_ptr, \
max_chunk_num); \ max_chunk_num); \
} while (0) } while (0)
PD_VEC_MULTI_TENSOR_APPLY(vec_size, PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL); PD_VEC_LAUNCH_KERNEL(vec_size, PD_LAUNCH_MULTI_TENSOR_APPLY_L2_NORM_KERNEL);
#undef PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL #undef PD_LAUNCH_MULTI_TENSOR_APPLY_L2_NORM_KERNEL
MultiTensorL2NormReduceAgainCUDAKernel<MT, OutT, kBlockDim, MultiTensorL2NormReduceAgainCUDAKernel<
NeedSqrt><<<n, kBlockDim, 0, stream>>>( MT, OutT, kBlockDim><<<n, kBlockDim, 0, stream>>>(tmp_out_ptr, y,
tmp_out_ptr, y, max_chunk_num); max_chunk_num);
} }
template <int LogLevel> template <int LogLevel>
...@@ -208,34 +204,17 @@ static void LogParamAndTrustRatioDivSquareNorm( ...@@ -208,34 +204,17 @@ static void LogParamAndTrustRatioDivSquareNorm(
auto tensors = ctx.MultiInput<framework::Tensor>("Param"); auto tensors = ctx.MultiInput<framework::Tensor>("Param");
if (tensors.empty()) return; if (tensors.empty()) return;
const auto *order = ctx.Input<framework::Tensor>("ParamOrder")->data<int>();
size_t n = tensors.size(); size_t n = tensors.size();
auto place = tensors[0]->place(); auto place = tensors[0]->place();
auto pn_vec = ToVector(param_square_norm, n, place); auto pn_vec = ToVector(param_square_norm, n, place);
auto tn_vec = ToVector(trust_ratio_div_square_norm, n, place); auto tn_vec = ToVector(trust_ratio_div_square_norm, n, place);
std::vector<size_t> fp32_indices, fp16_indices;
fp32_indices.reserve(n);
fp16_indices.reserve(n);
for (size_t i = 0; i < n; ++i) {
const auto *t = tensors[i];
if (t->dtype() == phi::DataType::FLOAT32) {
fp32_indices.push_back(i);
} else if (t->dtype() == phi::DataType::FLOAT16) {
fp16_indices.push_back(i);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported data type %s.", t->dtype()));
}
}
for (auto idx : fp16_indices) {
fp32_indices.push_back(idx);
}
const auto &names = ctx.GetOp().Inputs("Param"); const auto &names = ctx.GetOp().Inputs("Param");
for (size_t i = 0; i < fp32_indices.size(); ++i) { for (size_t i = 0; i < n; ++i) {
auto idx = fp32_indices[i]; auto idx = order[i];
VLOG(LogLevel) << "Param " << tensors[idx]->dtype() << " " << names[idx] VLOG(LogLevel) << "Param " << tensors[idx]->dtype() << " " << names[idx]
<< " pn = " << pn_vec[i] << " , tn = " << tn_vec[i]; << " pn = " << pn_vec[i] << " , tn = " << tn_vec[i];
} }
...@@ -353,7 +332,7 @@ static __global__ void CalcGradNormClipBeforeAllReduceScale( ...@@ -353,7 +332,7 @@ static __global__ void CalcGradNormClipBeforeAllReduceScale(
const T1 *__restrict__ global_scale, T1 max_global_grad_norm, const T1 *__restrict__ global_scale, T1 max_global_grad_norm,
const T1 *__restrict__ square_grad_norm, T1 *__restrict__ out1, const T1 *__restrict__ square_grad_norm, T1 *__restrict__ out1,
T2 *__restrict__ out2, T1 clip_rescale_grad) { T2 *__restrict__ out2, T1 clip_rescale_grad) {
T1 grad_norm = static_cast<T1>(sqrt(*square_grad_norm)) * clip_rescale_grad; T1 grad_norm = static_cast<T1>(sqrtf(*square_grad_norm)) * clip_rescale_grad;
T1 scale = global_scale[0] * max_global_grad_norm / (1e-6 + grad_norm); T1 scale = global_scale[0] * max_global_grad_norm / (1e-6 + grad_norm);
bool found_nan_inf = !isfinite(scale); bool found_nan_inf = !isfinite(scale);
if (scale >= 1 || found_nan_inf) { if (scale >= 1 || found_nan_inf) {
...@@ -380,19 +359,24 @@ static __global__ void SetNanInfValueCUDAKernelTwoFlag(const bool *in_flag_p_1, ...@@ -380,19 +359,24 @@ static __global__ void SetNanInfValueCUDAKernelTwoFlag(const bool *in_flag_p_1,
((*in_flag_p_1) || (*in_flag_p_2)) ? __int_as_float(0x7fffffffU) : 0.0f; ((*in_flag_p_1) || (*in_flag_p_2)) ? __int_as_float(0x7fffffffU) : 0.0f;
} }
// TODO(zengjinle): Vectorize this function template <typename T, typename GradT, int VecSize>
// NOTE: this method does not update Beta1Pow and Beta2Pow! static __global__ void UpdateLambMomentAndTrustRatioDivCUDAKernel(
template <typename T, typename GradT, typename IndexT>
static __global__ void UpdateLambMoment(
const T *__restrict__ param_p, const GradT *__restrict__ grad_p, const T *__restrict__ param_p, const GradT *__restrict__ grad_p,
const T *__restrict__ square_grad_norm_p, const T *__restrict__ square_grad_norm_p,
const T *__restrict__ global_scale, const IndexT *__restrict__ indices, const T *__restrict__ global_scale, const T *__restrict__ beta1pow_p,
const T *__restrict__ weight_decay_p, const T *__restrict__ beta1pow_p,
const T *__restrict__ beta2pow_p, T *__restrict__ mom1_p, const T *__restrict__ beta2pow_p, T *__restrict__ mom1_p,
T *__restrict__ mom2_p, T *__restrict__ trust_ratio_div_p, T beta1, T beta2, T *__restrict__ mom2_p, T *__restrict__ trust_ratio_div_p, bool *found_inf,
T epsilon, T max_global_grad_norm, int num, T rescale_grad) { T weight_decay, int weight_decay_end_numel, T beta1, T beta2, T epsilon,
T max_global_grad_norm, int num, T rescale_grad) {
T square_grad_norm = *square_grad_norm_p; T square_grad_norm = *square_grad_norm_p;
if (!isfinite(square_grad_norm)) return; bool need_update_found_inf =
(found_inf && threadIdx.x == 0 && blockIdx.x == 0);
if (!isfinite(square_grad_norm)) {
if (need_update_found_inf) *found_inf = true;
return;
} else if (need_update_found_inf) {
*found_inf = false;
}
T scale = rescale_grad / global_scale[0]; T scale = rescale_grad / global_scale[0];
if (max_global_grad_norm > 0) { if (max_global_grad_norm > 0) {
...@@ -406,27 +390,112 @@ static __global__ void UpdateLambMoment( ...@@ -406,27 +390,112 @@ static __global__ void UpdateLambMoment(
T one_minus_beta1pow = 1 - beta1pow_p[0]; T one_minus_beta1pow = 1 - beta1pow_p[0];
T one_minus_beta2pow = 1 - beta2pow_p[0]; T one_minus_beta2pow = 1 - beta2pow_p[0];
CUDA_KERNEL_LOOP(i, num) { int i = (threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
T p = param_p[i]; int stride = blockDim.x * gridDim.x * VecSize;
T g = static_cast<T>(grad_p[i]) * scale;
T weight_decay = weight_decay_p[i];
T mom1 = mom1_p[i];
T mom2 = mom2_p[i];
mom1 = beta1 * mom1 + (1 - beta1) * g; for (; i + VecSize <= num; i += stride) {
mom2 = beta2 * mom2 + (1 - beta2) * g * g; platform::AlignedVector<T, VecSize> param_vec;
platform::AlignedVector<GradT, VecSize> grad_vec;
platform::AlignedVector<T, VecSize> weight_decay_vec;
platform::AlignedVector<T, VecSize> mom1_vec;
platform::AlignedVector<T, VecSize> mom2_vec;
platform::AlignedVector<T, VecSize> trust_ratio_div_vec;
T mom1_unbiased = mom1 / one_minus_beta1pow; T cur_weight_decay = (i < weight_decay_end_numel) * weight_decay;
T mom2_unbiased = mom2 / one_minus_beta2pow; if (cur_weight_decay != static_cast<T>(0.0)) {
T trust_ratio_div = platform::Load(param_p + i, &param_vec);
mom1_unbiased / (sqrtf(mom2_unbiased) + epsilon) + weight_decay * p; } else {
#pragma unroll
for (int j = 0; j < VecSize; ++j) {
param_vec[j] = static_cast<T>(0);
}
}
platform::Load(grad_p + i, &grad_vec);
platform::Load(mom1_p + i, &mom1_vec);
platform::Load(mom2_p + i, &mom2_vec);
#define PD_LAMB_MOM_TRUST_RATIO_DIV_UPDATE(__param, __grad, __mom1, __mom2, \
__trust_ratio_div, __idx) \
T p = __param[__idx]; \
T g = static_cast<T>(__grad[__idx]) * scale; \
T mom1 = __mom1[__idx]; \
T mom2 = __mom2[__idx]; \
mom1 = beta1 * mom1 + (1 - beta1) * g; \
mom2 = beta2 * mom2 + (1 - beta2) * g * g; \
T mom1_unbiased = mom1 / one_minus_beta1pow; \
T mom2_unbiased = mom2 / one_minus_beta2pow; \
__trust_ratio_div[__idx] = \
mom1_unbiased / (sqrtf(mom2_unbiased) + epsilon) + cur_weight_decay * p; \
__mom1[__idx] = mom1; \
__mom2[__idx] = mom2;
mom1_p[i] = mom1; #pragma unroll
mom2_p[i] = mom2; for (int j = 0; j < VecSize; ++j) {
trust_ratio_div_p[i] = trust_ratio_div; PD_LAMB_MOM_TRUST_RATIO_DIV_UPDATE(param_vec, grad_vec, mom1_vec,
mom2_vec, trust_ratio_div_vec, j);
}
platform::Store(mom1_vec, mom1_p + i);
platform::Store(mom2_vec, mom2_p + i);
platform::Store(trust_ratio_div_vec, trust_ratio_div_p + i);
}
for (; i < num; ++i) {
T cur_weight_decay = (i < weight_decay_end_numel) * weight_decay;
PD_LAMB_MOM_TRUST_RATIO_DIV_UPDATE(param_p, grad_p, mom1_p, mom2_p,
trust_ratio_div_p, i);
} }
} }
template <typename T, typename GradT>
static void MultiTensorUpdateLambMomentAndTrustRatioDiv(
const platform::CUDADeviceContext &dev_ctx, const int *offsets, int n,
const T *param_p, const GradT *grad_p, const T *square_grad_norm_p,
const T *global_scale, const T *beta1pow_p, const T *beta2pow_p, T *mom1_p,
T *mom2_p, T *trust_ratio_div_p, bool *found_inf_p, T weight_decay,
int weight_decay_end_idx, T beta1, T beta2, T epsilon,
T max_global_grad_norm, T rescale_grad) {
if (n <= 0) return;
int numel = offsets[n] - offsets[0];
PADDLE_ENFORCE_GE(weight_decay_end_idx, 0,
platform::errors::InvalidArgument(
"The weight decay end index should be >= 0."));
PADDLE_ENFORCE_LE(weight_decay_end_idx, n,
platform::errors::InvalidArgument(
"The weight decay end index should be < %d.", n));
auto weight_decay_end_numel = offsets[weight_decay_end_idx] - offsets[0];
int vec_size = GetChunkedVecSize(param_p, 0);
vec_size = std::min(vec_size, GetChunkedVecSize(grad_p, 0));
vec_size = std::min(vec_size, GetChunkedVecSize(mom1_p, 0));
vec_size = std::min(vec_size, GetChunkedVecSize(mom2_p, 0));
vec_size = std::min(vec_size, GetChunkedVecSize(trust_ratio_div_p, 0));
for (int i = 0; i < n; ++i) {
auto length = offsets[i + 1] - offsets[i];
while (length % vec_size != 0) {
vec_size /= 2;
}
}
VLOG(1) << __func__ << " VecSize = " << vec_size;
auto stream = dev_ctx.stream();
auto config = platform::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
#define PD_LAUNCH_LAMB_MOM_TRUST_RATIO_DIV_KERNEL \
do { \
UpdateLambMomentAndTrustRatioDivCUDAKernel<T, GradT, kVecSize><<< \
config.block_per_grid, config.thread_per_block, 0, stream>>>( \
param_p, grad_p, square_grad_norm_p, global_scale, beta1pow_p, \
beta2pow_p, mom1_p, mom2_p, trust_ratio_div_p, found_inf_p, \
weight_decay, weight_decay_end_numel, beta1, beta2, epsilon, \
max_global_grad_norm, numel, rescale_grad); \
} while (0)
PD_VEC_LAUNCH_KERNEL(vec_size, PD_LAUNCH_LAMB_MOM_TRUST_RATIO_DIV_KERNEL);
#undef PD_LAUNCH_LAMB_MOM_TRUST_RATIO_DIV_KERNEL
}
template <typename T, bool NeedUpdate /*=true*/> template <typename T, bool NeedUpdate /*=true*/>
struct LambBetaPowUpdateOnceHelper { struct LambBetaPowUpdateOnceHelper {
LambBetaPowUpdateOnceHelper(T *beta1pow, T *beta2pow, T beta1, T beta2) { LambBetaPowUpdateOnceHelper(T *beta1pow, T *beta2pow, T beta1, T beta2) {
...@@ -468,33 +537,6 @@ struct LambBetaPowUpdateOnceHelper<T, false> { ...@@ -468,33 +537,6 @@ struct LambBetaPowUpdateOnceHelper<T, false> {
HOSTDEVICE void UpdateBetaPows() const {} HOSTDEVICE void UpdateBetaPows() const {}
}; };
template <bool HasFoundInf /*=true*/>
struct LambFoundInfHelper {
public:
explicit LambFoundInfHelper(bool *found_inf) : found_inf_(found_inf) {
PADDLE_ENFORCE_NOT_NULL(found_inf,
platform::errors::InvalidArgument(
"The found_inf should not be nullptr."));
}
HOSTDEVICE void UpdateFoundInf(bool value) { *found_inf_ = value; }
private:
bool *__restrict__ found_inf_;
};
template <>
struct LambFoundInfHelper<false> {
public:
explicit LambFoundInfHelper(bool *found_inf) {
PADDLE_ENFORCE_EQ(
found_inf, nullptr,
platform::errors::InvalidArgument("The found_inf should be nullptr."));
}
HOSTDEVICE void UpdateFoundInf(bool) {}
};
template <typename T, bool HasMasterParam /*=true*/> template <typename T, bool HasMasterParam /*=true*/>
struct LambParamHelper { struct LambParamHelper {
LambParamHelper(T *param, MasterT<T> *master_param) { LambParamHelper(T *param, MasterT<T> *master_param) {
...@@ -509,12 +551,9 @@ struct LambParamHelper { ...@@ -509,12 +551,9 @@ struct LambParamHelper {
master_param_ = master_param; master_param_ = master_param;
} }
HOSTDEVICE void SetParam(int i, MasterT<T> updated_p) { HOSTDEVICE T *__restrict__ ParamPtr() { return param_; }
param_[i] = static_cast<T>(updated_p);
master_param_[i] = updated_p;
}
HOSTDEVICE MasterT<T> GetParam(int i) { return master_param_[i]; } HOSTDEVICE MasterT<T> *__restrict__ MasterParamPtr() { return master_param_; }
private: private:
T *__restrict__ param_; T *__restrict__ param_;
...@@ -538,158 +577,169 @@ struct LambParamHelper<T, false> { ...@@ -538,158 +577,169 @@ struct LambParamHelper<T, false> {
param_ = param; param_ = param;
} }
HOSTDEVICE void SetParam(int i, MasterT<T> updated_p) { HOSTDEVICE T *__restrict__ ParamPtr() { return param_; }
param_[i] = static_cast<T>(updated_p);
}
HOSTDEVICE MasterT<T> GetParam(int i) { HOSTDEVICE constexpr MasterT<T> *MasterParamPtr() { return nullptr; }
return static_cast<MasterT<T>>(param_[i]);
}
private: private:
T *__restrict__ param_; T *__restrict__ param_;
}; };
template <typename ParamT, typename IndexT, bool HasMasterParam, template <typename ParamT, bool HasMasterParam, bool NeedUpdateBetaPow,
bool NeedUpdateBetaPow, bool HasFoundInf> int VecSize>
struct LambParamAndBetaPowsUpdateHelper struct LambUpdateParamAndBetaPowsFunctor {
: public LambParamHelper<ParamT, HasMasterParam>, DEVICE void operator()(
public LambBetaPowUpdateOnceHelper<MasterT<ParamT>, NeedUpdateBetaPow>, int tensor_id, int chunk_id, int offset, int size,
public LambFoundInfHelper<HasFoundInf> { LambParamHelper<ParamT, HasMasterParam> param_helper,
LambParamAndBetaPowsUpdateHelper( const MasterT<ParamT> *trust_ratio_div, const MasterT<ParamT> *lr,
ParamT *param, MasterT<ParamT> *master_param, MasterT<ParamT> *beta1pow,
MasterT<ParamT> *beta2pow, MasterT<ParamT> beta1, MasterT<ParamT> beta2,
bool *found_inf, const MasterT<ParamT> *trust_ratio_div,
const MasterT<ParamT> *lr, const IndexT *index,
const MasterT<ParamT> *param_square_norm, const MasterT<ParamT> *param_square_norm,
const MasterT<ParamT> *trust_ratio_div_square_norm, const MasterT<ParamT> *trust_ratio_div_square_norm, const bool *found_inf,
const MasterT<ParamT> *update_flag) LambBetaPowUpdateOnceHelper<MasterT<ParamT>, NeedUpdateBetaPow>
: LambParamHelper<ParamT, HasMasterParam>(param, master_param), betapow_helper) const {
LambBetaPowUpdateOnceHelper<MasterT<ParamT>, NeedUpdateBetaPow>( if (*found_inf) return;
beta1pow, beta2pow, beta1, beta2),
LambFoundInfHelper<HasFoundInf>(found_inf),
trust_ratio_div(trust_ratio_div),
lr(lr),
index(index),
param_square_norm(param_square_norm),
trust_ratio_div_square_norm(trust_ratio_div_square_norm),
update_flag(update_flag) {}
const MasterT<ParamT> *__restrict__ trust_ratio_div;
const MasterT<ParamT> *__restrict__ lr;
const IndexT *__restrict__ index;
const MasterT<ParamT> *__restrict__ param_square_norm;
const MasterT<ParamT> *__restrict__ trust_ratio_div_square_norm;
const MasterT<ParamT> *__restrict__ update_flag;
};
template <typename ParamT, typename IndexT, bool HasMasterParam,
bool NeedUpdateBetaPow, bool HasFoundInf>
static __global__ void LambUpdateParamAndBetaPowsCUDAKernel(
LambParamAndBetaPowsUpdateHelper<ParamT, IndexT, HasMasterParam,
NeedUpdateBetaPow, HasFoundInf>
args,
int num) {
auto should_update = *args.update_flag;
if (!isfinite(should_update)) {
if (HasFoundInf && threadIdx.x == 0 && blockIdx.x == 0) {
args.UpdateFoundInf(true);
}
return;
} else if (HasFoundInf && threadIdx.x == 0 && blockIdx.x == 0) {
args.UpdateFoundInf(false);
}
if (NeedUpdateBetaPow && threadIdx.x == 0 && blockIdx.x == 0) {
args.UpdateBetaPows();
}
using MT = MasterT<ParamT>; using MT = MasterT<ParamT>;
MT lr_value = *args.lr; MT p_square_norm = param_square_norm[tensor_id];
CUDA_KERNEL_LOOP(i, num) { MT t_square_norm = trust_ratio_div_square_norm[tensor_id];
MT p = args.GetParam(i); MT lr_value = *lr;
MT t = args.trust_ratio_div[i]; MT ratio = (p_square_norm != static_cast<MT>(0) &&
auto norm_idx = args.index[i]; t_square_norm != static_cast<MT>(0)
MT p_square_norm = args.param_square_norm[norm_idx]; ? lr_value * sqrtf(p_square_norm / t_square_norm)
MT t_square_norm = args.trust_ratio_div_square_norm[norm_idx]; : lr_value);
MT p_norm = static_cast<MT>(sqrtf(p_square_norm));
MT t_norm = static_cast<MT>(sqrtf(t_square_norm));
auto update = (p_norm != static_cast<MT>(0) && t_norm != static_cast<MT>(0)) int i;
? p_norm / t_norm int stride = blockDim.x * VecSize;
: static_cast<MT>(1);
ParamT *param = param_helper.ParamPtr() + offset;
MT updated_p = p - lr_value * update * t; MT *master_param = HasMasterParam ? param_helper.MasterParamPtr() + offset
args.SetParam(i, updated_p); : param_helper.MasterParamPtr();
trust_ratio_div += offset;
for (i = threadIdx.x * VecSize; i + VecSize <= size; i += stride) {
platform::AlignedVector<MT, VecSize> trust_ratio_div_vec;
platform::Load(trust_ratio_div + i, &trust_ratio_div_vec);
if (HasMasterParam) {
platform::AlignedVector<MT, VecSize> master_param_vec;
platform::Load(master_param + i, &master_param_vec);
platform::AlignedVector<ParamT, VecSize> param_vec;
#pragma unroll
for (int j = 0; j < VecSize; ++j) {
MT p = master_param_vec[j] - ratio * trust_ratio_div_vec[j];
master_param_vec[j] = p;
param_vec[j] = static_cast<ParamT>(p);
} }
} platform::Store(master_param_vec, master_param + i);
platform::Store(param_vec, param + i);
template <typename ParamT, typename IndexT>
static void LambUpdateParamAndBetaPows(
const platform::CUDADeviceContext &dev_ctx,
const MasterT<ParamT> *trust_ratio_div, const MasterT<ParamT> *lr,
const IndexT *index, const MasterT<ParamT> *param_square_norm,
const MasterT<ParamT> *trust_ratio_div_square_norm,
const MasterT<ParamT> *update_flag, MasterT<ParamT> **beta1pow,
MasterT<ParamT> **beta2pow, bool **found_inf, MasterT<ParamT> beta1,
MasterT<ParamT> beta2, int num, ParamT *param,
MasterT<ParamT> *master_param, gpuStream_t stream) {
if (num == 0) return;
bool has_master_param = !(std::is_same<ParamT, MasterT<ParamT>>::value);
auto has_beta_pow = (*beta1pow) != nullptr && (*beta2pow) != nullptr;
auto has_found_inf = (*found_inf) != nullptr;
#define PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL( \
__has_master_param, __has_beta_pow, __has_found_inf) \
do { \
LambParamAndBetaPowsUpdateHelper<ParamT, IndexT, __has_master_param, \
__has_beta_pow, __has_found_inf> \
helper(param, master_param, *beta1pow, *beta2pow, beta1, beta2, \
*found_inf, trust_ratio_div, lr, index, param_square_norm, \
trust_ratio_div_square_norm, update_flag); \
auto config = platform::GetGpuLaunchConfig1D(dev_ctx, num); \
LambUpdateParamAndBetaPowsCUDAKernel<<< \
config.block_per_grid, config.thread_per_block, 0, stream>>>(helper, \
num); \
} while (0)
if (has_master_param) {
if (has_beta_pow) {
if (has_found_inf) {
PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(true, true, true);
} else { } else {
PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(true, true, false); platform::AlignedVector<ParamT, VecSize> param_vec;
platform::Load(param + i, &param_vec);
#pragma unroll
for (int j = 0; j < VecSize; ++j) {
MT p = static_cast<MT>(param_vec[j]) - ratio * trust_ratio_div_vec[j];
param_vec[j] = static_cast<ParamT>(p);
} }
} else { platform::Store(param_vec, param + i);
if (has_found_inf) {
PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(true, false, true);
} else {
PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(true, false, false);
} }
} }
for (; i < size; ++i) {
if (HasMasterParam) {
MT p = master_param[i] - ratio * trust_ratio_div[i];
master_param[i] = p;
param[i] = static_cast<ParamT>(p);
} else { } else {
if (has_beta_pow) { MT p = static_cast<MT>(param[i]) - ratio * trust_ratio_div[i];
if (has_found_inf) { param[i] = static_cast<ParamT>(p);
PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(false, true, true);
} else {
PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(false, true, false);
} }
}
if (NeedUpdateBetaPow && threadIdx.x == 0 && blockIdx.x == 0) {
betapow_helper.UpdateBetaPows();
}
}
};
// TODO(zengjinle): which block_dim and chunk_size would be better?
template <typename ParamT, int MaxTensorNumPerLaunch = 160,
int MaxChunkNumPerLaunch = 780>
static void MultiTensorUpdateLambParamAndBetaPows(
const platform::CUDADeviceContext &dev_ctx, const int *offsets, int n,
const MasterT<ParamT> *trust_ratio_div, const MasterT<ParamT> *lr,
const MasterT<ParamT> *param_square_norm,
const MasterT<ParamT> *trust_ratio_div_square_norm, const bool *found_inf,
ParamT *param, MasterT<ParamT> *master_param, MasterT<ParamT> *beta1pow,
MasterT<ParamT> *beta2pow, MasterT<ParamT> beta1, MasterT<ParamT> beta2,
int chunk_size = 65536) {
constexpr bool kHasMasterParam =
!(std::is_same<ParamT, MasterT<ParamT>>::value);
bool has_beta_pow = (beta1pow != nullptr);
if (has_beta_pow) {
PADDLE_ENFORCE_NOT_NULL(beta2pow, platform::errors::InvalidArgument(
"Beta2Pow should not be nullptr."));
} else { } else {
if (has_found_inf) { PADDLE_ENFORCE_EQ(beta2pow, nullptr, platform::errors::InvalidArgument(
PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(false, false, true); "Beta2Pow should be nullptr."));
} else {
PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(false, false, false);
} }
const int block_dim = 512;
int vec_size = 8;
for (int i = 0; i < n; ++i) {
int offset = offsets[i] - offsets[0];
vec_size =
std::min(vec_size, GetChunkedVecSize(param + offset, chunk_size));
if (kHasMasterParam) {
vec_size = std::min(vec_size,
GetChunkedVecSize(master_param + offset, chunk_size));
} }
vec_size = std::min(
vec_size, GetChunkedVecSize(trust_ratio_div + offset, chunk_size));
} }
*beta1pow = nullptr; VLOG(1) << __func__ << " VecSize = " << vec_size;
*beta2pow = nullptr;
*found_inf = nullptr; constexpr auto kNumTensor = MaxTensorNumPerLaunch;
#undef PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL constexpr auto kNumChunk = MaxChunkNumPerLaunch;
auto stream = dev_ctx.stream();
#define PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(__has_beta_pow) \
do { \
using FunctorT = \
LambUpdateParamAndBetaPowsFunctor<ParamT, kHasMasterParam, \
__has_beta_pow, kVecSize>; \
LambParamHelper<ParamT, kHasMasterParam> param_helper(param, \
master_param); \
LambBetaPowUpdateOnceHelper<MasterT<ParamT>, __has_beta_pow> \
betapow_helper(beta1pow, beta2pow, beta1, beta2); \
launcher.Launch(FunctorT(), param_helper, trust_ratio_div, lr, \
param_square_norm, trust_ratio_div_square_norm, found_inf, \
betapow_helper); \
} while (0)
#define PD_LAUNCH_VEC_MULTI_TENSOR_UPDATE_PARAM_BETAPOW_CASE \
do { \
auto callback = [&]( \
const MultiTensorLauncher<kNumTensor, kNumChunk> &launcher, \
int launch_n) { \
if (has_beta_pow && launch_n == 0) { \
PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(true); \
beta1pow = nullptr; \
beta2pow = nullptr; \
} else { \
PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(false); \
} \
}; \
MultiTensorApplyWithCallback<kNumTensor, kNumChunk>( \
stream, offsets, n, chunk_size, block_dim, callback); \
} while (0)
PD_VEC_LAUNCH_KERNEL(vec_size,
PD_LAUNCH_VEC_MULTI_TENSOR_UPDATE_PARAM_BETAPOW_CASE);
#undef PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW
#undef PD_LAUNCH_VEC_MULTI_TENSOR_UPDATE_PARAM_BETAPOW_CASE
} }
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
...@@ -1005,15 +1055,16 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T> ...@@ -1005,15 +1055,16 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
"Too many parameter number. Only <= %d is supported.", "Too many parameter number. Only <= %d is supported.",
std::numeric_limits<int>::max())); std::numeric_limits<int>::max()));
// Step 3: Get FusedIndices, ParamInfo // Step 3: Get ParamInfo
const auto *indices = GetInputTensorPtr<int>(ctx, "FusedIndices");
const auto *param_info_tensor = GetInputTensorPtr<int>(ctx, "ParamInfo"); const auto *param_info_tensor = GetInputTensorPtr<int>(ctx, "ParamInfo");
auto fp32_local_start_idx = param_info_tensor[0]; auto fp32_local_start_idx = param_info_tensor[0];
auto fp32_local_param_num = param_info_tensor[1]; auto fp32_local_param_num = param_info_tensor[1];
auto fp32_global_param_num = param_info_tensor[2]; auto fp32_global_param_num = param_info_tensor[2];
auto fp16_local_start_idx = param_info_tensor[3]; auto fp32_weight_decay_end_idx = param_info_tensor[3];
auto fp16_local_param_num = param_info_tensor[4]; auto fp16_local_start_idx = param_info_tensor[4];
auto fp16_global_param_num = param_info_tensor[5]; auto fp16_local_param_num = param_info_tensor[5];
auto fp16_global_param_num = param_info_tensor[6];
auto fp16_weight_decay_end_idx = param_info_tensor[7];
auto local_param_num = fp32_local_param_num + fp16_local_param_num; auto local_param_num = fp32_local_param_num + fp16_local_param_num;
auto param_num = fp32_global_param_num + fp16_global_param_num; auto param_num = fp32_global_param_num + fp16_global_param_num;
...@@ -1031,7 +1082,7 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T> ...@@ -1031,7 +1082,7 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
<< " , fp16_global_param_num = " << fp16_global_param_num; << " , fp16_global_param_num = " << fp16_global_param_num;
// Step 4: Get LearningRate, Moment1, Moment2, Beta1Pow, Beta2Pow, // Step 4: Get LearningRate, Moment1, Moment2, Beta1Pow, Beta2Pow,
// WeightDecay, GlobalScale, FoundInf // GlobalScale, FoundInf
const auto *global_scale = GetInputTensorPtr<float>(ctx, "GlobalScale"); const auto *global_scale = GetInputTensorPtr<float>(ctx, "GlobalScale");
const auto *lr = GetInputTensorPtr<float>(ctx, "LearningRate"); const auto *lr = GetInputTensorPtr<float>(ctx, "LearningRate");
int64_t partial_numel = 0; int64_t partial_numel = 0;
...@@ -1065,14 +1116,15 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T> ...@@ -1065,14 +1116,15 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
GetSameInOutTensorPtr<float>(ctx, place, "Beta1Pow", "Beta1PowOut"); GetSameInOutTensorPtr<float>(ctx, place, "Beta1Pow", "Beta1PowOut");
auto *beta2pow = auto *beta2pow =
GetSameInOutTensorPtr<float>(ctx, place, "Beta2Pow", "Beta2PowOut"); GetSameInOutTensorPtr<float>(ctx, place, "Beta2Pow", "Beta2PowOut");
const float *weight_decay = GetInputTensorPtr<float>(ctx, "WeightDecay");
auto *found_inf_t = ctx.Output<framework::Tensor>("FoundInf"); auto *found_inf_t = ctx.Output<framework::Tensor>("FoundInf");
found_inf_t->Resize({1}); found_inf_t->Resize({1});
auto *found_inf = found_inf_t->mutable_data<bool>(place); auto *found_inf = found_inf_t->mutable_data<bool>(place);
// Step 5: Get attributes beta1, beta2, epsilon, max_grad_norm, ring_id, // Step 5: Get attributes weight_decay, beta1, beta2, epsilon,
// max_grad_norm, ring_id,
// use_master_param_norm, is_grad_scaled_by_nranks // use_master_param_norm, is_grad_scaled_by_nranks
auto weight_decay = ctx.Attr<float>("weight_decay");
auto beta1 = ctx.Attr<float>("beta1"); auto beta1 = ctx.Attr<float>("beta1");
auto beta2 = ctx.Attr<float>("beta2"); auto beta2 = ctx.Attr<float>("beta2");
auto epsilon = ctx.Attr<float>("epsilon"); auto epsilon = ctx.Attr<float>("epsilon");
...@@ -1105,7 +1157,8 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T> ...@@ -1105,7 +1157,8 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
platform::float16 *fp16_sum_grad; platform::float16 *fp16_sum_grad;
auto fp32_numel_each_device = fp32_numel / num_devices; auto fp32_numel_each_device = fp32_numel / num_devices;
auto fp16_numel_each_device = fp16_numel / num_devices; auto fp16_numel_each_device = fp16_numel / num_devices;
if (num_devices > 1) { if (num_devices > 1 ||
(max_global_grad_norm > 0 && !clip_after_allreduce)) {
auto ptr = sum_grad_buffer.Alloc<uint8_t>( auto ptr = sum_grad_buffer.Alloc<uint8_t>(
fp32_numel_each_device * sizeof(float) + fp32_numel_each_device * sizeof(float) +
fp16_numel_each_device * sizeof(platform::float16)); fp16_numel_each_device * sizeof(platform::float16));
...@@ -1181,7 +1234,11 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T> ...@@ -1181,7 +1234,11 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
float, platform::float16><<<1, 1, 0, stream>>>( float, platform::float16><<<1, 1, 0, stream>>>(
global_scale, max_global_grad_norm, fp32_square_grad_norm, global_scale, max_global_grad_norm, fp32_square_grad_norm,
fp32_scale, fp16_scale, clip_scale); fp32_scale, fp16_scale, clip_scale);
if (fp32_scale) {
VLOG(1) << "Grad scale: " << FlattenToString(fp32_scale, 1, place); VLOG(1) << "Grad scale: " << FlattenToString(fp32_scale, 1, place);
} else {
VLOG(1) << "Grad scale: " << FlattenToString(fp16_scale, 1, place);
}
if (num_devices > 1) { if (num_devices > 1) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
fp32_square_grad_norm, fp32_square_grad_norm, 1, ncclFloat32, fp32_square_grad_norm, fp32_square_grad_norm, 1, ncclFloat32,
...@@ -1218,36 +1275,56 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T> ...@@ -1218,36 +1275,56 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
VLOG(10) << "ReduceScatter done"; VLOG(10) << "ReduceScatter done";
// Step 7: update the moment1, moment2. Calcuate the trust_ratio_div // Step 7: update the moment1, moment2. Calcuate the trust_ratio_div
auto *fused_offsets_t = ctx.Input<framework::Tensor>("FusedParamOffsets");
auto *fused_offsets = fused_offsets_t->data<int>();
auto *fp32_partial_fused_offsets_t =
ctx.Input<framework::Tensor>("FP32ShardFusedParamOffsets");
const auto *fp32_partial_fused_offsets =
fp32_partial_fused_offsets_t->data<int>();
auto *fp16_partial_fused_offsets_t =
ctx.Input<framework::Tensor>("FP16ShardFusedParamOffsets");
const auto *fp16_partial_fused_offsets =
fp16_partial_fused_offsets_t->data<int>();
VLOG(1) << "FusedParamOffsets: "
<< FlattenToString(fused_offsets, fused_offsets_t->numel(),
fused_offsets_t->place());
VLOG(1) << "FP32ShardFusedParamOffsets: "
<< FlattenToString(fp32_partial_fused_offsets,
fp32_partial_fused_offsets_t->numel(),
fp32_partial_fused_offsets_t->place());
VLOG(1) << "FP16ShardFusedParamOffsets: "
<< FlattenToString(fp16_partial_fused_offsets,
fp16_partial_fused_offsets_t->numel(),
fp16_partial_fused_offsets_t->place());
memory::Buffer trust_ratio_div_buffer(place); memory::Buffer trust_ratio_div_buffer(place);
auto *trust_ratio_div = trust_ratio_div_buffer.Alloc<float>(partial_numel); auto *trust_ratio_div = trust_ratio_div_buffer.Alloc<float>(partial_numel);
auto fp32_offset = rank * fp32_numel_each_device; auto fp32_offset = rank * fp32_numel_each_device;
auto fp16_offset = rank * fp16_numel_each_device; auto fp16_offset = rank * fp16_numel_each_device;
if (has_fp32_param) { if (has_fp32_param) {
auto config =
platform::GetGpuLaunchConfig1D(dev_ctx, fp32_numel_each_device);
VLOG(10) << "Update FP32 Moment and TrustRatioDiv starts"; VLOG(10) << "Update FP32 Moment and TrustRatioDiv starts";
UpdateLambMoment<<<config.block_per_grid, config.thread_per_block, 0, MultiTensorUpdateLambMomentAndTrustRatioDiv(
stream>>>( dev_ctx, fp32_partial_fused_offsets, fp32_local_param_num,
fp32_param + fp32_offset, fp32_sum_grad, fp32_square_grad_norm, fp32_param + fp32_offset, fp32_sum_grad, fp32_square_grad_norm,
global_scale, indices + fp32_offset, weight_decay, beta1pow, beta2pow, global_scale, beta1pow, beta2pow, moment1, moment2, trust_ratio_div,
moment1, moment2, trust_ratio_div, beta1, beta2, epsilon, found_inf, weight_decay, fp32_weight_decay_end_idx, beta1, beta2,
max_global_grad_norm, fp32_numel_each_device, rescale_grad); epsilon, max_global_grad_norm, rescale_grad);
VLOG(10) << "Update FP32 Moment and TrustRatioDiv done"; VLOG(10) << "Update FP32 Moment and TrustRatioDiv done";
} }
float *master_param = nullptr; float *master_param = nullptr;
if (has_fp16_param) { if (has_fp16_param) {
master_param = fp32_param + fp32_numel; master_param = fp32_param + fp32_numel;
auto config =
platform::GetGpuLaunchConfig1D(dev_ctx, fp16_numel_each_device);
VLOG(10) << "Update FP16 Moment and TrustRatioDiv starts"; VLOG(10) << "Update FP16 Moment and TrustRatioDiv starts";
UpdateLambMoment<<<config.block_per_grid, config.thread_per_block, 0, auto tmp_found_inf = has_fp32_param ? nullptr : found_inf;
stream>>>( MultiTensorUpdateLambMomentAndTrustRatioDiv(
dev_ctx, fp16_partial_fused_offsets, fp16_local_param_num,
master_param + fp16_offset, fp16_sum_grad, fp32_square_grad_norm, master_param + fp16_offset, fp16_sum_grad, fp32_square_grad_norm,
global_scale, indices + fp32_numel + fp16_offset, weight_decay, global_scale, beta1pow, beta2pow, moment1 + fp32_numel_each_device,
beta1pow, beta2pow, moment1 + fp32_numel_each_device,
moment2 + fp32_numel_each_device, moment2 + fp32_numel_each_device,
trust_ratio_div + fp32_numel_each_device, beta1, beta2, epsilon, trust_ratio_div + fp32_numel_each_device, tmp_found_inf, weight_decay,
max_global_grad_norm, fp16_numel_each_device, rescale_grad); fp16_weight_decay_end_idx, beta1, beta2, epsilon,
max_global_grad_norm, rescale_grad);
VLOG(10) << "Update FP16 Moment and TrustRatioDiv done"; VLOG(10) << "Update FP16 Moment and TrustRatioDiv done";
} }
...@@ -1257,30 +1334,6 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T> ...@@ -1257,30 +1334,6 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
memory::Buffer square_norm_buffer(place); memory::Buffer square_norm_buffer(place);
auto *param_square_norm = square_norm_buffer.Alloc<float>(2 * param_num); auto *param_square_norm = square_norm_buffer.Alloc<float>(2 * param_num);
auto *trust_ratio_div_square_norm = param_square_norm + param_num; auto *trust_ratio_div_square_norm = param_square_norm + param_num;
auto *fused_offsets_t = ctx.Input<framework::Tensor>("FusedParamOffsets");
auto *fused_offsets = fused_offsets_t->data<int>();
auto *fp32_partial_fused_offsets_t =
ctx.Input<framework::Tensor>("FP32ShardFusedParamOffsets");
const auto *fp32_partial_fused_offsets =
fp32_partial_fused_offsets_t->data<int>();
auto *fp16_partial_fused_offsets_t =
ctx.Input<framework::Tensor>("FP16ShardFusedParamOffsets");
const auto *fp16_partial_fused_offsets =
fp16_partial_fused_offsets_t->data<int>();
VLOG(1) << "FusedParamOffsets: "
<< FlattenToString(fused_offsets, fused_offsets_t->numel(),
fused_offsets_t->place());
VLOG(1) << "FP32ShardFusedParamOffsets: "
<< FlattenToString(fp32_partial_fused_offsets,
fp32_partial_fused_offsets_t->numel(),
fp32_partial_fused_offsets_t->place());
VLOG(1) << "FP16ShardFusedParamOffsets: "
<< FlattenToString(fp16_partial_fused_offsets,
fp16_partial_fused_offsets_t->numel(),
fp16_partial_fused_offsets_t->place());
if (num_devices > 1) { if (num_devices > 1) {
if (use_master_param_norm) { if (use_master_param_norm) {
FillZeroWithPtr(param_square_norm + fp32_global_param_num, FillZeroWithPtr(param_square_norm + fp32_global_param_num,
...@@ -1296,11 +1349,11 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T> ...@@ -1296,11 +1349,11 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
fp16_partial_fused_offsets, fp16_local_param_num, fp16_partial_fused_offsets, fp16_local_param_num,
param_square_norm + fp16_local_start_idx); param_square_norm + fp16_local_start_idx);
} else { } else {
// NOTE: extra computation is performed. We can improve this performance
// if needed in the future.
MultiTensorL2Norm( MultiTensorL2Norm(
place, stream, fp16_param, fused_offsets + fp32_global_param_num, place, stream, fp16_param + fused_offsets[fp16_local_start_idx] -
fp16_global_param_num, param_square_norm + fp32_global_param_num); fused_offsets[fp32_global_param_num],
fused_offsets + fp16_local_start_idx, fp16_local_param_num,
param_square_norm + fp16_local_start_idx);
} }
MultiTensorL2Norm(place, stream, trust_ratio_div, MultiTensorL2Norm(place, stream, trust_ratio_div,
...@@ -1333,26 +1386,29 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T> ...@@ -1333,26 +1386,29 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
// Step 9: update parameter, beta1pow, beta2pow. All gather parameters. // Step 9: update parameter, beta1pow, beta2pow. All gather parameters.
if (has_fp32_param) { if (has_fp32_param) {
LambUpdateParamAndBetaPows<float>( MultiTensorUpdateLambParamAndBetaPows<float>(
dev_ctx, trust_ratio_div, lr, indices + fp32_offset, dev_ctx, fp32_partial_fused_offsets, fp32_local_param_num,
param_square_norm, trust_ratio_div_square_norm, fp32_square_grad_norm, trust_ratio_div, lr, param_square_norm + fp32_local_start_idx,
&beta1pow, &beta2pow, &found_inf, beta1, beta2, trust_ratio_div_square_norm + fp32_local_start_idx, found_inf,
fp32_numel_each_device, fp32_param + fp32_offset, nullptr, stream); fp32_param + fp32_offset, nullptr, beta1pow, beta2pow, beta1, beta2);
if (num_devices > 1) { if (num_devices > 1) {
// ncclAllGather // ncclAllGather
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
fp32_param + fp32_offset, fp32_param, fp32_numel_each_device, fp32_param + fp32_offset, fp32_param, fp32_numel_each_device,
ncclFloat32, comm, stream)); ncclFloat32, comm, stream));
} }
beta1pow = nullptr;
beta2pow = nullptr;
} }
if (has_fp16_param) { if (has_fp16_param) {
LambUpdateParamAndBetaPows<platform::float16>( MultiTensorUpdateLambParamAndBetaPows<platform::float16>(
dev_ctx, trust_ratio_div + fp32_numel_each_device, lr, dev_ctx, fp16_partial_fused_offsets, fp16_local_param_num,
indices + fp32_numel + fp16_offset, param_square_norm, trust_ratio_div + fp32_numel_each_device, lr,
trust_ratio_div_square_norm, fp32_square_grad_norm, &beta1pow, param_square_norm + fp16_local_start_idx,
&beta2pow, &found_inf, beta1, beta2, fp16_numel_each_device, trust_ratio_div_square_norm + fp16_local_start_idx, found_inf,
fp16_param + fp16_offset, master_param + fp16_offset, stream); fp16_param + fp16_offset, master_param + fp16_offset, beta1pow,
beta2pow, beta1, beta2);
if (num_devices > 1) { if (num_devices > 1) {
// ncclAllGather // ncclAllGather
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
......
...@@ -94,11 +94,40 @@ static __global__ void MultiTensorApplyCUDAKernel( ...@@ -94,11 +94,40 @@ static __global__ void MultiTensorApplyCUDAKernel(
args...); args...);
} }
template <typename Functor, int BlockDim, int MaxTensorNumPerLaunch, template <int MaxTensorNumPerLaunch, int MaxChunkNumPerLaunch>
int MaxChunkNumPerLaunch, typename... Args> class MultiTensorLauncher {
static void MultiTensorApply(Functor functor, gpuStream_t stream, public:
const int *offsets, int n, int chunk_size, MultiTensorLauncher(
Args... args) { const TensorMetaList<MaxTensorNumPerLaunch, MaxChunkNumPerLaunch> &meta,
const int &chunk_id, const int &chunk_size, const int &block_dim,
const gpuStream_t &stream)
: meta_(meta),
chunk_id_(chunk_id),
chunk_size_(chunk_size),
block_dim_(block_dim),
stream_(stream) {}
template <typename Functor, typename... Args>
void Launch(Functor &&functor, Args &&... args) const {
MultiTensorApplyCUDAKernel<
Functor, MaxTensorNumPerLaunch,
MaxChunkNumPerLaunch><<<chunk_id_, block_dim_, 0, stream_>>>(
functor, meta_, chunk_size_, args...);
}
private:
const TensorMetaList<MaxTensorNumPerLaunch, MaxChunkNumPerLaunch> &meta_;
const int &chunk_id_;
const int &chunk_size_;
const int &block_dim_;
const gpuStream_t &stream_;
};
template <int MaxTensorNumPerLaunch, int MaxChunkNumPerLaunch,
typename Callback>
static void MultiTensorApplyWithCallback(gpuStream_t stream, const int *offsets,
int n, int chunk_size, int block_dim,
Callback &&callback) {
if (n == 0) return; if (n == 0) return;
constexpr auto NumTensor = MaxTensorNumPerLaunch; constexpr auto NumTensor = MaxTensorNumPerLaunch;
...@@ -110,6 +139,11 @@ static void MultiTensorApply(Functor functor, gpuStream_t stream, ...@@ -110,6 +139,11 @@ static void MultiTensorApply(Functor functor, gpuStream_t stream,
int numel_offset = 0; int numel_offset = 0;
metas.start_tensor_id = 0; metas.start_tensor_id = 0;
metas.start_chunk_id = 0; metas.start_chunk_id = 0;
int launch_num = 0;
MultiTensorLauncher<MaxTensorNumPerLaunch, MaxChunkNumPerLaunch> launcher(
metas, chunk_id, chunk_size, block_dim, stream);
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
auto length = offsets[i + 1] - offsets[i]; auto length = offsets[i + 1] - offsets[i];
if (tensor_id == 0) { if (tensor_id == 0) {
...@@ -132,9 +166,8 @@ static void MultiTensorApply(Functor functor, gpuStream_t stream, ...@@ -132,9 +166,8 @@ static void MultiTensorApply(Functor functor, gpuStream_t stream,
bool last_chunk = (i + 1 == n && j + 1 == chunk_num); bool last_chunk = (i + 1 == n && j + 1 == chunk_num);
if (tensor_full || block_full || last_chunk) { if (tensor_full || block_full || last_chunk) {
MultiTensorApplyCUDAKernel<Functor, NumTensor, callback(launcher, launch_num);
NumChunk><<<chunk_id, BlockDim, 0, stream>>>( ++launch_num;
functor, metas, chunk_size, args...);
chunk_id = 0; chunk_id = 0;
if (j + 1 == chunk_num) { // chunk for the current tensor is full if (j + 1 == chunk_num) { // chunk for the current tensor is full
metas.start_chunk_id = 0; metas.start_chunk_id = 0;
...@@ -152,5 +185,17 @@ static void MultiTensorApply(Functor functor, gpuStream_t stream, ...@@ -152,5 +185,17 @@ static void MultiTensorApply(Functor functor, gpuStream_t stream,
} }
} }
template <typename Functor, int MaxTensorNumPerLaunch, int MaxChunkNumPerLaunch,
typename... Args>
static void MultiTensorApply(Functor functor, gpuStream_t stream,
const int *offsets, int n, int chunk_size,
int block_dim, Args &&... args) {
auto callback = [&](const MultiTensorLauncher<MaxTensorNumPerLaunch,
MaxChunkNumPerLaunch> &launcher,
int i) { launcher.Launch(functor, args...); };
MultiTensorApplyWithCallback<MaxTensorNumPerLaunch, MaxChunkNumPerLaunch>(
stream, offsets, n, chunk_size, block_dim, callback);
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -144,6 +144,11 @@ def run_model(use_distributed_lamb, use_fp16, use_master_param_norm, **kwargs): ...@@ -144,6 +144,11 @@ def run_model(use_distributed_lamb, use_fp16, use_master_param_norm, **kwargs):
grad_clip = kwargs.get('grad_clip', None) grad_clip = kwargs.get('grad_clip', None)
clip_after_allreduce = kwargs.get('clip_after_allreduce', True) clip_after_allreduce = kwargs.get('clip_after_allreduce', True)
parameters = [p.name for p in main.all_parameters()]
exclude_fn = lambda var: var.name in parameters[::4]
kwargs['exclude_from_weight_decay_fn'] = exclude_fn
kwargs['lamb_weight_decay'] = 0.1
if use_distributed_lamb: if use_distributed_lamb:
optimizer_class = DistributedFusedLamb optimizer_class = DistributedFusedLamb
kwargs = dict(kwargs) kwargs = dict(kwargs)
......
...@@ -171,10 +171,7 @@ class DistributedFusedLamb(Optimizer): ...@@ -171,10 +171,7 @@ class DistributedFusedLamb(Optimizer):
moment2.is_distributed = True moment2.is_distributed = True
beta1pow = self._create_persistable_var('beta1pow') beta1pow = self._create_persistable_var('beta1pow')
beta2pow = self._create_persistable_var('beta2pow') beta2pow = self._create_persistable_var('beta2pow')
fused_indices = self._create_persistable_var(
'fused_indices', dtype='int32')
weight_decay = self._create_persistable_var('weight_decay')
weight_decay.is_distributed = True
param_info = self._create_persistable_var('param_info', dtype='int32') param_info = self._create_persistable_var('param_info', dtype='int32')
param_info.is_distributed = True param_info.is_distributed = True
...@@ -189,17 +186,20 @@ class DistributedFusedLamb(Optimizer): ...@@ -189,17 +186,20 @@ class DistributedFusedLamb(Optimizer):
'fp16_partial_fused_offsets', dtype='int32') 'fp16_partial_fused_offsets', dtype='int32')
fp16_partial_fused_offsets.is_distributed = True fp16_partial_fused_offsets.is_distributed = True
param_order = self._create_persistable_var('param_order', dtype='int32')
param_order.is_distributed = True
rank = get_rank() rank = get_rank()
nranks = get_world_size() nranks = get_world_size()
scale = self._get_or_create_scale() scale = self._get_or_create_scale()
params = [p for p, _ in params_grads] params = [p for p, _ in params_grads]
grads = [g for _, g in params_grads] grads = [g for _, g in params_grads]
weight_decay_values = [self._weight_decay] * len(params) apply_weight_decay = [1] * len(params)
if self._exclude_from_weight_decay_fn is not None: if self._exclude_from_weight_decay_fn is not None:
for i, p in enumerate(params): for i, p in enumerate(params):
if self._exclude_from_weight_decay_fn(p): if self._exclude_from_weight_decay_fn(p):
weight_decay_values[i] = 0.0 apply_weight_decay[i] = 0
startup_block = self.helper.startup_program.global_block() startup_block = self.helper.startup_program.global_block()
for g in grads: for g in grads:
...@@ -225,8 +225,6 @@ class DistributedFusedLamb(Optimizer): ...@@ -225,8 +225,6 @@ class DistributedFusedLamb(Optimizer):
'Moment2': [moment2], 'Moment2': [moment2],
'Beta1Pow': [beta1pow], 'Beta1Pow': [beta1pow],
'Beta2Pow': [beta2pow], 'Beta2Pow': [beta2pow],
'FusedIndices': [fused_indices],
'WeightDecay': [weight_decay],
'GlobalScale': [scale], 'GlobalScale': [scale],
'ParamInfo': [param_info], 'ParamInfo': [param_info],
'ParamOut': params, 'ParamOut': params,
...@@ -235,12 +233,13 @@ class DistributedFusedLamb(Optimizer): ...@@ -235,12 +233,13 @@ class DistributedFusedLamb(Optimizer):
'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets], 'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets],
'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets], 'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets],
'FusedParamOffsets': [fused_offsets], 'FusedParamOffsets': [fused_offsets],
'ParamOrder': [param_order],
}, },
attrs={ attrs={
'alignment': self._alignment, 'alignment': self._alignment,
'rank': rank, 'rank': rank,
'nranks': nranks, 'nranks': nranks,
'weight_decay': weight_decay_values, 'apply_weight_decay': apply_weight_decay,
'moment1': 0.0, 'moment1': 0.0,
'moment2': 0.0, 'moment2': 0.0,
'beta1': self._beta1, 'beta1': self._beta1,
...@@ -272,8 +271,6 @@ class DistributedFusedLamb(Optimizer): ...@@ -272,8 +271,6 @@ class DistributedFusedLamb(Optimizer):
'Moment2': [moment2], 'Moment2': [moment2],
'Beta1Pow': [beta1pow], 'Beta1Pow': [beta1pow],
'Beta2Pow': [beta2pow], 'Beta2Pow': [beta2pow],
'FusedIndices': [fused_indices],
'WeightDecay': [weight_decay],
'GlobalScale': [scale], 'GlobalScale': [scale],
'ParamInfo': [param_info], 'ParamInfo': [param_info],
'Param': params, 'Param': params,
...@@ -281,6 +278,7 @@ class DistributedFusedLamb(Optimizer): ...@@ -281,6 +278,7 @@ class DistributedFusedLamb(Optimizer):
'FusedParamOffsets': [fused_offsets], 'FusedParamOffsets': [fused_offsets],
'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets], 'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets],
'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets], 'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets],
'ParamOrder': [param_order],
}, },
outputs={ outputs={
'FP32FusedParamOut': [fp32_fused_param], 'FP32FusedParamOut': [fp32_fused_param],
...@@ -294,6 +292,7 @@ class DistributedFusedLamb(Optimizer): ...@@ -294,6 +292,7 @@ class DistributedFusedLamb(Optimizer):
'FoundInf': [self._found_inf], 'FoundInf': [self._found_inf],
}, },
attrs={ attrs={
'weight_decay': self._weight_decay,
'beta1': self._beta1, 'beta1': self._beta1,
'beta2': self._beta2, 'beta2': self._beta2,
'epsilon': self._epsilon, 'epsilon': self._epsilon,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册