未验证 提交 d80fe268 编写于 作者: S sneaxiy 提交者: GitHub

Refine some AMP operators for BERT (#37923)

* support multi precision update for LAMB

* hide some api

* fix ci uts

* fix lamb output of dygraph

* remove some changes to some PR

* try to fix Py3 CI compile error

* fix test_imperative_optimizer, add lars ut, add layer_norm ut

* fix ut, fix format

* fix ut

* fix windows ci
上级 cff03734
......@@ -169,10 +169,16 @@ __inline__ __device__ half rsqrt_(const half val) {
}
#endif
template <typename T, typename U, int BlockDim>
__global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
T *y, U *mean, U *var, float epsilon,
int64_t feature_size) {
template <typename T, typename U, bool ScaleBiasWithSameTypeX>
using LayerNormScaleBiasT =
typename std::conditional<ScaleBiasWithSameTypeX, T, U>::type;
template <typename T, typename U, int BlockDim,
bool ScaleBiasWithSameTypeX = false>
__global__ void LayerNormForward(
const T *x, const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *bias, T *y,
U *mean, U *var, float epsilon, int64_t feature_size) {
__shared__ U mean_share;
__shared__ U var_share;
__shared__ U shared_mean[32]; // threadIdx.x / warpSize <= kMaxBlockDim /
......@@ -212,14 +218,15 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
if (bias != nullptr) {
for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>(
scale[j] * (static_cast<U>(x[i]) - mean_val) * invvar + bias[j]);
y[i] = static_cast<T>(static_cast<U>(scale[j]) *
(static_cast<U>(x[i]) - mean_val) * invvar +
static_cast<U>(bias[j]));
}
} else {
for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>(scale[j] * (static_cast<U>(x[i]) - mean_val) *
invvar);
y[i] = static_cast<T>(static_cast<U>(scale[j]) *
(static_cast<U>(x[i]) - mean_val) * invvar);
}
}
} else { // scale == nullptr
......@@ -227,7 +234,7 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar +
bias[j]);
static_cast<U>(bias[j]));
}
} else {
for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
......@@ -336,12 +343,15 @@ __global__ void LayerNormBackwardPartGradGammaBeta(
}
}
template <typename T, typename U, int BDIMX, int BDIMY>
template <typename T, typename U, int BDIMX, int BDIMY, bool ScaleBiasSameTypeX>
__global__ void LayerNormBackwardSumGradGammaBeta(
const U *part_grad_gamma, const U *part_grad_beta, const int part_size,
// const int n1, const int n2, T* grad_gamma, T* grad_beta) {
const int n1, const int n2, U *grad_gamma, U *grad_beta) {
const int n1, const int n2,
LayerNormScaleBiasT<T, U, ScaleBiasSameTypeX> *grad_gamma,
LayerNormScaleBiasT<T, U, ScaleBiasSameTypeX> *grad_beta) {
// sum partial gradients for gamma and beta
using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasSameTypeX>;
__shared__ U buf[BDIMX * BDIMY];
int64_t i2 = blockIdx.x * BDIMX + threadIdx.x;
if (i2 < n2) {
......@@ -378,20 +388,18 @@ __global__ void LayerNormBackwardSumGradGammaBeta(
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
grad_beta[i2] = sum_beta;
grad_gamma[i2] = static_cast<ScaleBiasT>(sum_gamma);
grad_beta[i2] = static_cast<ScaleBiasT>(sum_beta);
}
}
}
template <typename T, typename U, int BDIMX, int BDIMY>
template <typename T, typename U, int BDIMX, int BDIMY, bool ScaleBiasSameTypeX>
__global__ void LayerNormBackwardComputeGradInput(
const T *__restrict__ dout, const T *__restrict__ input, const int n1,
const int n2,
// const U* __restrict__ mean, const U* __restrict__ var, const float
// epsilon, const T* gamma,
const U *__restrict__ mean, const U *__restrict__ var, const float epsilon,
const U *gamma, T *grad_input) {
const int n2, const U *__restrict__ mean, const U *__restrict__ var,
const float epsilon,
const LayerNormScaleBiasT<T, U, ScaleBiasSameTypeX> *gamma, T *grad_input) {
#ifdef __HIPCC__
for (auto i1 = hipBlockIdx_x; i1 < n1; i1 += hipGridDim_x) {
#else
......@@ -411,15 +419,17 @@ __global__ void LayerNormBackwardComputeGradInput(
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss * gamma[l + k];
sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
sum_loss1 += c_loss * static_cast<U>(gamma[l + k]);
sum_loss2 +=
c_loss * static_cast<U>(gamma[l + k]) * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss * gamma[l];
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
sum_loss1 += c_loss * static_cast<U>(gamma[l]);
sum_loss2 +=
c_loss * static_cast<U>(gamma[l]) * (c_h - c_mean) * c_invvar;
}
} else {
int l = 4 * thrx;
......@@ -491,7 +501,7 @@ __global__ void LayerNormBackwardComputeGradInput(
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss * gamma[l];
U f_grad_input = fH * c_loss * static_cast<U>(gamma[l]);
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
......@@ -513,11 +523,17 @@ __global__ void LayerNormBackwardComputeGradInput(
// Make sure that d_scale != nullptr && d_bias != nullptr
// Since d_scale != nullptr, scale would not be nullptr
template <typename T, typename U, int BlockDim, bool HasDx>
template <typename T, typename U, int BlockDim, bool HasDx,
bool ScaleBiasWithSameTypeX>
__global__ void LayerNormBackwardGradientAll(
const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean,
const U *var, const U *scale, float epsilon, int64_t batch_size,
int64_t feature_size, int64_t col_offset) {
const T *x, const T *d_y,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias, T *d_x,
const U *mean, const U *var,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
float epsilon, int64_t batch_size, int64_t feature_size,
int64_t col_offset) {
using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>;
int64_t beg_idx = threadIdx.x * feature_size + (blockIdx.x + col_offset);
int64_t end_idx = batch_size * feature_size + (blockIdx.x + col_offset);
int64_t stride = BlockDim * feature_size;
......@@ -532,7 +548,8 @@ __global__ void LayerNormBackwardGradientAll(
d_bias_partial += static_cast<U>(d_y[i]);
if (HasDx) {
d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
scale[blockIdx.x + col_offset] / var_val);
static_cast<U>(scale[blockIdx.x + col_offset]) /
var_val);
}
}
......@@ -543,19 +560,24 @@ __global__ void LayerNormBackwardGradientAll(
d_bias_partial = BlockReduceSum<U>(d_bias_partial, shared_bias);
if (threadIdx.x == 0) {
d_scale[blockIdx.x + col_offset] = d_scale_partial;
d_bias[blockIdx.x + col_offset] = d_bias_partial;
d_scale[blockIdx.x + col_offset] = static_cast<ScaleBiasT>(d_scale_partial);
d_bias[blockIdx.x + col_offset] = static_cast<ScaleBiasT>(d_bias_partial);
}
}
// Make sure that there is only one true expression: d_scale != nullptr
// or d_bias != nullptr
// Notice: scale may be nullptr
template <typename T, typename U, int BlockDim, bool HasDx, bool HasDScale>
template <typename T, typename U, int BlockDim, bool HasDx, bool HasDScale,
bool ScaleBiasWithSameTypeX>
__global__ void LayerNormBackwardGradientScaleOrBias(
const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean,
const U *var, const U *scale, float epsilon, int64_t batch_size,
int64_t feature_size, int col_offset) {
const T *x, const T *d_y,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias, T *d_x,
const U *mean, const U *var,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
float epsilon, int64_t batch_size, int64_t feature_size, int col_offset) {
using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>;
using BlockReduce = cub::BlockReduce<U, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage;
int64_t beg_idx = threadIdx.x * feature_size + blockIdx.x + col_offset;
......@@ -578,7 +600,8 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
if (HasDx) {
if (scale != nullptr) {
d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
scale[blockIdx.x + col_offset] / var_val);
static_cast<U>(scale[blockIdx.x + col_offset]) /
var_val);
} else {
d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
}
......@@ -590,9 +613,11 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
if (threadIdx.x == 0) {
if (HasDScale) {
d_scale[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
d_scale[blockIdx.x + col_offset] =
static_cast<ScaleBiasT>(d_scale_or_d_bias_partial);
} else {
d_bias[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
d_bias[blockIdx.x + col_offset] =
static_cast<ScaleBiasT>(d_scale_or_d_bias_partial);
}
}
}
......@@ -640,12 +665,12 @@ __global__ void LayerNormBackwardPostProcessToCalculateDX(
}
// Here, we only calculate d_x
template <typename T, typename U, int BlockDim>
__global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
T *d_x, const U *mean,
const U *var, const U *scale,
float epsilon,
int64_t feature_size) {
template <typename T, typename U, int BlockDim, bool ScaleBiasWithSameTypeX>
__global__ void LayerNormBackwardGradientOnlyDX(
const T *x, const T *d_y, T *d_x, const U *mean, const U *var,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
float epsilon, int64_t feature_size) {
using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>;
using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ U d_x_reduce_tmp[2];
......@@ -660,8 +685,8 @@ __global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
static_cast<U>(real_sqrt(static_cast<float>(block_var) + epsilon));
if (scale != nullptr) {
int col_idx = i % feature_size;
d_x[i] =
static_cast<T>(static_cast<U>(d_y[i]) * scale[col_idx] / var_val);
d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
static_cast<U>(scale[col_idx]) / var_val);
} else {
d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
}
......@@ -692,11 +717,16 @@ __global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
}
}
template <typename T, typename U>
template <typename T, typename U, bool ScaleBiasWithSameTypeX>
__global__ void LayerNormBackwardWhenBatchSizeIsOne(
const T *x, const T *d_y, T *d_x, U *d_scale, U *d_bias, const U *mean,
const U *var, const U *scale, float epsilon, int64_t feature_size) {
const T *x, const T *d_y, T *d_x,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias, const U *mean,
const U *var,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
float epsilon, int64_t feature_size) {
int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>;
if (idx < feature_size) {
auto var_val =
static_cast<U>(real_sqrt(static_cast<float>(var[0]) + epsilon));
......@@ -704,26 +734,32 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne(
if (d_scale == nullptr) {
d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) / var_val);
} else {
d_x[idx] =
static_cast<T>(static_cast<U>(d_y[idx]) * scale[idx] / var_val);
d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) *
static_cast<U>(scale[idx]) / var_val);
}
}
if (d_scale != nullptr) {
d_scale[idx] = static_cast<U>(d_y[idx]) *
(static_cast<U>(x[idx]) - mean[0]) / var_val;
d_scale[idx] =
static_cast<ScaleBiasT>(static_cast<U>(d_y[idx]) *
(static_cast<U>(x[idx]) - mean[0]) / var_val);
}
if (d_bias != nullptr) d_bias[idx] = static_cast<U>(d_y[idx]);
if (d_bias != nullptr) {
d_bias[idx] = static_cast<ScaleBiasT>(d_y[idx]);
}
}
}
template <typename T, typename U>
static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
const U *mean, const U *var, T *d_x, U *d_scale,
U *d_bias, float epsilon, int64_t batch_size,
int64_t feature_size,
const platform::CUDADeviceContext &dev_ctx) {
template <typename T, typename U, bool ScaleBiasWithSameTypeX = false>
static void LayerNormBackward(
const T *x, const T *d_y,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
const U *mean, const U *var, T *d_x,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias, float epsilon,
int64_t batch_size, int64_t feature_size,
const platform::CUDADeviceContext &dev_ctx) {
auto stream = dev_ctx.stream();
#ifdef __HIPCC__
const int kMaxBlockDim = 256;
......@@ -737,10 +773,10 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
if (gradient_flag == 0) return;
if (batch_size == 1) {
LayerNormBackwardWhenBatchSizeIsOne<
T, U><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim,
0, stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale,
epsilon, feature_size);
LayerNormBackwardWhenBatchSizeIsOne<T, U, ScaleBiasWithSameTypeX><<<
(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim, 0,
stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale, epsilon,
feature_size);
if (d_x != nullptr) {
switch (GetDesiredBlockDim(feature_size)) {
......@@ -759,8 +795,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum,
LayerNormBackwardGradientScaleOrBias<
T, U, kBlockDim, false,
false><<<block_num, kBlockDim, 0, stream>>>(
T, U, kBlockDim, false, false,
ScaleBiasWithSameTypeX><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset));
}
......@@ -770,8 +806,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum,
LayerNormBackwardGradientScaleOrBias<
T, U, kBlockDim, false,
true><<<block_num, kBlockDim, 0, stream>>>(
T, U, kBlockDim, false, true,
ScaleBiasWithSameTypeX><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset));
}
......@@ -781,7 +817,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum,
LayerNormBackwardGradientAll<
T, U, kBlockDim, false><<<block_num, kBlockDim, 0, stream>>>(
T, U, kBlockDim, false,
ScaleBiasWithSameTypeX><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset));
}
......@@ -790,7 +827,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(
LayerNormBackwardGradientOnlyDX<
T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
T, U, kBlockDim,
ScaleBiasWithSameTypeX><<<batch_size, kBlockDim, 0, stream>>>(
x, d_y, d_x, mean, var, scale, epsilon, feature_size));
}
break;
......@@ -799,8 +837,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum,
LayerNormBackwardGradientScaleOrBias<
T, U, kBlockDim, true,
false><<<block_num, kBlockDim, 0, stream>>>(
T, U, kBlockDim, true, false,
ScaleBiasWithSameTypeX><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset));
}
......@@ -816,8 +854,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum,
LayerNormBackwardGradientScaleOrBias<
T, U, kBlockDim, true,
true><<<block_num, kBlockDim, 0, stream>>>(
T, U, kBlockDim, true, true,
ScaleBiasWithSameTypeX><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset));
}
......@@ -854,7 +892,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
dim3 threads3(BDIMX3, BDIMY3, 1);
const dim3 blocks3((feature_size + BDIMX2 - 1) / BDIMX2, 1, 1);
LayerNormBackwardSumGradGammaBeta<
T, U, BDIMX3, BDIMY3><<<blocks3, threads3, 0, stream>>>(
T, U, BDIMX3, BDIMY3,
ScaleBiasWithSameTypeX><<<blocks3, threads3, 0, stream>>>(
part_grad_gamma, part_grad_beta, part_size, batch_size, feature_size,
d_scale, d_bias);
......@@ -862,7 +901,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
constexpr int BDIMY1 = 4;
dim3 threads1(BDIMX1, BDIMY1, 1);
LayerNormBackwardComputeGradInput<
T, U, BDIMX1, BDIMY1><<<batch_size, threads1, 0, stream>>>(
T, U, BDIMX1, BDIMY1,
ScaleBiasWithSameTypeX><<<batch_size, threads1, 0, stream>>>(
d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x);
break;
}
......
......@@ -102,24 +102,6 @@ class LayerNormOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// By default, the type of the scale, bias, mean,
// and var tensors should both be float. (For float or float16 input tensor)
// or double (For double input tensor).
auto ln_param_type = framework::proto::VarType::FP32;
if (input_data_type == framework::proto::VarType::FP64) {
ln_param_type = framework::proto::VarType::FP64;
}
if (ctx.HasInput("Scale")) {
PADDLE_ENFORCE_EQ(ln_param_type, ctx.Input<Tensor>("Scale")->type(),
platform::errors::InvalidArgument(
"Scale input should be of float type"));
}
if (ctx.HasInput("Bias")) {
PADDLE_ENFORCE_EQ(ln_param_type, ctx.Input<Tensor>("Bias")->type(),
platform::errors::InvalidArgument(
"Bias input should be of float type"));
}
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
......
......@@ -63,8 +63,32 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
auto *y_data = y->mutable_data<T>(ctx.GetPlace());
auto *mean_data = mean->mutable_data<U>(ctx.GetPlace());
auto *var_data = var->mutable_data<U>(ctx.GetPlace());
auto *scale_data = (scale == nullptr ? nullptr : scale->data<U>());
auto *bias_data = (bias == nullptr ? nullptr : bias->data<U>());
auto *void_scale_data = (scale == nullptr ? nullptr : scale->data<void>());
auto *void_bias_data = (bias == nullptr ? nullptr : bias->data<void>());
framework::proto::VarType::Type x_dtype = x->type();
framework::proto::VarType::Type scale_bias_dtype;
if (void_scale_data != nullptr) {
scale_bias_dtype = scale->type();
if (void_bias_data != nullptr) {
PADDLE_ENFORCE_EQ(scale_bias_dtype, bias->type(),
platform::errors::InvalidArgument(
"Thie Scale and Bias of layer_norm op "
"should have the same data type."));
}
} else {
scale_bias_dtype = (void_bias_data != nullptr ? bias->type() : x_dtype);
}
bool is_scale_bias_same_dtype_with_x = x_dtype == scale_bias_dtype;
if (!is_scale_bias_same_dtype_with_x) {
PADDLE_ENFORCE_EQ(scale_bias_dtype,
framework::DataTypeTrait<U>::DataType(),
platform::errors::InvalidArgument(
"Unsupported data type of Scale and Bias: %s",
framework::DataTypeToString(scale_bias_dtype)));
}
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
......@@ -72,17 +96,28 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
auto stream = ctx.cuda_device_context().stream();
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(
LayerNormForward<T, U,
kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
x_data, scale_data, bias_data, y_data, mean_data, var_data,
epsilon, feature_size));
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Product from begin_norm_axis to end must be larger than 1"));
break;
#define PADDLE_LAUNCH_LAYERNORM_FWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \
do { \
switch (GetDesiredBlockDim(feature_size)) { \
FIXED_BLOCK_DIM_CASE( \
LayerNormForward<T, U, kBlockDim, IsScaleBiasSameDTypeWithX><<< \
batch_size, kBlockDim, 0, stream>>>( \
x_data, static_cast<const ScaleBiasT *>(void_scale_data), \
static_cast<const ScaleBiasT *>(void_bias_data), y_data, \
mean_data, var_data, epsilon, feature_size)); \
default: \
PADDLE_THROW(platform::errors::InvalidArgument( \
"Product from begin_norm_axis to end must be larger than 1")); \
break; \
} \
} while (0)
if (is_scale_bias_same_dtype_with_x) {
PADDLE_LAUNCH_LAYERNORM_FWD(T, true);
} else {
PADDLE_LAUNCH_LAYERNORM_FWD(U, false);
}
#undef PADDLE_LAUNCH_LAYERNORM_FWD
}
};
......@@ -102,32 +137,64 @@ class LayerNormGradKernel<platform::CUDADeviceContext, T>
auto *mean = ctx.Input<Tensor>("Mean");
auto *var = ctx.Input<Tensor>("Variance");
auto *scale = ctx.Input<Tensor>("Scale");
auto *bias = ctx.Input<Tensor>("Bias");
auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto &x_dims = x->dims();
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
auto *x_data = x->data<T>();
auto *d_y_data = d_y->data<T>();
auto *mean_data = mean->data<U>();
auto *var_data = var->data<U>();
auto *scale_data = (scale == nullptr ? nullptr : scale->data<U>());
auto *d_scale_data =
(d_scale == nullptr ? nullptr
: d_scale->mutable_data<U>(ctx.GetPlace()));
auto *d_bias_data =
(d_bias == nullptr ? nullptr : d_bias->mutable_data<U>(ctx.GetPlace()));
auto *d_x_data =
(d_x == nullptr ? nullptr : d_x->mutable_data<T>(ctx.GetPlace()));
const auto &x_dims = x->dims();
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
framework::proto::VarType::Type x_dtype = x->type();
framework::proto::VarType::Type scale_bias_dtype;
if (scale != nullptr) {
scale_bias_dtype = scale->type();
} else {
// FIXME(zengjinle): do not find a better way to get the right
// data type of the d_scale and d_bias if scale == nullptr.
auto *bias = ctx.Input<Tensor>("Bias");
if (bias != nullptr) {
scale_bias_dtype = bias->saved_type();
} else {
scale_bias_dtype = x_dtype;
}
}
#define PADDLE_LAUNCH_LAYERNORM_BWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \
do { \
auto *scale_data = \
(scale == nullptr ? nullptr : scale->data<ScaleBiasT>()); \
auto *d_scale_data = \
(d_scale == nullptr ? nullptr : d_scale->mutable_data<ScaleBiasT>( \
ctx.GetPlace())); \
auto *d_bias_data = \
(d_bias == nullptr ? nullptr : d_bias->mutable_data<ScaleBiasT>( \
ctx.GetPlace())); \
auto *d_x_data = \
(d_x == nullptr ? nullptr : d_x->mutable_data<T>(ctx.GetPlace())); \
LayerNormBackward<T, U, IsScaleBiasSameDTypeWithX>( \
x_data, d_y_data, scale_data, mean_data, var_data, d_x_data, \
d_scale_data, d_bias_data, epsilon, batch_size, feature_size, \
ctx.cuda_device_context()); \
} while (0)
if (scale_bias_dtype == x_dtype) {
PADDLE_LAUNCH_LAYERNORM_BWD(T, true);
} else {
PADDLE_LAUNCH_LAYERNORM_BWD(U, false);
}
LayerNormBackward<T, U>(x_data, d_y_data, scale_data, mean_data, var_data,
d_x_data, d_scale_data, d_bias_data, epsilon,
batch_size, feature_size,
ctx.cuda_device_context());
#undef PADDLE_LAUNCH_LAYERNORM_BWD
}
};
......
......@@ -152,6 +152,14 @@ class LambOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Moment2", "(Tensor) Input second moment.");
AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator.");
AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator.");
AddInput("MasterParam",
"(LoDTensor, default LoDTensor<float>) "
"Input master parameter that has to be updated.")
.AsDispensable();
AddInput(
"SkipUpdate",
"(Tensor) Input tensor to determine whether to update the parameter.")
.AsDispensable();
AddOutput("ParamOut", "(Tensor) Output parameter.");
AddOutput("Moment1Out", "(Tensor) Output first moment.");
......@@ -160,6 +168,8 @@ class LambOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDispensable();
AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator")
.AsDispensable();
AddOutput("MasterParamOut", "(Tensor) Output master parameter.")
.AsDispensable();
AddAttr<float>("weight_decay", "(float) Weight decay rate.");
AddAttr<float>("beta1",
"(float, default 0.9) The exponential decay rate for the "
......@@ -173,6 +183,10 @@ class LambOpMaker : public framework::OpProtoAndCheckerMaker {
"(float, default 1.0e-6) "
"Constant for numerical stability.")
.SetDefault(1.0e-6f);
AddAttr<bool>(
"multi_precision",
"(bool, default false) Whether to enable multi-precision mode.")
.SetDefault(false);
AddComment(R"DOC(
LAMB (Layer-wise Adaptive Moments optimizer for Batching training) Optimizer.
......
......@@ -16,5 +16,7 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
lamb, ops::LambOpKernel<paddle::platform::CUDADeviceContext, float>,
lamb, ops::LambOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::LambOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::LambOpKernel<paddle::platform::CUDADeviceContext, double>);
......@@ -17,8 +17,10 @@ limitations under the License. */
#include <Eigen/Dense>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/eigen_ext.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
......@@ -26,30 +28,35 @@ namespace operators {
namespace scatter = paddle::operators::math::scatter;
template <typename T>
template <typename T, bool IsMultiPrecision>
struct LambMomentREGUpdateFunctor {
T weight_decay_;
T beta1_;
T beta2_;
T epsilon_;
T beta1_pow_;
T* beta1_pow_out_;
T beta2_pow_;
T* beta2_pow_out_;
const T* moment1_;
T* moment1_out_;
const T* moment2_;
T* moment2_out_;
using MT = typename std::conditional<
IsMultiPrecision, typename details::MPTypeTrait<T>::Type, T>::type;
MT weight_decay_;
MT beta1_;
MT beta2_;
MT epsilon_;
MT beta1_pow_;
MT* beta1_pow_out_;
MT beta2_pow_;
MT* beta2_pow_out_;
const MT* moment1_;
MT* moment1_out_;
const MT* moment2_;
MT* moment2_out_;
const T* grad_;
const T* param_;
T* trust_ratio_div_;
LambMomentREGUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon,
T beta1_pow, T* beta1_pow_out, T beta2_pow,
T* beta2_pow_out, const T* mom1, T* mom1_out,
const T* mom2, T* mom2_out, const T* grad,
const T* param, T* trust_ratio_div)
const MT* param_;
MT* trust_ratio_div_;
const bool* skip_update_;
LambMomentREGUpdateFunctor(MT weight_decay, MT beta1, MT beta2, MT epsilon,
MT beta1_pow, MT* beta1_pow_out, MT beta2_pow,
MT* beta2_pow_out, const MT* mom1, MT* mom1_out,
const MT* mom2, MT* mom2_out, const T* grad,
const MT* param, MT* trust_ratio_div,
const bool* skip_update)
: weight_decay_(weight_decay),
beta1_(beta1),
beta2_(beta2),
......@@ -64,26 +71,30 @@ struct LambMomentREGUpdateFunctor {
moment2_out_(mom2_out),
grad_(grad),
param_(param),
trust_ratio_div_(trust_ratio_div) {}
trust_ratio_div_(trust_ratio_div),
skip_update_(skip_update) {}
inline HOSTDEVICE void operator()(size_t i) const {
T g = grad_[i];
T mom1 = moment1_[i];
T mom2 = moment2_[i];
T beta1_pow = beta1_pow_;
T beta2_pow = beta2_pow_;
T p = param_[i];
if (skip_update_ && *skip_update_) return;
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
MT g = static_cast<MT>(grad_[i]);
MT mom1 = moment1_[i];
MT mom2 = moment2_[i];
MT beta1_pow = beta1_pow_;
MT beta2_pow = beta2_pow_;
MT p = param_[i];
mom1 = beta1_ * mom1 + (static_cast<MT>(1) - beta1_) * g;
mom2 = beta2_ * mom2 + (static_cast<MT>(1) - beta2_) * g * g;
moment1_out_[i] = mom1;
moment2_out_[i] = mom2;
T mom1_unbiased = mom1 / (1 - beta1_pow);
T mom2_unbiased = mom2 / (1 - beta2_pow);
MT mom1_unbiased = mom1 / (static_cast<MT>(1) - beta1_pow);
MT mom2_unbiased = mom2 / (static_cast<MT>(1) - beta2_pow);
trust_ratio_div_[i] =
mom1_unbiased / (sqrt(mom2_unbiased) + epsilon_) + weight_decay_ * p;
mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
weight_decay_ * p;
if (beta1_pow_out_ && beta2_pow_out_) {
beta1_pow_out_[0] = beta1_pow * beta1_;
beta2_pow_out_[0] = beta2_pow * beta2_;
......@@ -91,31 +102,35 @@ struct LambMomentREGUpdateFunctor {
}
};
template <typename T>
template <typename T, bool IsMultiPrecision>
struct LambMomentMENUpdateFunctor {
T weight_decay_;
T beta1_;
T beta2_;
T epsilon_;
const T* beta1_pow_;
T* beta1_pow_out_;
const T* beta2_pow_;
T* beta2_pow_out_;
const T* moment1_;
T* moment1_out_;
const T* moment2_;
T* moment2_out_;
using MT = typename std::conditional<
IsMultiPrecision, typename details::MPTypeTrait<T>::Type, T>::type;
MT weight_decay_;
MT beta1_;
MT beta2_;
MT epsilon_;
const MT* beta1_pow_;
MT* beta1_pow_out_;
const MT* beta2_pow_;
MT* beta2_pow_out_;
const MT* moment1_;
MT* moment1_out_;
const MT* moment2_;
MT* moment2_out_;
const T* grad_;
const T* param_;
T* trust_ratio_div_;
LambMomentMENUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon,
const T* beta1_pow, T* beta1_pow_out,
const T* beta2_pow, T* beta2_pow_out,
const T* mom1, T* mom1_out, const T* mom2,
T* mom2_out, const T* grad, const T* param,
T* trust_ratio_div)
const MT* param_;
MT* trust_ratio_div_;
const bool* skip_update_;
LambMomentMENUpdateFunctor(MT weight_decay, MT beta1, MT beta2, MT epsilon,
const MT* beta1_pow, MT* beta1_pow_out,
const MT* beta2_pow, MT* beta2_pow_out,
const MT* mom1, MT* mom1_out, const MT* mom2,
MT* mom2_out, const T* grad, const MT* param,
MT* trust_ratio_div, const bool* skip_update)
: weight_decay_(weight_decay),
beta1_(beta1),
beta2_(beta2),
......@@ -130,26 +145,29 @@ struct LambMomentMENUpdateFunctor {
moment2_out_(mom2_out),
grad_(grad),
param_(param),
trust_ratio_div_(trust_ratio_div) {}
trust_ratio_div_(trust_ratio_div),
skip_update_(skip_update) {}
inline HOSTDEVICE void operator()(size_t i) const {
T g = grad_[i];
T mom1 = moment1_[i];
T mom2 = moment2_[i];
T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_;
T p = param_[i];
if (skip_update_ && *skip_update_) return;
MT g = static_cast<MT>(grad_[i]);
MT mom1 = moment1_[i];
MT mom2 = moment2_[i];
MT beta1_pow = *beta1_pow_;
MT beta2_pow = *beta2_pow_;
MT p = param_[i];
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
mom1 = beta1_ * mom1 + (static_cast<MT>(1) - beta1_) * g;
mom2 = beta2_ * mom2 + (static_cast<MT>(1) - beta2_) * g * g;
moment1_out_[i] = mom1;
moment2_out_[i] = mom2;
T mom1_unbiased = mom1 / (1 - beta1_pow);
T mom2_unbiased = mom2 / (1 - beta2_pow);
MT mom1_unbiased = mom1 / (static_cast<MT>(1) - beta1_pow);
MT mom2_unbiased = mom2 / (static_cast<MT>(1) - beta2_pow);
trust_ratio_div_[i] =
mom1_unbiased / (sqrt(mom2_unbiased) + epsilon_) + weight_decay_ * p;
mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
weight_decay_ * p;
if (beta1_pow_out_ && beta2_pow_out_) {
beta1_pow_out_[0] = beta1_pow * beta1_;
beta2_pow_out_[0] = beta2_pow * beta2_;
......@@ -180,13 +198,15 @@ struct SparseLambMomentREGUpdateFunctor {
int64_t row_numel_;
int64_t row_count_;
const bool* skip_update_;
SparseLambMomentREGUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon,
T beta1_pow, T* beta1_pow_out, T beta2_pow,
T* beta2_pow_out, const T* mom1, T* mom1_out,
const T* mom2, T* mom2_out, const T* grad,
const T* param, T* trust_ratio_div,
const int64_t* rows, int64_t row_numel,
int64_t row_count)
int64_t row_count, const bool* skip_update)
: weight_decay_(weight_decay),
beta1_(beta1),
beta2_(beta2),
......@@ -204,7 +224,8 @@ struct SparseLambMomentREGUpdateFunctor {
trust_ratio_div_(trust_ratio_div),
rows_(rows),
row_numel_(row_numel),
row_count_(row_count) {}
row_count_(row_count),
skip_update_(skip_update) {}
inline HOSTDEVICE void update(size_t i, T g) const {
// The following code is same as dense
......@@ -214,16 +235,17 @@ struct SparseLambMomentREGUpdateFunctor {
T beta2_pow = beta2_pow_;
T p = param_[i];
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
mom1 = beta1_ * mom1 + (static_cast<T>(1) - beta1_) * g;
mom2 = beta2_ * mom2 + (static_cast<T>(1) - beta2_) * g * g;
moment1_out_[i] = mom1;
moment2_out_[i] = mom2;
T mom1_unbiased = mom1 / (1 - beta1_pow);
T mom2_unbiased = mom2 / (1 - beta2_pow);
T mom1_unbiased = mom1 / (static_cast<T>(1) - beta1_pow);
T mom2_unbiased = mom2 / (static_cast<T>(1) - beta2_pow);
trust_ratio_div_[i] =
mom1_unbiased / (sqrt(mom2_unbiased) + epsilon_) + weight_decay_ * p;
mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
weight_decay_ * p;
if (beta1_pow_out_ && beta1_pow_out_) {
beta1_pow_out_[0] = beta1_pow * beta1_;
beta2_pow_out_[0] = beta2_pow * beta2_;
......@@ -231,9 +253,11 @@ struct SparseLambMomentREGUpdateFunctor {
}
inline HOSTDEVICE void operator()(size_t i) const {
if (skip_update_ && *skip_update_) return;
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_]
: static_cast<T>(0);
update(i, g);
}
};
......@@ -261,13 +285,16 @@ struct SparseLambMomentMENUpdateFunctor {
int64_t row_numel_;
int64_t row_count_;
const bool* skip_update_;
SparseLambMomentMENUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon,
const T* beta1_pow, T* beta1_pow_out,
const T* beta2_pow, T* beta2_pow_out,
const T* mom1, T* mom1_out, const T* mom2,
T* mom2_out, const T* grad, const T* param,
T* trust_ratio_div, const int64_t* rows,
int64_t row_numel, int64_t row_count)
int64_t row_numel, int64_t row_count,
const bool* skip_update)
: weight_decay_(weight_decay),
beta1_(beta1),
beta2_(beta2),
......@@ -285,7 +312,8 @@ struct SparseLambMomentMENUpdateFunctor {
trust_ratio_div_(trust_ratio_div),
rows_(rows),
row_numel_(row_numel),
row_count_(row_count) {}
row_count_(row_count),
skip_update_(skip_update) {}
inline HOSTDEVICE void update(size_t i, T g) const {
// The following code is same as dense
......@@ -295,16 +323,17 @@ struct SparseLambMomentMENUpdateFunctor {
T beta2_pow = *beta2_pow_;
T p = param_[i];
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
mom1 = beta1_ * mom1 + (static_cast<T>(1) - beta1_) * g;
mom2 = beta2_ * mom2 + (static_cast<T>(1) - beta2_) * g * g;
moment1_out_[i] = mom1;
moment2_out_[i] = mom2;
T mom1_unbiased = mom1 / (1 - beta1_pow);
T mom2_unbiased = mom2 / (1 - beta2_pow);
T mom1_unbiased = mom1 / (static_cast<T>(1) - beta1_pow);
T mom2_unbiased = mom2 / (static_cast<T>(1) - beta2_pow);
trust_ratio_div_[i] =
mom1_unbiased / (sqrt(mom2_unbiased) + epsilon_) + weight_decay_ * p;
mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
weight_decay_ * p;
if (beta1_pow_out_ && beta1_pow_out_) {
beta1_pow_out_[0] = beta1_pow * beta1_;
beta2_pow_out_[0] = beta2_pow * beta2_;
......@@ -312,40 +341,61 @@ struct SparseLambMomentMENUpdateFunctor {
}
inline HOSTDEVICE void operator()(size_t i) const {
if (skip_update_ && *skip_update_) return;
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_]
: static_cast<T>(0);
update(i, g);
}
};
template <typename T>
template <typename T, bool IsMultiPrecision>
struct LambParamUpateFunctor {
const T* lr_;
using MT = typename std::conditional<
IsMultiPrecision, typename details::MPTypeTrait<T>::Type, T>::type;
const MT* lr_;
const T* param_;
const T* param_norm_;
const T* trust_ratio_div_;
const T* trust_ratio_div_norm_;
const MT* master_param_;
const MT* param_norm_;
const MT* trust_ratio_div_;
const MT* trust_ratio_div_norm_;
T* param_out_;
MT* master_param_out_;
const bool* skip_update_;
LambParamUpateFunctor(const T* lr, const T* param, const T* param_norm,
const T* trust_ratio_div, const T* trust_ratio_div_norm,
T* param_out)
LambParamUpateFunctor(const MT* lr, const T* param, const MT* master_param,
const MT* param_norm, const MT* trust_ratio_div,
const MT* trust_ratio_div_norm, T* param_out,
MT* master_param_out, const bool* skip_update)
: lr_(lr),
param_(param),
master_param_(master_param),
param_norm_(param_norm),
trust_ratio_div_(trust_ratio_div),
trust_ratio_div_norm_(trust_ratio_div_norm),
param_out_(param_out) {}
param_out_(param_out),
master_param_out_(master_param_out),
skip_update_(skip_update) {}
inline HOSTDEVICE void operator()(size_t i) const {
T lr = *lr_;
T p = *param_norm_;
T t = *trust_ratio_div_norm_;
T r = (p > 0 && t > 0) ? p / t : 1.0;
if (skip_update_ && *skip_update_) return;
MT lr = *lr_;
MT pn = *param_norm_;
MT tn = *trust_ratio_div_norm_;
MT r = (pn > static_cast<MT>(0) && tn > static_cast<MT>(0))
? pn / tn
: static_cast<MT>(1);
lr *= r;
param_out_[i] = param_[i] - lr * trust_ratio_div_[i];
MT p = IsMultiPrecision ? master_param_[i] : static_cast<MT>(param_[i]);
MT param_out = p - lr * trust_ratio_div_[i];
param_out_[i] = static_cast<T>(param_out);
if (IsMultiPrecision) {
master_param_out_[i] = param_out;
}
}
};
......@@ -353,86 +403,146 @@ template <typename DeviceContext, typename T>
class LambOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type())));
using paddle::framework::LoDTensor;
T weight_decay = static_cast<T>(ctx.Attr<float>("weight_decay"));
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto& param = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Param"), "Input",
"Param", "Lamb");
auto* grad_var = ctx.InputVar("Grad");
auto& mom1 = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Moment1"), "Input",
"Moment1", "Lamb");
auto& mom2 = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Moment2"), "Input",
"Moment2", "Lamb");
auto& lr = GET_DATA_SAFELY(ctx.Input<LoDTensor>("LearningRate"), "Input",
"LearningRate", "Lamb");
auto& beta1_pow = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Beta1Pow"), "Input",
"Beta1Pow", "Lamb");
auto& beta2_pow = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Beta2Pow"), "Input",
"Beta2Pow", "Lamb");
auto& param_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("ParamOut"),
"Output", "ParamOut", "Lamb");
auto& mom1_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Moment1Out"),
"Output", "Moment1Out", "Lamb");
auto& mom2_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Moment2Out"),
"Output", "Moment2Out", "Lamb");
auto& beta1_pow_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Beta1PowOut"),
"Output", "Beta1PowOut", "Lamb");
auto& beta2_pow_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Beta2PowOut"),
"Output", "Beta2PowOut", "Lamb");
using MT = typename details::MPTypeTrait<T>::Type;
bool multi_precision = ctx.Attr<bool>("multi_precision");
if (multi_precision) {
ComputeImpl<MT, true>(ctx);
} else {
ComputeImpl<T, false>(ctx);
}
}
private:
template <typename MT, bool IsMultiPrecision>
void ComputeImpl(const framework::ExecutionContext& ctx) const {
if (!IsMultiPrecision) {
constexpr auto kIsSameType = std::is_same<T, MT>::value;
PADDLE_ENFORCE_EQ(
kIsSameType, true,
platform::errors::InvalidArgument(
"When multi_precision=False, T and MT must be the same type."));
}
const auto* skip_update = ctx.Input<framework::LoDTensor>("SkipUpdate");
const bool* skip_update_flag = skip_update && skip_update->IsInitialized()
? skip_update->data<bool>()
: nullptr;
if (skip_update_flag && platform::is_cpu_place(skip_update->place()) &&
(*skip_update_flag)) {
return;
}
auto weight_decay = static_cast<MT>(ctx.Attr<float>("weight_decay"));
auto beta1 = static_cast<MT>(ctx.Attr<float>("beta1"));
auto beta2 = static_cast<MT>(ctx.Attr<float>("beta2"));
auto epsilon = static_cast<MT>(ctx.Attr<float>("epsilon"));
const auto& param = GET_DATA_SAFELY(
ctx.Input<framework::LoDTensor>("Param"), "Input", "Param", "Lamb");
const auto* grad_var = ctx.InputVar("Grad");
const auto& mom1 = GET_DATA_SAFELY(
ctx.Input<framework::LoDTensor>("Moment1"), "Input", "Moment1", "Lamb");
const auto& mom2 = GET_DATA_SAFELY(
ctx.Input<framework::LoDTensor>("Moment2"), "Input", "Moment2", "Lamb");
const auto& lr =
GET_DATA_SAFELY(ctx.Input<framework::LoDTensor>("LearningRate"),
"Input", "LearningRate", "Lamb");
const auto& beta1_pow =
GET_DATA_SAFELY(ctx.Input<framework::LoDTensor>("Beta1Pow"), "Input",
"Beta1Pow", "Lamb");
const auto& beta2_pow =
GET_DATA_SAFELY(ctx.Input<framework::LoDTensor>("Beta2Pow"), "Input",
"Beta2Pow", "Lamb");
auto& param_out =
GET_DATA_SAFELY(ctx.Output<framework::LoDTensor>("ParamOut"), "Output",
"ParamOut", "Lamb");
auto& mom1_out =
GET_DATA_SAFELY(ctx.Output<framework::LoDTensor>("Moment1Out"),
"Output", "Moment1Out", "Lamb");
auto& mom2_out =
GET_DATA_SAFELY(ctx.Output<framework::LoDTensor>("Moment2Out"),
"Output", "Moment2Out", "Lamb");
auto& beta1_pow_out =
GET_DATA_SAFELY(ctx.Output<framework::LoDTensor>("Beta1PowOut"),
"Output", "Beta1PowOut", "Lamb");
auto& beta2_pow_out =
GET_DATA_SAFELY(ctx.Output<framework::LoDTensor>("Beta2PowOut"),
"Output", "Beta2PowOut", "Lamb");
const auto* master_param =
IsMultiPrecision ? ctx.Input<framework::LoDTensor>("MasterParam")
: nullptr;
auto* master_param_out =
IsMultiPrecision ? ctx.Output<framework::LoDTensor>("MasterParamOut")
: nullptr;
if (IsMultiPrecision) {
PADDLE_ENFORCE_NOT_NULL(master_param,
platform::errors::InvalidArgument(
"Input(MasterParam) must be provided when "
"multi_precision=True."));
PADDLE_ENFORCE_NOT_NULL(master_param_out,
platform::errors::InvalidArgument(
"Output(MasterParamOut) must be provided "
"when multi_precision=True."));
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, param.numel());
framework::Tensor trust_ratio_div =
ctx.AllocateTmpTensor<T, DeviceContext>(param.dims(), dev_ctx);
auto trust_ratio_div =
ctx.AllocateTmpTensor<MT, DeviceContext>(param.dims(), dev_ctx);
const void* param_ptr = param.template data<void>();
const void* master_param_ptr =
master_param ? master_param->template data<void>() : nullptr;
void* param_out_ptr = param_out.template mutable_data<T>(ctx.GetPlace());
void* master_param_out_ptr =
master_param_out
? master_param_out->template mutable_data<MT>(ctx.GetPlace())
: nullptr;
// Update moments
if (grad_var->IsType<framework::LoDTensor>()) {
auto& grad = *ctx.Input<LoDTensor>("Grad");
auto& grad = grad_var->Get<framework::LoDTensor>();
if (platform::is_gpu_place(ctx.GetPlace()) &&
beta1_pow.place() == platform::CPUPlace() &&
beta2_pow.place() == platform::CPUPlace()) {
LambMomentREGUpdateFunctor<T> moment_update_functor(
weight_decay, beta1, beta2, epsilon, *beta1_pow.template data<T>(),
nullptr, *beta2_pow.template data<T>(), nullptr,
mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
grad.template data<T>(), param.template data<T>(),
trust_ratio_div.template data<T>());
LambMomentREGUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
weight_decay, beta1, beta2, epsilon, *beta1_pow.template data<MT>(),
nullptr, *beta2_pow.template data<MT>(), nullptr,
mom1.template data<MT>(),
mom1_out.template mutable_data<MT>(ctx.GetPlace()),
mom2.template data<MT>(),
mom2_out.template mutable_data<MT>(ctx.GetPlace()),
grad.template data<T>(),
static_cast<const MT*>(IsMultiPrecision ? master_param_ptr
: param_ptr),
trust_ratio_div.template data<MT>(), skip_update_flag);
for_range(moment_update_functor);
beta1_pow_out.template mutable_data<T>(platform::CPUPlace())[0] =
beta1 * beta1_pow.template data<T>()[0];
beta2_pow_out.template mutable_data<T>(platform::CPUPlace())[0] =
beta2 * beta2_pow.template data<T>()[0];
beta1_pow_out.template mutable_data<MT>(platform::CPUPlace())[0] =
beta1 * beta1_pow.template data<MT>()[0];
beta2_pow_out.template mutable_data<MT>(platform::CPUPlace())[0] =
beta2 * beta2_pow.template data<MT>()[0];
} else {
LambMomentMENUpdateFunctor<T> moment_update_functor(
weight_decay, beta1, beta2, epsilon, beta1_pow.template data<T>(),
beta1_pow_out.template mutable_data<T>(ctx.GetPlace()),
beta2_pow.template data<T>(),
beta2_pow_out.template mutable_data<T>(ctx.GetPlace()),
mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
grad.template data<T>(), param.template data<T>(),
trust_ratio_div.template data<T>());
LambMomentMENUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
weight_decay, beta1, beta2, epsilon, beta1_pow.template data<MT>(),
beta1_pow_out.template mutable_data<MT>(ctx.GetPlace()),
beta2_pow.template data<MT>(),
beta2_pow_out.template mutable_data<MT>(ctx.GetPlace()),
mom1.template data<MT>(),
mom1_out.template mutable_data<MT>(ctx.GetPlace()),
mom2.template data<MT>(),
mom2_out.template mutable_data<MT>(ctx.GetPlace()),
grad.template data<T>(),
static_cast<const MT*>(IsMultiPrecision ? master_param_ptr
: param_ptr),
trust_ratio_div.template data<MT>(), skip_update_flag);
for_range(moment_update_functor);
}
} else if (grad_var->IsType<framework::SelectedRows>()) {
PADDLE_ENFORCE_EQ(IsMultiPrecision, false,
platform::errors::Unimplemented(
"SelectedRows gradient is not supported when "
"multi_precision=True"));
auto& grad = GET_DATA_SAFELY(ctx.Input<framework::SelectedRows>("Grad"),
"Input", "Grad", "Lamb");
if (grad.rows().size() == 0) {
......@@ -470,22 +580,25 @@ class LambOpKernel : public framework::OpKernel<T> {
beta1_pow.place() == platform::CPUPlace() &&
beta2_pow.place() == platform::CPUPlace()) {
SparseLambMomentREGUpdateFunctor<T> moment_update_functor(
weight_decay, beta1, beta2, epsilon, *beta1_pow.template data<T>(),
nullptr, *beta2_pow.template data<T>(), nullptr,
mom1.template data<T>(),
static_cast<T>(weight_decay), static_cast<T>(beta1),
static_cast<T>(beta2), static_cast<T>(epsilon),
*beta1_pow.template data<T>(), nullptr,
*beta2_pow.template data<T>(), nullptr, mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()), grad_data,
param.template data<T>(), trust_ratio_div.template data<T>(), rows,
row_numel, grad_merge.rows().size());
row_numel, grad_merge.rows().size(), skip_update_flag);
for_range(moment_update_functor);
beta1_pow_out.template mutable_data<T>(platform::CPUPlace())[0] =
beta1 * beta1_pow.template data<T>()[0];
static_cast<T>(beta1) * beta1_pow.template data<T>()[0];
beta2_pow_out.template mutable_data<T>(platform::CPUPlace())[0] =
beta2 * beta2_pow.template data<T>()[0];
static_cast<T>(beta2) * beta2_pow.template data<T>()[0];
} else {
SparseLambMomentMENUpdateFunctor<T> moment_update_functor(
weight_decay, beta1, beta2, epsilon, beta1_pow.template data<T>(),
static_cast<T>(weight_decay), static_cast<T>(beta1),
static_cast<T>(beta2), static_cast<T>(epsilon),
beta1_pow.template data<T>(),
beta1_pow_out.template mutable_data<T>(ctx.GetPlace()),
beta2_pow.template data<T>(),
beta2_pow_out.template mutable_data<T>(ctx.GetPlace()),
......@@ -494,36 +607,45 @@ class LambOpKernel : public framework::OpKernel<T> {
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()), grad_data,
param.template data<T>(), trust_ratio_div.template data<T>(), rows,
row_numel, grad_merge.rows().size());
row_numel, grad_merge.rows().size(), skip_update_flag);
for_range(moment_update_functor);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Variable type not supported by lamb_op. Expect LoDTensor or "
"SelectedRows, but got %s",
framework::ToTypeName(param_var->Type())));
framework::ToTypeName(grad_var->Type())));
}
// Update parameter
framework::Tensor p_norm_t =
ctx.AllocateTmpTensor<T, DeviceContext>({1}, dev_ctx);
framework::Tensor trust_ratio_div_norm_t =
ctx.AllocateTmpTensor<T, DeviceContext>({1}, dev_ctx);
auto p_norm = framework::EigenScalar<T>::From(p_norm_t);
auto trust_ratio_div_norm =
framework::EigenScalar<T>::From(trust_ratio_div_norm_t);
auto p_norm_t = ctx.AllocateTmpTensor<MT, DeviceContext>({1}, dev_ctx);
auto trust_ratio_div_norm_t =
ctx.AllocateTmpTensor<MT, DeviceContext>({1}, dev_ctx);
auto p = framework::EigenVector<T>::Flatten(param);
auto t = framework::EigenVector<T>::Flatten(trust_ratio_div);
auto p_norm = framework::EigenScalar<MT>::From(p_norm_t);
auto trust_ratio_div_norm =
framework::EigenScalar<MT>::From(trust_ratio_div_norm_t);
auto t = framework::EigenVector<MT>::Flatten(trust_ratio_div);
// TODO(zengjinle): remove the following Eigen operations when
// *skip_update == true.
auto* place = dev_ctx.eigen_device();
p_norm.device(*place) = p.square().sum().sqrt();
if (IsMultiPrecision) {
auto mp = framework::EigenVector<MT>::Flatten(*master_param);
p_norm.device(*place) = mp.square().sum().sqrt();
} else {
auto p = framework::EigenVector<MT>::Flatten(param);
p_norm.device(*place) = p.square().sum().sqrt();
}
trust_ratio_div_norm.device(*place) = t.square().sum().sqrt();
LambParamUpateFunctor<T> param_update_functor(
lr.template data<T>(), param.template data<T>(),
p_norm_t.template data<T>(), trust_ratio_div.template data<T>(),
trust_ratio_div_norm_t.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()));
LambParamUpateFunctor<T, IsMultiPrecision> param_update_functor(
lr.template data<MT>(), static_cast<const T*>(param_ptr),
static_cast<const MT*>(master_param_ptr), p_norm_t.template data<MT>(),
trust_ratio_div.template data<MT>(),
trust_ratio_div_norm_t.template data<MT>(),
static_cast<T*>(param_out_ptr), static_cast<MT*>(master_param_out_ptr),
skip_update_flag);
for_range(param_update_functor);
}
};
......
......@@ -72,6 +72,9 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"adamw",
{"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow",
"Beta2Pow", "MasterParam"}},
{"lamb",
{"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow",
"Beta2Pow", "MasterParam"}},
{"sparse_attention",
{"Q", "K", "V", "Offset", "Columns", "KeyPaddingMask", "AttnMask"}},
};
......@@ -112,8 +115,6 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
{"sparse_momentum", {"ParamOut", "VelocityOut"}},
{"rnn", {"DropoutState", "Reserve", "Out", "State"}},
{"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
{"run_program", {"DOut"}},
{"adam",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
......@@ -121,6 +122,9 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"adamw",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
{"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
};
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
......@@ -142,6 +146,9 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"adamw",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
{"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
{"average_accumulates",
{"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates",
"out_old_num_accumulates", "out_num_updates"}},
......@@ -173,8 +180,6 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"}},
{"moving_average_abs_max_scale",
{"Out", "OutScale", "OutAccum", "OutState"}},
{"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
{"rnn", {"DropoutState"}},
{"run_program", {"Out", "DOut", "OutScope"}},
{"clear_float_status", {"FloatStatusOut"}},
......
......@@ -371,6 +371,20 @@ class ClipGradByNorm(ClipGradBase):
return param, new_grad
_allow_pure_fp16_global_norm_clip_flag = False
def _allow_pure_fp16_global_norm_clip(*args):
global _allow_pure_fp16_global_norm_clip_flag
if len(args) == 0:
return _allow_pure_fp16_global_norm_clip_flag
else:
assert len(args) == 1 and isinstance(args[0], bool)
old_value = _allow_pure_fp16_global_norm_clip_flag
_allow_pure_fp16_global_norm_clip_flag = args[0]
return old_value
class ClipGradByGlobalNorm(ClipGradBase):
r"""
Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
......@@ -537,8 +551,12 @@ class ClipGradByGlobalNorm(ClipGradBase):
global_norm_var = []
if len(sum_square_list_fp16) > 0:
global_norm_var_fp16 = layers.sums(sum_square_list_fp16)
global_norm_var.append(
global_norm_var_fp16.astype(sum_dtype))
if sum_square_list_fp32 or sum_square_list or not _allow_pure_fp16_global_norm_clip(
):
global_norm_var.append(
global_norm_var_fp16.astype(sum_dtype))
else:
global_norm_var.append(global_norm_var_fp16)
if len(sum_square_list_fp32) > 0:
global_norm_var_fp32 = layers.sums(sum_square_list_fp32)
if sum_dtype == 'float32':
......@@ -573,8 +591,9 @@ class ClipGradByGlobalNorm(ClipGradBase):
with p.block.program._optimized_guard([p, g]):
# inplace
scale_input = (scale_var.astype('float16')
if g.dtype == core.VarDesc.VarType.FP16 else
scale_var)
if g.dtype == core.VarDesc.VarType.FP16 and
scale_var.dtype != core.VarDesc.VarType.FP16
else scale_var)
# NOTE(Yuang Liu): For pure dp with gradient merge, the p and g
# will be in different blocks with the gradient clip related ops.
# We need to handle the correct block, otherwise will encounter
......
......@@ -411,6 +411,8 @@ class OptimizerWithMixedPrecision(object):
found_inf = paddle.tensor.creation._memcpy(found_inf,
paddle.CPUPlace())
real_optimizer._set_auxiliary_var('found_inf', found_inf)
elif hasattr(real_optimizer, "_set_auxiliary_var"):
real_optimizer._set_auxiliary_var('found_inf', found_inf)
optimize_ops = self._optimizer.apply_gradients(params_grads)
return optimize_ops
......
......@@ -80,11 +80,27 @@ def _dtype_to_str(dtype):
return 'fp32'
_keep_layer_norm_scale_bias_to_fp32_flag = True
def _keep_layer_norm_scale_bias_to_fp32(*args):
global _keep_layer_norm_scale_bias_to_fp32_flag
if len(args) == 0:
return _keep_layer_norm_scale_bias_to_fp32_flag
else:
assert len(args) == 1 and isinstance(args[0], bool)
old_value = _keep_layer_norm_scale_bias_to_fp32_flag
_keep_layer_norm_scale_bias_to_fp32_flag = args[0]
return old_value
def _keep_fp32_input(op, in_name):
op_type = op.type
if op_type in ['batch_norm', 'layer_norm']:
if op_type == 'batch_norm':
# Scale, Bias, Mean, Variance should be float32.
return in_name != 'X'
if op_type == 'layer_norm' and _keep_layer_norm_scale_bias_to_fp32():
return in_name != 'X'
if op_type == 'fused_bn_add_activation':
return in_name not in {'X', 'Z'}
if op_type == 'resnet_unit':
......@@ -98,7 +114,9 @@ def _keep_fp32_input(op, in_name):
def _keep_fp32_output(op, out_name):
op_type = op.type
if op_type in ['batch_norm', 'fused_bn_add_activation', 'layer_norm']:
if op_type in ['batch_norm', 'fused_bn_add_activation']:
return out_name != 'Y'
if op_type == 'layer_norm' and _keep_layer_norm_scale_bias_to_fp32():
return out_name != 'Y'
if op_type == 'resnet_unit':
return out_name not in {'Y', 'ConvX', 'ConvZ'}
......
......@@ -3554,14 +3554,14 @@ class LambOptimizer(AdamOptimizer):
else:
weight_decay = self._weight_decay
lr = self._create_param_lr(param_and_grad)
master_weight = None
if framework.in_dygraph_mode():
_, _, _, _, _ = _C_ops.lamb(
param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1,
moment2, beta1_pow_acc, beta2_pow_acc, 'beta1', self._beta1,
'beta2', self._beta2, 'epsilon', self._epsilon, 'weight_decay',
weight_decay)
_C_ops.lamb(param_and_grad[0], param_and_grad[1], lr, moment1,
moment2, beta1_pow_acc, beta2_pow_acc, master_weight,
param_and_grad[0], moment1, moment2, beta1_pow_acc,
beta2_pow_acc, master_weight, 'beta1', self._beta1,
'beta2', self._beta2, 'epsilon', self._epsilon,
'weight_decay', weight_decay)
return None
# create the lamb optimize op
......
......@@ -21,6 +21,7 @@ import paddle.fluid.core as core
import paddle.fluid as fluid
import six
from fake_reader import fake_imdb_reader
from paddle.fluid.clip import _allow_pure_fp16_global_norm_clip
paddle.enable_static()
......@@ -566,5 +567,35 @@ class TestDygraphGradientClipFP64(unittest.TestCase):
% (a, b))
class TestPureFP16ClipGradByGlobalNorm(unittest.TestCase):
def check_main(self, expected_has_cast_op):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
names = ["p0", "p1"]
shapes = [[2, 3], [4, 5]]
param_and_grads = []
main_block = main_prog.global_block()
for name, shape in zip(names, shapes):
p = main_block.create_parameter(
name=name, shape=shape, dtype='float16')
g = main_block.create_parameter(
name=p.name + '@GRAD', shape=p.shape, dtype=p.dtype)
param_and_grads.append((p, g))
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
clip(param_and_grads)
actual_has_cast = any(op.type == 'cast' for op in main_block.ops)
self.assertEqual(actual_has_cast, expected_has_cast_op)
def test_main(self):
self.check_main(True)
_allow_pure_fp16_global_norm_clip(True)
self.check_main(False)
_allow_pure_fp16_global_norm_clip(False)
self.check_main(True)
if __name__ == '__main__':
unittest.main()
......@@ -19,6 +19,7 @@ import numpy as np
from op_test import OpTest
from paddle.fluid import core
from paddle.fluid.op import Operator
from paddle.fluid.dygraph.base import switch_to_static_graph
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
......@@ -181,5 +182,58 @@ class TestLambOpV2Group(TestLambOpV2):
adam.clear_gradients()
class TestLambOpMultiPrecision(unittest.TestCase):
def check_main(self, x_np, place, multi_precision=False, seed=10, n=10):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
paddle.seed(seed)
with paddle.static.amp.fp16_guard():
x = paddle.static.data(
name='x', shape=[None, 10], dtype='float32')
linear = paddle.nn.Linear(10, 2)
hidden = linear(x)
loss = paddle.mean(hidden)
optimizer = paddle.optimizer.Lamb(learning_rate=1e-3)
optimizer._multi_precision = multi_precision
if multi_precision:
optimizer = paddle.static.amp.decorate(
optimizer, use_pure_fp16=True, use_fp16_guard=True)
optimizer.minimize(loss)
weight, bias = linear.weight, linear.bias
scope = paddle.static.Scope()
exe = paddle.static.Executor(place)
scope = paddle.static.Scope()
x = main_prog.global_block().var(x.name)
if x.dtype == core.VarDesc.VarType.FP16:
x_np = x_np.astype(np.float16)
with paddle.static.scope_guard(scope):
exe.run(startup_prog)
if multi_precision:
optimizer.amp_init(place)
weight_np, bias_np = None, None
for i in range(n):
feed_dict = {x.name: x_np}
weight_np, bias_np = exe.run(main_prog,
feed=feed_dict,
fetch_list=[weight, bias])
return weight_np.astype('float32'), bias_np.astype('float32')
@switch_to_static_graph
def test_main(self):
if not paddle.is_compiled_with_cuda():
return
place = paddle.CUDAPlace(0)
x_np = np.random.random(size=[5, 10]).astype('float32')
weight_1, bias_1 = self.check_main(x_np, place, multi_precision=False)
weight_2, bias_2 = self.check_main(x_np, place, multi_precision=True)
self.assertTrue(np.all(np.abs(weight_1 - weight_2) < 1e-3))
self.assertTrue(np.all(np.abs(bias_1 - bias_2) < 1e-7))
if __name__ == "__main__":
unittest.main()
......@@ -20,9 +20,13 @@ import paddle
from operator import mul
import paddle.fluid.core as core
import paddle.fluid as fluid
import paddle.nn.functional as F
from functools import reduce
from op_test import _set_use_system_allocator
from paddle.fluid import Program, program_guard
from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_layer_norm_scale_bias_to_fp32
paddle.enable_static()
np.random.random(123)
......@@ -325,5 +329,58 @@ class TestDygraphLayerNormAPIError(unittest.TestCase):
self.assertRaises(TypeError, layer_norm, x2)
class TestFP16ScaleBiasLayerNorm(unittest.TestCase):
def check_main(self, x_np, weight_np, bias_np, dtype):
paddle.disable_static()
weight_np = weight_np.astype(dtype)
bias_np = bias_np.astype(dtype)
x = paddle.to_tensor(x_np)
weight = paddle.to_tensor(weight_np)
bias = paddle.to_tensor(bias_np)
x.stop_gradient = False
weight.stop_gradient = False
bias.stop_gradient = False
y = F.layer_norm(x, x.shape[1:], weight, bias)
x_g, w_g, b_g = paddle.grad(y, [x, weight, bias])
y_np = y.numpy().astype('float32')
x_g_np = x_g.numpy().astype('float32')
w_g_np = w_g.numpy().astype('float16')
b_g_np = b_g.numpy().astype('float32')
paddle.enable_static()
return y_np, x_g_np, w_g_np, b_g_np
def test_main(self):
if not paddle.is_compiled_with_cuda():
return
x_np = np.random.random([10, 20]).astype('float16')
weight_np = np.random.random([20]).astype('float16')
bias_np = np.random.random([20]).astype('float16')
y_np_1, x_g_np_1, w_g_np_1, b_g_np_1 = self.check_main(
x_np, weight_np, bias_np, 'float16')
y_np_2, x_g_np_2, w_g_np_2, b_g_np_2 = self.check_main(
x_np, weight_np, bias_np, 'float32')
def assert_equal(x, y):
self.assertTrue(np.array_equal(x, y))
assert_equal(y_np_1, y_np_2)
assert_equal(x_g_np_1, x_g_np_2)
assert_equal(w_g_np_1, w_g_np_2)
assert_equal(b_g_np_1, b_g_np_2)
class TestGetSetKeepLayerNormScaleBiasFP32Flag(unittest.TestCase):
def test_main(self):
self.assertTrue(_keep_layer_norm_scale_bias_to_fp32())
_keep_layer_norm_scale_bias_to_fp32(False)
self.assertFalse(_keep_layer_norm_scale_bias_to_fp32())
_keep_layer_norm_scale_bias_to_fp32(True)
self.assertTrue(_keep_layer_norm_scale_bias_to_fp32())
if __name__ == '__main__':
unittest.main()
......@@ -16,6 +16,9 @@ from .optimizer import Optimizer
from ..fluid import core
from ..fluid import framework
from ..fluid.framework import Variable
from ..fluid import layers
from ..fluid import unique_name
from ..fluid.layer_helper import LayerHelper
from paddle import _C_ops
__all__ = []
......@@ -127,6 +130,36 @@ class Lamb(Optimizer):
'lamb_weight_decay': lamb_weight_decay,
'exclude_from_weight_decay_fn': exclude_from_weight_decay_fn,
}
self._master_weights = {}
# TODO(zengjinle): expose API as soon as possible
self._multi_precision = False
def _create_master_weight(self, param):
assert self._multi_precision
if param.name in self._master_weights:
var = self._master_weights[param.name]
else:
assert isinstance(self.helper, LayerHelper)
var_name = param.name + "_fp32_master"
var_name = unique_name.generate(var_name)
var = layers.create_global_var(
name=var_name,
shape=param.shape,
value=0,
dtype='float32',
persistable=True)
block = self.helper.startup_program.global_block()
block.append_op(
type="cast",
inputs={"X": [param]},
outputs={"Out": [var]},
attrs={
"in_dtype": param.dtype,
"out_dtype": core.VarDesc.VarType.FP32
})
self._master_weights[param.name] = var
return var
def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block)
......@@ -135,18 +168,51 @@ class Lamb(Optimizer):
# Create accumulator tensors for first and second moments
for p in parameters:
self._add_accumulator(self._moment1_acc_str, p)
self._add_accumulator(self._moment2_acc_str, p)
self._add_accumulator(
if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
master_p = self._create_master_weight(p)
self._add_moments_pows(master_p)
else:
self._add_moments_pows(p)
def _get_accumulator(self, name, param):
"""Utility function to fetch an accumulator for a parameter
Args:
name: name of the accumulator
param: parameter variable for which accumulator is to be fetched
Returns:
accumulator variable for the parameter
"""
if self._name is not None:
name = self._name + "_" + name
find_master = self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
target_param = self._master_weights[
param.name] if find_master else param
target_name = target_param.name
if (name not in self._accumulators or
target_name not in self._accumulators[name]):
raise Exception("Accumulator {} does not exist for parameter {}".
format(name, target_name))
return self._accumulators[name][target_name]
def _add_moments_pows(self, p):
acc_dtype = p.dtype
if acc_dtype == core.VarDesc.VarType.FP16:
acc_dtype = core.VarDesc.VarType.FP32
self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
self._add_accumulator(
name=self._beta1_pow_acc_str,
param=p,
dtype=acc_dtype,
fill_value=0.9 if isinstance(self._beta1, Variable) \
else self._beta1,
shape=[1],
type=core.VarDesc.VarType.LOD_TENSOR, device='cpu')
self._add_accumulator(
self._add_accumulator(
name=self._beta2_pow_acc_str,
param=p,
dtype=acc_dtype,
fill_value=0.999 if isinstance(self._beta2, Variable) \
else self._beta2,
shape=[1],
......@@ -175,13 +241,20 @@ class Lamb(Optimizer):
weight_decay = self._lamb_weight_decay
lr = self._create_param_lr(param_and_grad)
find_master = self._multi_precision and param_and_grad[
0].dtype == core.VarDesc.VarType.FP16
master_weight = self._master_weights[param_and_grad[0]
.name] if find_master else None
found_inf = self._get_auxiliary_var('found_inf')
if framework.in_dygraph_mode():
_, _, _, _, _ = _C_ops.lamb(
param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1,
moment2, beta1_pow_acc, beta2_pow_acc, 'beta1', self._beta1,
'beta2', self._beta2, 'epsilon', self._epsilon, 'weight_decay',
weight_decay)
_C_ops.lamb(param_and_grad[0], param_and_grad[1], lr, moment1,
moment2, beta1_pow_acc, beta2_pow_acc, master_weight,
param_and_grad[0], moment1, moment2, beta1_pow_acc,
beta2_pow_acc, master_weight, 'beta1', self._beta1,
'beta2', self._beta2, 'epsilon', self._epsilon,
'weight_decay', weight_decay, 'multi_precision',
find_master)
return None
# create the lamb optimize op
......@@ -205,9 +278,17 @@ class Lamb(Optimizer):
"beta1": self._beta1,
"beta2": self._beta2,
"epsilon": self._epsilon,
"weight_decay": weight_decay
"weight_decay": weight_decay,
"multi_precision": find_master,
}
if find_master:
inputs["MasterParam"] = master_weight
outputs["MasterParamOut"] = master_weight
if found_inf:
inputs["SkipUpdate"] = found_inf
lamb_op = block.append_op(
type=self.type,
inputs=inputs,
......
......@@ -217,6 +217,14 @@ class Optimizer(object):
else:
self._param_groups = self._parameter_list
self._auxiliary_vars = {}
def _set_auxiliary_var(self, key, val):
self._auxiliary_vars[key] = val
def _get_auxiliary_var(self, key):
return self._auxiliary_vars.get(key, None)
@framework.dygraph_only
def state_dict(self):
'''
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册