Unverified commit 5f8e7d8f, authored by huangjiyi, committed by GitHub

Functionalize distributed_fused_lamb kernel (#53896)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update HostAlloc

* update param name

* update cpu kernel

* remove kernel header

* update

* update
Parent 6e0cf610
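This commit "functionalizes" the kernel: the fluid-style ops::DistributedFusedLambOpKernel, which read its inputs and attributes by name from a framework::ExecutionContext inside Compute(), becomes the free function phi::fusion::DistributedFusedLambKernel, whose tensors and attributes are explicit, typed parameters registered through PD_REGISTER_KERNEL (the CPU variant in the diff below is only an Unimplemented stub). The self-contained toy sketch that follows shows the shape of that refactor; ToyContext, ScaleKernel, and KernelRegistry are invented names for illustration only and are not Paddle's actual API.

// ---- Illustrative sketch (toy types, not Paddle's real API) ----
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

using Tensor = std::vector<float>;

// Old style: a kernel object pulls named inputs/attributes out of a
// run-time context, so the signature says nothing about what it needs.
struct ToyContext {
  std::map<std::string, Tensor> inputs;
  std::map<std::string, float> attrs;
};
struct OldStyleScaleKernel {
  void Compute(const ToyContext &ctx, Tensor *out) const {
    const Tensor &x = ctx.inputs.at("X");
    const float scale = ctx.attrs.at("scale");
    out->clear();
    for (float v : x) out->push_back(v * scale);
  }
};

// Functional style: every tensor and attribute is an explicit parameter.
void ScaleKernel(const Tensor &x, float scale, Tensor *out) {
  out->clear();
  for (float v : x) out->push_back(v * scale);
}

// Minimal stand-in for a kernel registry keyed by op name (the role
// PD_REGISTER_KERNEL plays in the real code).
using ScaleKernelFn = std::function<void(const Tensor &, float, Tensor *)>;
std::map<std::string, ScaleKernelFn> &KernelRegistry() {
  static std::map<std::string, ScaleKernelFn> registry;
  return registry;
}

int main() {
  const Tensor x = {1.f, 2.f};
  Tensor out;

  ToyContext ctx;
  ctx.inputs["X"] = x;
  ctx.attrs["scale"] = 3.f;
  OldStyleScaleKernel{}.Compute(ctx, &out);
  std::cout << "old-style out[1] = " << out[1] << "\n";  // prints 6

  KernelRegistry()["scale"] = ScaleKernel;  // "registration"
  KernelRegistry()["scale"](x, 3.f, &out);  // dispatch by name
  std::cout << "functional out[1] = " << out[1] << "\n";  // prints 6
  return 0;
}
// ---- end of sketch ----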
@@ -12,7 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/kernel_registry.h"
 
 namespace paddle {
 namespace operators {
@@ -170,8 +171,63 @@ REGISTER_OP_WITHOUT_GRADIENT(distributed_fused_lamb,
                             ops::DistributedFusedLambOp,
                             ops::DistributedFusedLambOpMaker);
 
-PD_REGISTER_STRUCT_KERNEL(distributed_fused_lamb,
-                          CPU,
-                          ALL_LAYOUT,
-                          ops::DistributedFusedLambOpKernel,
-                          float) {}
+namespace phi {
+namespace fusion {
+
+template <typename T, typename Context>
+void DistributedFusedLambKernel(const Context &dev_ctx,
+                                const std::vector<const DenseTensor *> &param,
+                                const std::vector<const DenseTensor *> &grad,
+                                const paddle::optional<DenseTensor> &fp32_param,
+                                const paddle::optional<DenseTensor> &fp32_grad,
+                                const paddle::optional<DenseTensor> &fp16_param,
+                                const paddle::optional<DenseTensor> &fp16_grad,
+                                const DenseTensor &moment1,
+                                const DenseTensor &moment2,
+                                const DenseTensor &beta1_pow,
+                                const DenseTensor &beta2_pow,
+                                const DenseTensor &param_offsets,
+                                const DenseTensor &fp32_partial_offsets,
+                                const DenseTensor &fp16_partial_offsets,
+                                const DenseTensor &param_info,
+                                const DenseTensor &param_order,
+                                const DenseTensor &learning_rate,
+                                const DenseTensor &global_scale,
+                                int acc_steps,
+                                float beta1,
+                                float beta2,
+                                float epsilon,
+                                float max_global_grad_norm,
+                                float weight_decay,
+                                bool clip_after_allreduce,
+                                bool use_master_param_norm,
+                                bool use_master_acc_grad,
+                                bool is_grad_scaled_by_nranks,
+                                bool use_hierarchical_allreduce,
+                                int64_t nranks,
+                                const std::vector<int> &ring_ids,
+                                DenseTensor *fp32_param_out,
+                                DenseTensor *fp16_param_out,
+                                DenseTensor *fp32_acc_grad,
+                                DenseTensor *fp16_acc_grad,
+                                DenseTensor *moment1_out,
+                                DenseTensor *moment2_out,
+                                DenseTensor *beta1_pow_out,
+                                DenseTensor *beta2_pow_out,
+                                DenseTensor *param_out,
+                                DenseTensor *found_inf,
+                                DenseTensor *acc_step,
+                                DenseTensor *stop_update,
+                                DenseTensor *step) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "The distributed_fused_lamb operator does not support CPU yet."));
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(distributed_fused_lamb,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::fusion::DistributedFusedLambKernel,
+                   float) {}

-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,19 +12,21 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <cmath>
-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
-#include "paddle/fluid/operators/optimizers/cast_with_ptr.h"
-#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h"
 #include "paddle/fluid/operators/optimizers/multi_tensor_apply.h"
 #include "paddle/fluid/platform/collective_helper.h"
-#include "paddle/fluid/platform/for_range.h"
-#include "paddle/fluid/string/string_helper.h"
+#include "paddle/phi/backends/context_pool.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/common/memory_utils.h"
+#include "paddle/phi/core/cuda_stream.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/utils/data_type.h"
 #include "paddle/phi/kernels/funcs/aligned_vector.h"
 #include "paddle/phi/kernels/funcs/tensor_to_string.h"
+#include "paddle/utils/optional.h"
 
 #ifdef __NVCC__
 #include "cub/cub.cuh"
@@ -38,11 +40,11 @@
 namespace cub = hipcub;
 #endif
 
-namespace paddle {
-namespace operators {
+namespace phi {
+namespace fusion {
 
 template <typename T>
-using MasterT = typename details::MPTypeTrait<T>::Type;
+using MasterT = typename phi::dtype::MPTypeTrait<T>::Type;
 
 using phi::funcs::FlattenToString;
 using phi::funcs::ToVector;
@@ -157,7 +159,7 @@ template <typename InT,
typename OutT, typename OutT,
int MaxTensorNumPerLaunch = 160, int MaxTensorNumPerLaunch = 160,
int MaxChunkNumPerLaunch = 780> int MaxChunkNumPerLaunch = 780>
static void MultiTensorL2Norm(const platform::CUDAPlace &place, static void MultiTensorL2Norm(const phi::GPUPlace &place,
gpuStream_t stream, gpuStream_t stream,
const InT *x, const InT *x,
const int *offsets, const int *offsets,
@@ -191,7 +193,7 @@ static void MultiTensorL2Norm(const platform::CUDAPlace &place,
<< " , tensor_num = " << n; << " , tensor_num = " << n;
using MT = MasterT<InT>; using MT = MasterT<InT>;
phi::memory_utils::Buffer tmp_out(place); memory_utils::Buffer tmp_out(place);
auto *tmp_out_ptr = tmp_out.Alloc<MT>(n * max_chunk_num); auto *tmp_out_ptr = tmp_out.Alloc<MT>(n * max_chunk_num);
FillZeroWithPtr(tmp_out_ptr, n * max_chunk_num, stream); FillZeroWithPtr(tmp_out_ptr, n * max_chunk_num, stream);
@@ -200,7 +202,8 @@ static void MultiTensorL2Norm(const platform::CUDAPlace &place,
using FunctorT = L2NormFunctor<InT, kBlockDim, kVecSize>; \ using FunctorT = L2NormFunctor<InT, kBlockDim, kVecSize>; \
VLOG(10) << __func__ << " " << typeid(InT).name() \ VLOG(10) << __func__ << " " << typeid(InT).name() \
<< " VecSize = " << kVecSize; \ << " VecSize = " << kVecSize; \
MultiTensorApply<FunctorT, kNumTensor, kNumChunk>(FunctorT(), \ paddle::operators::MultiTensorApply<FunctorT, kNumTensor, kNumChunk>( \
FunctorT(), \
stream, \ stream, \
offsets, \ offsets, \
n, \ n, \
@@ -220,27 +223,27 @@ static void MultiTensorL2Norm(const platform::CUDAPlace &place,
template <int LogLevel> template <int LogLevel>
static void LogParamAndTrustRatioDivSquareNorm( static void LogParamAndTrustRatioDivSquareNorm(
const framework::ExecutionContext &ctx, const std::vector<const DenseTensor *> &param,
const DenseTensor &order,
const float *param_square_norm, const float *param_square_norm,
const float *trust_ratio_div_square_norm) { const float *trust_ratio_div_square_norm) {
if (!VLOG_IS_ON(LogLevel)) return; if (!VLOG_IS_ON(LogLevel)) return;
auto tensors = ctx.MultiInput<phi::DenseTensor>("Param"); if (param.empty()) return;
if (tensors.empty()) return;
const auto *order = ctx.Input<phi::DenseTensor>("ParamOrder")->data<int>(); const auto *order_data = order.data<int>();
size_t n = tensors.size(); size_t n = param.size();
auto place = tensors[0]->place(); auto place = param[0]->place();
auto pn_vec = ToVector(param_square_norm, n, place); auto pn_vec = ToVector(param_square_norm, n, place);
auto tn_vec = ToVector(trust_ratio_div_square_norm, n, place); auto tn_vec = ToVector(trust_ratio_div_square_norm, n, place);
const auto &names = ctx.GetOp().Inputs("Param");
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
auto idx = order[i]; auto idx = order_data[i];
VLOG(LogLevel) << "Param " << tensors[idx]->dtype() << " " << names[idx] VLOG(LogLevel) << "Param " << param[idx]->dtype() << " "
<< " pn = " << pn_vec[i] << " , tn = " << tn_vec[i]; << param[idx]->name() << " pn = " << pn_vec[i]
<< " , tn = " << tn_vec[i];
} }
} }
@@ -261,13 +264,12 @@ static bool IsFinite(const phi::GPUContext &dev_ctx, const float *ptr) {
} }
template <typename T> template <typename T>
static const T *GetInputTensorPtr(const framework::ExecutionContext &ctx, static const T *GetInputTensorPtr(const DenseTensor *in_tensor,
const char *in_name, const char *in_name,
int64_t *numel = nullptr) { int64_t *numel = nullptr) {
const auto *in_tensor = ctx.Input<phi::DenseTensor>(in_name);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
in_tensor, in_tensor,
platform::errors::InvalidArgument("Input(%s) cannot be NULL.", in_name)); phi::errors::InvalidArgument("Input(%s) cannot be NULL.", in_name));
if (in_tensor->IsInitialized()) { if (in_tensor->IsInitialized()) {
if (numel) *numel = in_tensor->numel(); if (numel) *numel = in_tensor->numel();
return in_tensor->data<T>(); return in_tensor->data<T>();
@@ -277,34 +279,34 @@ static const T *GetInputTensorPtr(const framework::ExecutionContext &ctx,
} }
} }
template <typename T, bool AllowNotExist = false> template <typename T, typename Context, bool AllowNotExist = false>
static T *GetSameInOutTensorPtr(const framework::ExecutionContext &ctx, static T *GetSameInOutTensorPtr(const Context &dev_ctx,
const platform::Place &place, const DenseTensor *in_tensor,
DenseTensor *out_tensor,
const char *in_name, const char *in_name,
const char *out_name, const char *out_name,
int64_t *numel = nullptr) { int64_t *numel = nullptr) {
const auto *in_tensor = ctx.Input<phi::DenseTensor>(in_name);
if (in_tensor == nullptr || !in_tensor->IsInitialized()) { if (in_tensor == nullptr || !in_tensor->IsInitialized()) {
PADDLE_ENFORCE_EQ(AllowNotExist, PADDLE_ENFORCE_EQ(
AllowNotExist,
true, true,
platform::errors::InvalidArgument( phi::errors::InvalidArgument("Input(%s) cannot be NULL.", in_name));
"Input(%s) cannot be NULL.", in_name));
if (numel) *numel = 0; if (numel) *numel = 0;
return nullptr; return nullptr;
} }
auto *out_tensor = ctx.Output<phi::DenseTensor>(out_name);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
in_tensor, in_tensor,
platform::errors::InvalidArgument("Input(%s) cannot be NULL.", in_name)); phi::errors::InvalidArgument("Input(%s) cannot be NULL.", in_name));
PADDLE_ENFORCE_NOT_NULL(out_tensor, PADDLE_ENFORCE_NOT_NULL(
platform::errors::InvalidArgument( out_tensor,
"Output(%s) cannot be NULL.", out_name)); phi::errors::InvalidArgument("Output(%s) cannot be NULL.", out_name));
const T *in_data = in_tensor->data<T>(); const T *in_data = in_tensor->data<T>();
T *out_data = out_tensor->mutable_data<T>(place);
T *out_data = dev_ctx.template Alloc<T>(out_tensor);
PADDLE_ENFORCE_EQ(in_data, PADDLE_ENFORCE_EQ(in_data,
out_data, out_data,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"Input(%s) and Output(%s) must be the same Tensor.", "Input(%s) and Output(%s) must be the same Tensor.",
in_name, in_name,
out_name)); out_name));
@@ -535,11 +537,11 @@ static void MultiTensorUpdateLambMomentAndTrustRatioDiv(
int numel = offsets[n] - offsets[0]; int numel = offsets[n] - offsets[0];
PADDLE_ENFORCE_GE(weight_decay_end_idx, PADDLE_ENFORCE_GE(weight_decay_end_idx,
0, 0,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The weight decay end index should be >= 0.")); "The weight decay end index should be >= 0."));
PADDLE_ENFORCE_LE(weight_decay_end_idx, PADDLE_ENFORCE_LE(weight_decay_end_idx,
n, n,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The weight decay end index should be < %d.", n)); "The weight decay end index should be < %d.", n));
auto weight_decay_end_numel = offsets[weight_decay_end_idx] - offsets[0]; auto weight_decay_end_numel = offsets[weight_decay_end_idx] - offsets[0];
@@ -558,17 +560,17 @@ static void MultiTensorUpdateLambMomentAndTrustRatioDiv(
VLOG(1) << __func__ << " VecSize = " << vec_size; VLOG(1) << __func__ << " VecSize = " << vec_size;
auto stream = dev_ctx.stream(); auto stream = dev_ctx.stream();
auto config = platform::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size); auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
if (found_inf_p == nullptr) { if (found_inf_p == nullptr) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
step, step,
nullptr, nullptr,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"Output(Step) cannot be updated twice in one mini-batch.")); "Output(Step) cannot be updated twice in one mini-batch."));
} else { } else {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
step, step, phi::errors::InvalidArgument("Output(Step) cannot be nullptr."));
platform::errors::InvalidArgument("Output(Step) cannot be nullptr."));
} }
#define PD_LAUNCH_LAMB_MOM_TRUST_RATIO_DIV_KERNEL \ #define PD_LAUNCH_LAMB_MOM_TRUST_RATIO_DIV_KERNEL \
@@ -603,12 +605,12 @@ static void MultiTensorUpdateLambMomentAndTrustRatioDiv(
template <typename T, bool NeedUpdate /*=true*/> template <typename T, bool NeedUpdate /*=true*/>
struct LambBetaPowUpdateOnceHelper { struct LambBetaPowUpdateOnceHelper {
LambBetaPowUpdateOnceHelper(T *beta1pow, T *beta2pow, T beta1, T beta2) { LambBetaPowUpdateOnceHelper(T *beta1pow, T *beta2pow, T beta1, T beta2) {
PADDLE_ENFORCE_NOT_NULL(beta1pow, PADDLE_ENFORCE_NOT_NULL(
platform::errors::InvalidArgument( beta1pow,
"The beta1pow should not be nullptr.")); phi::errors::InvalidArgument("The beta1pow should not be nullptr."));
PADDLE_ENFORCE_NOT_NULL(beta2pow, PADDLE_ENFORCE_NOT_NULL(
platform::errors::InvalidArgument( beta2pow,
"The beta2pow should not be nullptr.")); phi::errors::InvalidArgument("The beta2pow should not be nullptr."));
beta1pow_ = beta1pow; beta1pow_ = beta1pow;
beta2pow_ = beta2pow; beta2pow_ = beta2pow;
beta1_ = beta1; beta1_ = beta1;
@@ -633,11 +635,11 @@ struct LambBetaPowUpdateOnceHelper<T, false> {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
beta1pow, beta1pow,
nullptr, nullptr,
platform::errors::InvalidArgument("The beta1pow should be nullptr.")); phi::errors::InvalidArgument("The beta1pow should be nullptr."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
beta2pow, beta2pow,
nullptr, nullptr,
platform::errors::InvalidArgument("The beta2pow should be nullptr.")); phi::errors::InvalidArgument("The beta2pow should be nullptr."));
} }
HOSTDEVICE void UpdateBetaPows() const {} HOSTDEVICE void UpdateBetaPows() const {}
@@ -649,11 +651,11 @@ struct LambParamHelper {
constexpr bool kIsSameType = std::is_same<T, MasterT<T>>::value; constexpr bool kIsSameType = std::is_same<T, MasterT<T>>::value;
PADDLE_ENFORCE_EQ(kIsSameType, PADDLE_ENFORCE_EQ(kIsSameType,
false, false,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"T must not be the same with MasterT<T>.")); "T must not be the same with MasterT<T>."));
PADDLE_ENFORCE_NOT_NULL(master_param, PADDLE_ENFORCE_NOT_NULL(
platform::errors::InvalidArgument( master_param,
"Master parameter must be provided.")); phi::errors::InvalidArgument("Master parameter must be provided."));
param_ = param; param_ = param;
master_param_ = master_param; master_param_ = master_param;
} }
@@ -671,14 +673,14 @@ template <typename T>
struct LambParamHelper<T, false> { struct LambParamHelper<T, false> {
LambParamHelper(T *param, MasterT<T> *master_param) { LambParamHelper(T *param, MasterT<T> *master_param) {
constexpr bool kIsSameType = std::is_same<T, MasterT<T>>::value; constexpr bool kIsSameType = std::is_same<T, MasterT<T>>::value;
PADDLE_ENFORCE_EQ(kIsSameType, PADDLE_ENFORCE_EQ(
kIsSameType,
true, true,
platform::errors::InvalidArgument( phi::errors::InvalidArgument("T must be the same with MasterT<T>."));
"T must be the same with MasterT<T>."));
if (master_param != nullptr) { if (master_param != nullptr) {
PADDLE_ENFORCE_EQ(static_cast<void *>(param), PADDLE_ENFORCE_EQ(static_cast<void *>(param),
static_cast<void *>(master_param), static_cast<void *>(master_param),
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"Master parameter must be nullptr or the same as " "Master parameter must be nullptr or the same as "
"non-master parameter.")); "non-master parameter."));
} }
@@ -802,12 +804,12 @@ static void MultiTensorUpdateLambParamAndBetaPows(
if (has_beta_pow) { if (has_beta_pow) {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
beta2pow, beta2pow,
platform::errors::InvalidArgument("Beta2Pow should not be nullptr.")); phi::errors::InvalidArgument("Beta2Pow should not be nullptr."));
} else { } else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
beta2pow, beta2pow,
nullptr, nullptr,
platform::errors::InvalidArgument("Beta2Pow should be nullptr.")); phi::errors::InvalidArgument("Beta2Pow should be nullptr."));
} }
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
@@ -858,7 +860,8 @@ static void MultiTensorUpdateLambParamAndBetaPows(
#define PD_LAUNCH_VEC_MULTI_TENSOR_UPDATE_PARAM_BETAPOW_CASE \ #define PD_LAUNCH_VEC_MULTI_TENSOR_UPDATE_PARAM_BETAPOW_CASE \
do { \ do { \
auto callback = \ auto callback = \
[&](const MultiTensorLauncher<kNumTensor, kNumChunk> &launcher, \ [&](const paddle::operators::MultiTensorLauncher<kNumTensor, \
kNumChunk> &launcher, \
int launch_n) { \ int launch_n) { \
if (has_beta_pow && launch_n == 0) { \ if (has_beta_pow && launch_n == 0) { \
PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(true); \ PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(true); \
@@ -868,7 +871,7 @@ static void MultiTensorUpdateLambParamAndBetaPows(
PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(false); \ PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(false); \
} \ } \
}; \ }; \
MultiTensorApplyWithCallback<kNumTensor, kNumChunk>( \ paddle::operators::MultiTensorApplyWithCallback<kNumTensor, kNumChunk>( \
stream, offsets, n, chunk_size, block_dim, callback); \ stream, offsets, n, chunk_size, block_dim, callback); \
} while (0) } while (0)
@@ -886,10 +889,10 @@ static bool CreatePreMulScaleOpIfSupported(ncclDataType_t dtype,
ncclRedOp_t *op) { ncclRedOp_t *op) {
#if NCCL_VERSION_CODE >= 21100 #if NCCL_VERSION_CODE >= 21100
int ver; int ver;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetVersion(&ver)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetVersion(&ver));
if (ver >= 21100) { if (ver >= 21100) {
VLOG(10) << "ncclRedOpCreatePreMulSum is supported."; VLOG(10) << "ncclRedOpCreatePreMulSum is supported.";
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRedOpCreatePreMulSum( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRedOpCreatePreMulSum(
op, const_cast<void *>(scale), dtype, ncclScalarDevice, comm)); op, const_cast<void *>(scale), dtype, ncclScalarDevice, comm));
return true; return true;
} }
@@ -906,7 +909,7 @@ static void LaunchScaleKernel(const phi::GPUContext &dev_ctx,
int n, int n,
gpuStream_t stream) { gpuStream_t stream) {
int vec_size = std::min(GetChunkedVecSize(x, 0), GetChunkedVecSize(y, 0)); int vec_size = std::min(GetChunkedVecSize(x, 0), GetChunkedVecSize(y, 0));
auto config = platform::GetGpuLaunchConfig1D(dev_ctx, n, vec_size); auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n, vec_size);
#define PD_LAMB_VEC_SCALE_KERNEL_CASE \ #define PD_LAMB_VEC_SCALE_KERNEL_CASE \
do { \ do { \
@@ -928,8 +931,8 @@ static void NCCLSumWithScaleBase(const T *sendbuff,
gpuStream_t stream, gpuStream_t stream,
const phi::GPUContext &dev_ctx, const phi::GPUContext &dev_ctx,
const T *scale = nullptr) { const T *scale = nullptr) {
static_assert(std::is_same<T, float>::value || static_assert(
std::is_same<T, platform::float16>::value, std::is_same<T, float>::value || std::is_same<T, dtype::float16>::value,
"T must be either float32 or float16."); "T must be either float32 or float16.");
if (recvcount == 0) return; if (recvcount == 0) return;
@@ -938,7 +941,7 @@ static void NCCLSumWithScaleBase(const T *sendbuff,
if (scale != nullptr) { if (scale != nullptr) {
PADDLE_ENFORCE_EQ(nranks, PADDLE_ENFORCE_EQ(nranks,
1, 1,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"nranks must be 1 when scale != nullptr.")); "nranks must be 1 when scale != nullptr."));
LaunchScaleKernel(dev_ctx, sendbuff, scale, recvbuff, numel, stream); LaunchScaleKernel(dev_ctx, sendbuff, scale, recvbuff, numel, stream);
} }
@@ -950,7 +953,7 @@ static void NCCLSumWithScaleBase(const T *sendbuff,
std::is_same<T, float>::value ? ncclFloat32 : ncclFloat16; std::is_same<T, float>::value ? ncclFloat32 : ncclFloat16;
bool should_destroy_op = bool should_destroy_op =
scale && CreatePreMulScaleOpIfSupported(dtype, comm, scale, &op); scale && CreatePreMulScaleOpIfSupported(dtype, comm, scale, &op);
phi::memory_utils::Buffer buffer(dev_ctx.GetPlace()); memory_utils::Buffer buffer(dev_ctx.GetPlace());
if (scale && !should_destroy_op) { if (scale && !should_destroy_op) {
T *new_sendbuff = buffer.Alloc<T>(numel); T *new_sendbuff = buffer.Alloc<T>(numel);
LaunchScaleKernel(dev_ctx, sendbuff, scale, new_sendbuff, numel, stream); LaunchScaleKernel(dev_ctx, sendbuff, scale, new_sendbuff, numel, stream);
@@ -958,17 +961,17 @@ static void NCCLSumWithScaleBase(const T *sendbuff,
} }
if (UseReduceScatter) { if (UseReduceScatter) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclReduceScatter(
sendbuff, recvbuff, recvcount, dtype, op, comm, stream)); sendbuff, recvbuff, recvcount, dtype, op, comm, stream));
} else { } else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce(
sendbuff, recvbuff, recvcount, dtype, op, comm, stream)); sendbuff, recvbuff, recvcount, dtype, op, comm, stream));
} }
#if NCCL_VERSION_CODE >= 21100 #if NCCL_VERSION_CODE >= 21100
if (should_destroy_op) { if (should_destroy_op) {
VLOG(10) << "ncclRedOpDestroy starts"; VLOG(10) << "ncclRedOpDestroy starts";
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRedOpDestroy(op, comm)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRedOpDestroy(op, comm));
VLOG(10) << "ncclRedOpDestroy ends"; VLOG(10) << "ncclRedOpDestroy ends";
} }
#endif #endif
@@ -1012,7 +1015,7 @@ static void CubDeviceReduce(InputIteratorT d_in,
ReduceOpT reduction_op, ReduceOpT reduction_op,
T init, T init,
gpuStream_t stream, gpuStream_t stream,
phi::memory_utils::Buffer *buffer) { memory_utils::Buffer *buffer) {
void *d_temp_storage = nullptr; void *d_temp_storage = nullptr;
size_t temp_storage_bytes = 0; size_t temp_storage_bytes = 0;
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::Reduce(d_temp_storage, PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::Reduce(d_temp_storage,
@@ -1041,7 +1044,7 @@ static void GetSquareGradNormImpl(const T *grad,
int n, int n,
float *square_norm, float *square_norm,
gpuStream_t stream, gpuStream_t stream,
phi::memory_utils::Buffer *cub_tmp_buffer) { memory_utils::Buffer *cub_tmp_buffer) {
using Iterator = using Iterator =
cub::TransformInputIterator<float, SquareFunctor<T>, const T *>; cub::TransformInputIterator<float, SquareFunctor<T>, const T *>;
Iterator iter(grad, SquareFunctor<T>()); Iterator iter(grad, SquareFunctor<T>());
@@ -1057,11 +1060,11 @@ static void GetSquareGradNormImpl(const T *grad,
// square_norm is of length 2 at least // square_norm is of length 2 at least
static void GetSquareGradNorm(const float *fp32_grad, static void GetSquareGradNorm(const float *fp32_grad,
int fp32_numel, int fp32_numel,
const platform::float16 *fp16_grad, const dtype::float16 *fp16_grad,
int fp16_numel, int fp16_numel,
float *square_norm, float *square_norm,
gpuStream_t stream, gpuStream_t stream,
phi::memory_utils::Buffer *cub_tmp_buffer) { memory_utils::Buffer *cub_tmp_buffer) {
VLOG(10) << "GetSquareGradNorm starts, fp32_numel = " << fp32_numel VLOG(10) << "GetSquareGradNorm starts, fp32_numel = " << fp32_numel
<< " , fp16_numel = " << fp16_numel; << " , fp16_numel = " << fp16_numel;
if (fp32_numel > 0) { if (fp32_numel > 0) {
@@ -1096,23 +1099,21 @@ std::string NumToString(T x) {
} }
template <typename T> template <typename T>
static std::string GetMinMaxStr(const T *x, static std::string GetMinMaxStr(const T *x, size_t n, const phi::Place &place) {
size_t n,
const platform::Place &place) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
platform::is_gpu_place(place), place.GetType() == phi::AllocationType::GPU,
true, true,
platform::errors::InvalidArgument("Only support CUDAPlace currently.")); phi::errors::InvalidArgument("Only support CUDAPlace currently."));
auto *dev_ctx = static_cast<phi::GPUContext *>( auto *dev_ctx = static_cast<phi::GPUContext *>(
platform::DeviceContextPool::Instance().Get(place)); phi::DeviceContextPool::Instance().Get(place));
auto stream = dev_ctx->stream(); auto stream = dev_ctx->stream();
phi::memory_utils::Buffer ret_buffer(place); memory_utils::Buffer ret_buffer(place);
T *ret = ret_buffer.Alloc<T>(2); T *ret = ret_buffer.Alloc<T>(2);
if (n > 0) { if (n > 0) {
phi::memory_utils::Buffer cub_buffer(place); memory_utils::Buffer cub_buffer(place);
CubDeviceReduce(x, CubDeviceReduce(x,
ret, ret,
n, n,
@@ -1160,45 +1161,20 @@ struct VisitDTypeFunctor {
static std::string GetMinMaxStr(const phi::DenseTensor *x) { static std::string GetMinMaxStr(const phi::DenseTensor *x) {
if (x == nullptr) return "null"; if (x == nullptr) return "null";
if (!x->IsInitialized()) return "not_inited"; if (!x->IsInitialized()) return "not_inited";
if (!platform::is_gpu_place(x->place())) return "CPUTensor"; if (x->place().GetType() != phi::AllocationType::GPU) return "CPUTensor";
std::string str; std::string str;
VisitDTypeFunctor functor(x, &str); VisitDTypeFunctor functor(x, &str);
phi::VisitDataType(x->dtype(), functor); phi::VisitDataType(x->dtype(), functor);
return str; return str;
} }
static void PrintAllMinMaxRange(const framework::ExecutionContext &ctx,
bool only_inputs) {
if (!VLOG_IS_ON(1)) return;
for (const auto &pair : ctx.GetOp().Inputs()) {
const auto &key = pair.first;
const auto tensors = ctx.MultiInput<phi::DenseTensor>(key);
size_t n = tensors.size();
for (size_t i = 0; i < n; ++i) {
VLOG(1) << "Input(" << key + ")[" << i << "] = " << pair.second[i]
<< " , " << GetMinMaxStr(tensors[i]);
}
}
if (only_inputs) return;
for (const auto &pair : ctx.GetOp().Outputs()) {
const auto &key = pair.first;
const auto tensors = ctx.MultiOutput<phi::DenseTensor>(key);
size_t n = tensors.size();
for (size_t i = 0; i < n; ++i) {
VLOG(1) << "Output(" << key + ")[" << i << "] = " << pair.second[i]
<< " , " << GetMinMaxStr(tensors[i]);
}
}
}
template <typename T> template <typename T>
static bool HasNanInf(const phi::GPUContext &dev_ctx, const T *x, int numel) { static bool HasNanInf(const phi::GPUContext &dev_ctx, const T *x, int numel) {
if (numel <= 0) return false; if (numel <= 0) return false;
cub::TransformInputIterator<bool, IsNanInfFunctor<T>, const T *> iter( cub::TransformInputIterator<bool, IsNanInfFunctor<T>, const T *> iter(
x, IsNanInfFunctor<T>()); x, IsNanInfFunctor<T>());
phi::memory_utils::Buffer buffer(dev_ctx.GetPlace()); memory_utils::Buffer buffer(dev_ctx.GetPlace());
phi::memory_utils::Buffer out(dev_ctx.GetPlace()); memory_utils::Buffer out(dev_ctx.GetPlace());
CubDeviceReduce(iter, CubDeviceReduce(iter,
out.Alloc<bool>(1), out.Alloc<bool>(1),
numel, numel,
@@ -1226,11 +1202,11 @@ static bool HasNanInf(const phi::GPUContext &dev_ctx, const T *x, int numel) {
static void CheckHasNanInfGrad(const float *fp32_grad, static void CheckHasNanInfGrad(const float *fp32_grad,
int fp32_numel, int fp32_numel,
const platform::float16 *fp16_grad, const dtype::float16 *fp16_grad,
int fp16_numel, int fp16_numel,
float *nan_inf_flag, float *nan_inf_flag,
gpuStream_t stream, gpuStream_t stream,
phi::memory_utils::Buffer *cub_tmp_buffer) { memory_utils::Buffer *cub_tmp_buffer) {
bool *fp32_has_nan_inf = nullptr; bool *fp32_has_nan_inf = nullptr;
bool *fp16_has_nan_inf = nullptr; bool *fp16_has_nan_inf = nullptr;
if (fp32_numel > 0) { if (fp32_numel > 0) {
@@ -1249,9 +1225,9 @@ static void CheckHasNanInfGrad(const float *fp32_grad,
if (fp16_numel > 0) { if (fp16_numel > 0) {
fp16_has_nan_inf = reinterpret_cast<bool *>(nan_inf_flag + 1) + 1; fp16_has_nan_inf = reinterpret_cast<bool *>(nan_inf_flag + 1) + 1;
cub::TransformInputIterator<bool, cub::TransformInputIterator<bool,
IsNanInfFunctor<platform::float16>, IsNanInfFunctor<dtype::float16>,
const platform::float16 *> const dtype::float16 *>
iter(fp16_grad, IsNanInfFunctor<platform::float16>()); iter(fp16_grad, IsNanInfFunctor<dtype::float16>());
CubDeviceReduce(iter, CubDeviceReduce(iter,
fp16_has_nan_inf, fp16_has_nan_inf,
fp16_numel, fp16_numel,
@@ -1316,7 +1292,7 @@ static void LaunchElementwiseAddWithCastKernel(const phi::GPUContext &dev_ctx,
int vec_size = int vec_size =
std::min(std::min(GetChunkedVecSize(x, 0), GetChunkedVecSize(y, 0)), std::min(std::min(GetChunkedVecSize(x, 0), GetChunkedVecSize(y, 0)),
GetChunkedVecSize(z, 0)); GetChunkedVecSize(z, 0));
auto config = platform::GetGpuLaunchConfig1D(dev_ctx, n, vec_size); auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n, vec_size);
#define PD_LAUNCH_ELEMENTWISE_ADD_WITH_CAST_KERNEL \ #define PD_LAUNCH_ELEMENTWISE_ADD_WITH_CAST_KERNEL \
do { \ do { \
@@ -1329,52 +1305,103 @@ static void LaunchElementwiseAddWithCastKernel(const phi::GPUContext &dev_ctx,
#undef PD_LAUNCH_ELEMENTWISE_ADD_WITH_CAST_KERNEL #undef PD_LAUNCH_ELEMENTWISE_ADD_WITH_CAST_KERNEL
} }
template <typename T> template <typename T, typename Context>
class DistributedFusedLambOpKernel<T, phi::GPUContext> void DistributedFusedLambKernel(
: public framework::OpKernel<T> { const Context &dev_ctx,
public: const std::vector<const DenseTensor *> &param,
void Compute(const framework::ExecutionContext &ctx) const override { const std::vector<const DenseTensor *> &grad, /*unused*/
const paddle::optional<DenseTensor> &fp32_param,
const paddle::optional<DenseTensor> &fp32_grad,
const paddle::optional<DenseTensor> &fp16_param,
const paddle::optional<DenseTensor> &fp16_grad,
const DenseTensor &moment1,
const DenseTensor &moment2,
const DenseTensor &beta1_pow,
const DenseTensor &beta2_pow,
const DenseTensor &param_offsets,
const DenseTensor &fp32_partial_offsets,
const DenseTensor &fp16_partial_offsets,
const DenseTensor &param_info,
const DenseTensor &param_order,
const DenseTensor &learning_rate,
const DenseTensor &global_scale,
int acc_steps,
float beta1,
float beta2,
float epsilon,
float max_global_grad_norm,
float weight_decay,
bool clip_after_allreduce,
bool use_master_param_norm,
bool use_master_acc_grad,
bool is_grad_scaled_by_nranks,
bool use_hierarchical_allreduce,
int64_t nranks,
const std::vector<int> &ring_ids,
DenseTensor *fp32_param_out,
DenseTensor *fp16_param_out,
DenseTensor *fp32_acc_grad,
DenseTensor *fp16_acc_grad,
DenseTensor *moment1_out,
DenseTensor *moment2_out,
DenseTensor *beta1_pow_out,
DenseTensor *beta2_pow_out,
DenseTensor *param_out, /*unused*/
DenseTensor *found_inf,
DenseTensor *acc_step,
DenseTensor *stop_update,
DenseTensor *step) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
auto stream = dev_ctx.stream(); auto stream = dev_ctx.stream();
auto place = dev_ctx.GetPlace(); auto place = dev_ctx.GetPlace();
auto *found_inf_t = ctx.Output<phi::DenseTensor>("FoundInf"); found_inf->Resize({1});
found_inf_t->Resize({1});
// Step 1: Get fp16 param and grad tensors // Step 1: Get fp16 param and grad tensors
int64_t fp16_numel; int64_t fp16_numel;
auto *fp16_param = GetSameInOutTensorPtr<platform::float16, true>( auto *fp16_param_data =
ctx, place, "FP16FusedParam", "FP16FusedParamOut", &fp16_numel); GetSameInOutTensorPtr<dtype::float16, Context, true>(dev_ctx,
fp16_param.get_ptr(),
fp16_param_out,
"FP16FusedParam",
"FP16FusedParamOut",
&fp16_numel);
bool has_fp16_param = (fp16_numel > 0); bool has_fp16_param = (fp16_numel > 0);
const platform::float16 *fp16_grad = nullptr; const dtype::float16 *fp16_grad_data = nullptr;
if (has_fp16_param) { if (has_fp16_param) {
fp16_grad = GetInputTensorPtr<platform::float16>(ctx, "FP16FusedGrad"); fp16_grad_data =
GetInputTensorPtr<dtype::float16>(fp16_grad.get_ptr(), "FP16FusedGrad");
} else { } else {
fp16_param = nullptr; fp16_param_data = nullptr;
} }
// Step 2: Get fp32 param and grad tensors // Step 2: Get fp32 param and grad tensors
int64_t fp32_numel = 0; int64_t fp32_numel = 0;
auto *fp32_param = GetSameInOutTensorPtr<float, true>( auto *fp32_param_data =
ctx, place, "FP32FusedParam", "FP32FusedParamOut", &fp32_numel); GetSameInOutTensorPtr<float, Context, true>(dev_ctx,
fp32_param.get_ptr(),
fp32_param_out,
"FP32FusedParam",
"FP32FusedParamOut",
&fp32_numel);
PADDLE_ENFORCE_GE(fp32_numel, PADDLE_ENFORCE_GE(fp32_numel,
fp16_numel, fp16_numel,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The element number in FP32FusedParam should be not " "The element number in FP32FusedParam should be not "
"less than FP16FusedParam.")); "less than FP16FusedParam."));
fp32_numel -= fp16_numel; // the FP32FusedParam contains fp32 param and fp32_numel -= fp16_numel; // the FP32FusedParam contains fp32 param and
// fp16 master weight // fp16 master weight
bool has_fp32_param = (fp32_numel > 0); bool has_fp32_param = (fp32_numel > 0);
const float *fp32_grad = nullptr; const float *fp32_grad_data = nullptr;
if (has_fp32_param) { if (has_fp32_param) {
fp32_grad = GetInputTensorPtr<float>(ctx, "FP32FusedGrad"); fp32_grad_data =
GetInputTensorPtr<float>(fp32_grad.get_ptr(), "FP32FusedGrad");
} else { } else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
has_fp16_param, has_fp16_param,
true, true,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"Either FP32FusedGrad or FP16FusedGrad cannot be NULL.")); "Either FP32FusedGrad or FP16FusedGrad cannot be NULL."));
} }
@@ -1385,92 +1412,84 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
// The NVIDIA cub library does not support number > INT32_MAX // The NVIDIA cub library does not support number > INT32_MAX
PADDLE_ENFORCE_LE(numel, PADDLE_ENFORCE_LE(numel,
std::numeric_limits<int>::max(), std::numeric_limits<int>::max(),
platform::errors::Unimplemented( phi::errors::Unimplemented(
"Too many parameter number. Only <= %d is supported.", "Too many parameter number. Only <= %d is supported.",
std::numeric_limits<int>::max())); std::numeric_limits<int>::max()));
auto acc_steps = ctx.Attr<int>("acc_steps");
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
acc_steps, acc_steps,
1, 1,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The gradient accumulation steps should be not less than 1.")); "The gradient accumulation steps should be not less than 1."));
if (acc_steps > 1) { if (acc_steps > 1) {
auto *step_t = ctx.Output<phi::DenseTensor>("AccStep");
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
step_t, acc_step,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"Output(AccStep) cannot be nullptr when Attr(acc_steps) > 1.")); "Output(AccStep) cannot be nullptr when Attr(acc_steps) > 1."));
bool is_initialized = step_t->IsInitialized(); bool is_initialized = acc_step->IsInitialized();
int64_t *step_ptr; int64_t *acc_step_data;
if (is_initialized) { if (is_initialized) {
step_ptr = step_t->mutable_data<int64_t>(platform::CPUPlace()); acc_step_data = dev_ctx.template HostAlloc<int64_t>(acc_step);
++(*step_ptr); ++(*acc_step_data);
} else { } else {
step_t->Resize({1}); acc_step->Resize({1});
step_ptr = step_t->mutable_data<int64_t>(platform::CPUPlace()); acc_step_data = dev_ctx.template HostAlloc<int64_t>(acc_step);
*step_ptr = 1; *acc_step_data = 1;
} }
int64_t rounded_step = (*step_ptr) % acc_steps; int64_t rounded_step = (*acc_step_data) % acc_steps;
float *fp32_acc_grad = nullptr; float *fp32_acc_grad_data = nullptr;
if (has_fp32_param) { if (has_fp32_param) {
auto *fp32_acc_grad_t = PADDLE_ENFORCE_NOT_NULL(fp32_acc_grad,
ctx.Output<phi::DenseTensor>("FP32AccFusedGrad"); phi::errors::InvalidArgument(
PADDLE_ENFORCE_NOT_NULL(
fp32_acc_grad_t,
platform::errors::InvalidArgument(
"Output(FP32AccFusedGrad) cannot be nullptr " "Output(FP32AccFusedGrad) cannot be nullptr "
"when Attr(acc_steps) > 1.")); "when Attr(acc_steps) > 1."));
if (!fp32_acc_grad_t->IsInitialized()) { if (!fp32_acc_grad->IsInitialized()) {
fp32_acc_grad_t->Resize({static_cast<int64_t>(fp32_numel)}); fp32_acc_grad->Resize({static_cast<int64_t>(fp32_numel)});
fp32_acc_grad = fp32_acc_grad_t->mutable_data<float>(place); fp32_acc_grad_data = dev_ctx.template Alloc<float>(fp32_acc_grad);
} else { } else {
fp32_acc_grad = fp32_acc_grad_t->data<float>(); fp32_acc_grad_data = fp32_acc_grad->data<float>();
} }
} }
platform::float16 *fp16_acc_grad = nullptr; dtype::float16 *fp16_acc_grad_data = nullptr;
float *master_acc_grad = nullptr; float *master_acc_grad = nullptr;
bool use_master_acc_grad = false;
if (has_fp16_param) { if (has_fp16_param) {
use_master_acc_grad = ctx.Attr<bool>("use_master_acc_grad"); PADDLE_ENFORCE_NOT_NULL(fp16_acc_grad,
auto *fp16_acc_grad_t = phi::errors::InvalidArgument(
ctx.Output<phi::DenseTensor>("FP16AccFusedGrad");
PADDLE_ENFORCE_NOT_NULL(
fp16_acc_grad_t,
platform::errors::InvalidArgument(
"Output(FP16AccFusedGrad) cannot be nullptr " "Output(FP16AccFusedGrad) cannot be nullptr "
"when Attr(acc_steps) > 1.")); "when Attr(acc_steps) > 1."));
if (!fp16_acc_grad_t->IsInitialized()) { if (!fp16_acc_grad->IsInitialized()) {
auto acc_grad_size = auto acc_grad_size =
use_master_acc_grad ? (3 * fp16_numel) : fp16_numel; use_master_acc_grad ? (3 * fp16_numel) : fp16_numel;
fp16_acc_grad_t->Resize({static_cast<int64_t>(acc_grad_size)}); fp16_acc_grad->Resize({static_cast<int64_t>(acc_grad_size)});
fp16_acc_grad = fp16_acc_grad_data =
fp16_acc_grad_t->mutable_data<platform::float16>(place); dev_ctx.template Alloc<dtype::float16>(fp16_acc_grad);
} else { } else {
fp16_acc_grad = fp16_acc_grad_t->data<platform::float16>(); fp16_acc_grad_data = fp16_acc_grad->data<dtype::float16>();
} }
if (use_master_acc_grad) { if (use_master_acc_grad) {
master_acc_grad = master_acc_grad =
reinterpret_cast<float *>(fp16_acc_grad + fp16_numel); reinterpret_cast<float *>(fp16_acc_grad_data + fp16_numel);
} }
} else {
use_master_acc_grad = false;
} }
// Inplace addto // Inplace addto
if (has_fp32_param) { if (has_fp32_param) {
if (rounded_step == 1) { if (rounded_step == 1) {
memory::Copy(place, memory_utils::Copy(place,
fp32_acc_grad, fp32_acc_grad_data,
place, place,
fp32_grad, fp32_grad_data,
fp32_numel * sizeof(float), fp32_numel * sizeof(float),
stream); stream);
} else { } else {
LaunchElementwiseAddWithCastKernel(dev_ctx, LaunchElementwiseAddWithCastKernel(dev_ctx,
fp32_grad, fp32_grad_data,
fp32_acc_grad, fp32_acc_grad_data,
fp32_acc_grad, fp32_acc_grad_data,
fp32_numel, fp32_numel,
stream); stream);
} }
@@ -1480,44 +1499,44 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
if (acc_steps == 2 || !use_master_acc_grad) { if (acc_steps == 2 || !use_master_acc_grad) {
if (rounded_step != 1) { if (rounded_step != 1) {
LaunchElementwiseAddWithCastKernel(dev_ctx, LaunchElementwiseAddWithCastKernel(dev_ctx,
fp16_acc_grad, fp16_acc_grad_data,
fp16_grad, fp16_grad_data,
fp16_acc_grad, fp16_acc_grad_data,
fp16_numel, fp16_numel,
stream); stream);
} else { } else {
memory::Copy(place, memory_utils::Copy(place,
fp16_acc_grad, fp16_acc_grad_data,
place, place,
fp16_grad, fp16_grad_data,
fp16_numel * sizeof(platform::float16), fp16_numel * sizeof(dtype::float16),
stream); stream);
} }
} else { // acc_steps >= 3 } else { // acc_steps >= 3
if (rounded_step == 0) { if (rounded_step == 0) {
LaunchElementwiseAddWithCastKernel(dev_ctx, LaunchElementwiseAddWithCastKernel(dev_ctx,
fp16_grad, fp16_grad_data,
master_acc_grad, master_acc_grad,
fp16_acc_grad, fp16_acc_grad_data,
fp16_numel, fp16_numel,
stream); stream);
} else if (rounded_step == 1) { } else if (rounded_step == 1) {
memory::Copy(place, memory_utils::Copy(place,
fp16_acc_grad, fp16_acc_grad_data,
place, place,
fp16_grad, fp16_grad_data,
fp16_numel * sizeof(platform::float16), fp16_numel * sizeof(dtype::float16),
stream); stream);
} else if (rounded_step == 2) { } else if (rounded_step == 2) {
LaunchElementwiseAddWithCastKernel(dev_ctx, LaunchElementwiseAddWithCastKernel(dev_ctx,
fp16_grad, fp16_grad_data,
fp16_acc_grad, fp16_acc_grad_data,
master_acc_grad, master_acc_grad,
fp16_numel, fp16_numel,
stream); stream);
} else { } else {
LaunchElementwiseAddWithCastKernel(dev_ctx, LaunchElementwiseAddWithCastKernel(dev_ctx,
fp16_grad, fp16_grad_data,
master_acc_grad, master_acc_grad,
master_acc_grad, master_acc_grad,
fp16_numel, fp16_numel,
@@ -1526,45 +1545,40 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
} }
} }
auto *stop_update_t = ctx.Output<phi::DenseTensor>("StopUpdate"); stop_update->Resize({1});
stop_update_t->Resize({1}); auto *stop_update_data = dev_ctx.template HostAlloc<bool>(stop_update);
auto *stop_update = auto *found_inf_cpu = dev_ctx.template HostAlloc<bool>(found_inf);
stop_update_t->mutable_data<bool>(platform::CPUPlace());
auto *found_inf_cpu =
found_inf_t->mutable_data<bool>(platform::CPUPlace());
if (rounded_step != 0) { if (rounded_step != 0) {
*stop_update = true; *stop_update_data = true;
auto *found_inf_cpu =
found_inf_t->mutable_data<bool>(platform::CPUPlace());
*found_inf_cpu = false; *found_inf_cpu = false;
return; return;
} else { } else {
// swap pointer // swap pointer
fp32_grad = fp32_acc_grad; fp32_grad_data = fp32_acc_grad_data;
fp16_grad = fp16_acc_grad; fp16_grad_data = fp16_acc_grad_data;
*stop_update = false; *stop_update_data = false;
found_inf_t->clear(); found_inf->clear();
} }
} }
// Step 3: Get ParamInfo // Step 3: Get ParamInfo
const auto *param_info_tensor = GetInputTensorPtr<int>(ctx, "ParamInfo"); const auto *param_info_data =
auto fp32_local_start_idx = param_info_tensor[0]; GetInputTensorPtr<int>(&param_info, "ParamInfo");
auto fp32_local_param_num = param_info_tensor[1]; auto fp32_local_start_idx = param_info_data[0];
auto fp32_global_param_num = param_info_tensor[2]; auto fp32_local_param_num = param_info_data[1];
auto fp32_weight_decay_end_idx = param_info_tensor[3]; auto fp32_global_param_num = param_info_data[2];
auto fp16_local_start_idx = param_info_tensor[4]; auto fp32_weight_decay_end_idx = param_info_data[3];
auto fp16_local_param_num = param_info_tensor[5]; auto fp16_local_start_idx = param_info_data[4];
auto fp16_global_param_num = param_info_tensor[6]; auto fp16_local_param_num = param_info_data[5];
auto fp16_weight_decay_end_idx = param_info_tensor[7]; auto fp16_global_param_num = param_info_data[6];
auto fp16_weight_decay_end_idx = param_info_data[7];
auto local_param_num = fp32_local_param_num + fp16_local_param_num; auto local_param_num = fp32_local_param_num + fp16_local_param_num;
auto param_num = fp32_global_param_num + fp16_global_param_num; auto param_num = fp32_global_param_num + fp16_global_param_num;
PADDLE_ENFORCE_LE(local_param_num, PADDLE_ENFORCE_LE(local_param_num,
param_num, param_num,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The local parameter number should not exceed the " "The local parameter number should not exceed the "
"global parameter number.")); "global parameter number."));
VLOG(1) << "local_param_num = " << local_param_num VLOG(1) << "local_param_num = " << local_param_num
@@ -1578,15 +1592,17 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
// Step 4: Get LearningRate, Moment1, Moment2, Beta1Pow, Beta2Pow, // Step 4: Get LearningRate, Moment1, Moment2, Beta1Pow, Beta2Pow,
// GlobalScale // GlobalScale
const auto *global_scale = GetInputTensorPtr<float>(ctx, "GlobalScale"); const auto *global_scale_data =
const auto *lr = GetInputTensorPtr<float>(ctx, "LearningRate"); GetInputTensorPtr<float>(&global_scale, "GlobalScale");
const auto *lr_data =
GetInputTensorPtr<float>(&learning_rate, "LearningRate");
int64_t partial_numel = 0; int64_t partial_numel = 0;
auto *moment1 = GetSameInOutTensorPtr<float>( auto *moment1_data = GetSameInOutTensorPtr<float, Context>(
ctx, place, "Moment1", "Moment1Out", &partial_numel); dev_ctx, &moment1, moment1_out, "Moment1", "Moment1Out", &partial_numel);
PADDLE_ENFORCE_EQ(numel % partial_numel, PADDLE_ENFORCE_EQ(numel % partial_numel,
0, 0,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The total parameter number %d should be divided " "The total parameter number %d should be divided "
"exactly by the element number %d of Moment1.", "exactly by the element number %d of Moment1.",
numel, numel,
@@ -1601,61 +1617,47 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
PADDLE_ENFORCE_EQ(fp32_numel % num_devices, PADDLE_ENFORCE_EQ(fp32_numel % num_devices,
0, 0,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The fp32 parameter number %d should be divided " "The fp32 parameter number %d should be divided "
"exactly by the device number %d.", "exactly by the device number %d.",
fp32_numel, fp32_numel,
num_devices)); num_devices));
PADDLE_ENFORCE_EQ(fp16_numel % num_devices, PADDLE_ENFORCE_EQ(fp16_numel % num_devices,
0, 0,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The fp16 parameter number %d should be divided " "The fp16 parameter number %d should be divided "
"exactly by the device number %d.", "exactly by the device number %d.",
fp16_numel, fp16_numel,
num_devices)); num_devices));
auto *moment2 = auto *moment2_data = GetSameInOutTensorPtr<float, Context>(
GetSameInOutTensorPtr<float>(ctx, place, "Moment2", "Moment2Out"); dev_ctx, &moment2, moment2_out, "Moment2", "Moment2Out");
auto *beta1pow = auto *beta1_pow_data = GetSameInOutTensorPtr<float, Context>(
GetSameInOutTensorPtr<float>(ctx, place, "Beta1Pow", "Beta1PowOut"); dev_ctx, &beta1_pow, beta1_pow_out, "Beta1Pow", "Beta1PowOut");
auto *beta2pow = auto *beta2_pow_data = GetSameInOutTensorPtr<float, Context>(
GetSameInOutTensorPtr<float>(ctx, place, "Beta2Pow", "Beta2PowOut"); dev_ctx, &beta2_pow, beta2_pow_out, "Beta2Pow", "Beta2PowOut");
auto *found_inf = found_inf_t->mutable_data<bool>(place); auto *found_inf_data = dev_ctx.template Alloc<bool>(found_inf);
// Step 5: Get attributes weight_decay, beta1, beta2, epsilon, // Step 5: Get attributes weight_decay, beta1, beta2, epsilon,
// max_grad_norm, ring_id, // max_grad_norm, ring_id,
// use_master_param_norm, is_grad_scaled_by_nranks // use_master_param_norm, is_grad_scaled_by_nranks
auto weight_decay = ctx.Attr<float>("weight_decay");
auto beta1 = ctx.Attr<float>("beta1");
auto beta2 = ctx.Attr<float>("beta2");
auto epsilon = ctx.Attr<float>("epsilon");
auto max_global_grad_norm = ctx.Attr<float>("max_global_grad_norm");
auto clip_after_allreduce = ctx.Attr<bool>("clip_after_allreduce");
auto nranks = ctx.Attr<int64_t>("nranks");
PADDLE_ENFORCE_GE(nranks, PADDLE_ENFORCE_GE(nranks,
num_devices, num_devices,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"The nranks must be not less than num_devices.")); "The nranks must be not less than num_devices."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(nranks % num_devices,
nranks % num_devices,
0, 0,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"The nranks must be exactly divided by num_devices.")); "The nranks must be exactly divided by num_devices."));
bool local_shard = (nranks > num_devices); bool local_shard = (nranks > num_devices);
const auto &ring_ids = ctx.Attr<std::vector<int>>("ring_ids");
auto use_master_param_norm = ctx.Attr<bool>("use_master_param_norm");
auto is_grad_scaled_by_nranks = ctx.Attr<bool>("is_grad_scaled_by_nranks");
auto use_hierarchical_allreduce =
ctx.Attr<bool>("use_hierarchical_allreduce");
VLOG(10) << "max_global_grad_norm = " << max_global_grad_norm VLOG(10) << "max_global_grad_norm = " << max_global_grad_norm
<< " , clip_after_allreduce = " << clip_after_allreduce << " , clip_after_allreduce = " << clip_after_allreduce
<< " , use_master_param_norm = " << use_master_param_norm << " , use_master_param_norm = " << use_master_param_norm
<< " , is_grad_scaled_by_nranks = " << is_grad_scaled_by_nranks << " , is_grad_scaled_by_nranks = " << is_grad_scaled_by_nranks
<< " , local_shard = " << local_shard << " , local_shard = " << local_shard
<< " , use_hierarchical_allreduce = " << " , use_hierarchical_allreduce = " << use_hierarchical_allreduce;
<< use_hierarchical_allreduce;
// Step 6: allreduce + global norm gradient clip // Step 6: allreduce + global norm gradient clip
int64_t global_rank = 0, local_rank = 0; int64_t global_rank = 0, local_rank = 0;
@@ -1663,17 +1665,17 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
external_comm = nullptr; external_comm = nullptr;
if (nranks > 1) { if (nranks > 1) {
auto *nccl_comm_handle = auto *nccl_comm_handle =
platform::NCCLCommContext::Instance().Get(ring_ids[0], place); paddle::platform::NCCLCommContext::Instance().Get(ring_ids[0], place);
global_comm = nccl_comm_handle->comm(); global_comm = nccl_comm_handle->comm();
global_rank = nccl_comm_handle->rank(); global_rank = nccl_comm_handle->rank();
if (local_shard) { if (local_shard) {
auto *local_nccl_comm_handle = auto *local_nccl_comm_handle =
platform::NCCLCommContext::Instance().Get(ring_ids[1], place); paddle::platform::NCCLCommContext::Instance().Get(ring_ids[1], place);
local_comm = local_nccl_comm_handle->comm(); local_comm = local_nccl_comm_handle->comm();
local_rank = local_nccl_comm_handle->rank(); local_rank = local_nccl_comm_handle->rank();
if (use_hierarchical_allreduce) { if (use_hierarchical_allreduce) {
external_comm = platform::NCCLCommContext::Instance() external_comm = paddle::platform::NCCLCommContext::Instance()
.Get(ring_ids[2], place) .Get(ring_ids[2], place)
->comm(); ->comm();
} }
@@ -1683,30 +1685,30 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
} }
} }
phi::memory_utils::Buffer grad_norm_square_buffer(place); memory_utils::Buffer grad_norm_square_buffer(place);
auto *fp32_square_grad_norm = grad_norm_square_buffer.Alloc<float>(2); auto *fp32_square_grad_norm = grad_norm_square_buffer.Alloc<float>(2);
phi::memory_utils::Buffer cub_tmp_buffer(place); memory_utils::Buffer cub_tmp_buffer(place);
phi::memory_utils::Buffer sum_grad_buffer(place); memory_utils::Buffer sum_grad_buffer(place);
float *fp32_sum_grad; float *fp32_sum_grad;
platform::float16 *fp16_sum_grad; dtype::float16 *fp16_sum_grad;
auto fp32_numel_each_device = fp32_numel / num_devices; auto fp32_numel_each_device = fp32_numel / num_devices;
auto fp16_numel_each_device = fp16_numel / num_devices; auto fp16_numel_each_device = fp16_numel / num_devices;
if (local_shard) { if (local_shard) {
auto ptr = sum_grad_buffer.Alloc<uint8_t>( auto ptr = sum_grad_buffer.Alloc<uint8_t>(
fp32_numel * sizeof(float) + fp16_numel * sizeof(platform::float16)); fp32_numel * sizeof(float) + fp16_numel * sizeof(dtype::float16));
fp32_sum_grad = has_fp32_param ? reinterpret_cast<float *>(ptr) : nullptr; fp32_sum_grad = has_fp32_param ? reinterpret_cast<float *>(ptr) : nullptr;
fp16_sum_grad = has_fp16_param ? reinterpret_cast<platform::float16 *>( fp16_sum_grad = has_fp16_param ? reinterpret_cast<dtype::float16 *>(
ptr + fp32_numel * sizeof(float)) ptr + fp32_numel * sizeof(float))
: nullptr; : nullptr;
} else if (nranks > 1 || } else if (nranks > 1 ||
(max_global_grad_norm > 0 && !clip_after_allreduce)) { (max_global_grad_norm > 0 && !clip_after_allreduce)) {
auto ptr = sum_grad_buffer.Alloc<uint8_t>( auto ptr = sum_grad_buffer.Alloc<uint8_t>(
fp32_numel_each_device * sizeof(float) + fp32_numel_each_device * sizeof(float) +
fp16_numel_each_device * sizeof(platform::float16)); fp16_numel_each_device * sizeof(dtype::float16));
fp32_sum_grad = has_fp32_param ? reinterpret_cast<float *>(ptr) : nullptr; fp32_sum_grad = has_fp32_param ? reinterpret_cast<float *>(ptr) : nullptr;
fp16_sum_grad = has_fp16_param fp16_sum_grad = has_fp16_param
? reinterpret_cast<platform::float16 *>( ? reinterpret_cast<dtype::float16 *>(
ptr + fp32_numel_each_device * sizeof(float)) ptr + fp32_numel_each_device * sizeof(float))
: nullptr; : nullptr;
} else { } else {
@@ -1716,8 +1718,8 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
// if-else codes (num_devices > 1) when I write the following code. // if-else codes (num_devices > 1) when I write the following code.
// So I prefer to use const_cast to unify the following code to reduce // So I prefer to use const_cast to unify the following code to reduce
// the if-else codes. // the if-else codes.
fp32_sum_grad = const_cast<float *>(fp32_grad); fp32_sum_grad = const_cast<float *>(fp32_grad_data);
fp16_sum_grad = const_cast<platform::float16 *>(fp16_grad); fp16_sum_grad = const_cast<dtype::float16 *>(fp16_grad_data);
} }
float rescale_grad = 1.0f; float rescale_grad = 1.0f;
@@ -1731,7 +1733,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
if (local_shard) { if (local_shard) {
if (use_hierarchical_allreduce) { if (use_hierarchical_allreduce) {
NCCLReduceScatterWithScale( NCCLReduceScatterWithScale(
fp32_grad, fp32_grad_data,
fp32_sum_grad + local_rank * fp32_numel_each_device, fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device, fp32_numel_each_device,
num_devices, num_devices,
@@ -1748,7 +1750,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
dev_ctx); dev_ctx);
NCCLReduceScatterWithScale( NCCLReduceScatterWithScale(
fp16_grad, fp16_grad_data,
fp16_sum_grad + local_rank * fp16_numel_each_device, fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device, fp16_numel_each_device,
num_devices, num_devices,
...@@ -1764,14 +1766,14 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -1764,14 +1766,14 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
stream, stream,
dev_ctx); dev_ctx);
} else { } else {
NCCLAllReduceWithScale(fp32_grad, NCCLAllReduceWithScale(fp32_grad_data,
fp32_sum_grad, fp32_sum_grad,
fp32_numel, fp32_numel,
nranks, nranks,
global_comm, global_comm,
stream, stream,
dev_ctx); dev_ctx);
NCCLAllReduceWithScale(fp16_grad, NCCLAllReduceWithScale(fp16_grad_data,
fp16_sum_grad, fp16_sum_grad,
fp16_numel, fp16_numel,
nranks, nranks,
...@@ -1782,14 +1784,14 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -1782,14 +1784,14 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
fp32_sum_grad += (local_rank * fp32_numel_each_device); fp32_sum_grad += (local_rank * fp32_numel_each_device);
fp16_sum_grad += (local_rank * fp16_numel_each_device); fp16_sum_grad += (local_rank * fp16_numel_each_device);
} else { } else {
NCCLReduceScatterWithScale(fp32_grad, NCCLReduceScatterWithScale(fp32_grad_data,
fp32_sum_grad, fp32_sum_grad,
fp32_numel_each_device, fp32_numel_each_device,
nranks, nranks,
global_comm, global_comm,
stream, stream,
dev_ctx); dev_ctx);
NCCLReduceScatterWithScale(fp16_grad, NCCLReduceScatterWithScale(fp16_grad_data,
fp16_sum_grad, fp16_sum_grad,
fp16_numel_each_device, fp16_numel_each_device,
nranks, nranks,
...@@ -1809,7 +1811,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -1809,7 +1811,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
<< FlattenToString(fp32_square_grad_norm, 1, place); << FlattenToString(fp32_square_grad_norm, 1, place);
if (num_devices > 1) { if (num_devices > 1) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllReduce(fp32_square_grad_norm, phi::dynload::ncclAllReduce(fp32_square_grad_norm,
fp32_square_grad_norm, fp32_square_grad_norm,
1, 1,
ncclFloat32, ncclFloat32,
...@@ -1821,9 +1823,9 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -1821,9 +1823,9 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
<< FlattenToString(fp32_square_grad_norm, 1, place); << FlattenToString(fp32_square_grad_norm, 1, place);
} else { } else {
// (1) Calculate the local grad norm // (1) Calculate the local grad norm
GetSquareGradNorm(fp32_grad, GetSquareGradNorm(fp32_grad_data,
fp32_numel, fp32_numel,
fp16_grad, fp16_grad_data,
fp16_numel, fp16_numel,
fp32_square_grad_norm, fp32_square_grad_norm,
stream, stream,
...@@ -1832,25 +1834,24 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -1832,25 +1834,24 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
<< FlattenToString(fp32_square_grad_norm, 1, place); << FlattenToString(fp32_square_grad_norm, 1, place);
// (2) Calculate the gradient clip scale // (2) Calculate the gradient clip scale
float *fp32_scale = nullptr; float *fp32_scale = nullptr;
platform::float16 *fp16_scale = nullptr; dtype::float16 *fp16_scale = nullptr;
if (has_fp32_param && has_fp16_param) { if (has_fp32_param && has_fp16_param) {
auto *ptr = cub_tmp_buffer.Alloc<uint8_t>(sizeof(float) + auto *ptr = cub_tmp_buffer.Alloc<uint8_t>(sizeof(float) +
sizeof(platform::float16)); sizeof(dtype::float16));
fp32_scale = reinterpret_cast<float *>(ptr); fp32_scale = reinterpret_cast<float *>(ptr);
fp16_scale = fp16_scale = reinterpret_cast<dtype::float16 *>(ptr + sizeof(float));
reinterpret_cast<platform::float16 *>(ptr + sizeof(float));
} else if (has_fp32_param) { } else if (has_fp32_param) {
fp32_scale = cub_tmp_buffer.Alloc<float>(1); fp32_scale = cub_tmp_buffer.Alloc<float>(1);
} else { } else {
fp16_scale = cub_tmp_buffer.Alloc<platform::float16>(1); fp16_scale = cub_tmp_buffer.Alloc<dtype::float16>(1);
} }
float clip_scale = 1.0f; float clip_scale = 1.0f;
if (is_grad_scaled_by_nranks) { if (is_grad_scaled_by_nranks) {
clip_scale *= nranks; clip_scale *= nranks;
} }
CalcGradNormClipBeforeAllReduceScale<float, platform::float16> CalcGradNormClipBeforeAllReduceScale<float, dtype::float16>
<<<1, 1, 0, stream>>>(global_scale, <<<1, 1, 0, stream>>>(global_scale_data,
max_global_grad_norm, max_global_grad_norm,
fp32_square_grad_norm, fp32_square_grad_norm,
fp32_scale, fp32_scale,
...@@ -1863,13 +1864,13 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -1863,13 +1864,13 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
} }
// (3) Do ReduceScatter with scale // (3) Do ReduceScatter with scale
VLOG(1) << "FP32 HasNanInf before all reduce: " VLOG(1) << "FP32 HasNanInf before all reduce: "
<< HasNanInf(dev_ctx, fp32_grad, fp32_numel); << HasNanInf(dev_ctx, fp32_grad_data, fp32_numel);
VLOG(1) << "FP16 HasNanInf before all reduce: " VLOG(1) << "FP16 HasNanInf before all reduce: "
<< HasNanInf(dev_ctx, fp16_grad, fp16_numel); << HasNanInf(dev_ctx, fp16_grad_data, fp16_numel);
if (local_shard) { if (local_shard) {
if (use_hierarchical_allreduce) { if (use_hierarchical_allreduce) {
NCCLReduceScatterWithScale( NCCLReduceScatterWithScale(
fp32_grad, fp32_grad_data,
fp32_sum_grad + local_rank * fp32_numel_each_device, fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device, fp32_numel_each_device,
num_devices, num_devices,
...@@ -1887,7 +1888,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -1887,7 +1888,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
dev_ctx); dev_ctx);
NCCLReduceScatterWithScale( NCCLReduceScatterWithScale(
fp16_grad, fp16_grad_data,
fp16_sum_grad + local_rank * fp16_numel_each_device, fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device, fp16_numel_each_device,
num_devices, num_devices,
...@@ -1904,7 +1905,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -1904,7 +1905,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
stream, stream,
dev_ctx); dev_ctx);
} else { } else {
NCCLAllReduceWithScale(fp32_grad, NCCLAllReduceWithScale(fp32_grad_data,
fp32_sum_grad, fp32_sum_grad,
fp32_numel, fp32_numel,
nranks, nranks,
...@@ -1912,7 +1913,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -1912,7 +1913,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
stream, stream,
dev_ctx, dev_ctx,
fp32_scale); fp32_scale);
NCCLAllReduceWithScale(fp16_grad, NCCLAllReduceWithScale(fp16_grad_data,
fp16_sum_grad, fp16_sum_grad,
fp16_numel, fp16_numel,
nranks, nranks,
...@@ -1924,7 +1925,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -1924,7 +1925,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
fp32_sum_grad += (local_rank * fp32_numel_each_device); fp32_sum_grad += (local_rank * fp32_numel_each_device);
fp16_sum_grad += (local_rank * fp16_numel_each_device); fp16_sum_grad += (local_rank * fp16_numel_each_device);
} else { } else {
NCCLReduceScatterWithScale(fp32_grad, NCCLReduceScatterWithScale(fp32_grad_data,
fp32_sum_grad, fp32_sum_grad,
fp32_numel_each_device, fp32_numel_each_device,
nranks, nranks,
...@@ -1932,7 +1933,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -1932,7 +1933,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
stream, stream,
dev_ctx, dev_ctx,
fp32_scale); fp32_scale);
NCCLReduceScatterWithScale(fp16_grad, NCCLReduceScatterWithScale(fp16_grad_data,
fp16_sum_grad, fp16_sum_grad,
fp16_numel_each_device, fp16_numel_each_device,
nranks, nranks,
...@@ -1954,7 +1955,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -1954,7 +1955,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
&cub_tmp_buffer); &cub_tmp_buffer);
if (num_devices > 1) { if (num_devices > 1) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllReduce(fp32_square_grad_norm, phi::dynload::ncclAllReduce(fp32_square_grad_norm,
fp32_square_grad_norm, fp32_square_grad_norm,
1, 1,
ncclFloat32, ncclFloat32,
...@@ -1972,7 +1973,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -1972,7 +1973,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
if (local_shard) { if (local_shard) {
if (use_hierarchical_allreduce) { if (use_hierarchical_allreduce) {
NCCLReduceScatterWithScale( NCCLReduceScatterWithScale(
fp32_grad, fp32_grad_data,
fp32_sum_grad + local_rank * fp32_numel_each_device, fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device, fp32_numel_each_device,
num_devices, num_devices,
...@@ -1989,7 +1990,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -1989,7 +1990,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
dev_ctx); dev_ctx);
NCCLReduceScatterWithScale( NCCLReduceScatterWithScale(
fp16_grad, fp16_grad_data,
fp16_sum_grad + local_rank * fp16_numel_each_device, fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device, fp16_numel_each_device,
num_devices, num_devices,
...@@ -2005,14 +2006,14 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -2005,14 +2006,14 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
stream, stream,
dev_ctx); dev_ctx);
} else { } else {
NCCLAllReduceWithScale(fp32_grad, NCCLAllReduceWithScale(fp32_grad_data,
fp32_sum_grad, fp32_sum_grad,
fp32_numel, fp32_numel,
nranks, nranks,
global_comm, global_comm,
stream, stream,
dev_ctx); dev_ctx);
NCCLAllReduceWithScale(fp16_grad, NCCLAllReduceWithScale(fp16_grad_data,
fp16_sum_grad, fp16_sum_grad,
fp16_numel, fp16_numel,
nranks, nranks,
...@@ -2023,14 +2024,14 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -2023,14 +2024,14 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
fp32_sum_grad += (local_rank * fp32_numel_each_device); fp32_sum_grad += (local_rank * fp32_numel_each_device);
fp16_sum_grad += (local_rank * fp16_numel_each_device); fp16_sum_grad += (local_rank * fp16_numel_each_device);
} else { } else {
NCCLReduceScatterWithScale(fp32_grad, NCCLReduceScatterWithScale(fp32_grad_data,
fp32_sum_grad, fp32_sum_grad,
fp32_numel_each_device, fp32_numel_each_device,
num_devices, num_devices,
global_comm, global_comm,
stream, stream,
dev_ctx); dev_ctx);
NCCLReduceScatterWithScale(fp16_grad, NCCLReduceScatterWithScale(fp16_grad_data,
fp16_sum_grad, fp16_sum_grad,
fp16_numel_each_device, fp16_numel_each_device,
num_devices, num_devices,
...@@ -2047,7 +2048,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -2047,7 +2048,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
&cub_tmp_buffer); &cub_tmp_buffer);
if (num_devices > 1) { if (num_devices > 1) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllReduce(fp32_square_grad_norm, phi::dynload::ncclAllReduce(fp32_square_grad_norm,
fp32_square_grad_norm, fp32_square_grad_norm,
1, 1,
ncclFloat32, ncclFloat32,
...@@ -2060,52 +2061,45 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -2060,52 +2061,45 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
VLOG(10) << "ReduceScatter done"; VLOG(10) << "ReduceScatter done";
// Step 7: update the moment1, moment2. Calcuate the trust_ratio_div // Step 7: update the moment1, moment2. Calcuate the trust_ratio_div
auto *fused_offsets_t = ctx.Input<phi::DenseTensor>("FusedParamOffsets"); auto *param_offsets_data = param_offsets.data<int>();
auto *fused_offsets = fused_offsets_t->data<int>(); const auto *fp32_partial_offsets_data = fp32_partial_offsets.data<int>();
auto *fp32_partial_fused_offsets_t = const auto *fp16_partial_offsets_data = fp16_partial_offsets.data<int>();
ctx.Input<phi::DenseTensor>("FP32ShardFusedParamOffsets");
const auto *fp32_partial_fused_offsets = auto *step_data = step->data<int64_t>();
fp32_partial_fused_offsets_t->data<int>();
auto *fp16_partial_fused_offsets_t =
ctx.Input<phi::DenseTensor>("FP16ShardFusedParamOffsets");
const auto *fp16_partial_fused_offsets =
fp16_partial_fused_offsets_t->data<int>();
auto *step = ctx.Output<phi::DenseTensor>("Step")->data<int64_t>();
VLOG(1) << "FusedParamOffsets: " VLOG(1) << "FusedParamOffsets: "
<< FlattenToString(fused_offsets, << FlattenToString(param_offsets_data,
fused_offsets_t->numel(), param_offsets.numel(),
fused_offsets_t->place()); param_offsets.place());
VLOG(1) << "FP32ShardFusedParamOffsets: " VLOG(1) << "FP32ShardFusedParamOffsets: "
<< FlattenToString(fp32_partial_fused_offsets, << FlattenToString(fp32_partial_offsets_data,
fp32_partial_fused_offsets_t->numel(), fp32_partial_offsets.numel(),
fp32_partial_fused_offsets_t->place()); fp32_partial_offsets.place());
VLOG(1) << "FP16ShardFusedParamOffsets: " VLOG(1) << "FP16ShardFusedParamOffsets: "
<< FlattenToString(fp16_partial_fused_offsets, << FlattenToString(fp16_partial_offsets_data,
fp16_partial_fused_offsets_t->numel(), fp16_partial_offsets.numel(),
fp16_partial_fused_offsets_t->place()); fp16_partial_offsets.place());
phi::memory_utils::Buffer trust_ratio_div_buffer(place); memory_utils::Buffer trust_ratio_div_buffer(place);
auto *trust_ratio_div = trust_ratio_div_buffer.Alloc<float>(partial_numel); auto *trust_ratio_div = trust_ratio_div_buffer.Alloc<float>(partial_numel);
auto fp32_offset = local_rank * fp32_numel_each_device; auto fp32_offset = local_rank * fp32_numel_each_device;
auto fp16_offset = local_rank * fp16_numel_each_device; auto fp16_offset = local_rank * fp16_numel_each_device;
if (has_fp32_param) { if (has_fp32_param) {
VLOG(10) << "Update FP32 Moment and TrustRatioDiv starts"; VLOG(10) << "Update FP32 Moment and TrustRatioDiv starts";
MultiTensorUpdateLambMomentAndTrustRatioDiv(dev_ctx, MultiTensorUpdateLambMomentAndTrustRatioDiv(dev_ctx,
fp32_partial_fused_offsets, fp32_partial_offsets_data,
fp32_local_param_num, fp32_local_param_num,
fp32_param + fp32_offset, fp32_param_data + fp32_offset,
fp32_sum_grad, fp32_sum_grad,
fp32_square_grad_norm, fp32_square_grad_norm,
global_scale, global_scale_data,
beta1pow, beta1_pow_data,
beta2pow, beta2_pow_data,
moment1, moment1_data,
moment2, moment2_data,
trust_ratio_div, trust_ratio_div,
found_inf, found_inf_data,
step, step_data,
weight_decay, weight_decay,
fp32_weight_decay_end_idx, fp32_weight_decay_end_idx,
beta1, beta1,
...@@ -2117,22 +2111,22 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -2117,22 +2111,22 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
} }
float *master_param = nullptr; float *master_param = nullptr;
if (has_fp16_param) { if (has_fp16_param) {
master_param = fp32_param + fp32_numel; master_param = fp32_param_data + fp32_numel;
VLOG(10) << "Update FP16 Moment and TrustRatioDiv starts"; VLOG(10) << "Update FP16 Moment and TrustRatioDiv starts";
auto tmp_found_inf = has_fp32_param ? nullptr : found_inf; auto tmp_found_inf = has_fp32_param ? nullptr : found_inf_data;
auto tmp_step = has_fp32_param ? nullptr : step; auto tmp_step = has_fp32_param ? nullptr : step_data;
MultiTensorUpdateLambMomentAndTrustRatioDiv( MultiTensorUpdateLambMomentAndTrustRatioDiv(
dev_ctx, dev_ctx,
fp16_partial_fused_offsets, fp16_partial_offsets_data,
fp16_local_param_num, fp16_local_param_num,
master_param + fp16_offset, master_param + fp16_offset,
fp16_sum_grad, fp16_sum_grad,
fp32_square_grad_norm, fp32_square_grad_norm,
global_scale, global_scale_data,
beta1pow, beta1_pow_data,
beta2pow, beta2_pow_data,
moment1 + fp32_numel_each_device, moment1_data + fp32_numel_each_device,
moment2 + fp32_numel_each_device, moment2_data + fp32_numel_each_device,
trust_ratio_div + fp32_numel_each_device, trust_ratio_div + fp32_numel_each_device,
tmp_found_inf, tmp_found_inf,
tmp_step, tmp_step,
...@@ -2149,7 +2143,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -2149,7 +2143,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
VLOG(10) << "Update Moment and TrustRatioDiv done hehahaha"; VLOG(10) << "Update Moment and TrustRatioDiv done hehahaha";
// Step 8: calculate L2-Norm square of parameter and trust_ratio_div // Step 8: calculate L2-Norm square of parameter and trust_ratio_div
phi::memory_utils::Buffer square_norm_buffer(place); memory_utils::Buffer square_norm_buffer(place);
auto *param_square_norm = square_norm_buffer.Alloc<float>(2 * param_num); auto *param_square_norm = square_norm_buffer.Alloc<float>(2 * param_num);
auto *trust_ratio_div_square_norm = param_square_norm + param_num; auto *trust_ratio_div_square_norm = param_square_norm + param_num;
if (num_devices > 1) { if (num_devices > 1) {
...@@ -2163,23 +2157,24 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -2163,23 +2157,24 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
} }
MultiTensorL2Norm(place, MultiTensorL2Norm(place,
stream, stream,
fp32_param, fp32_param_data,
fused_offsets, param_offsets_data,
fp32_global_param_num, fp32_global_param_num,
param_square_norm); param_square_norm);
if (use_master_param_norm) { if (use_master_param_norm) {
MultiTensorL2Norm(place, MultiTensorL2Norm(place,
stream, stream,
master_param + fp16_offset, master_param + fp16_offset,
fp16_partial_fused_offsets, fp16_partial_offsets_data,
fp16_local_param_num, fp16_local_param_num,
param_square_norm + fp16_local_start_idx); param_square_norm + fp16_local_start_idx);
} else { } else {
MultiTensorL2Norm(place, MultiTensorL2Norm(place,
stream, stream,
fp16_param + fused_offsets[fp16_local_start_idx] - fp16_param_data +
fused_offsets[fp32_global_param_num], param_offsets_data[fp16_local_start_idx] -
fused_offsets + fp16_local_start_idx, param_offsets_data[fp32_global_param_num],
param_offsets_data + fp16_local_start_idx,
fp16_local_param_num, fp16_local_param_num,
param_square_norm + fp16_local_start_idx); param_square_norm + fp16_local_start_idx);
} }
...@@ -2187,13 +2182,13 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -2187,13 +2182,13 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
MultiTensorL2Norm(place, MultiTensorL2Norm(place,
stream, stream,
trust_ratio_div, trust_ratio_div,
fp32_partial_fused_offsets, fp32_partial_offsets_data,
fp32_local_param_num, fp32_local_param_num,
trust_ratio_div_square_norm + fp32_local_start_idx); trust_ratio_div_square_norm + fp32_local_start_idx);
MultiTensorL2Norm(place, MultiTensorL2Norm(place,
stream, stream,
trust_ratio_div + fp32_numel_each_device, trust_ratio_div + fp32_numel_each_device,
fp16_partial_fused_offsets, fp16_partial_offsets_data,
fp16_local_param_num, fp16_local_param_num,
trust_ratio_div_square_norm + fp16_local_start_idx); trust_ratio_div_square_norm + fp16_local_start_idx);
...@@ -2201,8 +2196,8 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -2201,8 +2196,8 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
<< FlattenToString(trust_ratio_div_square_norm, param_num, place); << FlattenToString(trust_ratio_div_square_norm, param_num, place);
if (num_devices > 1) { if (num_devices > 1) {
if (use_master_param_norm) { if (use_master_param_norm) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( PADDLE_ENFORCE_GPU_SUCCESS(
param_square_norm + fp32_global_param_num, phi::dynload::ncclAllReduce(param_square_norm + fp32_global_param_num,
param_square_norm + fp32_global_param_num, param_square_norm + fp32_global_param_num,
2 * param_num - fp32_global_param_num, 2 * param_num - fp32_global_param_num,
ncclFloat32, ncclFloat32,
...@@ -2211,7 +2206,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -2211,7 +2206,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
stream)); stream));
} else { } else {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllReduce(trust_ratio_div_square_norm, phi::dynload::ncclAllReduce(trust_ratio_div_square_norm,
trust_ratio_div_square_norm, trust_ratio_div_square_norm,
param_num, param_num,
ncclFloat32, ncclFloat32,
...@@ -2223,61 +2218,61 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -2223,61 +2218,61 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
} }
LogParamAndTrustRatioDivSquareNorm<1>( LogParamAndTrustRatioDivSquareNorm<1>(
ctx, param_square_norm, trust_ratio_div_square_norm); param, param_order, param_square_norm, trust_ratio_div_square_norm);
VLOG(10) << "Calculate L2-Norm of Param and TrustRatioDiv done"; VLOG(10) << "Calculate L2-Norm of Param and TrustRatioDiv done";
// Step 9: update parameter, beta1pow, beta2pow. All gather parameters. // Step 9: update parameter, beta1pow, beta2pow. All gather parameters.
if (has_fp32_param) { if (has_fp32_param) {
MultiTensorUpdateLambParamAndBetaPows<float>( MultiTensorUpdateLambParamAndBetaPows<float>(
dev_ctx, dev_ctx,
fp32_partial_fused_offsets, fp32_partial_offsets_data,
fp32_local_param_num, fp32_local_param_num,
trust_ratio_div, trust_ratio_div,
lr, lr_data,
param_square_norm + fp32_local_start_idx, param_square_norm + fp32_local_start_idx,
trust_ratio_div_square_norm + fp32_local_start_idx, trust_ratio_div_square_norm + fp32_local_start_idx,
found_inf, found_inf_data,
fp32_param + fp32_offset, fp32_param_data + fp32_offset,
nullptr, nullptr,
beta1pow, beta1_pow_data,
beta2pow, beta2_pow_data,
beta1, beta1,
beta2); beta2);
if (num_devices > 1) { if (num_devices > 1) {
// ncclAllGather // ncclAllGather
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllGather(fp32_param + fp32_offset, phi::dynload::ncclAllGather(fp32_param_data + fp32_offset,
fp32_param, fp32_param_data,
fp32_numel_each_device, fp32_numel_each_device,
ncclFloat32, ncclFloat32,
local_comm, local_comm,
stream)); stream));
} }
beta1pow = nullptr; beta1_pow_data = nullptr;
beta2pow = nullptr; beta2_pow_data = nullptr;
} }
if (has_fp16_param) { if (has_fp16_param) {
MultiTensorUpdateLambParamAndBetaPows<platform::float16>( MultiTensorUpdateLambParamAndBetaPows<dtype::float16>(
dev_ctx, dev_ctx,
fp16_partial_fused_offsets, fp16_partial_offsets_data,
fp16_local_param_num, fp16_local_param_num,
trust_ratio_div + fp32_numel_each_device, trust_ratio_div + fp32_numel_each_device,
lr, lr_data,
param_square_norm + fp16_local_start_idx, param_square_norm + fp16_local_start_idx,
trust_ratio_div_square_norm + fp16_local_start_idx, trust_ratio_div_square_norm + fp16_local_start_idx,
found_inf, found_inf_data,
fp16_param + fp16_offset, fp16_param_data + fp16_offset,
master_param + fp16_offset, master_param + fp16_offset,
beta1pow, beta1_pow_data,
beta2pow, beta2_pow_data,
beta1, beta1,
beta2); beta2);
if (num_devices > 1) { if (num_devices > 1) {
// ncclAllGather // ncclAllGather
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllGather(fp16_param + fp16_offset, phi::dynload::ncclAllGather(fp16_param_data + fp16_offset,
fp16_param, fp16_param_data,
fp16_numel_each_device, fp16_numel_each_device,
ncclFloat16, ncclFloat16,
local_comm, local_comm,
...@@ -2288,20 +2283,29 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext> ...@@ -2288,20 +2283,29 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
VLOG(1) << "IsFinite: " << IsFinite(dev_ctx, fp32_square_grad_norm); VLOG(1) << "IsFinite: " << IsFinite(dev_ctx, fp32_square_grad_norm);
#else #else
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"distributed_fused_lamb op should be used with NCCL/RCCL.")); "distributed_fused_lamb op should be used with NCCL/RCCL."));
#endif #endif
} }
};
} // namespace operators
} // namespace paddle
namespace plat = paddle::platform; } // namespace fusion
namespace ops = paddle::operators; } // namespace phi
PD_REGISTER_STRUCT_KERNEL(distributed_fused_lamb, PD_REGISTER_KERNEL(distributed_fused_lamb,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
ops::DistributedFusedLambOpKernel, phi::fusion::DistributedFusedLambKernel,
float) {} float) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT16);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT16);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(7).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(9).SetDataType(phi::DataType::BOOL);
kernel->OutputAt(10).SetDataType(phi::DataType::INT64);
kernel->OutputAt(11).SetDataType(phi::DataType::BOOL);
kernel->OutputAt(12).SetDataType(phi::DataType::INT64);
}
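Editorial note: the GPU registration block above pins the dtype of most outputs by index because the kernel itself is only instantiated for float, while outputs such as FP16FusedParamOut, FoundInf/StopUpdate and AccStep/Step carry other dtypes; index 8 (ParamOut) is not pinned, presumably because it follows the dtype of the parameters being updated. Below is a minimal, self-contained mock of that "pin per-output dtypes at registration time" pattern. It is a sketch only: MockKernel and OutputSlot are stand-ins, not Paddle's real registry types.

// Hypothetical mock (not the real phi registry) of the registration
// customization above: the compute dtype is float, so any output whose
// dtype differs is pinned explicitly, indexed in the same order as the
// output list of the argument mapping (0 = FP32FusedParamOut, ...,
// 9 = FoundInf, 12 = Step).
#include <cstdio>
#include <map>

enum class DataType { FLOAT32, FLOAT16, BOOL, INT64 };

struct MockKernel {
  std::map<int, DataType> pinned_output_dtypes;
  struct OutputSlot {
    MockKernel* kernel;
    int index;
    void SetDataType(DataType dtype) {
      kernel->pinned_output_dtypes[index] = dtype;
    }
  };
  OutputSlot OutputAt(int index) { return OutputSlot{this, index}; }
};

int main() {
  MockKernel kernel;
  kernel.OutputAt(1).SetDataType(DataType::FLOAT16);  // FP16FusedParamOut
  kernel.OutputAt(9).SetDataType(DataType::BOOL);     // FoundInf
  kernel.OutputAt(10).SetDataType(DataType::INT64);   // AccStep
  kernel.OutputAt(11).SetDataType(DataType::BOOL);    // StopUpdate
  kernel.OutputAt(12).SetDataType(DataType::INT64);   // Step
  // Index 8 (ParamOut) is deliberately not pinned in this sketch, mirroring
  // the registration above, where it appears to be left at its default.
  std::printf("pinned %zu outputs\n", kernel.pinned_output_dtypes.size());
  return 0;
}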
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
template <typename T, typename DevCtx>
class DistributedFusedLambOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"The distributed_fused_lamb operator does not support CPU yet."));
}
};
} // namespace operators
} // namespace paddle
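Editorial note: the header deleted above held the fluid-style DistributedFusedLambOpKernel, whose Compute(ExecutionContext) pulled every tensor out of the context by string name; the functionalized phi kernel introduced by this PR instead receives each input, attribute and output as an explicit parameter. The following rough sketch contrasts the two calling styles using stand-in types, not the real framework classes, and a placeholder body instead of the actual LAMB update.

// Minimal mock of the migration this PR performs. DenseTensor and
// ExecutionContext here are hypothetical stand-ins.
#include <cassert>
#include <map>
#include <string>
#include <vector>

struct DenseTensor { std::vector<float> data; };

// Old style: everything comes from a name -> tensor map ("ExecutionContext").
struct ExecutionContext {
  std::map<std::string, DenseTensor*> vars;
  DenseTensor* Input(const std::string& name) const { return vars.at(name); }
  DenseTensor* Output(const std::string& name) const { return vars.at(name); }
};

void OldStyleKernel(const ExecutionContext& ctx) {
  const DenseTensor* moment1 = ctx.Input("Moment1");
  DenseTensor* moment1_out = ctx.Output("Moment1Out");
  moment1_out->data = moment1->data;  // placeholder for the real update
}

// New style: explicit, typed parameters, no string lookups inside the kernel.
void FunctionalKernel(const DenseTensor& moment1, DenseTensor* moment1_out) {
  moment1_out->data = moment1.data;  // placeholder for the real update
}

int main() {
  DenseTensor m1{{1.f, 2.f}}, m1_out;
  ExecutionContext ctx{{{"Moment1", &m1}, {"Moment1Out", &m1_out}}};
  OldStyleKernel(ctx);
  FunctionalKernel(m1, &m1_out);
  assert(m1_out.data.size() == 2);
  return 0;
}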
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
#include "math.h" // NOLINT #include "math.h" // NOLINT
#include "paddle/phi/core/cuda_stream.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature DistributedFusedLambOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("distributed_fused_lamb",
{"Param",
"Grad",
"FP32FusedParam",
"FP32FusedGrad",
"FP16FusedParam",
"FP16FusedGrad",
"Moment1",
"Moment2",
"Beta1Pow",
"Beta2Pow",
"FusedParamOffsets",
"FP32ShardFusedParamOffsets",
"FP16ShardFusedParamOffsets",
"ParamInfo",
"ParamOrder",
"LearningRate",
"GlobalScale"},
{"acc_steps",
"beta1",
"beta2",
"epsilon",
"max_global_grad_norm",
"weight_decay",
"clip_after_allreduce",
"use_master_param_norm",
"use_master_acc_grad",
"is_grad_scaled_by_nranks",
"use_hierarchical_allreduce",
"nranks",
"ring_ids"},
{"FP32FusedParamOut",
"FP16FusedParamOut",
"FP32AccFusedGrad",
"FP16AccFusedGrad",
"Moment1Out",
"Moment2Out",
"Beta1PowOut",
"Beta2PowOut",
"ParamOut",
"FoundInf",
"AccStep",
"StopUpdate",
"Step"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(distributed_fused_lamb,
phi::DistributedFusedLambOpArgumentMapping);
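Editorial note: the argument mapping above is what ties the legacy operator's named inputs, attributes and outputs to the positional parameter list of phi::fusion::DistributedFusedLambKernel; the position within each list is the contract, which is also why the GPU registration pins dtypes by output index (9 = FoundInf, 12 = Step). The small stand-alone illustration below uses a plain struct instead of the real KernelSignature type to show that positional correspondence; all names are copied from the mapping above.

// Sketch of the positional name lists encoded by the argument mapping.
#include <cstdio>
#include <string>
#include <vector>

struct MockKernelSignature {
  std::string name;
  std::vector<std::string> inputs, attrs, outputs;
};

int main() {
  MockKernelSignature sig{
      "distributed_fused_lamb",
      {"Param", "Grad", "FP32FusedParam", "FP32FusedGrad", "FP16FusedParam",
       "FP16FusedGrad", "Moment1", "Moment2", "Beta1Pow", "Beta2Pow",
       "FusedParamOffsets", "FP32ShardFusedParamOffsets",
       "FP16ShardFusedParamOffsets", "ParamInfo", "ParamOrder",
       "LearningRate", "GlobalScale"},
      {"acc_steps", "beta1", "beta2", "epsilon", "max_global_grad_norm",
       "weight_decay", "clip_after_allreduce", "use_master_param_norm",
       "use_master_acc_grad", "is_grad_scaled_by_nranks",
       "use_hierarchical_allreduce", "nranks", "ring_ids"},
      {"FP32FusedParamOut", "FP16FusedParamOut", "FP32AccFusedGrad",
       "FP16AccFusedGrad", "Moment1Out", "Moment2Out", "Beta1PowOut",
       "Beta2PowOut", "ParamOut", "FoundInf", "AccStep", "StopUpdate",
       "Step"}};
  // Output indices line up with the OutputAt(...) calls in the GPU
  // registration earlier in this commit.
  std::printf("output index 9  -> %s\n", sig.outputs[9].c_str());   // FoundInf
  std::printf("output index 12 -> %s\n", sig.outputs[12].c_str());  // Step
  return 0;
}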