Unverified  Commit 5f8e7d8f  authored by huangjiyi, committed by GitHub

Functionalize distributed_fused_lamb kernel (#53896)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update HostAlloc

* update param name

* update cpu kernel

* remove kernel header

* update

* update

Parent commit: 6e0cf610
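For context, this change follows PaddlePaddle's usual fluid-to-PHI "functionalization" pattern: the old struct-style OpKernel, whose Compute(ExecutionContext &) pulled inputs and attributes out of the context by name, becomes a free function template that receives the device context, every input tensor, attribute, and output pointer as explicit arguments, and is registered with PD_REGISTER_KERNEL instead of PD_REGISTER_STRUCT_KERNEL. The sketch below only illustrates that pattern; ExampleKernel, its parameters, and the "example" registration name are hypothetical and not part of this commit (the real signature of phi::fusion::DistributedFusedLambKernel appears in the diff that follows).

// Minimal sketch of the functionalization pattern, assuming the usual PHI
// headers; ExampleKernel and its arguments are illustrative only.
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
namespace fusion {

// Inputs, attributes, and outputs are explicit parameters instead of being
// looked up by name in a framework::ExecutionContext.
template <typename T, typename Context>
void ExampleKernel(const Context &dev_ctx,
                   const DenseTensor &x,  // input tensor
                   float scale,           // attribute
                   DenseTensor *out) {    // output tensor
  // Outputs are allocated through the device context rather than via
  // Tensor::mutable_data(place).
  out->Resize(x.dims());
  T *out_data = dev_ctx.template Alloc<T>(out);
  const T *x_data = x.data<T>();
  for (int64_t i = 0; i < x.numel(); ++i) {
    out_data[i] = static_cast<T>(scale) * x_data[i];
  }
}

}  // namespace fusion
}  // namespace phi

PD_REGISTER_KERNEL(example, CPU, ALL_LAYOUT, phi::fusion::ExampleKernel, float) {}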
@@ -12,7 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/kernel_registry.h"
 
 namespace paddle {
 namespace operators {
@@ -170,8 +171,63 @@ REGISTER_OP_WITHOUT_GRADIENT(distributed_fused_lamb,
                              ops::DistributedFusedLambOp,
                              ops::DistributedFusedLambOpMaker);
 
-PD_REGISTER_STRUCT_KERNEL(distributed_fused_lamb,
-                          CPU,
-                          ALL_LAYOUT,
-                          ops::DistributedFusedLambOpKernel,
-                          float) {}
+namespace phi {
+namespace fusion {
+
+template <typename T, typename Context>
+void DistributedFusedLambKernel(const Context &dev_ctx,
+                                const std::vector<const DenseTensor *> &param,
+                                const std::vector<const DenseTensor *> &grad,
+                                const paddle::optional<DenseTensor> &fp32_param,
+                                const paddle::optional<DenseTensor> &fp32_grad,
+                                const paddle::optional<DenseTensor> &fp16_param,
+                                const paddle::optional<DenseTensor> &fp16_grad,
+                                const DenseTensor &moment1,
+                                const DenseTensor &moment2,
+                                const DenseTensor &beta1_pow,
+                                const DenseTensor &beta2_pow,
+                                const DenseTensor &param_offsets,
+                                const DenseTensor &fp32_partial_offsets,
+                                const DenseTensor &fp16_partial_offsets,
+                                const DenseTensor &param_info,
+                                const DenseTensor &param_order,
+                                const DenseTensor &learning_rate,
+                                const DenseTensor &global_scale,
+                                int acc_steps,
+                                float beta1,
+                                float beta2,
+                                float epsilon,
+                                float max_global_grad_norm,
+                                float weight_decay,
+                                bool clip_after_allreduce,
+                                bool use_master_param_norm,
+                                bool use_master_acc_grad,
+                                bool is_grad_scaled_by_nranks,
+                                bool use_hierarchical_allreduce,
+                                int64_t nranks,
+                                const std::vector<int> &ring_ids,
+                                DenseTensor *fp32_param_out,
+                                DenseTensor *fp16_param_out,
+                                DenseTensor *fp32_acc_grad,
+                                DenseTensor *fp16_acc_grad,
+                                DenseTensor *moment1_out,
+                                DenseTensor *moment2_out,
+                                DenseTensor *beta1_pow_out,
+                                DenseTensor *beta2_pow_out,
+                                DenseTensor *param_out,
+                                DenseTensor *found_inf,
+                                DenseTensor *acc_step,
+                                DenseTensor *stop_update,
+                                DenseTensor *step) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "The distributed_fused_lamb operator does not support CPU yet."));
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(distributed_fused_lamb,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::fusion::DistributedFusedLambKernel,
+                   float) {}
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,19 +12,21 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <cmath>
-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
-#include "paddle/fluid/operators/optimizers/cast_with_ptr.h"
-#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h"
 #include "paddle/fluid/operators/optimizers/multi_tensor_apply.h"
 #include "paddle/fluid/platform/collective_helper.h"
-#include "paddle/fluid/platform/for_range.h"
-#include "paddle/fluid/string/string_helper.h"
+#include "paddle/phi/backends/context_pool.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/common/memory_utils.h"
+#include "paddle/phi/core/cuda_stream.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/utils/data_type.h"
 #include "paddle/phi/kernels/funcs/aligned_vector.h"
 #include "paddle/phi/kernels/funcs/tensor_to_string.h"
+#include "paddle/utils/optional.h"
 
 #ifdef __NVCC__
 #include "cub/cub.cuh"
@@ -38,11 +40,11 @@
 namespace cub = hipcub;
 #endif
 
-namespace paddle {
-namespace operators {
+namespace phi {
+namespace fusion {
 
 template <typename T>
-using MasterT = typename details::MPTypeTrait<T>::Type;
+using MasterT = typename phi::dtype::MPTypeTrait<T>::Type;
 
 using phi::funcs::FlattenToString;
 using phi::funcs::ToVector;
@@ -157,7 +159,7 @@ template <typename InT,
           typename OutT,
           int MaxTensorNumPerLaunch = 160,
           int MaxChunkNumPerLaunch = 780>
-static void MultiTensorL2Norm(const platform::CUDAPlace &place,
+static void MultiTensorL2Norm(const phi::GPUPlace &place,
                               gpuStream_t stream,
                               const InT *x,
                               const int *offsets,
@@ -191,24 +193,25 @@ static void MultiTensorL2Norm(const platform::CUDAPlace &place,
           << " , tensor_num = " << n;
 
   using MT = MasterT<InT>;
-  phi::memory_utils::Buffer tmp_out(place);
+  memory_utils::Buffer tmp_out(place);
   auto *tmp_out_ptr = tmp_out.Alloc<MT>(n * max_chunk_num);
   FillZeroWithPtr(tmp_out_ptr, n * max_chunk_num, stream);
 
 #define PD_LAUNCH_MULTI_TENSOR_APPLY_L2_NORM_KERNEL                       \
   do {                                                                    \
     using FunctorT = L2NormFunctor<InT, kBlockDim, kVecSize>;             \
     VLOG(10) << __func__ << " " << typeid(InT).name()                     \
              << " VecSize = " << kVecSize;                                \
-    MultiTensorApply<FunctorT, kNumTensor, kNumChunk>(FunctorT(),         \
-                                                      stream,             \
-                                                      offsets,            \
-                                                      n,                  \
-                                                      chunk_size,         \
-                                                      kBlockDim,          \
-                                                      x,                  \
-                                                      tmp_out_ptr,        \
-                                                      max_chunk_num);     \
+    paddle::operators::MultiTensorApply<FunctorT, kNumTensor, kNumChunk>( \
+        FunctorT(),                                                       \
+        stream,                                                           \
+        offsets,                                                          \
+        n,                                                                \
+        chunk_size,                                                       \
+        kBlockDim,                                                        \
+        x,                                                                \
+        tmp_out_ptr,                                                      \
+        max_chunk_num);                                                   \
   } while (0)
 
 PD_VEC_LAUNCH_KERNEL(vec_size, PD_LAUNCH_MULTI_TENSOR_APPLY_L2_NORM_KERNEL);
@@ -220,27 +223,27 @@ static void MultiTensorL2Norm(const platform::CUDAPlace &place,
 
 template <int LogLevel>
 static void LogParamAndTrustRatioDivSquareNorm(
-    const framework::ExecutionContext &ctx,
+    const std::vector<const DenseTensor *> &param,
+    const DenseTensor &order,
     const float *param_square_norm,
     const float *trust_ratio_div_square_norm) {
   if (!VLOG_IS_ON(LogLevel)) return;
 
-  auto tensors = ctx.MultiInput<phi::DenseTensor>("Param");
-  if (tensors.empty()) return;
+  if (param.empty()) return;
 
-  const auto *order = ctx.Input<phi::DenseTensor>("ParamOrder")->data<int>();
+  const auto *order_data = order.data<int>();
 
-  size_t n = tensors.size();
-  auto place = tensors[0]->place();
+  size_t n = param.size();
+  auto place = param[0]->place();
 
   auto pn_vec = ToVector(param_square_norm, n, place);
   auto tn_vec = ToVector(trust_ratio_div_square_norm, n, place);
 
-  const auto &names = ctx.GetOp().Inputs("Param");
   for (size_t i = 0; i < n; ++i) {
-    auto idx = order[i];
-    VLOG(LogLevel) << "Param " << tensors[idx]->dtype() << " " << names[idx]
-                   << " pn = " << pn_vec[i] << " , tn = " << tn_vec[i];
+    auto idx = order_data[i];
+    VLOG(LogLevel) << "Param " << param[idx]->dtype() << " "
                   << param[idx]->name() << " pn = " << pn_vec[i]
                   << " , tn = " << tn_vec[i];
  }
}
@@ -261,13 +264,12 @@ static bool IsFinite(const phi::GPUContext &dev_ctx, const float *ptr) {
 }
 
 template <typename T>
-static const T *GetInputTensorPtr(const framework::ExecutionContext &ctx,
+static const T *GetInputTensorPtr(const DenseTensor *in_tensor,
                                   const char *in_name,
                                   int64_t *numel = nullptr) {
-  const auto *in_tensor = ctx.Input<phi::DenseTensor>(in_name);
   PADDLE_ENFORCE_NOT_NULL(
       in_tensor,
-      platform::errors::InvalidArgument("Input(%s) cannot be NULL.", in_name));
+      phi::errors::InvalidArgument("Input(%s) cannot be NULL.", in_name));
   if (in_tensor->IsInitialized()) {
     if (numel) *numel = in_tensor->numel();
     return in_tensor->data<T>();
@@ -277,34 +279,34 @@ static const T *GetInputTensorPtr(const framework::ExecutionContext &ctx,
   }
 }
 
-template <typename T, bool AllowNotExist = false>
-static T *GetSameInOutTensorPtr(const framework::ExecutionContext &ctx,
-                                const platform::Place &place,
+template <typename T, typename Context, bool AllowNotExist = false>
+static T *GetSameInOutTensorPtr(const Context &dev_ctx,
+                                const DenseTensor *in_tensor,
+                                DenseTensor *out_tensor,
                                 const char *in_name,
                                 const char *out_name,
                                 int64_t *numel = nullptr) {
-  const auto *in_tensor = ctx.Input<phi::DenseTensor>(in_name);
   if (in_tensor == nullptr || !in_tensor->IsInitialized()) {
-    PADDLE_ENFORCE_EQ(AllowNotExist,
-                      true,
-                      platform::errors::InvalidArgument(
-                          "Input(%s) cannot be NULL.", in_name));
+    PADDLE_ENFORCE_EQ(
+        AllowNotExist,
+        true,
+        phi::errors::InvalidArgument("Input(%s) cannot be NULL.", in_name));
     if (numel) *numel = 0;
     return nullptr;
   }
 
-  auto *out_tensor = ctx.Output<phi::DenseTensor>(out_name);
   PADDLE_ENFORCE_NOT_NULL(
       in_tensor,
-      platform::errors::InvalidArgument("Input(%s) cannot be NULL.", in_name));
-  PADDLE_ENFORCE_NOT_NULL(out_tensor,
-                          platform::errors::InvalidArgument(
-                              "Output(%s) cannot be NULL.", out_name));
+      phi::errors::InvalidArgument("Input(%s) cannot be NULL.", in_name));
+  PADDLE_ENFORCE_NOT_NULL(
+      out_tensor,
+      phi::errors::InvalidArgument("Output(%s) cannot be NULL.", out_name));
   const T *in_data = in_tensor->data<T>();
-  T *out_data = out_tensor->mutable_data<T>(place);
+  T *out_data = dev_ctx.template Alloc<T>(out_tensor);
   PADDLE_ENFORCE_EQ(in_data,
                     out_data,
-                    platform::errors::InvalidArgument(
+                    phi::errors::InvalidArgument(
                         "Input(%s) and Output(%s) must be the same Tensor.",
                         in_name,
                         out_name));
@@ -535,11 +537,11 @@ static void MultiTensorUpdateLambMomentAndTrustRatioDiv(
   int numel = offsets[n] - offsets[0];
   PADDLE_ENFORCE_GE(weight_decay_end_idx,
                     0,
-                    platform::errors::InvalidArgument(
+                    phi::errors::InvalidArgument(
                         "The weight decay end index should be >= 0."));
   PADDLE_ENFORCE_LE(weight_decay_end_idx,
                     n,
-                    platform::errors::InvalidArgument(
+                    phi::errors::InvalidArgument(
                         "The weight decay end index should be < %d.", n));
   auto weight_decay_end_numel = offsets[weight_decay_end_idx] - offsets[0];
@@ -558,17 +560,17 @@ static void MultiTensorUpdateLambMomentAndTrustRatioDiv(
 
   VLOG(1) << __func__ << " VecSize = " << vec_size;
 
   auto stream = dev_ctx.stream();
-  auto config = platform::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
+  auto config =
+      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
   if (found_inf_p == nullptr) {
     PADDLE_ENFORCE_EQ(
         step,
         nullptr,
-        platform::errors::InvalidArgument(
+        phi::errors::InvalidArgument(
             "Output(Step) cannot be updated twice in one mini-batch."));
   } else {
     PADDLE_ENFORCE_NOT_NULL(
-        step,
-        platform::errors::InvalidArgument("Output(Step) cannot be nullptr."));
+        step, phi::errors::InvalidArgument("Output(Step) cannot be nullptr."));
   }
 
 #define PD_LAUNCH_LAMB_MOM_TRUST_RATIO_DIV_KERNEL                        \
@@ -603,12 +605,12 @@ static void MultiTensorUpdateLambMomentAndTrustRatioDiv(
 template <typename T, bool NeedUpdate /*=true*/>
 struct LambBetaPowUpdateOnceHelper {
   LambBetaPowUpdateOnceHelper(T *beta1pow, T *beta2pow, T beta1, T beta2) {
-    PADDLE_ENFORCE_NOT_NULL(beta1pow,
-                            platform::errors::InvalidArgument(
-                                "The beta1pow should not be nullptr."));
-    PADDLE_ENFORCE_NOT_NULL(beta2pow,
-                            platform::errors::InvalidArgument(
-                                "The beta2pow should not be nullptr."));
+    PADDLE_ENFORCE_NOT_NULL(
+        beta1pow,
+        phi::errors::InvalidArgument("The beta1pow should not be nullptr."));
+    PADDLE_ENFORCE_NOT_NULL(
+        beta2pow,
+        phi::errors::InvalidArgument("The beta2pow should not be nullptr."));
     beta1pow_ = beta1pow;
     beta2pow_ = beta2pow;
     beta1_ = beta1;
@@ -633,11 +635,11 @@ struct LambBetaPowUpdateOnceHelper<T, false> {
     PADDLE_ENFORCE_EQ(
         beta1pow,
         nullptr,
-        platform::errors::InvalidArgument("The beta1pow should be nullptr."));
+        phi::errors::InvalidArgument("The beta1pow should be nullptr."));
     PADDLE_ENFORCE_EQ(
         beta2pow,
         nullptr,
-        platform::errors::InvalidArgument("The beta2pow should be nullptr."));
+        phi::errors::InvalidArgument("The beta2pow should be nullptr."));
   }
 
   HOSTDEVICE void UpdateBetaPows() const {}
...@@ -649,11 +651,11 @@ struct LambParamHelper { ...@@ -649,11 +651,11 @@ struct LambParamHelper {
constexpr bool kIsSameType = std::is_same<T, MasterT<T>>::value; constexpr bool kIsSameType = std::is_same<T, MasterT<T>>::value;
PADDLE_ENFORCE_EQ(kIsSameType, PADDLE_ENFORCE_EQ(kIsSameType,
false, false,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"T must not be the same with MasterT<T>.")); "T must not be the same with MasterT<T>."));
PADDLE_ENFORCE_NOT_NULL(master_param, PADDLE_ENFORCE_NOT_NULL(
platform::errors::InvalidArgument( master_param,
"Master parameter must be provided.")); phi::errors::InvalidArgument("Master parameter must be provided."));
param_ = param; param_ = param;
master_param_ = master_param; master_param_ = master_param;
} }
...@@ -671,14 +673,14 @@ template <typename T> ...@@ -671,14 +673,14 @@ template <typename T>
struct LambParamHelper<T, false> { struct LambParamHelper<T, false> {
LambParamHelper(T *param, MasterT<T> *master_param) { LambParamHelper(T *param, MasterT<T> *master_param) {
constexpr bool kIsSameType = std::is_same<T, MasterT<T>>::value; constexpr bool kIsSameType = std::is_same<T, MasterT<T>>::value;
PADDLE_ENFORCE_EQ(kIsSameType, PADDLE_ENFORCE_EQ(
true, kIsSameType,
platform::errors::InvalidArgument( true,
"T must be the same with MasterT<T>.")); phi::errors::InvalidArgument("T must be the same with MasterT<T>."));
if (master_param != nullptr) { if (master_param != nullptr) {
PADDLE_ENFORCE_EQ(static_cast<void *>(param), PADDLE_ENFORCE_EQ(static_cast<void *>(param),
static_cast<void *>(master_param), static_cast<void *>(master_param),
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"Master parameter must be nullptr or the same as " "Master parameter must be nullptr or the same as "
"non-master parameter.")); "non-master parameter."));
} }
...@@ -802,12 +804,12 @@ static void MultiTensorUpdateLambParamAndBetaPows( ...@@ -802,12 +804,12 @@ static void MultiTensorUpdateLambParamAndBetaPows(
if (has_beta_pow) { if (has_beta_pow) {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
beta2pow, beta2pow,
platform::errors::InvalidArgument("Beta2Pow should not be nullptr.")); phi::errors::InvalidArgument("Beta2Pow should not be nullptr."));
} else { } else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
beta2pow, beta2pow,
nullptr, nullptr,
platform::errors::InvalidArgument("Beta2Pow should be nullptr.")); phi::errors::InvalidArgument("Beta2Pow should be nullptr."));
} }
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
...@@ -855,21 +857,22 @@ static void MultiTensorUpdateLambParamAndBetaPows( ...@@ -855,21 +857,22 @@ static void MultiTensorUpdateLambParamAndBetaPows(
betapow_helper); \ betapow_helper); \
} while (0) } while (0)
#define PD_LAUNCH_VEC_MULTI_TENSOR_UPDATE_PARAM_BETAPOW_CASE \ #define PD_LAUNCH_VEC_MULTI_TENSOR_UPDATE_PARAM_BETAPOW_CASE \
do { \ do { \
auto callback = \ auto callback = \
[&](const MultiTensorLauncher<kNumTensor, kNumChunk> &launcher, \ [&](const paddle::operators::MultiTensorLauncher<kNumTensor, \
int launch_n) { \ kNumChunk> &launcher, \
if (has_beta_pow && launch_n == 0) { \ int launch_n) { \
PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(true); \ if (has_beta_pow && launch_n == 0) { \
beta1pow = nullptr; \ PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(true); \
beta2pow = nullptr; \ beta1pow = nullptr; \
} else { \ beta2pow = nullptr; \
PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(false); \ } else { \
} \ PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(false); \
}; \ } \
MultiTensorApplyWithCallback<kNumTensor, kNumChunk>( \ }; \
stream, offsets, n, chunk_size, block_dim, callback); \ paddle::operators::MultiTensorApplyWithCallback<kNumTensor, kNumChunk>( \
stream, offsets, n, chunk_size, block_dim, callback); \
} while (0) } while (0)
PD_VEC_LAUNCH_KERNEL(vec_size, PD_VEC_LAUNCH_KERNEL(vec_size,
...@@ -886,10 +889,10 @@ static bool CreatePreMulScaleOpIfSupported(ncclDataType_t dtype, ...@@ -886,10 +889,10 @@ static bool CreatePreMulScaleOpIfSupported(ncclDataType_t dtype,
ncclRedOp_t *op) { ncclRedOp_t *op) {
#if NCCL_VERSION_CODE >= 21100 #if NCCL_VERSION_CODE >= 21100
int ver; int ver;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetVersion(&ver)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetVersion(&ver));
if (ver >= 21100) { if (ver >= 21100) {
VLOG(10) << "ncclRedOpCreatePreMulSum is supported."; VLOG(10) << "ncclRedOpCreatePreMulSum is supported.";
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRedOpCreatePreMulSum( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRedOpCreatePreMulSum(
op, const_cast<void *>(scale), dtype, ncclScalarDevice, comm)); op, const_cast<void *>(scale), dtype, ncclScalarDevice, comm));
return true; return true;
} }
...@@ -906,7 +909,7 @@ static void LaunchScaleKernel(const phi::GPUContext &dev_ctx, ...@@ -906,7 +909,7 @@ static void LaunchScaleKernel(const phi::GPUContext &dev_ctx,
int n, int n,
gpuStream_t stream) { gpuStream_t stream) {
int vec_size = std::min(GetChunkedVecSize(x, 0), GetChunkedVecSize(y, 0)); int vec_size = std::min(GetChunkedVecSize(x, 0), GetChunkedVecSize(y, 0));
auto config = platform::GetGpuLaunchConfig1D(dev_ctx, n, vec_size); auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n, vec_size);
#define PD_LAMB_VEC_SCALE_KERNEL_CASE \ #define PD_LAMB_VEC_SCALE_KERNEL_CASE \
do { \ do { \
...@@ -928,9 +931,9 @@ static void NCCLSumWithScaleBase(const T *sendbuff, ...@@ -928,9 +931,9 @@ static void NCCLSumWithScaleBase(const T *sendbuff,
gpuStream_t stream, gpuStream_t stream,
const phi::GPUContext &dev_ctx, const phi::GPUContext &dev_ctx,
const T *scale = nullptr) { const T *scale = nullptr) {
static_assert(std::is_same<T, float>::value || static_assert(
std::is_same<T, platform::float16>::value, std::is_same<T, float>::value || std::is_same<T, dtype::float16>::value,
"T must be either float32 or float16."); "T must be either float32 or float16.");
if (recvcount == 0) return; if (recvcount == 0) return;
auto numel = UseReduceScatter ? (recvcount * nranks) : recvcount; auto numel = UseReduceScatter ? (recvcount * nranks) : recvcount;
...@@ -938,7 +941,7 @@ static void NCCLSumWithScaleBase(const T *sendbuff, ...@@ -938,7 +941,7 @@ static void NCCLSumWithScaleBase(const T *sendbuff,
if (scale != nullptr) { if (scale != nullptr) {
PADDLE_ENFORCE_EQ(nranks, PADDLE_ENFORCE_EQ(nranks,
1, 1,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"nranks must be 1 when scale != nullptr.")); "nranks must be 1 when scale != nullptr."));
LaunchScaleKernel(dev_ctx, sendbuff, scale, recvbuff, numel, stream); LaunchScaleKernel(dev_ctx, sendbuff, scale, recvbuff, numel, stream);
} }
...@@ -950,7 +953,7 @@ static void NCCLSumWithScaleBase(const T *sendbuff, ...@@ -950,7 +953,7 @@ static void NCCLSumWithScaleBase(const T *sendbuff,
std::is_same<T, float>::value ? ncclFloat32 : ncclFloat16; std::is_same<T, float>::value ? ncclFloat32 : ncclFloat16;
bool should_destroy_op = bool should_destroy_op =
scale && CreatePreMulScaleOpIfSupported(dtype, comm, scale, &op); scale && CreatePreMulScaleOpIfSupported(dtype, comm, scale, &op);
phi::memory_utils::Buffer buffer(dev_ctx.GetPlace()); memory_utils::Buffer buffer(dev_ctx.GetPlace());
if (scale && !should_destroy_op) { if (scale && !should_destroy_op) {
T *new_sendbuff = buffer.Alloc<T>(numel); T *new_sendbuff = buffer.Alloc<T>(numel);
LaunchScaleKernel(dev_ctx, sendbuff, scale, new_sendbuff, numel, stream); LaunchScaleKernel(dev_ctx, sendbuff, scale, new_sendbuff, numel, stream);
...@@ -958,17 +961,17 @@ static void NCCLSumWithScaleBase(const T *sendbuff, ...@@ -958,17 +961,17 @@ static void NCCLSumWithScaleBase(const T *sendbuff,
} }
if (UseReduceScatter) { if (UseReduceScatter) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclReduceScatter(
sendbuff, recvbuff, recvcount, dtype, op, comm, stream)); sendbuff, recvbuff, recvcount, dtype, op, comm, stream));
} else { } else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce(
sendbuff, recvbuff, recvcount, dtype, op, comm, stream)); sendbuff, recvbuff, recvcount, dtype, op, comm, stream));
} }
#if NCCL_VERSION_CODE >= 21100 #if NCCL_VERSION_CODE >= 21100
if (should_destroy_op) { if (should_destroy_op) {
VLOG(10) << "ncclRedOpDestroy starts"; VLOG(10) << "ncclRedOpDestroy starts";
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRedOpDestroy(op, comm)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRedOpDestroy(op, comm));
VLOG(10) << "ncclRedOpDestroy ends"; VLOG(10) << "ncclRedOpDestroy ends";
} }
#endif #endif
...@@ -1012,7 +1015,7 @@ static void CubDeviceReduce(InputIteratorT d_in, ...@@ -1012,7 +1015,7 @@ static void CubDeviceReduce(InputIteratorT d_in,
ReduceOpT reduction_op, ReduceOpT reduction_op,
T init, T init,
gpuStream_t stream, gpuStream_t stream,
phi::memory_utils::Buffer *buffer) { memory_utils::Buffer *buffer) {
void *d_temp_storage = nullptr; void *d_temp_storage = nullptr;
size_t temp_storage_bytes = 0; size_t temp_storage_bytes = 0;
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::Reduce(d_temp_storage, PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::Reduce(d_temp_storage,
...@@ -1041,7 +1044,7 @@ static void GetSquareGradNormImpl(const T *grad, ...@@ -1041,7 +1044,7 @@ static void GetSquareGradNormImpl(const T *grad,
int n, int n,
float *square_norm, float *square_norm,
gpuStream_t stream, gpuStream_t stream,
phi::memory_utils::Buffer *cub_tmp_buffer) { memory_utils::Buffer *cub_tmp_buffer) {
using Iterator = using Iterator =
cub::TransformInputIterator<float, SquareFunctor<T>, const T *>; cub::TransformInputIterator<float, SquareFunctor<T>, const T *>;
Iterator iter(grad, SquareFunctor<T>()); Iterator iter(grad, SquareFunctor<T>());
...@@ -1057,11 +1060,11 @@ static void GetSquareGradNormImpl(const T *grad, ...@@ -1057,11 +1060,11 @@ static void GetSquareGradNormImpl(const T *grad,
// square_norm is of length 2 at least // square_norm is of length 2 at least
static void GetSquareGradNorm(const float *fp32_grad, static void GetSquareGradNorm(const float *fp32_grad,
int fp32_numel, int fp32_numel,
const platform::float16 *fp16_grad, const dtype::float16 *fp16_grad,
int fp16_numel, int fp16_numel,
float *square_norm, float *square_norm,
gpuStream_t stream, gpuStream_t stream,
phi::memory_utils::Buffer *cub_tmp_buffer) { memory_utils::Buffer *cub_tmp_buffer) {
VLOG(10) << "GetSquareGradNorm starts, fp32_numel = " << fp32_numel VLOG(10) << "GetSquareGradNorm starts, fp32_numel = " << fp32_numel
<< " , fp16_numel = " << fp16_numel; << " , fp16_numel = " << fp16_numel;
if (fp32_numel > 0) { if (fp32_numel > 0) {
...@@ -1096,23 +1099,21 @@ std::string NumToString(T x) { ...@@ -1096,23 +1099,21 @@ std::string NumToString(T x) {
} }
template <typename T> template <typename T>
static std::string GetMinMaxStr(const T *x, static std::string GetMinMaxStr(const T *x, size_t n, const phi::Place &place) {
size_t n,
const platform::Place &place) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
platform::is_gpu_place(place), place.GetType() == phi::AllocationType::GPU,
true, true,
platform::errors::InvalidArgument("Only support CUDAPlace currently.")); phi::errors::InvalidArgument("Only support CUDAPlace currently."));
auto *dev_ctx = static_cast<phi::GPUContext *>( auto *dev_ctx = static_cast<phi::GPUContext *>(
platform::DeviceContextPool::Instance().Get(place)); phi::DeviceContextPool::Instance().Get(place));
auto stream = dev_ctx->stream(); auto stream = dev_ctx->stream();
phi::memory_utils::Buffer ret_buffer(place); memory_utils::Buffer ret_buffer(place);
T *ret = ret_buffer.Alloc<T>(2); T *ret = ret_buffer.Alloc<T>(2);
if (n > 0) { if (n > 0) {
phi::memory_utils::Buffer cub_buffer(place); memory_utils::Buffer cub_buffer(place);
CubDeviceReduce(x, CubDeviceReduce(x,
ret, ret,
n, n,
...@@ -1160,45 +1161,20 @@ struct VisitDTypeFunctor { ...@@ -1160,45 +1161,20 @@ struct VisitDTypeFunctor {
static std::string GetMinMaxStr(const phi::DenseTensor *x) { static std::string GetMinMaxStr(const phi::DenseTensor *x) {
if (x == nullptr) return "null"; if (x == nullptr) return "null";
if (!x->IsInitialized()) return "not_inited"; if (!x->IsInitialized()) return "not_inited";
if (!platform::is_gpu_place(x->place())) return "CPUTensor"; if (x->place().GetType() != phi::AllocationType::GPU) return "CPUTensor";
std::string str; std::string str;
VisitDTypeFunctor functor(x, &str); VisitDTypeFunctor functor(x, &str);
phi::VisitDataType(x->dtype(), functor); phi::VisitDataType(x->dtype(), functor);
return str; return str;
} }
static void PrintAllMinMaxRange(const framework::ExecutionContext &ctx,
bool only_inputs) {
if (!VLOG_IS_ON(1)) return;
for (const auto &pair : ctx.GetOp().Inputs()) {
const auto &key = pair.first;
const auto tensors = ctx.MultiInput<phi::DenseTensor>(key);
size_t n = tensors.size();
for (size_t i = 0; i < n; ++i) {
VLOG(1) << "Input(" << key + ")[" << i << "] = " << pair.second[i]
<< " , " << GetMinMaxStr(tensors[i]);
}
}
if (only_inputs) return;
for (const auto &pair : ctx.GetOp().Outputs()) {
const auto &key = pair.first;
const auto tensors = ctx.MultiOutput<phi::DenseTensor>(key);
size_t n = tensors.size();
for (size_t i = 0; i < n; ++i) {
VLOG(1) << "Output(" << key + ")[" << i << "] = " << pair.second[i]
<< " , " << GetMinMaxStr(tensors[i]);
}
}
}
template <typename T> template <typename T>
static bool HasNanInf(const phi::GPUContext &dev_ctx, const T *x, int numel) { static bool HasNanInf(const phi::GPUContext &dev_ctx, const T *x, int numel) {
if (numel <= 0) return false; if (numel <= 0) return false;
cub::TransformInputIterator<bool, IsNanInfFunctor<T>, const T *> iter( cub::TransformInputIterator<bool, IsNanInfFunctor<T>, const T *> iter(
x, IsNanInfFunctor<T>()); x, IsNanInfFunctor<T>());
phi::memory_utils::Buffer buffer(dev_ctx.GetPlace()); memory_utils::Buffer buffer(dev_ctx.GetPlace());
phi::memory_utils::Buffer out(dev_ctx.GetPlace()); memory_utils::Buffer out(dev_ctx.GetPlace());
CubDeviceReduce(iter, CubDeviceReduce(iter,
out.Alloc<bool>(1), out.Alloc<bool>(1),
numel, numel,
...@@ -1226,11 +1202,11 @@ static bool HasNanInf(const phi::GPUContext &dev_ctx, const T *x, int numel) { ...@@ -1226,11 +1202,11 @@ static bool HasNanInf(const phi::GPUContext &dev_ctx, const T *x, int numel) {
static void CheckHasNanInfGrad(const float *fp32_grad, static void CheckHasNanInfGrad(const float *fp32_grad,
int fp32_numel, int fp32_numel,
const platform::float16 *fp16_grad, const dtype::float16 *fp16_grad,
int fp16_numel, int fp16_numel,
float *nan_inf_flag, float *nan_inf_flag,
gpuStream_t stream, gpuStream_t stream,
phi::memory_utils::Buffer *cub_tmp_buffer) { memory_utils::Buffer *cub_tmp_buffer) {
bool *fp32_has_nan_inf = nullptr; bool *fp32_has_nan_inf = nullptr;
bool *fp16_has_nan_inf = nullptr; bool *fp16_has_nan_inf = nullptr;
if (fp32_numel > 0) { if (fp32_numel > 0) {
...@@ -1249,9 +1225,9 @@ static void CheckHasNanInfGrad(const float *fp32_grad, ...@@ -1249,9 +1225,9 @@ static void CheckHasNanInfGrad(const float *fp32_grad,
if (fp16_numel > 0) { if (fp16_numel > 0) {
fp16_has_nan_inf = reinterpret_cast<bool *>(nan_inf_flag + 1) + 1; fp16_has_nan_inf = reinterpret_cast<bool *>(nan_inf_flag + 1) + 1;
cub::TransformInputIterator<bool, cub::TransformInputIterator<bool,
IsNanInfFunctor<platform::float16>, IsNanInfFunctor<dtype::float16>,
const platform::float16 *> const dtype::float16 *>
iter(fp16_grad, IsNanInfFunctor<platform::float16>()); iter(fp16_grad, IsNanInfFunctor<dtype::float16>());
CubDeviceReduce(iter, CubDeviceReduce(iter,
fp16_has_nan_inf, fp16_has_nan_inf,
fp16_numel, fp16_numel,
...@@ -1316,7 +1292,7 @@ static void LaunchElementwiseAddWithCastKernel(const phi::GPUContext &dev_ctx, ...@@ -1316,7 +1292,7 @@ static void LaunchElementwiseAddWithCastKernel(const phi::GPUContext &dev_ctx,
int vec_size = int vec_size =
std::min(std::min(GetChunkedVecSize(x, 0), GetChunkedVecSize(y, 0)), std::min(std::min(GetChunkedVecSize(x, 0), GetChunkedVecSize(y, 0)),
GetChunkedVecSize(z, 0)); GetChunkedVecSize(z, 0));
auto config = platform::GetGpuLaunchConfig1D(dev_ctx, n, vec_size); auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n, vec_size);
#define PD_LAUNCH_ELEMENTWISE_ADD_WITH_CAST_KERNEL \ #define PD_LAUNCH_ELEMENTWISE_ADD_WITH_CAST_KERNEL \
do { \ do { \
...@@ -1329,656 +1305,579 @@ static void LaunchElementwiseAddWithCastKernel(const phi::GPUContext &dev_ctx, ...@@ -1329,656 +1305,579 @@ static void LaunchElementwiseAddWithCastKernel(const phi::GPUContext &dev_ctx,
#undef PD_LAUNCH_ELEMENTWISE_ADD_WITH_CAST_KERNEL #undef PD_LAUNCH_ELEMENTWISE_ADD_WITH_CAST_KERNEL
} }
template <typename T> template <typename T, typename Context>
class DistributedFusedLambOpKernel<T, phi::GPUContext> void DistributedFusedLambKernel(
: public framework::OpKernel<T> { const Context &dev_ctx,
public: const std::vector<const DenseTensor *> &param,
void Compute(const framework::ExecutionContext &ctx) const override { const std::vector<const DenseTensor *> &grad, /*unused*/
const paddle::optional<DenseTensor> &fp32_param,
const paddle::optional<DenseTensor> &fp32_grad,
const paddle::optional<DenseTensor> &fp16_param,
const paddle::optional<DenseTensor> &fp16_grad,
const DenseTensor &moment1,
const DenseTensor &moment2,
const DenseTensor &beta1_pow,
const DenseTensor &beta2_pow,
const DenseTensor &param_offsets,
const DenseTensor &fp32_partial_offsets,
const DenseTensor &fp16_partial_offsets,
const DenseTensor &param_info,
const DenseTensor &param_order,
const DenseTensor &learning_rate,
const DenseTensor &global_scale,
int acc_steps,
float beta1,
float beta2,
float epsilon,
float max_global_grad_norm,
float weight_decay,
bool clip_after_allreduce,
bool use_master_param_norm,
bool use_master_acc_grad,
bool is_grad_scaled_by_nranks,
bool use_hierarchical_allreduce,
int64_t nranks,
const std::vector<int> &ring_ids,
DenseTensor *fp32_param_out,
DenseTensor *fp16_param_out,
DenseTensor *fp32_acc_grad,
DenseTensor *fp16_acc_grad,
DenseTensor *moment1_out,
DenseTensor *moment2_out,
DenseTensor *beta1_pow_out,
DenseTensor *beta2_pow_out,
DenseTensor *param_out, /*unused*/
DenseTensor *found_inf,
DenseTensor *acc_step,
DenseTensor *stop_update,
DenseTensor *step) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto &dev_ctx = ctx.template device_context<phi::GPUContext>(); auto stream = dev_ctx.stream();
auto stream = dev_ctx.stream(); auto place = dev_ctx.GetPlace();
auto place = dev_ctx.GetPlace();
found_inf->Resize({1});
auto *found_inf_t = ctx.Output<phi::DenseTensor>("FoundInf");
found_inf_t->Resize({1}); // Step 1: Get fp16 param and grad tensors
int64_t fp16_numel;
// Step 1: Get fp16 param and grad tensors auto *fp16_param_data =
int64_t fp16_numel; GetSameInOutTensorPtr<dtype::float16, Context, true>(dev_ctx,
auto *fp16_param = GetSameInOutTensorPtr<platform::float16, true>( fp16_param.get_ptr(),
ctx, place, "FP16FusedParam", "FP16FusedParamOut", &fp16_numel); fp16_param_out,
bool has_fp16_param = (fp16_numel > 0); "FP16FusedParam",
const platform::float16 *fp16_grad = nullptr; "FP16FusedParamOut",
if (has_fp16_param) { &fp16_numel);
fp16_grad = GetInputTensorPtr<platform::float16>(ctx, "FP16FusedGrad"); bool has_fp16_param = (fp16_numel > 0);
const dtype::float16 *fp16_grad_data = nullptr;
if (has_fp16_param) {
fp16_grad_data =
GetInputTensorPtr<dtype::float16>(fp16_grad.get_ptr(), "FP16FusedGrad");
} else {
fp16_param_data = nullptr;
}
// Step 2: Get fp32 param and grad tensors
int64_t fp32_numel = 0;
auto *fp32_param_data =
GetSameInOutTensorPtr<float, Context, true>(dev_ctx,
fp32_param.get_ptr(),
fp32_param_out,
"FP32FusedParam",
"FP32FusedParamOut",
&fp32_numel);
PADDLE_ENFORCE_GE(fp32_numel,
fp16_numel,
phi::errors::InvalidArgument(
"The element number in FP32FusedParam should be not "
"less than FP16FusedParam."));
fp32_numel -= fp16_numel; // the FP32FusedParam contains fp32 param and
// fp16 master weight
bool has_fp32_param = (fp32_numel > 0);
const float *fp32_grad_data = nullptr;
if (has_fp32_param) {
fp32_grad_data =
GetInputTensorPtr<float>(fp32_grad.get_ptr(), "FP32FusedGrad");
} else {
PADDLE_ENFORCE_EQ(
has_fp16_param,
true,
phi::errors::InvalidArgument(
"Either FP32FusedGrad or FP16FusedGrad cannot be NULL."));
}
auto numel = fp32_numel + fp16_numel;
VLOG(1) << "numel = " << numel << " , fp32_numel = " << fp32_numel
<< " , fp16_numel = " << fp16_numel;
// The NVIDIA cub library does not support number > INT32_MAX
PADDLE_ENFORCE_LE(numel,
std::numeric_limits<int>::max(),
phi::errors::Unimplemented(
"Too many parameter number. Only <= %d is supported.",
std::numeric_limits<int>::max()));
PADDLE_ENFORCE_GE(
acc_steps,
1,
phi::errors::InvalidArgument(
"The gradient accumulation steps should be not less than 1."));
if (acc_steps > 1) {
PADDLE_ENFORCE_NOT_NULL(
acc_step,
phi::errors::InvalidArgument(
"Output(AccStep) cannot be nullptr when Attr(acc_steps) > 1."));
bool is_initialized = acc_step->IsInitialized();
int64_t *acc_step_data;
if (is_initialized) {
acc_step_data = dev_ctx.template HostAlloc<int64_t>(acc_step);
++(*acc_step_data);
} else { } else {
fp16_param = nullptr; acc_step->Resize({1});
acc_step_data = dev_ctx.template HostAlloc<int64_t>(acc_step);
*acc_step_data = 1;
} }
int64_t rounded_step = (*acc_step_data) % acc_steps;
// Step 2: Get fp32 param and grad tensors float *fp32_acc_grad_data = nullptr;
int64_t fp32_numel = 0;
auto *fp32_param = GetSameInOutTensorPtr<float, true>(
ctx, place, "FP32FusedParam", "FP32FusedParamOut", &fp32_numel);
PADDLE_ENFORCE_GE(fp32_numel,
fp16_numel,
platform::errors::InvalidArgument(
"The element number in FP32FusedParam should be not "
"less than FP16FusedParam."));
fp32_numel -= fp16_numel; // the FP32FusedParam contains fp32 param and
// fp16 master weight
bool has_fp32_param = (fp32_numel > 0);
const float *fp32_grad = nullptr;
if (has_fp32_param) { if (has_fp32_param) {
fp32_grad = GetInputTensorPtr<float>(ctx, "FP32FusedGrad"); PADDLE_ENFORCE_NOT_NULL(fp32_acc_grad,
} else { phi::errors::InvalidArgument(
PADDLE_ENFORCE_EQ( "Output(FP32AccFusedGrad) cannot be nullptr "
has_fp16_param, "when Attr(acc_steps) > 1."));
true, if (!fp32_acc_grad->IsInitialized()) {
platform::errors::InvalidArgument( fp32_acc_grad->Resize({static_cast<int64_t>(fp32_numel)});
"Either FP32FusedGrad or FP16FusedGrad cannot be NULL.")); fp32_acc_grad_data = dev_ctx.template Alloc<float>(fp32_acc_grad);
} else {
fp32_acc_grad_data = fp32_acc_grad->data<float>();
}
} }
auto numel = fp32_numel + fp16_numel; dtype::float16 *fp16_acc_grad_data = nullptr;
VLOG(1) << "numel = " << numel << " , fp32_numel = " << fp32_numel float *master_acc_grad = nullptr;
<< " , fp16_numel = " << fp16_numel; if (has_fp16_param) {
PADDLE_ENFORCE_NOT_NULL(fp16_acc_grad,
// The NVIDIA cub library does not support number > INT32_MAX phi::errors::InvalidArgument(
PADDLE_ENFORCE_LE(numel, "Output(FP16AccFusedGrad) cannot be nullptr "
std::numeric_limits<int>::max(), "when Attr(acc_steps) > 1."));
platform::errors::Unimplemented( if (!fp16_acc_grad->IsInitialized()) {
"Too many parameter number. Only <= %d is supported.", auto acc_grad_size =
std::numeric_limits<int>::max())); use_master_acc_grad ? (3 * fp16_numel) : fp16_numel;
fp16_acc_grad->Resize({static_cast<int64_t>(acc_grad_size)});
auto acc_steps = ctx.Attr<int>("acc_steps"); fp16_acc_grad_data =
PADDLE_ENFORCE_GE( dev_ctx.template Alloc<dtype::float16>(fp16_acc_grad);
acc_steps,
1,
platform::errors::InvalidArgument(
"The gradient accumulation steps should be not less than 1."));
if (acc_steps > 1) {
auto *step_t = ctx.Output<phi::DenseTensor>("AccStep");
PADDLE_ENFORCE_NOT_NULL(
step_t,
platform::errors::InvalidArgument(
"Output(AccStep) cannot be nullptr when Attr(acc_steps) > 1."));
bool is_initialized = step_t->IsInitialized();
int64_t *step_ptr;
if (is_initialized) {
step_ptr = step_t->mutable_data<int64_t>(platform::CPUPlace());
++(*step_ptr);
} else { } else {
step_t->Resize({1}); fp16_acc_grad_data = fp16_acc_grad->data<dtype::float16>();
step_ptr = step_t->mutable_data<int64_t>(platform::CPUPlace());
*step_ptr = 1;
} }
int64_t rounded_step = (*step_ptr) % acc_steps; if (use_master_acc_grad) {
master_acc_grad =
float *fp32_acc_grad = nullptr; reinterpret_cast<float *>(fp16_acc_grad_data + fp16_numel);
if (has_fp32_param) {
auto *fp32_acc_grad_t =
ctx.Output<phi::DenseTensor>("FP32AccFusedGrad");
PADDLE_ENFORCE_NOT_NULL(
fp32_acc_grad_t,
platform::errors::InvalidArgument(
"Output(FP32AccFusedGrad) cannot be nullptr "
"when Attr(acc_steps) > 1."));
if (!fp32_acc_grad_t->IsInitialized()) {
fp32_acc_grad_t->Resize({static_cast<int64_t>(fp32_numel)});
fp32_acc_grad = fp32_acc_grad_t->mutable_data<float>(place);
} else {
fp32_acc_grad = fp32_acc_grad_t->data<float>();
}
} }
} else {
use_master_acc_grad = false;
}
platform::float16 *fp16_acc_grad = nullptr; // Inplace addto
float *master_acc_grad = nullptr; if (has_fp32_param) {
bool use_master_acc_grad = false; if (rounded_step == 1) {
if (has_fp16_param) { memory_utils::Copy(place,
use_master_acc_grad = ctx.Attr<bool>("use_master_acc_grad"); fp32_acc_grad_data,
auto *fp16_acc_grad_t = place,
ctx.Output<phi::DenseTensor>("FP16AccFusedGrad"); fp32_grad_data,
PADDLE_ENFORCE_NOT_NULL( fp32_numel * sizeof(float),
fp16_acc_grad_t, stream);
platform::errors::InvalidArgument( } else {
"Output(FP16AccFusedGrad) cannot be nullptr " LaunchElementwiseAddWithCastKernel(dev_ctx,
"when Attr(acc_steps) > 1.")); fp32_grad_data,
if (!fp16_acc_grad_t->IsInitialized()) { fp32_acc_grad_data,
auto acc_grad_size = fp32_acc_grad_data,
use_master_acc_grad ? (3 * fp16_numel) : fp16_numel; fp32_numel,
fp16_acc_grad_t->Resize({static_cast<int64_t>(acc_grad_size)}); stream);
fp16_acc_grad =
fp16_acc_grad_t->mutable_data<platform::float16>(place);
} else {
fp16_acc_grad = fp16_acc_grad_t->data<platform::float16>();
}
if (use_master_acc_grad) {
master_acc_grad =
reinterpret_cast<float *>(fp16_acc_grad + fp16_numel);
}
} }
}
// Inplace addto if (has_fp16_param) {
if (has_fp32_param) { if (acc_steps == 2 || !use_master_acc_grad) {
if (rounded_step == 1) { if (rounded_step != 1) {
memory::Copy(place, LaunchElementwiseAddWithCastKernel(dev_ctx,
fp32_acc_grad, fp16_acc_grad_data,
place, fp16_grad_data,
fp32_grad, fp16_acc_grad_data,
fp32_numel * sizeof(float), fp16_numel,
stream); stream);
} else {
memory_utils::Copy(place,
fp16_acc_grad_data,
place,
fp16_grad_data,
fp16_numel * sizeof(dtype::float16),
stream);
}
} else { // acc_steps >= 3
if (rounded_step == 0) {
LaunchElementwiseAddWithCastKernel(dev_ctx,
fp16_grad_data,
master_acc_grad,
fp16_acc_grad_data,
fp16_numel,
stream);
} else if (rounded_step == 1) {
memory_utils::Copy(place,
fp16_acc_grad_data,
place,
fp16_grad_data,
fp16_numel * sizeof(dtype::float16),
stream);
} else if (rounded_step == 2) {
LaunchElementwiseAddWithCastKernel(dev_ctx,
fp16_grad_data,
fp16_acc_grad_data,
master_acc_grad,
fp16_numel,
stream);
} else { } else {
LaunchElementwiseAddWithCastKernel(dev_ctx, LaunchElementwiseAddWithCastKernel(dev_ctx,
fp32_grad, fp16_grad_data,
fp32_acc_grad, master_acc_grad,
fp32_acc_grad, master_acc_grad,
fp32_numel, fp16_numel,
stream); stream);
} }
} }
}
if (has_fp16_param) { stop_update->Resize({1});
if (acc_steps == 2 || !use_master_acc_grad) { auto *stop_update_data = dev_ctx.template HostAlloc<bool>(stop_update);
if (rounded_step != 1) { auto *found_inf_cpu = dev_ctx.template HostAlloc<bool>(found_inf);
LaunchElementwiseAddWithCastKernel(dev_ctx,
fp16_acc_grad,
fp16_grad,
fp16_acc_grad,
fp16_numel,
stream);
} else {
memory::Copy(place,
fp16_acc_grad,
place,
fp16_grad,
fp16_numel * sizeof(platform::float16),
stream);
}
} else { // acc_steps >= 3
if (rounded_step == 0) {
LaunchElementwiseAddWithCastKernel(dev_ctx,
fp16_grad,
master_acc_grad,
fp16_acc_grad,
fp16_numel,
stream);
} else if (rounded_step == 1) {
memory::Copy(place,
fp16_acc_grad,
place,
fp16_grad,
fp16_numel * sizeof(platform::float16),
stream);
} else if (rounded_step == 2) {
LaunchElementwiseAddWithCastKernel(dev_ctx,
fp16_grad,
fp16_acc_grad,
master_acc_grad,
fp16_numel,
stream);
} else {
LaunchElementwiseAddWithCastKernel(dev_ctx,
fp16_grad,
master_acc_grad,
master_acc_grad,
fp16_numel,
stream);
}
}
}
auto *stop_update_t = ctx.Output<phi::DenseTensor>("StopUpdate"); if (rounded_step != 0) {
stop_update_t->Resize({1}); *stop_update_data = true;
auto *stop_update = *found_inf_cpu = false;
stop_update_t->mutable_data<bool>(platform::CPUPlace()); return;
} else {
// swap pointer
fp32_grad_data = fp32_acc_grad_data;
fp16_grad_data = fp16_acc_grad_data;
*stop_update_data = false;
found_inf->clear();
}
}
auto *found_inf_cpu = // Step 3: Get ParamInfo
found_inf_t->mutable_data<bool>(platform::CPUPlace()); const auto *param_info_data =
GetInputTensorPtr<int>(&param_info, "ParamInfo");
auto fp32_local_start_idx = param_info_data[0];
auto fp32_local_param_num = param_info_data[1];
auto fp32_global_param_num = param_info_data[2];
auto fp32_weight_decay_end_idx = param_info_data[3];
auto fp16_local_start_idx = param_info_data[4];
auto fp16_local_param_num = param_info_data[5];
auto fp16_global_param_num = param_info_data[6];
auto fp16_weight_decay_end_idx = param_info_data[7];
auto local_param_num = fp32_local_param_num + fp16_local_param_num;
auto param_num = fp32_global_param_num + fp16_global_param_num;
PADDLE_ENFORCE_LE(local_param_num,
param_num,
phi::errors::InvalidArgument(
"The local parameter number should not exceed the "
"global parameter number."));
VLOG(1) << "local_param_num = " << local_param_num
<< " , global_param_num = " << param_num
<< " , fp32_local_start_idx = " << fp32_local_start_idx
<< " , fp32_local_param_num = " << fp32_local_param_num
<< " , fp32_global_param_num = " << fp32_global_param_num
<< " , fp16_local_start_idx = " << fp16_local_start_idx
<< " , fp16_local_param_num = " << fp16_local_param_num
<< " , fp16_global_param_num = " << fp16_global_param_num;
// Step 4: Get LearningRate, Moment1, Moment2, Beta1Pow, Beta2Pow,
// GlobalScale
const auto *global_scale_data =
GetInputTensorPtr<float>(&global_scale, "GlobalScale");
const auto *lr_data =
GetInputTensorPtr<float>(&learning_rate, "LearningRate");
int64_t partial_numel = 0;
auto *moment1_data = GetSameInOutTensorPtr<float, Context>(
dev_ctx, &moment1, moment1_out, "Moment1", "Moment1Out", &partial_numel);
PADDLE_ENFORCE_EQ(numel % partial_numel,
0,
phi::errors::InvalidArgument(
"The total parameter number %d should be divided "
"exactly by the element number %d of Moment1.",
numel,
partial_numel));
// The num_devices means the number of devices that shard a complete set
// of all parameters. It may be num_devices < nranks or num_devices ==
// nranks.
int64_t num_devices = numel / partial_numel;
VLOG(1) << "num_devices = " << num_devices
<< " , partial_numel = " << partial_numel;
PADDLE_ENFORCE_EQ(fp32_numel % num_devices,
0,
phi::errors::InvalidArgument(
"The fp32 parameter number %d should be divided "
"exactly by the device number %d.",
fp32_numel,
num_devices));
PADDLE_ENFORCE_EQ(fp16_numel % num_devices,
0,
phi::errors::InvalidArgument(
"The fp16 parameter number %d should be divided "
"exactly by the device number %d.",
fp16_numel,
num_devices));
auto *moment2_data = GetSameInOutTensorPtr<float, Context>(
dev_ctx, &moment2, moment2_out, "Moment2", "Moment2Out");
auto *beta1_pow_data = GetSameInOutTensorPtr<float, Context>(
dev_ctx, &beta1_pow, beta1_pow_out, "Beta1Pow", "Beta1PowOut");
auto *beta2_pow_data = GetSameInOutTensorPtr<float, Context>(
dev_ctx, &beta2_pow, beta2_pow_out, "Beta2Pow", "Beta2PowOut");
auto *found_inf_data = dev_ctx.template Alloc<bool>(found_inf);
// Step 5: Get attributes weight_decay, beta1, beta2, epsilon,
// max_grad_norm, ring_id,
// use_master_param_norm, is_grad_scaled_by_nranks
PADDLE_ENFORCE_GE(nranks,
num_devices,
phi::errors::InvalidArgument(
"The nranks must be not less than num_devices."));
PADDLE_ENFORCE_EQ(nranks % num_devices,
0,
phi::errors::InvalidArgument(
"The nranks must be exactly divided by num_devices."));
bool local_shard = (nranks > num_devices);
VLOG(10) << "max_global_grad_norm = " << max_global_grad_norm
<< " , clip_after_allreduce = " << clip_after_allreduce
<< " , use_master_param_norm = " << use_master_param_norm
<< " , is_grad_scaled_by_nranks = " << is_grad_scaled_by_nranks
<< " , local_shard = " << local_shard
<< " , use_hierarchical_allreduce = " << use_hierarchical_allreduce;
// Step 6: allreduce + global norm gradient clip
int64_t global_rank = 0, local_rank = 0;
ncclComm_t global_comm = nullptr, local_comm = nullptr,
external_comm = nullptr;
if (nranks > 1) {
auto *nccl_comm_handle =
paddle::platform::NCCLCommContext::Instance().Get(ring_ids[0], place);
global_comm = nccl_comm_handle->comm();
global_rank = nccl_comm_handle->rank();
if (rounded_step != 0) { if (local_shard) {
*stop_update = true; auto *local_nccl_comm_handle =
auto *found_inf_cpu = paddle::platform::NCCLCommContext::Instance().Get(ring_ids[1], place);
found_inf_t->mutable_data<bool>(platform::CPUPlace()); local_comm = local_nccl_comm_handle->comm();
*found_inf_cpu = false; local_rank = local_nccl_comm_handle->rank();
return; if (use_hierarchical_allreduce) {
} else { external_comm = paddle::platform::NCCLCommContext::Instance()
// swap pointer .Get(ring_ids[2], place)
fp32_grad = fp32_acc_grad; ->comm();
fp16_grad = fp16_acc_grad;
*stop_update = false;
found_inf_t->clear();
} }
} else {
local_comm = global_comm;
local_rank = global_rank;
} }
}
// Step 3: Get ParamInfo memory_utils::Buffer grad_norm_square_buffer(place);
const auto *param_info_tensor = GetInputTensorPtr<int>(ctx, "ParamInfo"); auto *fp32_square_grad_norm = grad_norm_square_buffer.Alloc<float>(2);
auto fp32_local_start_idx = param_info_tensor[0]; memory_utils::Buffer cub_tmp_buffer(place);
auto fp32_local_param_num = param_info_tensor[1];
auto fp32_global_param_num = param_info_tensor[2]; memory_utils::Buffer sum_grad_buffer(place);
auto fp32_weight_decay_end_idx = param_info_tensor[3]; float *fp32_sum_grad;
auto fp16_local_start_idx = param_info_tensor[4]; dtype::float16 *fp16_sum_grad;
auto fp16_local_param_num = param_info_tensor[5]; auto fp32_numel_each_device = fp32_numel / num_devices;
auto fp16_global_param_num = param_info_tensor[6]; auto fp16_numel_each_device = fp16_numel / num_devices;
auto fp16_weight_decay_end_idx = param_info_tensor[7]; if (local_shard) {
auto ptr = sum_grad_buffer.Alloc<uint8_t>(
auto local_param_num = fp32_local_param_num + fp16_local_param_num; fp32_numel * sizeof(float) + fp16_numel * sizeof(dtype::float16));
auto param_num = fp32_global_param_num + fp16_global_param_num; fp32_sum_grad = has_fp32_param ? reinterpret_cast<float *>(ptr) : nullptr;
PADDLE_ENFORCE_LE(local_param_num, fp16_sum_grad = has_fp16_param ? reinterpret_cast<dtype::float16 *>(
param_num, ptr + fp32_numel * sizeof(float))
platform::errors::InvalidArgument( : nullptr;
"The local parameter number should not exceed the " } else if (nranks > 1 ||
"global parameter number.")); (max_global_grad_norm > 0 && !clip_after_allreduce)) {
VLOG(1) << "local_param_num = " << local_param_num auto ptr = sum_grad_buffer.Alloc<uint8_t>(
<< " , global_param_num = " << param_num fp32_numel_each_device * sizeof(float) +
<< " , fp32_local_start_idx = " << fp32_local_start_idx fp16_numel_each_device * sizeof(dtype::float16));
<< " , fp32_local_param_num = " << fp32_local_param_num fp32_sum_grad = has_fp32_param ? reinterpret_cast<float *>(ptr) : nullptr;
<< " , fp32_global_param_num = " << fp32_global_param_num fp16_sum_grad = has_fp16_param
<< " , fp16_local_start_idx = " << fp16_local_start_idx ? reinterpret_cast<dtype::float16 *>(
<< " , fp16_local_param_num = " << fp16_local_param_num ptr + fp32_numel_each_device * sizeof(float))
<< " , fp16_global_param_num = " << fp16_global_param_num; : nullptr;
} else {
// Step 4: Get LearningRate, Moment1, Moment2, Beta1Pow, Beta2Pow, // NOTE: The const_cast here is not important. The fp32_sum_grad and
// GlobalScale // fp16_sum_grad would not be changed when num_devices == 1
const auto *global_scale = GetInputTensorPtr<float>(ctx, "GlobalScale"); // But if I do not perform const_cast here, there would be more
const auto *lr = GetInputTensorPtr<float>(ctx, "LearningRate"); // if-else codes (num_devices > 1) when I write the following code.
int64_t partial_numel = 0; // So I prefer to use const_cast to unify the following code to reduce
auto *moment1 = GetSameInOutTensorPtr<float>( // the if-else codes.
ctx, place, "Moment1", "Moment1Out", &partial_numel); fp32_sum_grad = const_cast<float *>(fp32_grad_data);
fp16_sum_grad = const_cast<dtype::float16 *>(fp16_grad_data);
PADDLE_ENFORCE_EQ(numel % partial_numel, }
0,
platform::errors::InvalidArgument( float rescale_grad = 1.0f;
"The total parameter number %d should be divided " if (!is_grad_scaled_by_nranks) {
"exactly by the element number %d of Moment1.", rescale_grad /= nranks;
numel, }
partial_numel));
// The num_devices means the number of devices that shard a complete set
// of all parameters. It may be num_devices < nranks or num_devices ==
// nranks.
int64_t num_devices = numel / partial_numel;
VLOG(1) << "num_devices = " << num_devices
<< " , partial_numel = " << partial_numel;
PADDLE_ENFORCE_EQ(fp32_numel % num_devices,
0,
platform::errors::InvalidArgument(
"The fp32 parameter number %d should be divided "
"exactly by the device number %d.",
fp32_numel,
num_devices));
PADDLE_ENFORCE_EQ(fp16_numel % num_devices,
0,
platform::errors::InvalidArgument(
"The fp16 parameter number %d should be divided "
"exactly by the device number %d.",
fp16_numel,
num_devices));
auto *moment2 =
GetSameInOutTensorPtr<float>(ctx, place, "Moment2", "Moment2Out");
auto *beta1pow =
GetSameInOutTensorPtr<float>(ctx, place, "Beta1Pow", "Beta1PowOut");
auto *beta2pow =
GetSameInOutTensorPtr<float>(ctx, place, "Beta2Pow", "Beta2PowOut");
auto *found_inf = found_inf_t->mutable_data<bool>(place);
// Step 5: Get attributes weight_decay, beta1, beta2, epsilon,
// max_grad_norm, ring_id,
// use_master_param_norm, is_grad_scaled_by_nranks
auto weight_decay = ctx.Attr<float>("weight_decay");
auto beta1 = ctx.Attr<float>("beta1");
auto beta2 = ctx.Attr<float>("beta2");
auto epsilon = ctx.Attr<float>("epsilon");
auto max_global_grad_norm = ctx.Attr<float>("max_global_grad_norm");
auto clip_after_allreduce = ctx.Attr<bool>("clip_after_allreduce");
auto nranks = ctx.Attr<int64_t>("nranks");
PADDLE_ENFORCE_GE(nranks,
num_devices,
phi::errors::InvalidArgument(
"The nranks must be not less than num_devices."));
PADDLE_ENFORCE_EQ(
nranks % num_devices,
0,
phi::errors::InvalidArgument(
"The nranks must be exactly divided by num_devices."));
bool local_shard = (nranks > num_devices);
const auto &ring_ids = ctx.Attr<std::vector<int>>("ring_ids");
auto use_master_param_norm = ctx.Attr<bool>("use_master_param_norm");
auto is_grad_scaled_by_nranks = ctx.Attr<bool>("is_grad_scaled_by_nranks");
auto use_hierarchical_allreduce =
ctx.Attr<bool>("use_hierarchical_allreduce");
VLOG(10) << "max_global_grad_norm = " << max_global_grad_norm
<< " , clip_after_allreduce = " << clip_after_allreduce
<< " , use_master_param_norm = " << use_master_param_norm
<< " , is_grad_scaled_by_nranks = " << is_grad_scaled_by_nranks
<< " , local_shard = " << local_shard
<< " , use_hierarchical_allreduce = "
<< use_hierarchical_allreduce;
// Step 6: allreduce + global norm gradient clip
int64_t global_rank = 0, local_rank = 0;
ncclComm_t global_comm = nullptr, local_comm = nullptr,
external_comm = nullptr;
if (nranks > 1) {
auto *nccl_comm_handle =
platform::NCCLCommContext::Instance().Get(ring_ids[0], place);
global_comm = nccl_comm_handle->comm();
global_rank = nccl_comm_handle->rank();
if (max_global_grad_norm > 0) {
if (clip_after_allreduce) {
// (1) ReduceScater first
if (local_shard) { if (local_shard) {
auto *local_nccl_comm_handle =
platform::NCCLCommContext::Instance().Get(ring_ids[1], place);
local_comm = local_nccl_comm_handle->comm();
local_rank = local_nccl_comm_handle->rank();
if (use_hierarchical_allreduce) { if (use_hierarchical_allreduce) {
external_comm = platform::NCCLCommContext::Instance() NCCLReduceScatterWithScale(
.Get(ring_ids[2], place) fp32_grad_data,
->comm(); fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
num_devices,
local_comm,
stream,
dev_ctx);
NCCLAllReduceWithScale(
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
NCCLReduceScatterWithScale(
fp16_grad_data,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
num_devices,
local_comm,
stream,
dev_ctx);
NCCLAllReduceWithScale(
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
} else {
NCCLAllReduceWithScale(fp32_grad_data,
fp32_sum_grad,
fp32_numel,
nranks,
global_comm,
stream,
dev_ctx);
NCCLAllReduceWithScale(fp16_grad_data,
fp16_sum_grad,
fp16_numel,
nranks,
global_comm,
stream,
dev_ctx);
} }
fp32_sum_grad += (local_rank * fp32_numel_each_device);
fp16_sum_grad += (local_rank * fp16_numel_each_device);
} else { } else {
local_comm = global_comm; NCCLReduceScatterWithScale(fp32_grad_data,
local_rank = global_rank;
}
}
phi::memory_utils::Buffer grad_norm_square_buffer(place);
auto *fp32_square_grad_norm = grad_norm_square_buffer.Alloc<float>(2);
phi::memory_utils::Buffer cub_tmp_buffer(place);
phi::memory_utils::Buffer sum_grad_buffer(place);
float *fp32_sum_grad;
platform::float16 *fp16_sum_grad;
auto fp32_numel_each_device = fp32_numel / num_devices;
auto fp16_numel_each_device = fp16_numel / num_devices;
if (local_shard) {
auto ptr = sum_grad_buffer.Alloc<uint8_t>(
fp32_numel * sizeof(float) + fp16_numel * sizeof(platform::float16));
fp32_sum_grad = has_fp32_param ? reinterpret_cast<float *>(ptr) : nullptr;
fp16_sum_grad = has_fp16_param ? reinterpret_cast<platform::float16 *>(
ptr + fp32_numel * sizeof(float))
: nullptr;
} else if (nranks > 1 ||
(max_global_grad_norm > 0 && !clip_after_allreduce)) {
auto ptr = sum_grad_buffer.Alloc<uint8_t>(
fp32_numel_each_device * sizeof(float) +
fp16_numel_each_device * sizeof(platform::float16));
fp32_sum_grad = has_fp32_param ? reinterpret_cast<float *>(ptr) : nullptr;
fp16_sum_grad = has_fp16_param
? reinterpret_cast<platform::float16 *>(
ptr + fp32_numel_each_device * sizeof(float))
: nullptr;
} else {
// NOTE: The const_cast here is not important. The fp32_sum_grad and
// fp16_sum_grad would not be changed when num_devices == 1
// But if I do not perform const_cast here, there would be more
// if-else codes (num_devices > 1) when I write the following code.
// So I prefer to use const_cast to unify the following code to reduce
// the if-else codes.
fp32_sum_grad = const_cast<float *>(fp32_grad);
fp16_sum_grad = const_cast<platform::float16 *>(fp16_grad);
}
float rescale_grad = 1.0f;
if (!is_grad_scaled_by_nranks) {
rescale_grad /= nranks;
}
if (max_global_grad_norm > 0) {
if (clip_after_allreduce) {
// (1) ReduceScater first
if (local_shard) {
if (use_hierarchical_allreduce) {
NCCLReduceScatterWithScale(
fp32_grad,
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
num_devices,
local_comm,
stream,
dev_ctx);
NCCLAllReduceWithScale(
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
NCCLReduceScatterWithScale(
fp16_grad,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
num_devices,
local_comm,
stream,
dev_ctx);
NCCLAllReduceWithScale(
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
} else {
NCCLAllReduceWithScale(fp32_grad,
fp32_sum_grad, fp32_sum_grad,
fp32_numel, fp32_numel_each_device,
nranks, nranks,
global_comm, global_comm,
stream, stream,
dev_ctx); dev_ctx);
NCCLAllReduceWithScale(fp16_grad, NCCLReduceScatterWithScale(fp16_grad_data,
fp16_sum_grad, fp16_sum_grad,
fp16_numel, fp16_numel_each_device,
nranks, nranks,
global_comm, global_comm,
stream, stream,
dev_ctx); dev_ctx);
} }
fp32_sum_grad += (local_rank * fp32_numel_each_device); // (2) Calculate the global grad norm
fp16_sum_grad += (local_rank * fp16_numel_each_device); GetSquareGradNorm(fp32_sum_grad,
} else { fp32_numel_each_device,
NCCLReduceScatterWithScale(fp32_grad, fp16_sum_grad,
fp32_sum_grad, fp16_numel_each_device,
fp32_numel_each_device, fp32_square_grad_norm,
nranks, stream,
global_comm, &cub_tmp_buffer);
stream, VLOG(1) << "Grad square norm before all reduce: "
dev_ctx); << FlattenToString(fp32_square_grad_norm, 1, place);
NCCLReduceScatterWithScale(fp16_grad, if (num_devices > 1) {
fp16_sum_grad, PADDLE_ENFORCE_GPU_SUCCESS(
fp16_numel_each_device, phi::dynload::ncclAllReduce(fp32_square_grad_norm,
nranks, fp32_square_grad_norm,
global_comm, 1,
stream, ncclFloat32,
dev_ctx); ncclSum,
} local_comm,
// (2) Calculate the global grad norm stream));
GetSquareGradNorm(fp32_sum_grad, }
fp32_numel_each_device, VLOG(1) << "Grad square norm after all reduce: "
fp16_sum_grad, << FlattenToString(fp32_square_grad_norm, 1, place);
fp16_numel_each_device, } else {
fp32_square_grad_norm, // (1) Calculate the local grad norm
stream, GetSquareGradNorm(fp32_grad_data,
&cub_tmp_buffer); fp32_numel,
VLOG(1) << "Grad square norm before all reduce: " fp16_grad_data,
<< FlattenToString(fp32_square_grad_norm, 1, place); fp16_numel,
if (num_devices > 1) { fp32_square_grad_norm,
PADDLE_ENFORCE_GPU_SUCCESS( stream,
platform::dynload::ncclAllReduce(fp32_square_grad_norm, &cub_tmp_buffer);
fp32_square_grad_norm, VLOG(1) << "Grad square norm before all reduce: "
1, << FlattenToString(fp32_square_grad_norm, 1, place);
ncclFloat32, // (2) Calculate the gradient clip scale
ncclSum, float *fp32_scale = nullptr;
local_comm, dtype::float16 *fp16_scale = nullptr;
stream)); if (has_fp32_param && has_fp16_param) {
} auto *ptr = cub_tmp_buffer.Alloc<uint8_t>(sizeof(float) +
VLOG(1) << "Grad square norm after all reduce: " sizeof(dtype::float16));
<< FlattenToString(fp32_square_grad_norm, 1, place); fp32_scale = reinterpret_cast<float *>(ptr);
fp16_scale = reinterpret_cast<dtype::float16 *>(ptr + sizeof(float));
} else if (has_fp32_param) {
fp32_scale = cub_tmp_buffer.Alloc<float>(1);
} else { } else {
// (1) Calculate the local grad norm fp16_scale = cub_tmp_buffer.Alloc<dtype::float16>(1);
GetSquareGradNorm(fp32_grad, }
fp32_numel,
fp16_grad,
fp16_numel,
fp32_square_grad_norm,
stream,
&cub_tmp_buffer);
VLOG(1) << "Grad square norm before all reduce: "
<< FlattenToString(fp32_square_grad_norm, 1, place);
// (2) Calculate the gradient clip scale
float *fp32_scale = nullptr;
platform::float16 *fp16_scale = nullptr;
if (has_fp32_param && has_fp16_param) {
auto *ptr = cub_tmp_buffer.Alloc<uint8_t>(sizeof(float) +
sizeof(platform::float16));
fp32_scale = reinterpret_cast<float *>(ptr);
fp16_scale =
reinterpret_cast<platform::float16 *>(ptr + sizeof(float));
} else if (has_fp32_param) {
fp32_scale = cub_tmp_buffer.Alloc<float>(1);
} else {
fp16_scale = cub_tmp_buffer.Alloc<platform::float16>(1);
}
float clip_scale = 1.0f; float clip_scale = 1.0f;
if (is_grad_scaled_by_nranks) { if (is_grad_scaled_by_nranks) {
clip_scale *= nranks; clip_scale *= nranks;
}
CalcGradNormClipBeforeAllReduceScale<float, platform::float16>
<<<1, 1, 0, stream>>>(global_scale,
max_global_grad_norm,
fp32_square_grad_norm,
fp32_scale,
fp16_scale,
clip_scale);
if (fp32_scale) {
VLOG(1) << "Grad scale: " << FlattenToString(fp32_scale, 1, place);
} else {
VLOG(1) << "Grad scale: " << FlattenToString(fp16_scale, 1, place);
}
// (3) Do ReduceScatter with scale
VLOG(1) << "FP32 HasNanInf before all reduce: "
<< HasNanInf(dev_ctx, fp32_grad, fp32_numel);
VLOG(1) << "FP16 HasNanInf before all reduce: "
<< HasNanInf(dev_ctx, fp16_grad, fp16_numel);
if (local_shard) {
if (use_hierarchical_allreduce) {
NCCLReduceScatterWithScale(
fp32_grad,
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
num_devices,
local_comm,
stream,
dev_ctx,
fp32_scale);
NCCLAllReduceWithScale(
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
NCCLReduceScatterWithScale(
fp16_grad,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
num_devices,
local_comm,
stream,
dev_ctx,
fp16_scale);
NCCLAllReduceWithScale(
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
} else {
NCCLAllReduceWithScale(fp32_grad,
fp32_sum_grad,
fp32_numel,
nranks,
global_comm,
stream,
dev_ctx,
fp32_scale);
NCCLAllReduceWithScale(fp16_grad,
fp16_sum_grad,
fp16_numel,
nranks,
global_comm,
stream,
dev_ctx,
fp16_scale);
}
fp32_sum_grad += (local_rank * fp32_numel_each_device);
fp16_sum_grad += (local_rank * fp16_numel_each_device);
} else {
NCCLReduceScatterWithScale(fp32_grad,
fp32_sum_grad,
fp32_numel_each_device,
nranks,
global_comm,
stream,
dev_ctx,
fp32_scale);
NCCLReduceScatterWithScale(fp16_grad,
fp16_sum_grad,
fp16_numel_each_device,
nranks,
global_comm,
stream,
dev_ctx,
fp16_scale);
}
VLOG(1) << "FP32 HasNanInf after all reduce: "
<< HasNanInf(dev_ctx, fp32_sum_grad, fp32_numel_each_device);
VLOG(1) << "FP16 HasNanInf after all reduce: "
<< HasNanInf(dev_ctx, fp16_sum_grad, fp16_numel_each_device);
CheckHasNanInfGrad(fp32_sum_grad,
fp32_numel_each_device,
fp16_sum_grad,
fp16_numel_each_device,
fp32_square_grad_norm,
stream,
&cub_tmp_buffer);
if (num_devices > 1) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllReduce(fp32_square_grad_norm,
fp32_square_grad_norm,
1,
ncclFloat32,
ncclSum,
local_comm,
stream));
VLOG(1) << "Grad square norm after all reduce: "
<< FlattenToString(fp32_square_grad_norm, 1, place);
}
// (4) mark max_global_grad_norm as 0, meaning that clip has been
// already performed
max_global_grad_norm = 0;
} }
} else { CalcGradNormClipBeforeAllReduceScale<float, dtype::float16>
<<<1, 1, 0, stream>>>(global_scale_data,
max_global_grad_norm,
fp32_square_grad_norm,
fp32_scale,
fp16_scale,
clip_scale);
if (fp32_scale) {
VLOG(1) << "Grad scale: " << FlattenToString(fp32_scale, 1, place);
} else {
VLOG(1) << "Grad scale: " << FlattenToString(fp16_scale, 1, place);
}
// (3) Do ReduceScatter with scale
VLOG(1) << "FP32 HasNanInf before all reduce: "
<< HasNanInf(dev_ctx, fp32_grad_data, fp32_numel);
VLOG(1) << "FP16 HasNanInf before all reduce: "
<< HasNanInf(dev_ctx, fp16_grad_data, fp16_numel);
if (local_shard) { if (local_shard) {
if (use_hierarchical_allreduce) { if (use_hierarchical_allreduce) {
NCCLReduceScatterWithScale( NCCLReduceScatterWithScale(
fp32_grad, fp32_grad_data,
fp32_sum_grad + local_rank * fp32_numel_each_device, fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device, fp32_numel_each_device,
num_devices, num_devices,
local_comm, local_comm,
stream, stream,
dev_ctx); dev_ctx,
fp32_scale);
NCCLAllReduceWithScale( NCCLAllReduceWithScale(
fp32_sum_grad + local_rank * fp32_numel_each_device, fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_sum_grad + local_rank * fp32_numel_each_device, fp32_sum_grad + local_rank * fp32_numel_each_device,
@@ -1989,13 +1888,14 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
dev_ctx); dev_ctx);
NCCLReduceScatterWithScale( NCCLReduceScatterWithScale(
fp16_grad, fp16_grad_data,
fp16_sum_grad + local_rank * fp16_numel_each_device, fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device, fp16_numel_each_device,
num_devices, num_devices,
local_comm, local_comm,
stream, stream,
dev_ctx); dev_ctx,
fp16_scale);
NCCLAllReduceWithScale( NCCLAllReduceWithScale(
fp16_sum_grad + local_rank * fp16_numel_each_device, fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_sum_grad + local_rank * fp16_numel_each_device, fp16_sum_grad + local_rank * fp16_numel_each_device,
@@ -2005,39 +1905,47 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
stream, stream,
dev_ctx); dev_ctx);
} else { } else {
NCCLAllReduceWithScale(fp32_grad, NCCLAllReduceWithScale(fp32_grad_data,
fp32_sum_grad, fp32_sum_grad,
fp32_numel, fp32_numel,
nranks, nranks,
global_comm, global_comm,
stream, stream,
dev_ctx); dev_ctx,
NCCLAllReduceWithScale(fp16_grad, fp32_scale);
NCCLAllReduceWithScale(fp16_grad_data,
fp16_sum_grad, fp16_sum_grad,
fp16_numel, fp16_numel,
nranks, nranks,
global_comm, global_comm,
stream, stream,
dev_ctx); dev_ctx,
fp16_scale);
} }
fp32_sum_grad += (local_rank * fp32_numel_each_device); fp32_sum_grad += (local_rank * fp32_numel_each_device);
fp16_sum_grad += (local_rank * fp16_numel_each_device); fp16_sum_grad += (local_rank * fp16_numel_each_device);
} else { } else {
NCCLReduceScatterWithScale(fp32_grad, NCCLReduceScatterWithScale(fp32_grad_data,
fp32_sum_grad, fp32_sum_grad,
fp32_numel_each_device, fp32_numel_each_device,
num_devices, nranks,
global_comm, global_comm,
stream, stream,
dev_ctx); dev_ctx,
NCCLReduceScatterWithScale(fp16_grad, fp32_scale);
NCCLReduceScatterWithScale(fp16_grad_data,
fp16_sum_grad, fp16_sum_grad,
fp16_numel_each_device, fp16_numel_each_device,
num_devices, nranks,
global_comm, global_comm,
stream, stream,
dev_ctx); dev_ctx,
fp16_scale);
} }
VLOG(1) << "FP32 HasNanInf after all reduce: "
<< HasNanInf(dev_ctx, fp32_sum_grad, fp32_numel_each_device);
VLOG(1) << "FP16 HasNanInf after all reduce: "
<< HasNanInf(dev_ctx, fp16_sum_grad, fp16_numel_each_device);
CheckHasNanInfGrad(fp32_sum_grad, CheckHasNanInfGrad(fp32_sum_grad,
fp32_numel_each_device, fp32_numel_each_device,
fp16_sum_grad, fp16_sum_grad,
@@ -2047,261 +1955,357 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
&cub_tmp_buffer); &cub_tmp_buffer);
if (num_devices > 1) { if (num_devices > 1) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllReduce(fp32_square_grad_norm, phi::dynload::ncclAllReduce(fp32_square_grad_norm,
fp32_square_grad_norm, fp32_square_grad_norm,
1, 1,
ncclFloat32, ncclFloat32,
ncclSum, ncclSum,
local_comm, local_comm,
stream)); stream));
VLOG(1) << "Grad square norm after all reduce: "
<< FlattenToString(fp32_square_grad_norm, 1, place);
} }
// (4) mark max_global_grad_norm as 0, meaning that clip has been
// already performed
max_global_grad_norm = 0; max_global_grad_norm = 0;
} }
VLOG(10) << "ReduceScatter done"; } else {
if (local_shard) {
// Step 7: update the moment1, moment2. Calcuate the trust_ratio_div if (use_hierarchical_allreduce) {
auto *fused_offsets_t = ctx.Input<phi::DenseTensor>("FusedParamOffsets"); NCCLReduceScatterWithScale(
auto *fused_offsets = fused_offsets_t->data<int>(); fp32_grad_data,
auto *fp32_partial_fused_offsets_t = fp32_sum_grad + local_rank * fp32_numel_each_device,
ctx.Input<phi::DenseTensor>("FP32ShardFusedParamOffsets"); fp32_numel_each_device,
const auto *fp32_partial_fused_offsets = num_devices,
fp32_partial_fused_offsets_t->data<int>(); local_comm,
auto *fp16_partial_fused_offsets_t = stream,
ctx.Input<phi::DenseTensor>("FP16ShardFusedParamOffsets"); dev_ctx);
const auto *fp16_partial_fused_offsets = NCCLAllReduceWithScale(
fp16_partial_fused_offsets_t->data<int>(); fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_sum_grad + local_rank * fp32_numel_each_device,
auto *step = ctx.Output<phi::DenseTensor>("Step")->data<int64_t>(); fp32_numel_each_device,
nranks / num_devices,
VLOG(1) << "FusedParamOffsets: " external_comm,
<< FlattenToString(fused_offsets, stream,
fused_offsets_t->numel(), dev_ctx);
fused_offsets_t->place());
VLOG(1) << "FP32ShardFusedParamOffsets: " NCCLReduceScatterWithScale(
<< FlattenToString(fp32_partial_fused_offsets, fp16_grad_data,
fp32_partial_fused_offsets_t->numel(), fp16_sum_grad + local_rank * fp16_numel_each_device,
fp32_partial_fused_offsets_t->place()); fp16_numel_each_device,
VLOG(1) << "FP16ShardFusedParamOffsets: " num_devices,
<< FlattenToString(fp16_partial_fused_offsets, local_comm,
fp16_partial_fused_offsets_t->numel(), stream,
fp16_partial_fused_offsets_t->place()); dev_ctx);
NCCLAllReduceWithScale(
phi::memory_utils::Buffer trust_ratio_div_buffer(place); fp16_sum_grad + local_rank * fp16_numel_each_device,
auto *trust_ratio_div = trust_ratio_div_buffer.Alloc<float>(partial_numel); fp16_sum_grad + local_rank * fp16_numel_each_device,
auto fp32_offset = local_rank * fp32_numel_each_device; fp16_numel_each_device,
auto fp16_offset = local_rank * fp16_numel_each_device; nranks / num_devices,
if (has_fp32_param) { external_comm,
VLOG(10) << "Update FP32 Moment and TrustRatioDiv starts"; stream,
MultiTensorUpdateLambMomentAndTrustRatioDiv(dev_ctx, dev_ctx);
fp32_partial_fused_offsets, } else {
fp32_local_param_num, NCCLAllReduceWithScale(fp32_grad_data,
fp32_param + fp32_offset, fp32_sum_grad,
fp32_sum_grad, fp32_numel,
fp32_square_grad_norm, nranks,
global_scale, global_comm,
beta1pow, stream,
beta2pow, dev_ctx);
moment1, NCCLAllReduceWithScale(fp16_grad_data,
moment2, fp16_sum_grad,
trust_ratio_div, fp16_numel,
found_inf, nranks,
step, global_comm,
weight_decay, stream,
fp32_weight_decay_end_idx, dev_ctx);
beta1, }
beta2, fp32_sum_grad += (local_rank * fp32_numel_each_device);
epsilon, fp16_sum_grad += (local_rank * fp16_numel_each_device);
max_global_grad_norm, } else {
rescale_grad); NCCLReduceScatterWithScale(fp32_grad_data,
VLOG(10) << "Update FP32 Moment and TrustRatioDiv done"; fp32_sum_grad,
fp32_numel_each_device,
num_devices,
global_comm,
stream,
dev_ctx);
NCCLReduceScatterWithScale(fp16_grad_data,
fp16_sum_grad,
fp16_numel_each_device,
num_devices,
global_comm,
stream,
dev_ctx);
} }
float *master_param = nullptr; CheckHasNanInfGrad(fp32_sum_grad,
if (has_fp16_param) { fp32_numel_each_device,
master_param = fp32_param + fp32_numel; fp16_sum_grad,
VLOG(10) << "Update FP16 Moment and TrustRatioDiv starts"; fp16_numel_each_device,
auto tmp_found_inf = has_fp32_param ? nullptr : found_inf; fp32_square_grad_norm,
auto tmp_step = has_fp32_param ? nullptr : step; stream,
MultiTensorUpdateLambMomentAndTrustRatioDiv( &cub_tmp_buffer);
dev_ctx, if (num_devices > 1) {
fp16_partial_fused_offsets, PADDLE_ENFORCE_GPU_SUCCESS(
fp16_local_param_num, phi::dynload::ncclAllReduce(fp32_square_grad_norm,
master_param + fp16_offset, fp32_square_grad_norm,
fp16_sum_grad, 1,
fp32_square_grad_norm, ncclFloat32,
global_scale, ncclSum,
beta1pow, local_comm,
beta2pow, stream));
moment1 + fp32_numel_each_device,
moment2 + fp32_numel_each_device,
trust_ratio_div + fp32_numel_each_device,
tmp_found_inf,
tmp_step,
weight_decay,
fp16_weight_decay_end_idx,
beta1,
beta2,
epsilon,
max_global_grad_norm,
rescale_grad);
VLOG(10) << "Update FP16 Moment and TrustRatioDiv done";
} }
max_global_grad_norm = 0;
}
VLOG(10) << "ReduceScatter done";
// Step 7: update the moment1, moment2. Calculate the trust_ratio_div
auto *param_offsets_data = param_offsets.data<int>();
const auto *fp32_partial_offsets_data = fp32_partial_offsets.data<int>();
const auto *fp16_partial_offsets_data = fp16_partial_offsets.data<int>();
auto *step_data = step->data<int64_t>();
VLOG(1) << "FusedParamOffsets: "
<< FlattenToString(param_offsets_data,
param_offsets.numel(),
param_offsets.place());
VLOG(1) << "FP32ShardFusedParamOffsets: "
<< FlattenToString(fp32_partial_offsets_data,
fp32_partial_offsets.numel(),
fp32_partial_offsets.place());
VLOG(1) << "FP16ShardFusedParamOffsets: "
<< FlattenToString(fp16_partial_offsets_data,
fp16_partial_offsets.numel(),
fp16_partial_offsets.place());
memory_utils::Buffer trust_ratio_div_buffer(place);
auto *trust_ratio_div = trust_ratio_div_buffer.Alloc<float>(partial_numel);
auto fp32_offset = local_rank * fp32_numel_each_device;
auto fp16_offset = local_rank * fp16_numel_each_device;
if (has_fp32_param) {
VLOG(10) << "Update FP32 Moment and TrustRatioDiv starts";
MultiTensorUpdateLambMomentAndTrustRatioDiv(dev_ctx,
fp32_partial_offsets_data,
fp32_local_param_num,
fp32_param_data + fp32_offset,
fp32_sum_grad,
fp32_square_grad_norm,
global_scale_data,
beta1_pow_data,
beta2_pow_data,
moment1_data,
moment2_data,
trust_ratio_div,
found_inf_data,
step_data,
weight_decay,
fp32_weight_decay_end_idx,
beta1,
beta2,
epsilon,
max_global_grad_norm,
rescale_grad);
VLOG(10) << "Update FP32 Moment and TrustRatioDiv done";
}
float *master_param = nullptr;
if (has_fp16_param) {
master_param = fp32_param_data + fp32_numel;
VLOG(10) << "Update FP16 Moment and TrustRatioDiv starts";
auto tmp_found_inf = has_fp32_param ? nullptr : found_inf_data;
auto tmp_step = has_fp32_param ? nullptr : step_data;
MultiTensorUpdateLambMomentAndTrustRatioDiv(
dev_ctx,
fp16_partial_offsets_data,
fp16_local_param_num,
master_param + fp16_offset,
fp16_sum_grad,
fp32_square_grad_norm,
global_scale_data,
beta1_pow_data,
beta2_pow_data,
moment1_data + fp32_numel_each_device,
moment2_data + fp32_numel_each_device,
trust_ratio_div + fp32_numel_each_device,
tmp_found_inf,
tmp_step,
weight_decay,
fp16_weight_decay_end_idx,
beta1,
beta2,
epsilon,
max_global_grad_norm,
rescale_grad);
VLOG(10) << "Update FP16 Moment and TrustRatioDiv done";
}
VLOG(10) << "Update Moment and TrustRatioDiv done hehahaha"; VLOG(10) << "Update Moment and TrustRatioDiv done hehahaha";
// Step 8: calculate L2-Norm square of parameter and trust_ratio_div // Step 8: calculate L2-Norm square of parameter and trust_ratio_div
phi::memory_utils::Buffer square_norm_buffer(place); memory_utils::Buffer square_norm_buffer(place);
auto *param_square_norm = square_norm_buffer.Alloc<float>(2 * param_num); auto *param_square_norm = square_norm_buffer.Alloc<float>(2 * param_num);
auto *trust_ratio_div_square_norm = param_square_norm + param_num; auto *trust_ratio_div_square_norm = param_square_norm + param_num;
if (num_devices > 1) { if (num_devices > 1) {
if (use_master_param_norm) {
FillZeroWithPtr(param_square_norm + fp32_global_param_num,
2 * param_num - fp32_global_param_num,
stream);
} else {
FillZeroWithPtr(trust_ratio_div_square_norm, param_num, stream);
}
}
MultiTensorL2Norm(place,
stream,
fp32_param,
fused_offsets,
fp32_global_param_num,
param_square_norm);
if (use_master_param_norm) { if (use_master_param_norm) {
MultiTensorL2Norm(place, FillZeroWithPtr(param_square_norm + fp32_global_param_num,
stream, 2 * param_num - fp32_global_param_num,
master_param + fp16_offset, stream);
fp16_partial_fused_offsets,
fp16_local_param_num,
param_square_norm + fp16_local_start_idx);
} else { } else {
MultiTensorL2Norm(place, FillZeroWithPtr(trust_ratio_div_square_norm, param_num, stream);
stream,
fp16_param + fused_offsets[fp16_local_start_idx] -
fused_offsets[fp32_global_param_num],
fused_offsets + fp16_local_start_idx,
fp16_local_param_num,
param_square_norm + fp16_local_start_idx);
} }
}
MultiTensorL2Norm(place,
stream,
fp32_param_data,
param_offsets_data,
fp32_global_param_num,
param_square_norm);
if (use_master_param_norm) {
MultiTensorL2Norm(place, MultiTensorL2Norm(place,
stream, stream,
trust_ratio_div, master_param + fp16_offset,
fp32_partial_fused_offsets, fp16_partial_offsets_data,
fp32_local_param_num, fp16_local_param_num,
trust_ratio_div_square_norm + fp32_local_start_idx); param_square_norm + fp16_local_start_idx);
} else {
MultiTensorL2Norm(place, MultiTensorL2Norm(place,
stream, stream,
trust_ratio_div + fp32_numel_each_device, fp16_param_data +
fp16_partial_fused_offsets, param_offsets_data[fp16_local_start_idx] -
param_offsets_data[fp32_global_param_num],
param_offsets_data + fp16_local_start_idx,
fp16_local_param_num, fp16_local_param_num,
trust_ratio_div_square_norm + fp16_local_start_idx); param_square_norm + fp16_local_start_idx);
}
VLOG(1) << "TrustRatioDiv L2-Norm before allreduce: " MultiTensorL2Norm(place,
<< FlattenToString(trust_ratio_div_square_norm, param_num, place); stream,
if (num_devices > 1) { trust_ratio_div,
if (use_master_param_norm) { fp32_partial_offsets_data,
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( fp32_local_param_num,
param_square_norm + fp32_global_param_num, trust_ratio_div_square_norm + fp32_local_start_idx);
param_square_norm + fp32_global_param_num, MultiTensorL2Norm(place,
2 * param_num - fp32_global_param_num, stream,
ncclFloat32, trust_ratio_div + fp32_numel_each_device,
ncclSum, fp16_partial_offsets_data,
local_comm, fp16_local_param_num,
stream)); trust_ratio_div_square_norm + fp16_local_start_idx);
} else {
PADDLE_ENFORCE_GPU_SUCCESS( VLOG(1) << "TrustRatioDiv L2-Norm before allreduce: "
platform::dynload::ncclAllReduce(trust_ratio_div_square_norm, << FlattenToString(trust_ratio_div_square_norm, param_num, place);
trust_ratio_div_square_norm, if (num_devices > 1) {
param_num, if (use_master_param_norm) {
ncclFloat32, PADDLE_ENFORCE_GPU_SUCCESS(
ncclSum, phi::dynload::ncclAllReduce(param_square_norm + fp32_global_param_num,
local_comm, param_square_norm + fp32_global_param_num,
stream)); 2 * param_num - fp32_global_param_num,
} ncclFloat32,
VLOG(10) << "ncclAllReduce done"; ncclSum,
local_comm,
stream));
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::ncclAllReduce(trust_ratio_div_square_norm,
trust_ratio_div_square_norm,
param_num,
ncclFloat32,
ncclSum,
local_comm,
stream));
} }
VLOG(10) << "ncclAllReduce done";
}
LogParamAndTrustRatioDivSquareNorm<1>( LogParamAndTrustRatioDivSquareNorm<1>(
ctx, param_square_norm, trust_ratio_div_square_norm); param, param_order, param_square_norm, trust_ratio_div_square_norm);
VLOG(10) << "Calculate L2-Norm of Param and TrustRatioDiv done"; VLOG(10) << "Calculate L2-Norm of Param and TrustRatioDiv done";
// Step 9: update parameter, beta1pow, beta2pow. All gather parameters. // Step 9: update parameter, beta1pow, beta2pow. All gather parameters.
if (has_fp32_param) { if (has_fp32_param) {
MultiTensorUpdateLambParamAndBetaPows<float>( MultiTensorUpdateLambParamAndBetaPows<float>(
dev_ctx, dev_ctx,
fp32_partial_fused_offsets, fp32_partial_offsets_data,
fp32_local_param_num, fp32_local_param_num,
trust_ratio_div, trust_ratio_div,
lr, lr_data,
param_square_norm + fp32_local_start_idx, param_square_norm + fp32_local_start_idx,
trust_ratio_div_square_norm + fp32_local_start_idx, trust_ratio_div_square_norm + fp32_local_start_idx,
found_inf, found_inf_data,
fp32_param + fp32_offset, fp32_param_data + fp32_offset,
nullptr, nullptr,
beta1pow, beta1_pow_data,
beta2pow, beta2_pow_data,
beta1, beta1,
beta2); beta2);
if (num_devices > 1) { if (num_devices > 1) {
// ncclAllGather // ncclAllGather
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllGather(fp32_param + fp32_offset, phi::dynload::ncclAllGather(fp32_param_data + fp32_offset,
fp32_param, fp32_param_data,
fp32_numel_each_device, fp32_numel_each_device,
ncclFloat32, ncclFloat32,
local_comm, local_comm,
stream)); stream));
}
beta1pow = nullptr;
beta2pow = nullptr;
} }
if (has_fp16_param) {
MultiTensorUpdateLambParamAndBetaPows<platform::float16>( beta1_pow_data = nullptr;
dev_ctx, beta2_pow_data = nullptr;
fp16_partial_fused_offsets, }
fp16_local_param_num, if (has_fp16_param) {
trust_ratio_div + fp32_numel_each_device, MultiTensorUpdateLambParamAndBetaPows<dtype::float16>(
lr, dev_ctx,
param_square_norm + fp16_local_start_idx, fp16_partial_offsets_data,
trust_ratio_div_square_norm + fp16_local_start_idx, fp16_local_param_num,
found_inf, trust_ratio_div + fp32_numel_each_device,
fp16_param + fp16_offset, lr_data,
master_param + fp16_offset, param_square_norm + fp16_local_start_idx,
beta1pow, trust_ratio_div_square_norm + fp16_local_start_idx,
beta2pow, found_inf_data,
beta1, fp16_param_data + fp16_offset,
beta2); master_param + fp16_offset,
if (num_devices > 1) { beta1_pow_data,
// ncclAllGather beta2_pow_data,
PADDLE_ENFORCE_GPU_SUCCESS( beta1,
platform::dynload::ncclAllGather(fp16_param + fp16_offset, beta2);
fp16_param, if (num_devices > 1) {
fp16_numel_each_device, // ncclAllGather
ncclFloat16, PADDLE_ENFORCE_GPU_SUCCESS(
local_comm, phi::dynload::ncclAllGather(fp16_param_data + fp16_offset,
stream)); fp16_param_data,
} fp16_numel_each_device,
ncclFloat16,
local_comm,
stream));
} }
VLOG(10) << "Update Param done"; }
VLOG(10) << "Update Param done";
VLOG(1) << "IsFinite: " << IsFinite(dev_ctx, fp32_square_grad_norm); VLOG(1) << "IsFinite: " << IsFinite(dev_ctx, fp32_square_grad_norm);
#else #else
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"distributed_fused_lamb op should be used with NCCL/RCCL.")); "distributed_fused_lamb op should be used with NCCL/RCCL."));
#endif #endif
} }
};
} // namespace operators
} // namespace paddle
namespace plat = paddle::platform;
namespace ops = paddle::operators;
PD_REGISTER_STRUCT_KERNEL(distributed_fused_lamb,
                          GPU,
                          ALL_LAYOUT,
                          ops::DistributedFusedLambOpKernel,
                          float) {}
}  // namespace fusion
}  // namespace phi

PD_REGISTER_KERNEL(distributed_fused_lamb,
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::DistributedFusedLambKernel,
                   float) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT16);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT16);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(7).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(9).SetDataType(phi::DataType::BOOL);
kernel->OutputAt(10).SetDataType(phi::DataType::INT64);
kernel->OutputAt(11).SetDataType(phi::DataType::BOOL);
kernel->OutputAt(12).SetDataType(phi::DataType::INT64);
}
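The GPU kernel registered above performs its gradient communication either as a flat NCCL all-reduce, or, when use_hierarchical_allreduce is set, as a reduce-scatter inside each node followed by an all-reduce of the resulting shard across nodes. The following CPU-only sketch (illustrative assumptions only: toy sizes, plain vectors standing in for the NCCL calls, not Paddle code) shows that the two-step hierarchical pattern leaves each device with the same shard as a flat global reduction would:

// hierarchical_reduce_sketch.cc -- illustrative only, not Paddle code.
#include <cassert>
#include <iostream>
#include <vector>

int main() {
  const int num_nodes = 2, num_devices = 2;        // nranks = 4 (assumed toy sizes)
  const int shard = 3, numel = shard * num_devices;

  // grads[node][device] holds that rank's full-length gradient.
  std::vector<std::vector<std::vector<float>>> grads(
      num_nodes,
      std::vector<std::vector<float>>(num_devices, std::vector<float>(numel)));
  for (int n = 0; n < num_nodes; ++n)
    for (int d = 0; d < num_devices; ++d)
      for (int i = 0; i < numel; ++i)
        grads[n][d][i] = static_cast<float>(n * 100 + d * 10 + i);

  // Step 1: reduce-scatter inside each node. Device d keeps shard d,
  // summed over the devices of its own node (local_comm in the kernel).
  std::vector<std::vector<std::vector<float>>> shards(
      num_nodes,
      std::vector<std::vector<float>>(num_devices, std::vector<float>(shard, 0)));
  for (int n = 0; n < num_nodes; ++n)
    for (int d = 0; d < num_devices; ++d)
      for (int src = 0; src < num_devices; ++src)
        for (int i = 0; i < shard; ++i)
          shards[n][d][i] += grads[n][src][d * shard + i];

  // Step 2: all-reduce each shard across nodes (external_comm in the kernel):
  // the same local rank on every node exchanges its partial shard sum.
  for (int d = 0; d < num_devices; ++d) {
    std::vector<float> sum(shard, 0);
    for (int n = 0; n < num_nodes; ++n)
      for (int i = 0; i < shard; ++i) sum[i] += shards[n][d][i];
    for (int n = 0; n < num_nodes; ++n) shards[n][d] = sum;
  }

  // Check against a flat global reduce-scatter of the same data.
  for (int d = 0; d < num_devices; ++d)
    for (int i = 0; i < shard; ++i) {
      float global = 0;
      for (int n = 0; n < num_nodes; ++n)
        for (int src = 0; src < num_devices; ++src)
          global += grads[n][src][d * shard + i];
      assert(shards[0][d][i] == global);
    }
  std::cout << "hierarchical reduce matches flat reduce-scatter shard\n";
  return 0;
}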
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
template <typename T, typename DevCtx>
class DistributedFusedLambOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"The distributed_fused_lamb operator does not support CPU yet."));
}
};
} // namespace operators
} // namespace paddle
@@ -18,6 +18,8 @@
#include "math.h" // NOLINT #include "math.h" // NOLINT
#include "paddle/phi/core/cuda_stream.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature DistributedFusedLambOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("distributed_fused_lamb",
{"Param",
"Grad",
"FP32FusedParam",
"FP32FusedGrad",
"FP16FusedParam",
"FP16FusedGrad",
"Moment1",
"Moment2",
"Beta1Pow",
"Beta2Pow",
"FusedParamOffsets",
"FP32ShardFusedParamOffsets",
"FP16ShardFusedParamOffsets",
"ParamInfo",
"ParamOrder",
"LearningRate",
"GlobalScale"},
{"acc_steps",
"beta1",
"beta2",
"epsilon",
"max_global_grad_norm",
"weight_decay",
"clip_after_allreduce",
"use_master_param_norm",
"use_master_acc_grad",
"is_grad_scaled_by_nranks",
"use_hierarchical_allreduce",
"nranks",
"ring_ids"},
{"FP32FusedParamOut",
"FP16FusedParamOut",
"FP32AccFusedGrad",
"FP16AccFusedGrad",
"Moment1Out",
"Moment2Out",
"Beta1PowOut",
"Beta2PowOut",
"ParamOut",
"FoundInf",
"AccStep",
"StopUpdate",
"Step"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(distributed_fused_lamb,
phi::DistributedFusedLambOpArgumentMapping);
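The argument mapping above wires the legacy operator's input, attribute, and output names to the functional phi kernel, in the same order as the kernel's parameter list. A minimal sketch of the same pattern for a hypothetical, much smaller op (my_toy_lamb and every name inside it are invented for illustration, not part of Paddle):

// toy_lamb_sig.cc -- hypothetical example of the mapping pattern above.
#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

KernelSignature MyToyLambOpArgumentMapping(const ArgumentMappingContext& ctx) {
  // Inputs, then attributes, then outputs, each list in the order the
  // functional kernel declares its parameters.
  return KernelSignature("my_toy_lamb",
                         {"Param", "Grad", "LearningRate"},  // inputs
                         {"beta1", "beta2", "epsilon"},      // attributes
                         {"ParamOut"});                      // outputs
}

}  // namespace phi

PD_REGISTER_ARG_MAPPING_FN(my_toy_lamb, phi::MyToyLambOpArgumentMapping);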