Unverified commit 5f8e7d8f authored by huangjiyi, committed by GitHub

Functionalize distributed_fused_lamb kernel (#53896)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update HostAlloc

* update param name

* update cpu kernel

* remove kernel header

* update

* update
Parent 6e0cf610
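For context, this commit follows the usual phi "functionalization" pattern: the fluid `OpKernel` class that pulled inputs and attributes out of an `ExecutionContext` is replaced by a free template function whose inputs, attributes, and outputs are explicit arguments, and the old `PD_REGISTER_STRUCT_KERNEL` registration is replaced by `PD_REGISTER_KERNEL`. A minimal sketch of that pattern is below; the kernel name and arguments are illustrative only and are not part of this PR.

```cpp
// Illustrative sketch of a functional phi kernel (CPU-only toy example).
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
namespace fusion {

// Inputs are const DenseTensor references, attributes are plain values,
// and outputs are DenseTensor pointers; no ExecutionContext is involved.
template <typename T, typename Context>
void ToyScaleKernel(const Context &dev_ctx,
                    const DenseTensor &x,
                    float scale,
                    DenseTensor *out) {
  out->Resize(x.dims());
  // dev_ctx.template Alloc<T>() replaces the old mutable_data(place) call.
  T *out_data = dev_ctx.template Alloc<T>(out);
  const T *x_data = x.data<T>();
  for (int64_t i = 0; i < x.numel(); ++i) {
    out_data[i] = static_cast<T>(scale) * x_data[i];
  }
}

}  // namespace fusion
}  // namespace phi

// Registration replaces the old struct-kernel macro.
PD_REGISTER_KERNEL(
    toy_scale, CPU, ALL_LAYOUT, phi::fusion::ToyScaleKernel, float) {}
```

The diff below applies this same shape to distributed_fused_lamb: the CPU file gains an unimplemented functional kernel, the GPU file rewrites the fluid kernel body as `phi::fusion::DistributedFusedLambKernel`, and an argument-mapping file wires the old op input/attribute/output names to the new kernel signature.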
......@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/kernel_registry.h"
namespace paddle {
namespace operators {
......@@ -170,8 +171,63 @@ REGISTER_OP_WITHOUT_GRADIENT(distributed_fused_lamb,
ops::DistributedFusedLambOp,
ops::DistributedFusedLambOpMaker);
PD_REGISTER_STRUCT_KERNEL(distributed_fused_lamb,
namespace phi {
namespace fusion {
template <typename T, typename Context>
void DistributedFusedLambKernel(const Context &dev_ctx,
const std::vector<const DenseTensor *> &param,
const std::vector<const DenseTensor *> &grad,
const paddle::optional<DenseTensor> &fp32_param,
const paddle::optional<DenseTensor> &fp32_grad,
const paddle::optional<DenseTensor> &fp16_param,
const paddle::optional<DenseTensor> &fp16_grad,
const DenseTensor &moment1,
const DenseTensor &moment2,
const DenseTensor &beta1_pow,
const DenseTensor &beta2_pow,
const DenseTensor &param_offsets,
const DenseTensor &fp32_partial_offsets,
const DenseTensor &fp16_partial_offsets,
const DenseTensor &param_info,
const DenseTensor &param_order,
const DenseTensor &learning_rate,
const DenseTensor &global_scale,
int acc_steps,
float beta1,
float beta2,
float epsilon,
float max_global_grad_norm,
float weight_decay,
bool clip_after_allreduce,
bool use_master_param_norm,
bool use_master_acc_grad,
bool is_grad_scaled_by_nranks,
bool use_hierarchical_allreduce,
int64_t nranks,
const std::vector<int> &ring_ids,
DenseTensor *fp32_param_out,
DenseTensor *fp16_param_out,
DenseTensor *fp32_acc_grad,
DenseTensor *fp16_acc_grad,
DenseTensor *moment1_out,
DenseTensor *moment2_out,
DenseTensor *beta1_pow_out,
DenseTensor *beta2_pow_out,
DenseTensor *param_out,
DenseTensor *found_inf,
DenseTensor *acc_step,
DenseTensor *stop_update,
DenseTensor *step) {
PADDLE_THROW(phi::errors::Unimplemented(
"The distributed_fused_lamb operator does not support CPU yet."));
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(distributed_fused_lamb,
CPU,
ALL_LAYOUT,
ops::DistributedFusedLambOpKernel,
phi::fusion::DistributedFusedLambKernel,
float) {}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,19 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cmath>
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/cast_with_ptr.h"
#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h"
#include "paddle/fluid/operators/optimizers/multi_tensor_apply.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/cuda_stream.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/tensor_to_string.h"
#include "paddle/utils/optional.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
......@@ -38,11 +40,11 @@
namespace cub = hipcub;
#endif
namespace paddle {
namespace operators {
namespace phi {
namespace fusion {
template <typename T>
using MasterT = typename details::MPTypeTrait<T>::Type;
using MasterT = typename phi::dtype::MPTypeTrait<T>::Type;
using phi::funcs::FlattenToString;
using phi::funcs::ToVector;
......@@ -157,7 +159,7 @@ template <typename InT,
typename OutT,
int MaxTensorNumPerLaunch = 160,
int MaxChunkNumPerLaunch = 780>
static void MultiTensorL2Norm(const platform::CUDAPlace &place,
static void MultiTensorL2Norm(const phi::GPUPlace &place,
gpuStream_t stream,
const InT *x,
const int *offsets,
......@@ -191,7 +193,7 @@ static void MultiTensorL2Norm(const platform::CUDAPlace &place,
<< " , tensor_num = " << n;
using MT = MasterT<InT>;
phi::memory_utils::Buffer tmp_out(place);
memory_utils::Buffer tmp_out(place);
auto *tmp_out_ptr = tmp_out.Alloc<MT>(n * max_chunk_num);
FillZeroWithPtr(tmp_out_ptr, n * max_chunk_num, stream);
......@@ -200,7 +202,8 @@ static void MultiTensorL2Norm(const platform::CUDAPlace &place,
using FunctorT = L2NormFunctor<InT, kBlockDim, kVecSize>; \
VLOG(10) << __func__ << " " << typeid(InT).name() \
<< " VecSize = " << kVecSize; \
MultiTensorApply<FunctorT, kNumTensor, kNumChunk>(FunctorT(), \
paddle::operators::MultiTensorApply<FunctorT, kNumTensor, kNumChunk>( \
FunctorT(), \
stream, \
offsets, \
n, \
......@@ -220,27 +223,27 @@ static void MultiTensorL2Norm(const platform::CUDAPlace &place,
template <int LogLevel>
static void LogParamAndTrustRatioDivSquareNorm(
const framework::ExecutionContext &ctx,
const std::vector<const DenseTensor *> &param,
const DenseTensor &order,
const float *param_square_norm,
const float *trust_ratio_div_square_norm) {
if (!VLOG_IS_ON(LogLevel)) return;
auto tensors = ctx.MultiInput<phi::DenseTensor>("Param");
if (tensors.empty()) return;
if (param.empty()) return;
const auto *order = ctx.Input<phi::DenseTensor>("ParamOrder")->data<int>();
const auto *order_data = order.data<int>();
size_t n = tensors.size();
auto place = tensors[0]->place();
size_t n = param.size();
auto place = param[0]->place();
auto pn_vec = ToVector(param_square_norm, n, place);
auto tn_vec = ToVector(trust_ratio_div_square_norm, n, place);
const auto &names = ctx.GetOp().Inputs("Param");
for (size_t i = 0; i < n; ++i) {
auto idx = order[i];
VLOG(LogLevel) << "Param " << tensors[idx]->dtype() << " " << names[idx]
<< " pn = " << pn_vec[i] << " , tn = " << tn_vec[i];
auto idx = order_data[i];
VLOG(LogLevel) << "Param " << param[idx]->dtype() << " "
<< param[idx]->name() << " pn = " << pn_vec[i]
<< " , tn = " << tn_vec[i];
}
}
......@@ -261,13 +264,12 @@ static bool IsFinite(const phi::GPUContext &dev_ctx, const float *ptr) {
}
template <typename T>
static const T *GetInputTensorPtr(const framework::ExecutionContext &ctx,
static const T *GetInputTensorPtr(const DenseTensor *in_tensor,
const char *in_name,
int64_t *numel = nullptr) {
const auto *in_tensor = ctx.Input<phi::DenseTensor>(in_name);
PADDLE_ENFORCE_NOT_NULL(
in_tensor,
platform::errors::InvalidArgument("Input(%s) cannot be NULL.", in_name));
phi::errors::InvalidArgument("Input(%s) cannot be NULL.", in_name));
if (in_tensor->IsInitialized()) {
if (numel) *numel = in_tensor->numel();
return in_tensor->data<T>();
......@@ -277,34 +279,34 @@ static const T *GetInputTensorPtr(const framework::ExecutionContext &ctx,
}
}
template <typename T, bool AllowNotExist = false>
static T *GetSameInOutTensorPtr(const framework::ExecutionContext &ctx,
const platform::Place &place,
template <typename T, typename Context, bool AllowNotExist = false>
static T *GetSameInOutTensorPtr(const Context &dev_ctx,
const DenseTensor *in_tensor,
DenseTensor *out_tensor,
const char *in_name,
const char *out_name,
int64_t *numel = nullptr) {
const auto *in_tensor = ctx.Input<phi::DenseTensor>(in_name);
if (in_tensor == nullptr || !in_tensor->IsInitialized()) {
PADDLE_ENFORCE_EQ(AllowNotExist,
PADDLE_ENFORCE_EQ(
AllowNotExist,
true,
platform::errors::InvalidArgument(
"Input(%s) cannot be NULL.", in_name));
phi::errors::InvalidArgument("Input(%s) cannot be NULL.", in_name));
if (numel) *numel = 0;
return nullptr;
}
auto *out_tensor = ctx.Output<phi::DenseTensor>(out_name);
PADDLE_ENFORCE_NOT_NULL(
in_tensor,
platform::errors::InvalidArgument("Input(%s) cannot be NULL.", in_name));
PADDLE_ENFORCE_NOT_NULL(out_tensor,
platform::errors::InvalidArgument(
"Output(%s) cannot be NULL.", out_name));
phi::errors::InvalidArgument("Input(%s) cannot be NULL.", in_name));
PADDLE_ENFORCE_NOT_NULL(
out_tensor,
phi::errors::InvalidArgument("Output(%s) cannot be NULL.", out_name));
const T *in_data = in_tensor->data<T>();
T *out_data = out_tensor->mutable_data<T>(place);
T *out_data = dev_ctx.template Alloc<T>(out_tensor);
PADDLE_ENFORCE_EQ(in_data,
out_data,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"Input(%s) and Output(%s) must be the same Tensor.",
in_name,
out_name));
......@@ -535,11 +537,11 @@ static void MultiTensorUpdateLambMomentAndTrustRatioDiv(
int numel = offsets[n] - offsets[0];
PADDLE_ENFORCE_GE(weight_decay_end_idx,
0,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The weight decay end index should be >= 0."));
PADDLE_ENFORCE_LE(weight_decay_end_idx,
n,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The weight decay end index should be < %d.", n));
auto weight_decay_end_numel = offsets[weight_decay_end_idx] - offsets[0];
......@@ -558,17 +560,17 @@ static void MultiTensorUpdateLambMomentAndTrustRatioDiv(
VLOG(1) << __func__ << " VecSize = " << vec_size;
auto stream = dev_ctx.stream();
auto config = platform::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
if (found_inf_p == nullptr) {
PADDLE_ENFORCE_EQ(
step,
nullptr,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"Output(Step) cannot be updated twice in one mini-batch."));
} else {
PADDLE_ENFORCE_NOT_NULL(
step,
platform::errors::InvalidArgument("Output(Step) cannot be nullptr."));
step, phi::errors::InvalidArgument("Output(Step) cannot be nullptr."));
}
#define PD_LAUNCH_LAMB_MOM_TRUST_RATIO_DIV_KERNEL \
......@@ -603,12 +605,12 @@ static void MultiTensorUpdateLambMomentAndTrustRatioDiv(
template <typename T, bool NeedUpdate /*=true*/>
struct LambBetaPowUpdateOnceHelper {
LambBetaPowUpdateOnceHelper(T *beta1pow, T *beta2pow, T beta1, T beta2) {
PADDLE_ENFORCE_NOT_NULL(beta1pow,
platform::errors::InvalidArgument(
"The beta1pow should not be nullptr."));
PADDLE_ENFORCE_NOT_NULL(beta2pow,
platform::errors::InvalidArgument(
"The beta2pow should not be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
beta1pow,
phi::errors::InvalidArgument("The beta1pow should not be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
beta2pow,
phi::errors::InvalidArgument("The beta2pow should not be nullptr."));
beta1pow_ = beta1pow;
beta2pow_ = beta2pow;
beta1_ = beta1;
......@@ -633,11 +635,11 @@ struct LambBetaPowUpdateOnceHelper<T, false> {
PADDLE_ENFORCE_EQ(
beta1pow,
nullptr,
platform::errors::InvalidArgument("The beta1pow should be nullptr."));
phi::errors::InvalidArgument("The beta1pow should be nullptr."));
PADDLE_ENFORCE_EQ(
beta2pow,
nullptr,
platform::errors::InvalidArgument("The beta2pow should be nullptr."));
phi::errors::InvalidArgument("The beta2pow should be nullptr."));
}
HOSTDEVICE void UpdateBetaPows() const {}
......@@ -649,11 +651,11 @@ struct LambParamHelper {
constexpr bool kIsSameType = std::is_same<T, MasterT<T>>::value;
PADDLE_ENFORCE_EQ(kIsSameType,
false,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"T must not be the same with MasterT<T>."));
PADDLE_ENFORCE_NOT_NULL(master_param,
platform::errors::InvalidArgument(
"Master parameter must be provided."));
PADDLE_ENFORCE_NOT_NULL(
master_param,
phi::errors::InvalidArgument("Master parameter must be provided."));
param_ = param;
master_param_ = master_param;
}
......@@ -671,14 +673,14 @@ template <typename T>
struct LambParamHelper<T, false> {
LambParamHelper(T *param, MasterT<T> *master_param) {
constexpr bool kIsSameType = std::is_same<T, MasterT<T>>::value;
PADDLE_ENFORCE_EQ(kIsSameType,
PADDLE_ENFORCE_EQ(
kIsSameType,
true,
platform::errors::InvalidArgument(
"T must be the same with MasterT<T>."));
phi::errors::InvalidArgument("T must be the same with MasterT<T>."));
if (master_param != nullptr) {
PADDLE_ENFORCE_EQ(static_cast<void *>(param),
static_cast<void *>(master_param),
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"Master parameter must be nullptr or the same as "
"non-master parameter."));
}
......@@ -802,12 +804,12 @@ static void MultiTensorUpdateLambParamAndBetaPows(
if (has_beta_pow) {
PADDLE_ENFORCE_NOT_NULL(
beta2pow,
platform::errors::InvalidArgument("Beta2Pow should not be nullptr."));
phi::errors::InvalidArgument("Beta2Pow should not be nullptr."));
} else {
PADDLE_ENFORCE_EQ(
beta2pow,
nullptr,
platform::errors::InvalidArgument("Beta2Pow should be nullptr."));
phi::errors::InvalidArgument("Beta2Pow should be nullptr."));
}
#ifdef PADDLE_WITH_HIP
......@@ -858,7 +860,8 @@ static void MultiTensorUpdateLambParamAndBetaPows(
#define PD_LAUNCH_VEC_MULTI_TENSOR_UPDATE_PARAM_BETAPOW_CASE \
do { \
auto callback = \
[&](const MultiTensorLauncher<kNumTensor, kNumChunk> &launcher, \
[&](const paddle::operators::MultiTensorLauncher<kNumTensor, \
kNumChunk> &launcher, \
int launch_n) { \
if (has_beta_pow && launch_n == 0) { \
PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(true); \
......@@ -868,7 +871,7 @@ static void MultiTensorUpdateLambParamAndBetaPows(
PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(false); \
} \
}; \
MultiTensorApplyWithCallback<kNumTensor, kNumChunk>( \
paddle::operators::MultiTensorApplyWithCallback<kNumTensor, kNumChunk>( \
stream, offsets, n, chunk_size, block_dim, callback); \
} while (0)
......@@ -886,10 +889,10 @@ static bool CreatePreMulScaleOpIfSupported(ncclDataType_t dtype,
ncclRedOp_t *op) {
#if NCCL_VERSION_CODE >= 21100
int ver;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetVersion(&ver));
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetVersion(&ver));
if (ver >= 21100) {
VLOG(10) << "ncclRedOpCreatePreMulSum is supported.";
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRedOpCreatePreMulSum(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRedOpCreatePreMulSum(
op, const_cast<void *>(scale), dtype, ncclScalarDevice, comm));
return true;
}
......@@ -906,7 +909,7 @@ static void LaunchScaleKernel(const phi::GPUContext &dev_ctx,
int n,
gpuStream_t stream) {
int vec_size = std::min(GetChunkedVecSize(x, 0), GetChunkedVecSize(y, 0));
auto config = platform::GetGpuLaunchConfig1D(dev_ctx, n, vec_size);
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n, vec_size);
#define PD_LAMB_VEC_SCALE_KERNEL_CASE \
do { \
......@@ -928,8 +931,8 @@ static void NCCLSumWithScaleBase(const T *sendbuff,
gpuStream_t stream,
const phi::GPUContext &dev_ctx,
const T *scale = nullptr) {
static_assert(std::is_same<T, float>::value ||
std::is_same<T, platform::float16>::value,
static_assert(
std::is_same<T, float>::value || std::is_same<T, dtype::float16>::value,
"T must be either float32 or float16.");
if (recvcount == 0) return;
......@@ -938,7 +941,7 @@ static void NCCLSumWithScaleBase(const T *sendbuff,
if (scale != nullptr) {
PADDLE_ENFORCE_EQ(nranks,
1,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"nranks must be 1 when scale != nullptr."));
LaunchScaleKernel(dev_ctx, sendbuff, scale, recvbuff, numel, stream);
}
......@@ -950,7 +953,7 @@ static void NCCLSumWithScaleBase(const T *sendbuff,
std::is_same<T, float>::value ? ncclFloat32 : ncclFloat16;
bool should_destroy_op =
scale && CreatePreMulScaleOpIfSupported(dtype, comm, scale, &op);
phi::memory_utils::Buffer buffer(dev_ctx.GetPlace());
memory_utils::Buffer buffer(dev_ctx.GetPlace());
if (scale && !should_destroy_op) {
T *new_sendbuff = buffer.Alloc<T>(numel);
LaunchScaleKernel(dev_ctx, sendbuff, scale, new_sendbuff, numel, stream);
......@@ -958,17 +961,17 @@ static void NCCLSumWithScaleBase(const T *sendbuff,
}
if (UseReduceScatter) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclReduceScatter(
sendbuff, recvbuff, recvcount, dtype, op, comm, stream));
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce(
sendbuff, recvbuff, recvcount, dtype, op, comm, stream));
}
#if NCCL_VERSION_CODE >= 21100
if (should_destroy_op) {
VLOG(10) << "ncclRedOpDestroy starts";
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRedOpDestroy(op, comm));
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRedOpDestroy(op, comm));
VLOG(10) << "ncclRedOpDestroy ends";
}
#endif
......@@ -1012,7 +1015,7 @@ static void CubDeviceReduce(InputIteratorT d_in,
ReduceOpT reduction_op,
T init,
gpuStream_t stream,
phi::memory_utils::Buffer *buffer) {
memory_utils::Buffer *buffer) {
void *d_temp_storage = nullptr;
size_t temp_storage_bytes = 0;
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::Reduce(d_temp_storage,
......@@ -1041,7 +1044,7 @@ static void GetSquareGradNormImpl(const T *grad,
int n,
float *square_norm,
gpuStream_t stream,
phi::memory_utils::Buffer *cub_tmp_buffer) {
memory_utils::Buffer *cub_tmp_buffer) {
using Iterator =
cub::TransformInputIterator<float, SquareFunctor<T>, const T *>;
Iterator iter(grad, SquareFunctor<T>());
......@@ -1057,11 +1060,11 @@ static void GetSquareGradNormImpl(const T *grad,
// square_norm is of length 2 at least
static void GetSquareGradNorm(const float *fp32_grad,
int fp32_numel,
const platform::float16 *fp16_grad,
const dtype::float16 *fp16_grad,
int fp16_numel,
float *square_norm,
gpuStream_t stream,
phi::memory_utils::Buffer *cub_tmp_buffer) {
memory_utils::Buffer *cub_tmp_buffer) {
VLOG(10) << "GetSquareGradNorm starts, fp32_numel = " << fp32_numel
<< " , fp16_numel = " << fp16_numel;
if (fp32_numel > 0) {
......@@ -1096,23 +1099,21 @@ std::string NumToString(T x) {
}
template <typename T>
static std::string GetMinMaxStr(const T *x,
size_t n,
const platform::Place &place) {
static std::string GetMinMaxStr(const T *x, size_t n, const phi::Place &place) {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(place),
place.GetType() == phi::AllocationType::GPU,
true,
platform::errors::InvalidArgument("Only support CUDAPlace currently."));
phi::errors::InvalidArgument("Only support CUDAPlace currently."));
auto *dev_ctx = static_cast<phi::GPUContext *>(
platform::DeviceContextPool::Instance().Get(place));
phi::DeviceContextPool::Instance().Get(place));
auto stream = dev_ctx->stream();
phi::memory_utils::Buffer ret_buffer(place);
memory_utils::Buffer ret_buffer(place);
T *ret = ret_buffer.Alloc<T>(2);
if (n > 0) {
phi::memory_utils::Buffer cub_buffer(place);
memory_utils::Buffer cub_buffer(place);
CubDeviceReduce(x,
ret,
n,
......@@ -1160,45 +1161,20 @@ struct VisitDTypeFunctor {
static std::string GetMinMaxStr(const phi::DenseTensor *x) {
if (x == nullptr) return "null";
if (!x->IsInitialized()) return "not_inited";
if (!platform::is_gpu_place(x->place())) return "CPUTensor";
if (x->place().GetType() != phi::AllocationType::GPU) return "CPUTensor";
std::string str;
VisitDTypeFunctor functor(x, &str);
phi::VisitDataType(x->dtype(), functor);
return str;
}
static void PrintAllMinMaxRange(const framework::ExecutionContext &ctx,
bool only_inputs) {
if (!VLOG_IS_ON(1)) return;
for (const auto &pair : ctx.GetOp().Inputs()) {
const auto &key = pair.first;
const auto tensors = ctx.MultiInput<phi::DenseTensor>(key);
size_t n = tensors.size();
for (size_t i = 0; i < n; ++i) {
VLOG(1) << "Input(" << key + ")[" << i << "] = " << pair.second[i]
<< " , " << GetMinMaxStr(tensors[i]);
}
}
if (only_inputs) return;
for (const auto &pair : ctx.GetOp().Outputs()) {
const auto &key = pair.first;
const auto tensors = ctx.MultiOutput<phi::DenseTensor>(key);
size_t n = tensors.size();
for (size_t i = 0; i < n; ++i) {
VLOG(1) << "Output(" << key + ")[" << i << "] = " << pair.second[i]
<< " , " << GetMinMaxStr(tensors[i]);
}
}
}
template <typename T>
static bool HasNanInf(const phi::GPUContext &dev_ctx, const T *x, int numel) {
if (numel <= 0) return false;
cub::TransformInputIterator<bool, IsNanInfFunctor<T>, const T *> iter(
x, IsNanInfFunctor<T>());
phi::memory_utils::Buffer buffer(dev_ctx.GetPlace());
phi::memory_utils::Buffer out(dev_ctx.GetPlace());
memory_utils::Buffer buffer(dev_ctx.GetPlace());
memory_utils::Buffer out(dev_ctx.GetPlace());
CubDeviceReduce(iter,
out.Alloc<bool>(1),
numel,
......@@ -1226,11 +1202,11 @@ static bool HasNanInf(const phi::GPUContext &dev_ctx, const T *x, int numel) {
static void CheckHasNanInfGrad(const float *fp32_grad,
int fp32_numel,
const platform::float16 *fp16_grad,
const dtype::float16 *fp16_grad,
int fp16_numel,
float *nan_inf_flag,
gpuStream_t stream,
phi::memory_utils::Buffer *cub_tmp_buffer) {
memory_utils::Buffer *cub_tmp_buffer) {
bool *fp32_has_nan_inf = nullptr;
bool *fp16_has_nan_inf = nullptr;
if (fp32_numel > 0) {
......@@ -1249,9 +1225,9 @@ static void CheckHasNanInfGrad(const float *fp32_grad,
if (fp16_numel > 0) {
fp16_has_nan_inf = reinterpret_cast<bool *>(nan_inf_flag + 1) + 1;
cub::TransformInputIterator<bool,
IsNanInfFunctor<platform::float16>,
const platform::float16 *>
iter(fp16_grad, IsNanInfFunctor<platform::float16>());
IsNanInfFunctor<dtype::float16>,
const dtype::float16 *>
iter(fp16_grad, IsNanInfFunctor<dtype::float16>());
CubDeviceReduce(iter,
fp16_has_nan_inf,
fp16_numel,
......@@ -1316,7 +1292,7 @@ static void LaunchElementwiseAddWithCastKernel(const phi::GPUContext &dev_ctx,
int vec_size =
std::min(std::min(GetChunkedVecSize(x, 0), GetChunkedVecSize(y, 0)),
GetChunkedVecSize(z, 0));
auto config = platform::GetGpuLaunchConfig1D(dev_ctx, n, vec_size);
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n, vec_size);
#define PD_LAUNCH_ELEMENTWISE_ADD_WITH_CAST_KERNEL \
do { \
......@@ -1329,52 +1305,103 @@ static void LaunchElementwiseAddWithCastKernel(const phi::GPUContext &dev_ctx,
#undef PD_LAUNCH_ELEMENTWISE_ADD_WITH_CAST_KERNEL
}
template <typename T>
class DistributedFusedLambOpKernel<T, phi::GPUContext>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
template <typename T, typename Context>
void DistributedFusedLambKernel(
const Context &dev_ctx,
const std::vector<const DenseTensor *> &param,
const std::vector<const DenseTensor *> &grad, /*unused*/
const paddle::optional<DenseTensor> &fp32_param,
const paddle::optional<DenseTensor> &fp32_grad,
const paddle::optional<DenseTensor> &fp16_param,
const paddle::optional<DenseTensor> &fp16_grad,
const DenseTensor &moment1,
const DenseTensor &moment2,
const DenseTensor &beta1_pow,
const DenseTensor &beta2_pow,
const DenseTensor &param_offsets,
const DenseTensor &fp32_partial_offsets,
const DenseTensor &fp16_partial_offsets,
const DenseTensor &param_info,
const DenseTensor &param_order,
const DenseTensor &learning_rate,
const DenseTensor &global_scale,
int acc_steps,
float beta1,
float beta2,
float epsilon,
float max_global_grad_norm,
float weight_decay,
bool clip_after_allreduce,
bool use_master_param_norm,
bool use_master_acc_grad,
bool is_grad_scaled_by_nranks,
bool use_hierarchical_allreduce,
int64_t nranks,
const std::vector<int> &ring_ids,
DenseTensor *fp32_param_out,
DenseTensor *fp16_param_out,
DenseTensor *fp32_acc_grad,
DenseTensor *fp16_acc_grad,
DenseTensor *moment1_out,
DenseTensor *moment2_out,
DenseTensor *beta1_pow_out,
DenseTensor *beta2_pow_out,
DenseTensor *param_out, /*unused*/
DenseTensor *found_inf,
DenseTensor *acc_step,
DenseTensor *stop_update,
DenseTensor *step) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
auto stream = dev_ctx.stream();
auto place = dev_ctx.GetPlace();
auto *found_inf_t = ctx.Output<phi::DenseTensor>("FoundInf");
found_inf_t->Resize({1});
found_inf->Resize({1});
// Step 1: Get fp16 param and grad tensors
int64_t fp16_numel;
auto *fp16_param = GetSameInOutTensorPtr<platform::float16, true>(
ctx, place, "FP16FusedParam", "FP16FusedParamOut", &fp16_numel);
auto *fp16_param_data =
GetSameInOutTensorPtr<dtype::float16, Context, true>(dev_ctx,
fp16_param.get_ptr(),
fp16_param_out,
"FP16FusedParam",
"FP16FusedParamOut",
&fp16_numel);
bool has_fp16_param = (fp16_numel > 0);
const platform::float16 *fp16_grad = nullptr;
const dtype::float16 *fp16_grad_data = nullptr;
if (has_fp16_param) {
fp16_grad = GetInputTensorPtr<platform::float16>(ctx, "FP16FusedGrad");
fp16_grad_data =
GetInputTensorPtr<dtype::float16>(fp16_grad.get_ptr(), "FP16FusedGrad");
} else {
fp16_param = nullptr;
fp16_param_data = nullptr;
}
// Step 2: Get fp32 param and grad tensors
int64_t fp32_numel = 0;
auto *fp32_param = GetSameInOutTensorPtr<float, true>(
ctx, place, "FP32FusedParam", "FP32FusedParamOut", &fp32_numel);
auto *fp32_param_data =
GetSameInOutTensorPtr<float, Context, true>(dev_ctx,
fp32_param.get_ptr(),
fp32_param_out,
"FP32FusedParam",
"FP32FusedParamOut",
&fp32_numel);
PADDLE_ENFORCE_GE(fp32_numel,
fp16_numel,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The element number in FP32FusedParam should be not "
"less than FP16FusedParam."));
fp32_numel -= fp16_numel; // the FP32FusedParam contains fp32 param and
// fp16 master weight
bool has_fp32_param = (fp32_numel > 0);
const float *fp32_grad = nullptr;
const float *fp32_grad_data = nullptr;
if (has_fp32_param) {
fp32_grad = GetInputTensorPtr<float>(ctx, "FP32FusedGrad");
fp32_grad_data =
GetInputTensorPtr<float>(fp32_grad.get_ptr(), "FP32FusedGrad");
} else {
PADDLE_ENFORCE_EQ(
has_fp16_param,
true,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"Either FP32FusedGrad or FP16FusedGrad cannot be NULL."));
}
......@@ -1385,92 +1412,84 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
// The NVIDIA cub library does not support number > INT32_MAX
PADDLE_ENFORCE_LE(numel,
std::numeric_limits<int>::max(),
platform::errors::Unimplemented(
phi::errors::Unimplemented(
"Too many parameter number. Only <= %d is supported.",
std::numeric_limits<int>::max()));
auto acc_steps = ctx.Attr<int>("acc_steps");
PADDLE_ENFORCE_GE(
acc_steps,
1,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The gradient accumulation steps should be not less than 1."));
if (acc_steps > 1) {
auto *step_t = ctx.Output<phi::DenseTensor>("AccStep");
PADDLE_ENFORCE_NOT_NULL(
step_t,
platform::errors::InvalidArgument(
acc_step,
phi::errors::InvalidArgument(
"Output(AccStep) cannot be nullptr when Attr(acc_steps) > 1."));
bool is_initialized = step_t->IsInitialized();
int64_t *step_ptr;
bool is_initialized = acc_step->IsInitialized();
int64_t *acc_step_data;
if (is_initialized) {
step_ptr = step_t->mutable_data<int64_t>(platform::CPUPlace());
++(*step_ptr);
acc_step_data = dev_ctx.template HostAlloc<int64_t>(acc_step);
++(*acc_step_data);
} else {
step_t->Resize({1});
step_ptr = step_t->mutable_data<int64_t>(platform::CPUPlace());
*step_ptr = 1;
acc_step->Resize({1});
acc_step_data = dev_ctx.template HostAlloc<int64_t>(acc_step);
*acc_step_data = 1;
}
int64_t rounded_step = (*step_ptr) % acc_steps;
int64_t rounded_step = (*acc_step_data) % acc_steps;
float *fp32_acc_grad = nullptr;
float *fp32_acc_grad_data = nullptr;
if (has_fp32_param) {
auto *fp32_acc_grad_t =
ctx.Output<phi::DenseTensor>("FP32AccFusedGrad");
PADDLE_ENFORCE_NOT_NULL(
fp32_acc_grad_t,
platform::errors::InvalidArgument(
PADDLE_ENFORCE_NOT_NULL(fp32_acc_grad,
phi::errors::InvalidArgument(
"Output(FP32AccFusedGrad) cannot be nullptr "
"when Attr(acc_steps) > 1."));
if (!fp32_acc_grad_t->IsInitialized()) {
fp32_acc_grad_t->Resize({static_cast<int64_t>(fp32_numel)});
fp32_acc_grad = fp32_acc_grad_t->mutable_data<float>(place);
if (!fp32_acc_grad->IsInitialized()) {
fp32_acc_grad->Resize({static_cast<int64_t>(fp32_numel)});
fp32_acc_grad_data = dev_ctx.template Alloc<float>(fp32_acc_grad);
} else {
fp32_acc_grad = fp32_acc_grad_t->data<float>();
fp32_acc_grad_data = fp32_acc_grad->data<float>();
}
}
platform::float16 *fp16_acc_grad = nullptr;
dtype::float16 *fp16_acc_grad_data = nullptr;
float *master_acc_grad = nullptr;
bool use_master_acc_grad = false;
if (has_fp16_param) {
use_master_acc_grad = ctx.Attr<bool>("use_master_acc_grad");
auto *fp16_acc_grad_t =
ctx.Output<phi::DenseTensor>("FP16AccFusedGrad");
PADDLE_ENFORCE_NOT_NULL(
fp16_acc_grad_t,
platform::errors::InvalidArgument(
PADDLE_ENFORCE_NOT_NULL(fp16_acc_grad,
phi::errors::InvalidArgument(
"Output(FP16AccFusedGrad) cannot be nullptr "
"when Attr(acc_steps) > 1."));
if (!fp16_acc_grad_t->IsInitialized()) {
if (!fp16_acc_grad->IsInitialized()) {
auto acc_grad_size =
use_master_acc_grad ? (3 * fp16_numel) : fp16_numel;
fp16_acc_grad_t->Resize({static_cast<int64_t>(acc_grad_size)});
fp16_acc_grad =
fp16_acc_grad_t->mutable_data<platform::float16>(place);
fp16_acc_grad->Resize({static_cast<int64_t>(acc_grad_size)});
fp16_acc_grad_data =
dev_ctx.template Alloc<dtype::float16>(fp16_acc_grad);
} else {
fp16_acc_grad = fp16_acc_grad_t->data<platform::float16>();
fp16_acc_grad_data = fp16_acc_grad->data<dtype::float16>();
}
if (use_master_acc_grad) {
master_acc_grad =
reinterpret_cast<float *>(fp16_acc_grad + fp16_numel);
reinterpret_cast<float *>(fp16_acc_grad_data + fp16_numel);
}
} else {
use_master_acc_grad = false;
}
// Inplace addto
if (has_fp32_param) {
if (rounded_step == 1) {
memory::Copy(place,
fp32_acc_grad,
memory_utils::Copy(place,
fp32_acc_grad_data,
place,
fp32_grad,
fp32_grad_data,
fp32_numel * sizeof(float),
stream);
} else {
LaunchElementwiseAddWithCastKernel(dev_ctx,
fp32_grad,
fp32_acc_grad,
fp32_acc_grad,
fp32_grad_data,
fp32_acc_grad_data,
fp32_acc_grad_data,
fp32_numel,
stream);
}
......@@ -1480,44 +1499,44 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
if (acc_steps == 2 || !use_master_acc_grad) {
if (rounded_step != 1) {
LaunchElementwiseAddWithCastKernel(dev_ctx,
fp16_acc_grad,
fp16_grad,
fp16_acc_grad,
fp16_acc_grad_data,
fp16_grad_data,
fp16_acc_grad_data,
fp16_numel,
stream);
} else {
memory::Copy(place,
fp16_acc_grad,
memory_utils::Copy(place,
fp16_acc_grad_data,
place,
fp16_grad,
fp16_numel * sizeof(platform::float16),
fp16_grad_data,
fp16_numel * sizeof(dtype::float16),
stream);
}
} else { // acc_steps >= 3
if (rounded_step == 0) {
LaunchElementwiseAddWithCastKernel(dev_ctx,
fp16_grad,
fp16_grad_data,
master_acc_grad,
fp16_acc_grad,
fp16_acc_grad_data,
fp16_numel,
stream);
} else if (rounded_step == 1) {
memory::Copy(place,
fp16_acc_grad,
memory_utils::Copy(place,
fp16_acc_grad_data,
place,
fp16_grad,
fp16_numel * sizeof(platform::float16),
fp16_grad_data,
fp16_numel * sizeof(dtype::float16),
stream);
} else if (rounded_step == 2) {
LaunchElementwiseAddWithCastKernel(dev_ctx,
fp16_grad,
fp16_acc_grad,
fp16_grad_data,
fp16_acc_grad_data,
master_acc_grad,
fp16_numel,
stream);
} else {
LaunchElementwiseAddWithCastKernel(dev_ctx,
fp16_grad,
fp16_grad_data,
master_acc_grad,
master_acc_grad,
fp16_numel,
......@@ -1526,45 +1545,40 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
}
}
auto *stop_update_t = ctx.Output<phi::DenseTensor>("StopUpdate");
stop_update_t->Resize({1});
auto *stop_update =
stop_update_t->mutable_data<bool>(platform::CPUPlace());
auto *found_inf_cpu =
found_inf_t->mutable_data<bool>(platform::CPUPlace());
stop_update->Resize({1});
auto *stop_update_data = dev_ctx.template HostAlloc<bool>(stop_update);
auto *found_inf_cpu = dev_ctx.template HostAlloc<bool>(found_inf);
if (rounded_step != 0) {
*stop_update = true;
auto *found_inf_cpu =
found_inf_t->mutable_data<bool>(platform::CPUPlace());
*stop_update_data = true;
*found_inf_cpu = false;
return;
} else {
// swap pointer
fp32_grad = fp32_acc_grad;
fp16_grad = fp16_acc_grad;
*stop_update = false;
found_inf_t->clear();
fp32_grad_data = fp32_acc_grad_data;
fp16_grad_data = fp16_acc_grad_data;
*stop_update_data = false;
found_inf->clear();
}
}
// Step 3: Get ParamInfo
const auto *param_info_tensor = GetInputTensorPtr<int>(ctx, "ParamInfo");
auto fp32_local_start_idx = param_info_tensor[0];
auto fp32_local_param_num = param_info_tensor[1];
auto fp32_global_param_num = param_info_tensor[2];
auto fp32_weight_decay_end_idx = param_info_tensor[3];
auto fp16_local_start_idx = param_info_tensor[4];
auto fp16_local_param_num = param_info_tensor[5];
auto fp16_global_param_num = param_info_tensor[6];
auto fp16_weight_decay_end_idx = param_info_tensor[7];
const auto *param_info_data =
GetInputTensorPtr<int>(&param_info, "ParamInfo");
auto fp32_local_start_idx = param_info_data[0];
auto fp32_local_param_num = param_info_data[1];
auto fp32_global_param_num = param_info_data[2];
auto fp32_weight_decay_end_idx = param_info_data[3];
auto fp16_local_start_idx = param_info_data[4];
auto fp16_local_param_num = param_info_data[5];
auto fp16_global_param_num = param_info_data[6];
auto fp16_weight_decay_end_idx = param_info_data[7];
auto local_param_num = fp32_local_param_num + fp16_local_param_num;
auto param_num = fp32_global_param_num + fp16_global_param_num;
PADDLE_ENFORCE_LE(local_param_num,
param_num,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The local parameter number should not exceed the "
"global parameter number."));
VLOG(1) << "local_param_num = " << local_param_num
......@@ -1578,15 +1592,17 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
// Step 4: Get LearningRate, Moment1, Moment2, Beta1Pow, Beta2Pow,
// GlobalScale
const auto *global_scale = GetInputTensorPtr<float>(ctx, "GlobalScale");
const auto *lr = GetInputTensorPtr<float>(ctx, "LearningRate");
const auto *global_scale_data =
GetInputTensorPtr<float>(&global_scale, "GlobalScale");
const auto *lr_data =
GetInputTensorPtr<float>(&learning_rate, "LearningRate");
int64_t partial_numel = 0;
auto *moment1 = GetSameInOutTensorPtr<float>(
ctx, place, "Moment1", "Moment1Out", &partial_numel);
auto *moment1_data = GetSameInOutTensorPtr<float, Context>(
dev_ctx, &moment1, moment1_out, "Moment1", "Moment1Out", &partial_numel);
PADDLE_ENFORCE_EQ(numel % partial_numel,
0,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The total parameter number %d should be divided "
"exactly by the element number %d of Moment1.",
numel,
......@@ -1601,61 +1617,47 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
PADDLE_ENFORCE_EQ(fp32_numel % num_devices,
0,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The fp32 parameter number %d should be divided "
"exactly by the device number %d.",
fp32_numel,
num_devices));
PADDLE_ENFORCE_EQ(fp16_numel % num_devices,
0,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The fp16 parameter number %d should be divided "
"exactly by the device number %d.",
fp16_numel,
num_devices));
auto *moment2 =
GetSameInOutTensorPtr<float>(ctx, place, "Moment2", "Moment2Out");
auto *beta1pow =
GetSameInOutTensorPtr<float>(ctx, place, "Beta1Pow", "Beta1PowOut");
auto *beta2pow =
GetSameInOutTensorPtr<float>(ctx, place, "Beta2Pow", "Beta2PowOut");
auto *moment2_data = GetSameInOutTensorPtr<float, Context>(
dev_ctx, &moment2, moment2_out, "Moment2", "Moment2Out");
auto *beta1_pow_data = GetSameInOutTensorPtr<float, Context>(
dev_ctx, &beta1_pow, beta1_pow_out, "Beta1Pow", "Beta1PowOut");
auto *beta2_pow_data = GetSameInOutTensorPtr<float, Context>(
dev_ctx, &beta2_pow, beta2_pow_out, "Beta2Pow", "Beta2PowOut");
auto *found_inf = found_inf_t->mutable_data<bool>(place);
auto *found_inf_data = dev_ctx.template Alloc<bool>(found_inf);
// Step 5: Get attributes weight_decay, beta1, beta2, epsilon,
// max_grad_norm, ring_id,
// use_master_param_norm, is_grad_scaled_by_nranks
auto weight_decay = ctx.Attr<float>("weight_decay");
auto beta1 = ctx.Attr<float>("beta1");
auto beta2 = ctx.Attr<float>("beta2");
auto epsilon = ctx.Attr<float>("epsilon");
auto max_global_grad_norm = ctx.Attr<float>("max_global_grad_norm");
auto clip_after_allreduce = ctx.Attr<bool>("clip_after_allreduce");
auto nranks = ctx.Attr<int64_t>("nranks");
PADDLE_ENFORCE_GE(nranks,
num_devices,
phi::errors::InvalidArgument(
"The nranks must be not less than num_devices."));
PADDLE_ENFORCE_EQ(
nranks % num_devices,
PADDLE_ENFORCE_EQ(nranks % num_devices,
0,
phi::errors::InvalidArgument(
"The nranks must be exactly divided by num_devices."));
bool local_shard = (nranks > num_devices);
const auto &ring_ids = ctx.Attr<std::vector<int>>("ring_ids");
auto use_master_param_norm = ctx.Attr<bool>("use_master_param_norm");
auto is_grad_scaled_by_nranks = ctx.Attr<bool>("is_grad_scaled_by_nranks");
auto use_hierarchical_allreduce =
ctx.Attr<bool>("use_hierarchical_allreduce");
VLOG(10) << "max_global_grad_norm = " << max_global_grad_norm
<< " , clip_after_allreduce = " << clip_after_allreduce
<< " , use_master_param_norm = " << use_master_param_norm
<< " , is_grad_scaled_by_nranks = " << is_grad_scaled_by_nranks
<< " , local_shard = " << local_shard
<< " , use_hierarchical_allreduce = "
<< use_hierarchical_allreduce;
<< " , use_hierarchical_allreduce = " << use_hierarchical_allreduce;
// Step 6: allreduce + global norm gradient clip
int64_t global_rank = 0, local_rank = 0;
......@@ -1663,17 +1665,17 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
external_comm = nullptr;
if (nranks > 1) {
auto *nccl_comm_handle =
platform::NCCLCommContext::Instance().Get(ring_ids[0], place);
paddle::platform::NCCLCommContext::Instance().Get(ring_ids[0], place);
global_comm = nccl_comm_handle->comm();
global_rank = nccl_comm_handle->rank();
if (local_shard) {
auto *local_nccl_comm_handle =
platform::NCCLCommContext::Instance().Get(ring_ids[1], place);
paddle::platform::NCCLCommContext::Instance().Get(ring_ids[1], place);
local_comm = local_nccl_comm_handle->comm();
local_rank = local_nccl_comm_handle->rank();
if (use_hierarchical_allreduce) {
external_comm = platform::NCCLCommContext::Instance()
external_comm = paddle::platform::NCCLCommContext::Instance()
.Get(ring_ids[2], place)
->comm();
}
......@@ -1683,30 +1685,30 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
}
}
phi::memory_utils::Buffer grad_norm_square_buffer(place);
memory_utils::Buffer grad_norm_square_buffer(place);
auto *fp32_square_grad_norm = grad_norm_square_buffer.Alloc<float>(2);
phi::memory_utils::Buffer cub_tmp_buffer(place);
memory_utils::Buffer cub_tmp_buffer(place);
phi::memory_utils::Buffer sum_grad_buffer(place);
memory_utils::Buffer sum_grad_buffer(place);
float *fp32_sum_grad;
platform::float16 *fp16_sum_grad;
dtype::float16 *fp16_sum_grad;
auto fp32_numel_each_device = fp32_numel / num_devices;
auto fp16_numel_each_device = fp16_numel / num_devices;
if (local_shard) {
auto ptr = sum_grad_buffer.Alloc<uint8_t>(
fp32_numel * sizeof(float) + fp16_numel * sizeof(platform::float16));
fp32_numel * sizeof(float) + fp16_numel * sizeof(dtype::float16));
fp32_sum_grad = has_fp32_param ? reinterpret_cast<float *>(ptr) : nullptr;
fp16_sum_grad = has_fp16_param ? reinterpret_cast<platform::float16 *>(
fp16_sum_grad = has_fp16_param ? reinterpret_cast<dtype::float16 *>(
ptr + fp32_numel * sizeof(float))
: nullptr;
} else if (nranks > 1 ||
(max_global_grad_norm > 0 && !clip_after_allreduce)) {
auto ptr = sum_grad_buffer.Alloc<uint8_t>(
fp32_numel_each_device * sizeof(float) +
fp16_numel_each_device * sizeof(platform::float16));
fp16_numel_each_device * sizeof(dtype::float16));
fp32_sum_grad = has_fp32_param ? reinterpret_cast<float *>(ptr) : nullptr;
fp16_sum_grad = has_fp16_param
? reinterpret_cast<platform::float16 *>(
? reinterpret_cast<dtype::float16 *>(
ptr + fp32_numel_each_device * sizeof(float))
: nullptr;
} else {
......@@ -1716,8 +1718,8 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
// if-else codes (num_devices > 1) when I write the following code.
// So I prefer to use const_cast to unify the following code to reduce
// the if-else codes.
fp32_sum_grad = const_cast<float *>(fp32_grad);
fp16_sum_grad = const_cast<platform::float16 *>(fp16_grad);
fp32_sum_grad = const_cast<float *>(fp32_grad_data);
fp16_sum_grad = const_cast<dtype::float16 *>(fp16_grad_data);
}
float rescale_grad = 1.0f;
......@@ -1731,7 +1733,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
if (local_shard) {
if (use_hierarchical_allreduce) {
NCCLReduceScatterWithScale(
fp32_grad,
fp32_grad_data,
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
num_devices,
......@@ -1748,7 +1750,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
dev_ctx);
NCCLReduceScatterWithScale(
fp16_grad,
fp16_grad_data,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
num_devices,
......@@ -1764,14 +1766,14 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
stream,
dev_ctx);
} else {
NCCLAllReduceWithScale(fp32_grad,
NCCLAllReduceWithScale(fp32_grad_data,
fp32_sum_grad,
fp32_numel,
nranks,
global_comm,
stream,
dev_ctx);
NCCLAllReduceWithScale(fp16_grad,
NCCLAllReduceWithScale(fp16_grad_data,
fp16_sum_grad,
fp16_numel,
nranks,
......@@ -1782,14 +1784,14 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
fp32_sum_grad += (local_rank * fp32_numel_each_device);
fp16_sum_grad += (local_rank * fp16_numel_each_device);
} else {
NCCLReduceScatterWithScale(fp32_grad,
NCCLReduceScatterWithScale(fp32_grad_data,
fp32_sum_grad,
fp32_numel_each_device,
nranks,
global_comm,
stream,
dev_ctx);
NCCLReduceScatterWithScale(fp16_grad,
NCCLReduceScatterWithScale(fp16_grad_data,
fp16_sum_grad,
fp16_numel_each_device,
nranks,
......@@ -1809,7 +1811,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
<< FlattenToString(fp32_square_grad_norm, 1, place);
if (num_devices > 1) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllReduce(fp32_square_grad_norm,
phi::dynload::ncclAllReduce(fp32_square_grad_norm,
fp32_square_grad_norm,
1,
ncclFloat32,
......@@ -1821,9 +1823,9 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
<< FlattenToString(fp32_square_grad_norm, 1, place);
} else {
// (1) Calculate the local grad norm
GetSquareGradNorm(fp32_grad,
GetSquareGradNorm(fp32_grad_data,
fp32_numel,
fp16_grad,
fp16_grad_data,
fp16_numel,
fp32_square_grad_norm,
stream,
......@@ -1832,25 +1834,24 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
<< FlattenToString(fp32_square_grad_norm, 1, place);
// (2) Calculate the gradient clip scale
float *fp32_scale = nullptr;
platform::float16 *fp16_scale = nullptr;
dtype::float16 *fp16_scale = nullptr;
if (has_fp32_param && has_fp16_param) {
auto *ptr = cub_tmp_buffer.Alloc<uint8_t>(sizeof(float) +
sizeof(platform::float16));
sizeof(dtype::float16));
fp32_scale = reinterpret_cast<float *>(ptr);
fp16_scale =
reinterpret_cast<platform::float16 *>(ptr + sizeof(float));
fp16_scale = reinterpret_cast<dtype::float16 *>(ptr + sizeof(float));
} else if (has_fp32_param) {
fp32_scale = cub_tmp_buffer.Alloc<float>(1);
} else {
fp16_scale = cub_tmp_buffer.Alloc<platform::float16>(1);
fp16_scale = cub_tmp_buffer.Alloc<dtype::float16>(1);
}
float clip_scale = 1.0f;
if (is_grad_scaled_by_nranks) {
clip_scale *= nranks;
}
CalcGradNormClipBeforeAllReduceScale<float, platform::float16>
<<<1, 1, 0, stream>>>(global_scale,
CalcGradNormClipBeforeAllReduceScale<float, dtype::float16>
<<<1, 1, 0, stream>>>(global_scale_data,
max_global_grad_norm,
fp32_square_grad_norm,
fp32_scale,
......@@ -1863,13 +1864,13 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
}
// (3) Do ReduceScatter with scale
VLOG(1) << "FP32 HasNanInf before all reduce: "
<< HasNanInf(dev_ctx, fp32_grad, fp32_numel);
<< HasNanInf(dev_ctx, fp32_grad_data, fp32_numel);
VLOG(1) << "FP16 HasNanInf before all reduce: "
<< HasNanInf(dev_ctx, fp16_grad, fp16_numel);
<< HasNanInf(dev_ctx, fp16_grad_data, fp16_numel);
if (local_shard) {
if (use_hierarchical_allreduce) {
NCCLReduceScatterWithScale(
fp32_grad,
fp32_grad_data,
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
num_devices,
......@@ -1887,7 +1888,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
dev_ctx);
NCCLReduceScatterWithScale(
fp16_grad,
fp16_grad_data,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
num_devices,
......@@ -1904,7 +1905,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
stream,
dev_ctx);
} else {
NCCLAllReduceWithScale(fp32_grad,
NCCLAllReduceWithScale(fp32_grad_data,
fp32_sum_grad,
fp32_numel,
nranks,
......@@ -1912,7 +1913,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
stream,
dev_ctx,
fp32_scale);
NCCLAllReduceWithScale(fp16_grad,
NCCLAllReduceWithScale(fp16_grad_data,
fp16_sum_grad,
fp16_numel,
nranks,
......@@ -1924,7 +1925,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
fp32_sum_grad += (local_rank * fp32_numel_each_device);
fp16_sum_grad += (local_rank * fp16_numel_each_device);
} else {
NCCLReduceScatterWithScale(fp32_grad,
NCCLReduceScatterWithScale(fp32_grad_data,
fp32_sum_grad,
fp32_numel_each_device,
nranks,
......@@ -1932,7 +1933,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
stream,
dev_ctx,
fp32_scale);
NCCLReduceScatterWithScale(fp16_grad,
NCCLReduceScatterWithScale(fp16_grad_data,
fp16_sum_grad,
fp16_numel_each_device,
nranks,
......@@ -1954,7 +1955,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
&cub_tmp_buffer);
if (num_devices > 1) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllReduce(fp32_square_grad_norm,
phi::dynload::ncclAllReduce(fp32_square_grad_norm,
fp32_square_grad_norm,
1,
ncclFloat32,
......@@ -1972,7 +1973,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
if (local_shard) {
if (use_hierarchical_allreduce) {
NCCLReduceScatterWithScale(
fp32_grad,
fp32_grad_data,
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
num_devices,
......@@ -1989,7 +1990,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
dev_ctx);
NCCLReduceScatterWithScale(
fp16_grad,
fp16_grad_data,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
num_devices,
......@@ -2005,14 +2006,14 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
stream,
dev_ctx);
} else {
NCCLAllReduceWithScale(fp32_grad,
NCCLAllReduceWithScale(fp32_grad_data,
fp32_sum_grad,
fp32_numel,
nranks,
global_comm,
stream,
dev_ctx);
NCCLAllReduceWithScale(fp16_grad,
NCCLAllReduceWithScale(fp16_grad_data,
fp16_sum_grad,
fp16_numel,
nranks,
......@@ -2023,14 +2024,14 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
fp32_sum_grad += (local_rank * fp32_numel_each_device);
fp16_sum_grad += (local_rank * fp16_numel_each_device);
} else {
NCCLReduceScatterWithScale(fp32_grad,
NCCLReduceScatterWithScale(fp32_grad_data,
fp32_sum_grad,
fp32_numel_each_device,
num_devices,
global_comm,
stream,
dev_ctx);
NCCLReduceScatterWithScale(fp16_grad,
NCCLReduceScatterWithScale(fp16_grad_data,
fp16_sum_grad,
fp16_numel_each_device,
num_devices,
......@@ -2047,7 +2048,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
&cub_tmp_buffer);
if (num_devices > 1) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllReduce(fp32_square_grad_norm,
phi::dynload::ncclAllReduce(fp32_square_grad_norm,
fp32_square_grad_norm,
1,
ncclFloat32,
......@@ -2060,52 +2061,45 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
VLOG(10) << "ReduceScatter done";
// Step 7: update the moment1, moment2. Calcuate the trust_ratio_div
auto *fused_offsets_t = ctx.Input<phi::DenseTensor>("FusedParamOffsets");
auto *fused_offsets = fused_offsets_t->data<int>();
auto *fp32_partial_fused_offsets_t =
ctx.Input<phi::DenseTensor>("FP32ShardFusedParamOffsets");
const auto *fp32_partial_fused_offsets =
fp32_partial_fused_offsets_t->data<int>();
auto *fp16_partial_fused_offsets_t =
ctx.Input<phi::DenseTensor>("FP16ShardFusedParamOffsets");
const auto *fp16_partial_fused_offsets =
fp16_partial_fused_offsets_t->data<int>();
auto *step = ctx.Output<phi::DenseTensor>("Step")->data<int64_t>();
auto *param_offsets_data = param_offsets.data<int>();
const auto *fp32_partial_offsets_data = fp32_partial_offsets.data<int>();
const auto *fp16_partial_offsets_data = fp16_partial_offsets.data<int>();
auto *step_data = step->data<int64_t>();
VLOG(1) << "FusedParamOffsets: "
<< FlattenToString(fused_offsets,
fused_offsets_t->numel(),
fused_offsets_t->place());
<< FlattenToString(param_offsets_data,
param_offsets.numel(),
param_offsets.place());
VLOG(1) << "FP32ShardFusedParamOffsets: "
<< FlattenToString(fp32_partial_fused_offsets,
fp32_partial_fused_offsets_t->numel(),
fp32_partial_fused_offsets_t->place());
<< FlattenToString(fp32_partial_offsets_data,
fp32_partial_offsets.numel(),
fp32_partial_offsets.place());
VLOG(1) << "FP16ShardFusedParamOffsets: "
<< FlattenToString(fp16_partial_fused_offsets,
fp16_partial_fused_offsets_t->numel(),
fp16_partial_fused_offsets_t->place());
<< FlattenToString(fp16_partial_offsets_data,
fp16_partial_offsets.numel(),
fp16_partial_offsets.place());
phi::memory_utils::Buffer trust_ratio_div_buffer(place);
memory_utils::Buffer trust_ratio_div_buffer(place);
auto *trust_ratio_div = trust_ratio_div_buffer.Alloc<float>(partial_numel);
auto fp32_offset = local_rank * fp32_numel_each_device;
auto fp16_offset = local_rank * fp16_numel_each_device;
if (has_fp32_param) {
VLOG(10) << "Update FP32 Moment and TrustRatioDiv starts";
MultiTensorUpdateLambMomentAndTrustRatioDiv(dev_ctx,
fp32_partial_fused_offsets,
fp32_partial_offsets_data,
fp32_local_param_num,
fp32_param + fp32_offset,
fp32_param_data + fp32_offset,
fp32_sum_grad,
fp32_square_grad_norm,
global_scale,
beta1pow,
beta2pow,
moment1,
moment2,
global_scale_data,
beta1_pow_data,
beta2_pow_data,
moment1_data,
moment2_data,
trust_ratio_div,
found_inf,
step,
found_inf_data,
step_data,
weight_decay,
fp32_weight_decay_end_idx,
beta1,
......@@ -2117,22 +2111,22 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
}
float *master_param = nullptr;
if (has_fp16_param) {
master_param = fp32_param + fp32_numel;
master_param = fp32_param_data + fp32_numel;
VLOG(10) << "Update FP16 Moment and TrustRatioDiv starts";
auto tmp_found_inf = has_fp32_param ? nullptr : found_inf;
auto tmp_step = has_fp32_param ? nullptr : step;
auto tmp_found_inf = has_fp32_param ? nullptr : found_inf_data;
auto tmp_step = has_fp32_param ? nullptr : step_data;
MultiTensorUpdateLambMomentAndTrustRatioDiv(
dev_ctx,
fp16_partial_fused_offsets,
fp16_partial_offsets_data,
fp16_local_param_num,
master_param + fp16_offset,
fp16_sum_grad,
fp32_square_grad_norm,
global_scale,
beta1pow,
beta2pow,
moment1 + fp32_numel_each_device,
moment2 + fp32_numel_each_device,
global_scale_data,
beta1_pow_data,
beta2_pow_data,
moment1_data + fp32_numel_each_device,
moment2_data + fp32_numel_each_device,
trust_ratio_div + fp32_numel_each_device,
tmp_found_inf,
tmp_step,
......@@ -2149,7 +2143,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
VLOG(10) << "Update Moment and TrustRatioDiv done hehahaha";
// Step 8: calculate L2-Norm square of parameter and trust_ratio_div
phi::memory_utils::Buffer square_norm_buffer(place);
memory_utils::Buffer square_norm_buffer(place);
auto *param_square_norm = square_norm_buffer.Alloc<float>(2 * param_num);
auto *trust_ratio_div_square_norm = param_square_norm + param_num;
if (num_devices > 1) {
......@@ -2163,23 +2157,24 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
}
MultiTensorL2Norm(place,
stream,
fp32_param,
fused_offsets,
fp32_param_data,
param_offsets_data,
fp32_global_param_num,
param_square_norm);
if (use_master_param_norm) {
MultiTensorL2Norm(place,
stream,
master_param + fp16_offset,
fp16_partial_fused_offsets,
fp16_partial_offsets_data,
fp16_local_param_num,
param_square_norm + fp16_local_start_idx);
} else {
MultiTensorL2Norm(place,
stream,
fp16_param + fused_offsets[fp16_local_start_idx] -
fused_offsets[fp32_global_param_num],
fused_offsets + fp16_local_start_idx,
fp16_param_data +
param_offsets_data[fp16_local_start_idx] -
param_offsets_data[fp32_global_param_num],
param_offsets_data + fp16_local_start_idx,
fp16_local_param_num,
param_square_norm + fp16_local_start_idx);
}
......@@ -2187,13 +2182,13 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
MultiTensorL2Norm(place,
stream,
trust_ratio_div,
fp32_partial_fused_offsets,
fp32_partial_offsets_data,
fp32_local_param_num,
trust_ratio_div_square_norm + fp32_local_start_idx);
MultiTensorL2Norm(place,
stream,
trust_ratio_div + fp32_numel_each_device,
fp16_partial_fused_offsets,
fp16_partial_offsets_data,
fp16_local_param_num,
trust_ratio_div_square_norm + fp16_local_start_idx);
......@@ -2201,8 +2196,8 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
<< FlattenToString(trust_ratio_div_square_norm, param_num, place);
if (num_devices > 1) {
if (use_master_param_norm) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
param_square_norm + fp32_global_param_num,
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::ncclAllReduce(param_square_norm + fp32_global_param_num,
param_square_norm + fp32_global_param_num,
2 * param_num - fp32_global_param_num,
ncclFloat32,
......@@ -2211,7 +2206,7 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
stream));
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllReduce(trust_ratio_div_square_norm,
phi::dynload::ncclAllReduce(trust_ratio_div_square_norm,
trust_ratio_div_square_norm,
param_num,
ncclFloat32,
......@@ -2223,61 +2218,61 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
}
LogParamAndTrustRatioDivSquareNorm<1>(
ctx, param_square_norm, trust_ratio_div_square_norm);
param, param_order, param_square_norm, trust_ratio_div_square_norm);
VLOG(10) << "Calculate L2-Norm of Param and TrustRatioDiv done";
// Step 9: update parameter, beta1pow, beta2pow. All gather parameters.
if (has_fp32_param) {
MultiTensorUpdateLambParamAndBetaPows<float>(
dev_ctx,
fp32_partial_fused_offsets,
fp32_partial_offsets_data,
fp32_local_param_num,
trust_ratio_div,
lr,
lr_data,
param_square_norm + fp32_local_start_idx,
trust_ratio_div_square_norm + fp32_local_start_idx,
found_inf,
fp32_param + fp32_offset,
found_inf_data,
fp32_param_data + fp32_offset,
nullptr,
beta1pow,
beta2pow,
beta1_pow_data,
beta2_pow_data,
beta1,
beta2);
if (num_devices > 1) {
// ncclAllGather
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllGather(fp32_param + fp32_offset,
fp32_param,
phi::dynload::ncclAllGather(fp32_param_data + fp32_offset,
fp32_param_data,
fp32_numel_each_device,
ncclFloat32,
local_comm,
stream));
}
beta1pow = nullptr;
beta2pow = nullptr;
beta1_pow_data = nullptr;
beta2_pow_data = nullptr;
}
if (has_fp16_param) {
MultiTensorUpdateLambParamAndBetaPows<platform::float16>(
MultiTensorUpdateLambParamAndBetaPows<dtype::float16>(
dev_ctx,
fp16_partial_fused_offsets,
fp16_partial_offsets_data,
fp16_local_param_num,
trust_ratio_div + fp32_numel_each_device,
lr,
lr_data,
param_square_norm + fp16_local_start_idx,
trust_ratio_div_square_norm + fp16_local_start_idx,
found_inf,
fp16_param + fp16_offset,
found_inf_data,
fp16_param_data + fp16_offset,
master_param + fp16_offset,
beta1pow,
beta2pow,
beta1_pow_data,
beta2_pow_data,
beta1,
beta2);
if (num_devices > 1) {
// ncclAllGather
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllGather(fp16_param + fp16_offset,
fp16_param,
phi::dynload::ncclAllGather(fp16_param_data + fp16_offset,
fp16_param_data,
fp16_numel_each_device,
ncclFloat16,
local_comm,
......@@ -2288,20 +2283,29 @@ class DistributedFusedLambOpKernel<T, phi::GPUContext>
VLOG(1) << "IsFinite: " << IsFinite(dev_ctx, fp32_square_grad_norm);
#else
PADDLE_THROW(platform::errors::Unimplemented(
PADDLE_THROW(phi::errors::Unimplemented(
"distributed_fused_lamb op should be used with NCCL/RCCL."));
#endif
}
};
} // namespace operators
} // namespace paddle
}
namespace plat = paddle::platform;
namespace ops = paddle::operators;
} // namespace fusion
} // namespace phi
PD_REGISTER_STRUCT_KERNEL(distributed_fused_lamb,
PD_REGISTER_KERNEL(distributed_fused_lamb,
GPU,
ALL_LAYOUT,
ops::DistributedFusedLambOpKernel,
float) {}
phi::fusion::DistributedFusedLambKernel,
float) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT16);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT16);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(7).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(9).SetDataType(phi::DataType::BOOL);
kernel->OutputAt(10).SetDataType(phi::DataType::INT64);
kernel->OutputAt(11).SetDataType(phi::DataType::BOOL);
kernel->OutputAt(12).SetDataType(phi::DataType::INT64);
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
template <typename T, typename DevCtx>
class DistributedFusedLambOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"The distributed_fused_lamb operator does not support CPU yet."));
}
};
} // namespace operators
} // namespace paddle
......@@ -18,6 +18,8 @@
#include "math.h" // NOLINT
#include "paddle/phi/core/cuda_stream.h"
namespace paddle {
namespace operators {
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature DistributedFusedLambOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("distributed_fused_lamb",
{"Param",
"Grad",
"FP32FusedParam",
"FP32FusedGrad",
"FP16FusedParam",
"FP16FusedGrad",
"Moment1",
"Moment2",
"Beta1Pow",
"Beta2Pow",
"FusedParamOffsets",
"FP32ShardFusedParamOffsets",
"FP16ShardFusedParamOffsets",
"ParamInfo",
"ParamOrder",
"LearningRate",
"GlobalScale"},
{"acc_steps",
"beta1",
"beta2",
"epsilon",
"max_global_grad_norm",
"weight_decay",
"clip_after_allreduce",
"use_master_param_norm",
"use_master_acc_grad",
"is_grad_scaled_by_nranks",
"use_hierarchical_allreduce",
"nranks",
"ring_ids"},
{"FP32FusedParamOut",
"FP16FusedParamOut",
"FP32AccFusedGrad",
"FP16AccFusedGrad",
"Moment1Out",
"Moment2Out",
"Beta1PowOut",
"Beta2PowOut",
"ParamOut",
"FoundInf",
"AccStep",
"StopUpdate",
"Step"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(distributed_fused_lamb,
phi::DistributedFusedLambOpArgumentMapping);