Unverified commit d17961ed authored by sneaxiy, committed by GitHub

Optimize the CUDA kernel in DistributedFusedLamb optimizer (#39972)

* vectorize lamb kernel

* remove flags, add ut

* remove useless codes

* refine code, add param order

Parent 1b585b28
@@ -61,30 +61,31 @@ class DistributedFusedLambInitOpMaker
               "The fp32 beta1 power accumulator tensor. Its shape is [1].");
     AddOutput("Beta2Pow",
               "The fp32 beta2 power accumulator tensor. Its shape is [1].");
-    AddOutput("FusedIndices",
-              "The param index of each element in FP32FusedParam. Its shape is "
-              "[M1+M2]. It is like [0,0,0,1,1,1,1,2,2,...].");
     AddOutput(
         "FusedParamOffsets",
         "The numel offset of each parameter inside the FP32FusedParam. Its "
         "shape is [param_num + 1]. It is like [0, n_0, n_0 + n_1, n_0 + n_1 "
-        "+ n_2, ...].");
-    AddOutput("FP32ShardFusedParamOffsets",
-              "The sharded numel offset of each parameter in the local rank. "
-              "Its shape is [fp32_local_param_num + 1].");
-    AddOutput("FP16ShardFusedParamOffsets",
-              "The sharded numel offset of each parameter in the local rank. "
-              "Its shape is [fp16_local_param_num + 1].");
+        "+ n_2, ...]. It should be in CPUPlace.");
     AddOutput(
-        "WeightDecay",
-        "The sharded fp32 weight decay tensor. Its shape is [(M1+M2)/N].");
+        "FP32ShardFusedParamOffsets",
+        "The sharded numel offset of each parameter in the local rank. "
+        "Its shape is [fp32_local_param_num + 1]. It should be in CPUPlace.");
+    AddOutput(
+        "FP16ShardFusedParamOffsets",
+        "The sharded numel offset of each parameter in the local rank. "
+        "Its shape is [fp16_local_param_num + 1]. It should be in CPUPlace.");
     AddOutput("ParamInfo",
               "The param info. It should be in CPUPlace, and its shape is [6]"
-              "CPUPlace, and its shape is [6]. It is "
+              "CPUPlace, and its shape is [8]. It is "
               "[fp32_shard_param_start_idx, fp32_local_param_num, "
-              "fp32_global_param_num, fp16_shard_param_start_idx, "
-              "fp16_local_param_num, fp16_global_param_num].");
+              "fp32_global_param_num, fp32_weight_decay_end_idx, "
+              "fp16_shard_param_start_idx, "
+              "fp16_local_param_num, fp16_global_param_num, "
+              "fp16_weight_decay_end_idx].");
+    AddOutput("ParamOrder",
+              "The reordered parameter order. Inside this op, "
+              "the parameter would be reordered by data type and weight decay "
+              "value.");
     AddOutput("ParamOut", "The output parameter list.").AsDuplicable();
     AddOutput("MasterParamOut",
               "The output master parameter list. It would share the memory of "
@@ -96,10 +97,8 @@ class DistributedFusedLambInitOpMaker
     AddAttr<float>("beta1", "The initial value of Beta1Pow.");
     AddAttr<float>("beta2", "The initial value of Beta2Pow.");
-    AddAttr<std::vector<float>>(
-        "weight_decay",
-        "The weight decay for each parameter. Its "
-        "shape is equal to the global parameter number.");
+    AddAttr<std::vector<int>>("apply_weight_decay",
+                              "Whether to apply weight decay.");
     AddAttr<int>("alignment", "The alignment in bytes for the fused tensors.");
     AddAttr<int>("rank", "The global rank of the current process.");
     AddAttr<int>("nranks", "The global world size.");
......
@@ -258,32 +258,6 @@ static void ShareBufferForNonInitedTensor(framework::Tensor *origin,
           << ") , dtype = " << fused_out->dtype();
 }

-template <typename OffsetT, typename IndexT>
-static __global__ void LambFillFusedIndicesCUDAKernel(const OffsetT *offsets,
-                                                      IndexT *out,
-                                                      int offset_num,
-                                                      int out_num) {
-  CUDA_KERNEL_LOOP_TYPE(i, out_num, int) {
-    auto idx = phi::funcs::LowerBound(offsets, offset_num, i);
-    if (idx == offset_num || offsets[idx] != i) {
-      --idx;
-    }
-    out[i] = idx;
-  }
-}
-
-template <typename T>
-static void CopyVectorToTensor(const std::vector<T> &src,
-                               framework::Tensor *dst,
-                               const platform::Place &place,
-                               gpuStream_t stream) {
-  dst->Resize({static_cast<int64_t>(src.size())});
-  T *dst_ptr = dst->mutable_data<T>(place);
-  const T *src_ptr = src.data();
-  auto nbytes = src.size() * sizeof(T);
-  memory::Copy(place, dst_ptr, platform::CPUPlace(), src_ptr, nbytes, stream);
-}
-
 template <typename T>
 static void CopyVectorToCPUTensor(const std::vector<T> &src,
                                   framework::Tensor *dst) {
@@ -294,6 +268,42 @@ static void CopyVectorToCPUTensor(const std::vector<T> &src,
   std::memcpy(dst_ptr, src_ptr, nbytes);
 }

+static size_t ReorderParamGradInfoList(const std::vector<int> &flags,
+                                       std::vector<ParamGradInfo> *infos) {
+  size_t n = infos->size();
+  std::vector<int> cur_flags;
+  cur_flags.reserve(n);
+  for (size_t i = 0; i < n; ++i) {
+    auto idx = (*infos)[i].idx;
+    cur_flags.push_back(flags[idx]);
+  }
+
+  auto origin_infos = *infos;
+  size_t j = 0;
+  for (size_t i = 0; i < n; ++i) {
+    if (cur_flags[i]) {
+      (*infos)[j] = origin_infos[i];
+      ++j;
+    }
+  }
+  size_t ret_idx = j;
+  for (size_t i = 0; i < n; ++i) {
+    if (!cur_flags[i]) {
+      (*infos)[j] = origin_infos[i];
+      ++j;
+    }
+  }
+  return ret_idx;
+}
+
+template <typename T>
+static T ClipByBound(T x, T low_value, T high_value) {
+  if (x < low_value) return low_value;
+  if (x > high_value) return high_value;
+  return x;
+}
+
 template <typename T>
 class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
     : public framework::OpKernel<T> {
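The new ReorderParamGradInfoList above is a stable two-pass partition: parameters whose apply_weight_decay flag is set are moved to the front (keeping their relative order), and the returned index marks the end of that group; it later becomes fp32_weight_decay_end_idx / fp16_weight_decay_end_idx in ParamInfo. A minimal, self-contained sketch of the same logic, where the simplified Info struct and main() are illustrative only (the real code works on ParamGradInfo):

```cpp
#include <cstdio>
#include <vector>

struct Info { size_t idx; };  // simplified stand-in for ParamGradInfo

// Stable two-pass partition: flagged entries first, the return value is the
// end index of the flagged group.
static size_t Reorder(const std::vector<int> &flags, std::vector<Info> *infos) {
  std::vector<Info> origin = *infos;
  size_t j = 0;
  for (const auto &info : origin) {   // pass 1: params that keep weight decay
    if (flags[info.idx]) (*infos)[j++] = info;
  }
  size_t ret_idx = j;                 // boundary index of the decayed group
  for (const auto &info : origin) {   // pass 2: params excluded from decay
    if (!flags[info.idx]) (*infos)[j++] = info;
  }
  return ret_idx;
}

int main() {
  std::vector<Info> infos{{0}, {1}, {2}, {3}};
  std::vector<int> apply_weight_decay{1, 0, 1, 0};  // exclude params 1 and 3
  size_t end = Reorder(apply_weight_decay, &infos);
  for (const auto &info : infos) std::printf("%zu ", info.idx);  // prints: 0 2 1 3
  std::printf("| wd_end_idx = %zu\n", end);                      // prints: 2
  return 0;
}
```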
@@ -404,6 +414,24 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
         info->numel_offset = 0;  // not determined yet
       }
     }
+    const auto &apply_weight_decay =
+        ctx.Attr<std::vector<int>>("apply_weight_decay");
+    size_t fp32_wd_end_idx =
+        ReorderParamGradInfoList(apply_weight_decay, &fp32_infos);
+    size_t fp16_wd_end_idx =
+        ReorderParamGradInfoList(apply_weight_decay, &fp16_infos);
+
+    auto *param_order_t = ctx.Output<framework::Tensor>("ParamOrder");
+    auto param_num = fp32_infos.size() + fp16_infos.size();
+    param_order_t->Resize({static_cast<int64_t>(param_num)});
+    auto *param_order = param_order_t->mutable_data<int>(platform::CPUPlace());
+    for (size_t i = 0; i < fp32_infos.size(); ++i) {
+      param_order[i] = static_cast<int>(fp32_infos[i].idx);
+    }
+    for (size_t i = 0; i < fp16_infos.size(); ++i) {
+      param_order[i + fp32_infos.size()] = static_cast<int>(fp16_infos[i].idx);
+    }
+
     VLOG(10) << "Fill ParamGradInfo ends";

     // Step 2: determine the numel_with_padding and numel_offset
@@ -568,45 +596,29 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
     VLOG(10) << "Found the sharding arguments";

     auto *param_info_t = ctx.Output<framework::Tensor>("ParamInfo");
-    param_info_t->Resize({6});
+    param_info_t->Resize({8});
     auto *param_info = param_info_t->mutable_data<int>(platform::CPUPlace());
     param_info[0] = static_cast<int>(fp32_start_idx);
     param_info[1] = static_cast<int>(fp32_local_param_num);
     param_info[2] = static_cast<int>(fp32_infos.size());
-    param_info[3] = static_cast<int>(fp16_start_idx + fp32_infos.size());
-    param_info[4] = static_cast<int>(fp16_local_param_num);
-    param_info[5] = static_cast<int>(fp16_infos.size());
+    param_info[3] = ClipByBound<int>(fp32_wd_end_idx, fp32_start_idx,
+                                     fp32_start_idx + fp32_local_param_num) -
+                    static_cast<int>(fp32_start_idx);
+    param_info[4] = static_cast<int>(fp16_start_idx + fp32_infos.size());
+    param_info[5] = static_cast<int>(fp16_local_param_num);
+    param_info[6] = static_cast<int>(fp16_infos.size());
+    param_info[7] = ClipByBound<int>(fp16_wd_end_idx, fp16_start_idx,
+                                     fp16_start_idx + fp16_local_param_num) -
+                    static_cast<int>(fp16_start_idx);

     VLOG(10) << "Start FP32 idx: " << param_info[0];
     VLOG(10) << "Local FP32 param num: " << param_info[1];
     VLOG(10) << "Global FP32 param num: " << param_info[2];
-    VLOG(10) << "Start FP16 idx: " << param_info[3];
-    VLOG(10) << "Local FP16 param num: " << param_info[4];
-    VLOG(10) << "Global FP16 param num: " << param_info[5];
+    VLOG(10) << "Start FP16 idx: " << param_info[4];
+    VLOG(10) << "Local FP16 param num: " << param_info[5];
+    VLOG(10) << "Global FP16 param num: " << param_info[6];

-    // For WeightDecay, shard and perform H2D copy
-    const auto &origin_weight_decay =
-        ctx.Attr<std::vector<float>>("weight_decay");
-    PADDLE_ENFORCE_EQ(params.size(), origin_weight_decay.size(),
-                      platform::errors::InvalidArgument(
-                          "The attr(weight_decay) should have the "
-                          "same length with Input(Param)."));
-    std::vector<float> shard_weight_decay;
-    shard_weight_decay.reserve(total_local_param_num);
-    for (size_t i = 0; i < fp32_local_param_num; ++i) {
-      shard_weight_decay.push_back(
-          origin_weight_decay[fp32_infos[i + fp32_start_idx].idx]);
-    }
-    for (size_t i = 0; i < fp16_local_param_num; ++i) {
-      shard_weight_decay.push_back(
-          origin_weight_decay[fp16_infos[i + fp16_start_idx].idx]);
-    }
-
-    // For FusedIndices, launch CUDA kernel to do binary search
-    auto *fused_indices_t = ctx.Output<framework::Tensor>("FusedIndices");
-    fused_indices_t->Resize({static_cast<int64_t>(total_numel)});
-    auto *fused_indices = fused_indices_t->mutable_data<int>(place);
     std::vector<int> numel_offsets;
     numel_offsets.reserve(params.size() + 1);
     for (const auto &info : fp32_infos) {
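With this hunk, ParamInfo grows from 6 to 8 entries, adding a weight-decay end index for each parameter group. A hedged summary of the new layout; the enum names below are illustrative and do not appear in the source, only the slot order is taken from the diff:

```cpp
// Slot layout of the 8-element ParamInfo tensor after this change.
enum ParamInfoSlot {
  kFP32ShardParamStartIdx = 0,   // fp32_shard_param_start_idx
  kFP32LocalParamNum = 1,        // fp32_local_param_num
  kFP32GlobalParamNum = 2,       // fp32_global_param_num
  kFP32WeightDecayEndIdx = 3,    // fp32_weight_decay_end_idx (local, clipped)
  kFP16ShardParamStartIdx = 4,   // fp16_shard_param_start_idx
  kFP16LocalParamNum = 5,        // fp16_local_param_num
  kFP16GlobalParamNum = 6,       // fp16_global_param_num
  kFP16WeightDecayEndIdx = 7     // fp16_weight_decay_end_idx (local, clipped)
};
```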
@@ -621,21 +633,6 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
                           "The numel_offsets number must be one larger than "
                           "the parameter number."));
     VLOG(10) << "Total numel offset: " << FlattenToString(numel_offsets);
-    auto *fused_param_offset_t =
-        ctx.Output<framework::Tensor>("FusedParamOffsets");
-    fused_param_offset_t->Resize({static_cast<int64_t>(numel_offsets.size())});
-    auto *fused_param_offset = fused_param_offset_t->mutable_data<int>(place);
-    memory::Copy(place, fused_param_offset, platform::CPUPlace(),
-                 numel_offsets.data(),
-                 numel_offsets.size() * sizeof(numel_offsets[0]), stream);
-    auto config = platform::GetGpuLaunchConfig1D(dev_ctx, total_numel);
-    LambFillFusedIndicesCUDAKernel<<<config.block_per_grid,
-                                     config.thread_per_block, 0, stream>>>(
-        fused_param_offset, fused_indices, numel_offsets.size() - 1,
-        total_numel);
-
-    std::vector<int> lengths;
-    lengths.reserve(fp32_local_param_num + fp16_local_param_num);

     std::vector<int> fp32_partial_numel_offsets;
     fp32_partial_numel_offsets.reserve(fp32_local_param_num + 1);
@@ -659,9 +656,9 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
       VLOG(10) << "FP32 Partial numel = ["
                << valid_start_n + fp32_infos[i].numel << ","
                << end_n + fp32_infos[i].numel;
-      lengths.push_back(end_n - valid_start_n);
+      auto len = end_n - valid_start_n;
       fp32_partial_numel_offsets.push_back(fp32_partial_numel_offsets.back() +
-                                           lengths.back());
+                                           len);
     }

     std::vector<int> fp16_partial_numel_offsets;
@@ -682,9 +679,9 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
       PADDLE_ENFORCE_NE(valid_start_n, end_n,
                         platform::errors::InvalidArgument(
                             "Indices sharding error. This may be a bug."));
-      lengths.push_back(end_n - valid_start_n);
+      auto len = end_n - valid_start_n;
       fp16_partial_numel_offsets.push_back(fp16_partial_numel_offsets.back() +
-                                           lengths.back());
+                                           len);
     }

     CopyVectorToCPUTensor(numel_offsets,
@@ -696,23 +693,6 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
         fp16_partial_numel_offsets,
         ctx.Output<framework::Tensor>("FP16ShardFusedParamOffsets"));

-    // Fill the weight decay tensor
-    PADDLE_ENFORCE_EQ(lengths.size(), shard_weight_decay.size(),
-                      platform::errors::InvalidArgument(
-                          "Invalid weight decay sharding. This may be a bug."));
-    std::vector<float> wd_cpu;
-    for (size_t i = 0; i < shard_weight_decay.size(); ++i) {
-      int len = lengths[i];
-      for (int j = 0; j < len; ++j) {
-        wd_cpu.push_back(shard_weight_decay[i]);
-      }
-    }
-    PADDLE_ENFORCE_EQ(wd_cpu.size() * nranks, fp32_numel + fp16_numel,
-                      platform::errors::InvalidArgument(
-                          "Invalid weight decay sharding. This may be a bug."));
-    CopyVectorToTensor(wd_cpu, ctx.Output<framework::Tensor>("WeightDecay"),
-                       place, stream);
-
     auto *global_scale = ctx.Output<framework::Tensor>("GlobalScale");
     if (!global_scale->IsInitialized()) {
       TensorFillConstant<float>(dev_ctx, global_scale, {1}, 1.0f);
......
@@ -66,28 +66,31 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker {
              "The fp32 beta1 power accumulator tensor. Its shape is [1].");
     AddInput("Beta2Pow",
              "The fp32 beta2 power accumulator tensor. Its shape is [1].");
-    AddInput("FusedIndices",
-             "The param index of each element in FP32FusedParam. Its shape is "
-             "[M1+M2]. It is like [0,0,0,1,1,1,1,2,2,...].");
     AddInput(
         "FusedParamOffsets",
         "The numel offset of each parameter inside the FP32FusedParam. Its "
         "shape is [param_num + 1]. It is like [0, n_0, n_0 + n_1, n_0 + n_1 "
-        "+ n_2, ...].");
-    AddInput("FP32ShardFusedParamOffsets",
-             "The sharded numel offset of each parameter in the local rank. "
-             "Its shape is [fp32_local_param_num + 1].");
-    AddInput("FP16ShardFusedParamOffsets",
-             "The sharded numel offset of each parameter in the local rank. "
-             "Its shape is [fp16_local_param_num + 1].");
-    AddInput("WeightDecay",
-             "The sharded fp32 weight decay tensor. Its shape is [(M1+M2)/N].");
+        "+ n_2, ...]. It should be in CPUPlace.");
+    AddInput(
+        "FP32ShardFusedParamOffsets",
+        "The sharded numel offset of each parameter in the local rank. "
+        "Its shape is [fp32_local_param_num + 1]. It should be in CPUPlace.");
+    AddInput(
+        "FP16ShardFusedParamOffsets",
+        "The sharded numel offset of each parameter in the local rank. "
+        "Its shape is [fp16_local_param_num + 1]. It should be in CPUPlace.");
     AddInput("ParamInfo",
              "The param info. It should be in CPUPlace, and its shape is [6]"
-             "CPUPlace, and its shape is [6]. It is "
+             "CPUPlace, and its shape is [8]. It is "
             "[fp32_shard_param_start_idx, fp32_local_param_num, "
-             "fp32_global_param_num, fp16_shard_param_start_idx, "
-             "fp16_local_param_num, fp16_global_param_num].");
+             "fp32_global_param_num, fp32_weight_decay_end_idx, "
+             "fp16_shard_param_start_idx, "
+             "fp16_local_param_num, fp16_global_param_num, "
+             "fp16_weight_decay_end_idx].");
+    AddInput("ParamOrder",
+             "The reordered parameter order. Inside this op, "
+             "the parameter would be reordered by data type and weight decay "
+             "value.");
     AddInput("LearningRate",
              "The fp32 learning rate tensor. Its shape is [1].");
@@ -116,6 +119,7 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker {
         "max_global_grad_norm",
         "The maximum global gradient l2-norm value for clipping. If "
         "max_global_grad_norm <= 0, no clipping would be performed.");
+    AddAttr<float>("weight_decay", "The weight decay value.");
     AddAttr<bool>("clip_after_allreduce",
                   "Whether to clip before allreduce, only valid when the "
                   "world size is larger than 1.");
......
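Taken together with the ParamInfo change, weight decay becomes a single scalar attribute combined with the per-group end indices: parameters reordered to the front of each local group keep the decay, the rest get zero. A hedged sketch of that selection rule; the function and names are illustrative, not code from this PR:

```cpp
#include <cstdio>

// Pick the decay coefficient for a local parameter given its group's
// weight-decay end index (illustrative helper, not a Paddle API).
static float DecayForLocalParam(int local_param_idx, int wd_end_idx,
                                float weight_decay) {
  return local_param_idx < wd_end_idx ? weight_decay : 0.0f;
}

int main() {
  const float weight_decay = 0.1f;  // the new scalar "weight_decay" attribute
  const int fp32_wd_end_idx = 2;    // first 2 local fp32 params keep decay
  for (int i = 0; i < 4; ++i) {
    std::printf("local fp32 param %d -> decay %.1f\n", i,
                DecayForLocalParam(i, fp32_wd_end_idx, weight_decay));
  }
  return 0;
}
```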
@@ -94,11 +94,40 @@ static __global__ void MultiTensorApplyCUDAKernel(
       args...);
 }

-template <typename Functor, int BlockDim, int MaxTensorNumPerLaunch,
-          int MaxChunkNumPerLaunch, typename... Args>
-static void MultiTensorApply(Functor functor, gpuStream_t stream,
-                             const int *offsets, int n, int chunk_size,
-                             Args... args) {
+template <int MaxTensorNumPerLaunch, int MaxChunkNumPerLaunch>
+class MultiTensorLauncher {
+ public:
+  MultiTensorLauncher(
+      const TensorMetaList<MaxTensorNumPerLaunch, MaxChunkNumPerLaunch> &meta,
+      const int &chunk_id, const int &chunk_size, const int &block_dim,
+      const gpuStream_t &stream)
+      : meta_(meta),
+        chunk_id_(chunk_id),
+        chunk_size_(chunk_size),
+        block_dim_(block_dim),
+        stream_(stream) {}
+
+  template <typename Functor, typename... Args>
+  void Launch(Functor &&functor, Args &&... args) const {
+    MultiTensorApplyCUDAKernel<
+        Functor, MaxTensorNumPerLaunch,
+        MaxChunkNumPerLaunch><<<chunk_id_, block_dim_, 0, stream_>>>(
+        functor, meta_, chunk_size_, args...);
+  }
+
+ private:
+  const TensorMetaList<MaxTensorNumPerLaunch, MaxChunkNumPerLaunch> &meta_;
+  const int &chunk_id_;
+  const int &chunk_size_;
+  const int &block_dim_;
+  const gpuStream_t &stream_;
+};
+
+template <int MaxTensorNumPerLaunch, int MaxChunkNumPerLaunch,
+          typename Callback>
+static void MultiTensorApplyWithCallback(gpuStream_t stream, const int *offsets,
+                                         int n, int chunk_size, int block_dim,
+                                         Callback &&callback) {
   if (n == 0) return;

   constexpr auto NumTensor = MaxTensorNumPerLaunch;
@@ -110,6 +139,11 @@ static void MultiTensorApply(Functor functor, gpuStream_t stream,
   int numel_offset = 0;
   metas.start_tensor_id = 0;
   metas.start_chunk_id = 0;
+  int launch_num = 0;
+
+  MultiTensorLauncher<MaxTensorNumPerLaunch, MaxChunkNumPerLaunch> launcher(
+      metas, chunk_id, chunk_size, block_dim, stream);
+
   for (int i = 0; i < n; ++i) {
     auto length = offsets[i + 1] - offsets[i];
     if (tensor_id == 0) {
@@ -132,9 +166,8 @@ static void MultiTensorApply(Functor functor, gpuStream_t stream,
     bool last_chunk = (i + 1 == n && j + 1 == chunk_num);

     if (tensor_full || block_full || last_chunk) {
-      MultiTensorApplyCUDAKernel<Functor, NumTensor,
-                                 NumChunk><<<chunk_id, BlockDim, 0, stream>>>(
-          functor, metas, chunk_size, args...);
+      callback(launcher, launch_num);
+      ++launch_num;
       chunk_id = 0;
       if (j + 1 == chunk_num) {  // chunk for the current tensor is full
         metas.start_chunk_id = 0;
@@ -152,5 +185,17 @@ static void MultiTensorApply(Functor functor, gpuStream_t stream,
   }
 }

+template <typename Functor, int MaxTensorNumPerLaunch, int MaxChunkNumPerLaunch,
+          typename... Args>
+static void MultiTensorApply(Functor functor, gpuStream_t stream,
+                             const int *offsets, int n, int chunk_size,
+                             int block_dim, Args &&... args) {
+  auto callback = [&](const MultiTensorLauncher<MaxTensorNumPerLaunch,
+                                                MaxChunkNumPerLaunch> &launcher,
+                      int i) { launcher.Launch(functor, args...); };
+  MultiTensorApplyWithCallback<MaxTensorNumPerLaunch, MaxChunkNumPerLaunch>(
+      stream, offsets, n, chunk_size, block_dim, callback);
+}
+
 }  // namespace operators
 }  // namespace paddle
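The refactor above splits the chunking loop from the kernel launch: MultiTensorApplyWithCallback hands a launcher to a caller-supplied callback on each launch, and the old MultiTensorApply behaviour is recovered by a callback that always launches the same functor. A minimal, self-contained sketch of that launcher/callback pattern in isolation, with illustrative names and no CUDA dependency (it is not the Paddle API itself):

```cpp
#include <cstdio>

struct Launcher {
  int chunk_count;
  template <typename Functor>
  void Launch(Functor &&functor) const { functor(chunk_count); }  // stand-in for a kernel launch
};

// The loop only decides *when* to launch; the callback decides *what* to launch.
template <typename Callback>
static void ApplyWithCallback(int n_launches, Callback &&callback) {
  for (int i = 0; i < n_launches; ++i) {
    Launcher launcher{i + 1};  // pretend launch i covers (i + 1) chunks
    callback(launcher, i);
  }
}

int main() {
  // Old-style usage: one fixed functor for every launch, mirroring the
  // MultiTensorApply wrapper at the end of the diff.
  auto functor = [](int chunks) { std::printf("launch with %d chunk(s)\n", chunks); };
  ApplyWithCallback(3, [&](const Launcher &launcher, int i) {
    std::printf("launch #%d: ", i);
    launcher.Launch(functor);
  });
  return 0;
}
```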
@@ -144,6 +144,11 @@ def run_model(use_distributed_lamb, use_fp16, use_master_param_norm, **kwargs):
     grad_clip = kwargs.get('grad_clip', None)
     clip_after_allreduce = kwargs.get('clip_after_allreduce', True)

+    parameters = [p.name for p in main.all_parameters()]
+    exclude_fn = lambda var: var.name in parameters[::4]
+    kwargs['exclude_from_weight_decay_fn'] = exclude_fn
+    kwargs['lamb_weight_decay'] = 0.1
+
     if use_distributed_lamb:
         optimizer_class = DistributedFusedLamb
         kwargs = dict(kwargs)
......
@@ -171,10 +171,7 @@ class DistributedFusedLamb(Optimizer):
         moment2.is_distributed = True
         beta1pow = self._create_persistable_var('beta1pow')
         beta2pow = self._create_persistable_var('beta2pow')
-        fused_indices = self._create_persistable_var(
-            'fused_indices', dtype='int32')
-        weight_decay = self._create_persistable_var('weight_decay')
-        weight_decay.is_distributed = True
+
         param_info = self._create_persistable_var('param_info', dtype='int32')
         param_info.is_distributed = True
@@ -189,17 +186,20 @@ class DistributedFusedLamb(Optimizer):
             'fp16_partial_fused_offsets', dtype='int32')
         fp16_partial_fused_offsets.is_distributed = True

+        param_order = self._create_persistable_var('param_order', dtype='int32')
+        param_order.is_distributed = True
+
         rank = get_rank()
         nranks = get_world_size()
         scale = self._get_or_create_scale()

         params = [p for p, _ in params_grads]
         grads = [g for _, g in params_grads]
-        weight_decay_values = [self._weight_decay] * len(params)
+        apply_weight_decay = [1] * len(params)
         if self._exclude_from_weight_decay_fn is not None:
             for i, p in enumerate(params):
                 if self._exclude_from_weight_decay_fn(p):
-                    weight_decay_values[i] = 0.0
+                    apply_weight_decay[i] = 0

         startup_block = self.helper.startup_program.global_block()
         for g in grads:
@@ -225,8 +225,6 @@ class DistributedFusedLamb(Optimizer):
                 'Moment2': [moment2],
                 'Beta1Pow': [beta1pow],
                 'Beta2Pow': [beta2pow],
-                'FusedIndices': [fused_indices],
-                'WeightDecay': [weight_decay],
                 'GlobalScale': [scale],
                 'ParamInfo': [param_info],
                 'ParamOut': params,
@@ -235,12 +233,13 @@ class DistributedFusedLamb(Optimizer):
                 'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets],
                 'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets],
                 'FusedParamOffsets': [fused_offsets],
+                'ParamOrder': [param_order],
             },
             attrs={
                 'alignment': self._alignment,
                 'rank': rank,
                 'nranks': nranks,
-                'weight_decay': weight_decay_values,
+                'apply_weight_decay': apply_weight_decay,
                 'moment1': 0.0,
                 'moment2': 0.0,
                 'beta1': self._beta1,
@@ -272,8 +271,6 @@ class DistributedFusedLamb(Optimizer):
                 'Moment2': [moment2],
                 'Beta1Pow': [beta1pow],
                 'Beta2Pow': [beta2pow],
-                'FusedIndices': [fused_indices],
-                'WeightDecay': [weight_decay],
                 'GlobalScale': [scale],
                 'ParamInfo': [param_info],
                 'Param': params,
@@ -281,6 +278,7 @@ class DistributedFusedLamb(Optimizer):
                 'FusedParamOffsets': [fused_offsets],
                 'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets],
                 'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets],
+                'ParamOrder': [param_order],
             },
             outputs={
                 'FP32FusedParamOut': [fp32_fused_param],
@@ -294,6 +292,7 @@ class DistributedFusedLamb(Optimizer):
                 'FoundInf': [self._found_inf],
             },
             attrs={
+                'weight_decay': self._weight_decay,
                 'beta1': self._beta1,
                 'beta2': self._beta2,
                 'epsilon': self._epsilon,
......