未验证 提交 d80fe268 编写于 作者: S sneaxiy 提交者: GitHub

Refine some AMP operators for BERT (#37923)

* support multi precision update for LAMB

* hide some api

* fix ci uts

* fix lamb output of dygraph

* remove some changes to some PR

* try to fix Py3 CI compile error

* fix test_imperative_optimizer, add lars ut, add layer_norm ut

* fix ut, fix format

* fix ut

* fix windows ci
上级 cff03734
......@@ -169,10 +169,16 @@ __inline__ __device__ half rsqrt_(const half val) {
}
#endif
template <typename T, typename U, int BlockDim>
__global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
T *y, U *mean, U *var, float epsilon,
int64_t feature_size) {
template <typename T, typename U, bool ScaleBiasWithSameTypeX>
using LayerNormScaleBiasT =
typename std::conditional<ScaleBiasWithSameTypeX, T, U>::type;
template <typename T, typename U, int BlockDim,
bool ScaleBiasWithSameTypeX = false>
__global__ void LayerNormForward(
const T *x, const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *bias, T *y,
U *mean, U *var, float epsilon, int64_t feature_size) {
__shared__ U mean_share;
__shared__ U var_share;
__shared__ U shared_mean[32]; // threadIdx.x / warpSize <= kMaxBlockDim /
......@@ -212,14 +218,15 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
if (bias != nullptr) {
for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>(
scale[j] * (static_cast<U>(x[i]) - mean_val) * invvar + bias[j]);
y[i] = static_cast<T>(static_cast<U>(scale[j]) *
(static_cast<U>(x[i]) - mean_val) * invvar +
static_cast<U>(bias[j]));
}
} else {
for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>(scale[j] * (static_cast<U>(x[i]) - mean_val) *
invvar);
y[i] = static_cast<T>(static_cast<U>(scale[j]) *
(static_cast<U>(x[i]) - mean_val) * invvar);
}
}
} else { // scale == nullptr
......@@ -227,7 +234,7 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar +
bias[j]);
static_cast<U>(bias[j]));
}
} else {
for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
......@@ -336,12 +343,15 @@ __global__ void LayerNormBackwardPartGradGammaBeta(
}
}
template <typename T, typename U, int BDIMX, int BDIMY>
template <typename T, typename U, int BDIMX, int BDIMY, bool ScaleBiasSameTypeX>
__global__ void LayerNormBackwardSumGradGammaBeta(
const U *part_grad_gamma, const U *part_grad_beta, const int part_size,
// const int n1, const int n2, T* grad_gamma, T* grad_beta) {
const int n1, const int n2, U *grad_gamma, U *grad_beta) {
const int n1, const int n2,
LayerNormScaleBiasT<T, U, ScaleBiasSameTypeX> *grad_gamma,
LayerNormScaleBiasT<T, U, ScaleBiasSameTypeX> *grad_beta) {
// sum partial gradients for gamma and beta
using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasSameTypeX>;
__shared__ U buf[BDIMX * BDIMY];
int64_t i2 = blockIdx.x * BDIMX + threadIdx.x;
if (i2 < n2) {
......@@ -378,20 +388,18 @@ __global__ void LayerNormBackwardSumGradGammaBeta(
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
grad_beta[i2] = sum_beta;
grad_gamma[i2] = static_cast<ScaleBiasT>(sum_gamma);
grad_beta[i2] = static_cast<ScaleBiasT>(sum_beta);
}
}
}
template <typename T, typename U, int BDIMX, int BDIMY>
template <typename T, typename U, int BDIMX, int BDIMY, bool ScaleBiasSameTypeX>
__global__ void LayerNormBackwardComputeGradInput(
const T *__restrict__ dout, const T *__restrict__ input, const int n1,
const int n2,
// const U* __restrict__ mean, const U* __restrict__ var, const float
// epsilon, const T* gamma,
const U *__restrict__ mean, const U *__restrict__ var, const float epsilon,
const U *gamma, T *grad_input) {
const int n2, const U *__restrict__ mean, const U *__restrict__ var,
const float epsilon,
const LayerNormScaleBiasT<T, U, ScaleBiasSameTypeX> *gamma, T *grad_input) {
#ifdef __HIPCC__
for (auto i1 = hipBlockIdx_x; i1 < n1; i1 += hipGridDim_x) {
#else
......@@ -411,15 +419,17 @@ __global__ void LayerNormBackwardComputeGradInput(
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss * gamma[l + k];
sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
sum_loss1 += c_loss * static_cast<U>(gamma[l + k]);
sum_loss2 +=
c_loss * static_cast<U>(gamma[l + k]) * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss * gamma[l];
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
sum_loss1 += c_loss * static_cast<U>(gamma[l]);
sum_loss2 +=
c_loss * static_cast<U>(gamma[l]) * (c_h - c_mean) * c_invvar;
}
} else {
int l = 4 * thrx;
......@@ -491,7 +501,7 @@ __global__ void LayerNormBackwardComputeGradInput(
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss * gamma[l];
U f_grad_input = fH * c_loss * static_cast<U>(gamma[l]);
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
......@@ -513,11 +523,17 @@ __global__ void LayerNormBackwardComputeGradInput(
// Make sure that d_scale != nullptr && d_bias != nullptr
// Since d_scale != nullptr, scale would not be nullptr
template <typename T, typename U, int BlockDim, bool HasDx>
template <typename T, typename U, int BlockDim, bool HasDx,
bool ScaleBiasWithSameTypeX>
__global__ void LayerNormBackwardGradientAll(
const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean,
const U *var, const U *scale, float epsilon, int64_t batch_size,
int64_t feature_size, int64_t col_offset) {
const T *x, const T *d_y,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias, T *d_x,
const U *mean, const U *var,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
float epsilon, int64_t batch_size, int64_t feature_size,
int64_t col_offset) {
using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>;
int64_t beg_idx = threadIdx.x * feature_size + (blockIdx.x + col_offset);
int64_t end_idx = batch_size * feature_size + (blockIdx.x + col_offset);
int64_t stride = BlockDim * feature_size;
......@@ -532,7 +548,8 @@ __global__ void LayerNormBackwardGradientAll(
d_bias_partial += static_cast<U>(d_y[i]);
if (HasDx) {
d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
scale[blockIdx.x + col_offset] / var_val);
static_cast<U>(scale[blockIdx.x + col_offset]) /
var_val);
}
}
......@@ -543,19 +560,24 @@ __global__ void LayerNormBackwardGradientAll(
d_bias_partial = BlockReduceSum<U>(d_bias_partial, shared_bias);
if (threadIdx.x == 0) {
d_scale[blockIdx.x + col_offset] = d_scale_partial;
d_bias[blockIdx.x + col_offset] = d_bias_partial;
d_scale[blockIdx.x + col_offset] = static_cast<ScaleBiasT>(d_scale_partial);
d_bias[blockIdx.x + col_offset] = static_cast<ScaleBiasT>(d_bias_partial);
}
}
// Make sure that there is only one true expression: d_scale != nullptr
// or d_bias != nullptr
// Notice: scale may be nullptr
template <typename T, typename U, int BlockDim, bool HasDx, bool HasDScale>
template <typename T, typename U, int BlockDim, bool HasDx, bool HasDScale,
bool ScaleBiasWithSameTypeX>
__global__ void LayerNormBackwardGradientScaleOrBias(
const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean,
const U *var, const U *scale, float epsilon, int64_t batch_size,
int64_t feature_size, int col_offset) {
const T *x, const T *d_y,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias, T *d_x,
const U *mean, const U *var,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
float epsilon, int64_t batch_size, int64_t feature_size, int col_offset) {
using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>;
using BlockReduce = cub::BlockReduce<U, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage;
int64_t beg_idx = threadIdx.x * feature_size + blockIdx.x + col_offset;
......@@ -578,7 +600,8 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
if (HasDx) {
if (scale != nullptr) {
d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
scale[blockIdx.x + col_offset] / var_val);
static_cast<U>(scale[blockIdx.x + col_offset]) /
var_val);
} else {
d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
}
......@@ -590,9 +613,11 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
if (threadIdx.x == 0) {
if (HasDScale) {
d_scale[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
d_scale[blockIdx.x + col_offset] =
static_cast<ScaleBiasT>(d_scale_or_d_bias_partial);
} else {
d_bias[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
d_bias[blockIdx.x + col_offset] =
static_cast<ScaleBiasT>(d_scale_or_d_bias_partial);
}
}
}
......@@ -640,12 +665,12 @@ __global__ void LayerNormBackwardPostProcessToCalculateDX(
}
// Here, we only calculate d_x
template <typename T, typename U, int BlockDim>
__global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
T *d_x, const U *mean,
const U *var, const U *scale,
float epsilon,
int64_t feature_size) {
template <typename T, typename U, int BlockDim, bool ScaleBiasWithSameTypeX>
__global__ void LayerNormBackwardGradientOnlyDX(
const T *x, const T *d_y, T *d_x, const U *mean, const U *var,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
float epsilon, int64_t feature_size) {
using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>;
using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ U d_x_reduce_tmp[2];
......@@ -660,8 +685,8 @@ __global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
static_cast<U>(real_sqrt(static_cast<float>(block_var) + epsilon));
if (scale != nullptr) {
int col_idx = i % feature_size;
d_x[i] =
static_cast<T>(static_cast<U>(d_y[i]) * scale[col_idx] / var_val);
d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
static_cast<U>(scale[col_idx]) / var_val);
} else {
d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
}
......@@ -692,11 +717,16 @@ __global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
}
}
template <typename T, typename U>
template <typename T, typename U, bool ScaleBiasWithSameTypeX>
__global__ void LayerNormBackwardWhenBatchSizeIsOne(
const T *x, const T *d_y, T *d_x, U *d_scale, U *d_bias, const U *mean,
const U *var, const U *scale, float epsilon, int64_t feature_size) {
const T *x, const T *d_y, T *d_x,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias, const U *mean,
const U *var,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
float epsilon, int64_t feature_size) {
int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>;
if (idx < feature_size) {
auto var_val =
static_cast<U>(real_sqrt(static_cast<float>(var[0]) + epsilon));
......@@ -704,25 +734,31 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne(
if (d_scale == nullptr) {
d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) / var_val);
} else {
d_x[idx] =
static_cast<T>(static_cast<U>(d_y[idx]) * scale[idx] / var_val);
d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) *
static_cast<U>(scale[idx]) / var_val);
}
}
if (d_scale != nullptr) {
d_scale[idx] = static_cast<U>(d_y[idx]) *
(static_cast<U>(x[idx]) - mean[0]) / var_val;
d_scale[idx] =
static_cast<ScaleBiasT>(static_cast<U>(d_y[idx]) *
(static_cast<U>(x[idx]) - mean[0]) / var_val);
}
if (d_bias != nullptr) d_bias[idx] = static_cast<U>(d_y[idx]);
if (d_bias != nullptr) {
d_bias[idx] = static_cast<ScaleBiasT>(d_y[idx]);
}
}
}
template <typename T, typename U>
static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
const U *mean, const U *var, T *d_x, U *d_scale,
U *d_bias, float epsilon, int64_t batch_size,
int64_t feature_size,
template <typename T, typename U, bool ScaleBiasWithSameTypeX = false>
static void LayerNormBackward(
const T *x, const T *d_y,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
const U *mean, const U *var, T *d_x,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias, float epsilon,
int64_t batch_size, int64_t feature_size,
const platform::CUDADeviceContext &dev_ctx) {
auto stream = dev_ctx.stream();
#ifdef __HIPCC__
......@@ -737,10 +773,10 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
if (gradient_flag == 0) return;
if (batch_size == 1) {
LayerNormBackwardWhenBatchSizeIsOne<
T, U><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim,
0, stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale,
epsilon, feature_size);
LayerNormBackwardWhenBatchSizeIsOne<T, U, ScaleBiasWithSameTypeX><<<
(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim, 0,
stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale, epsilon,
feature_size);
if (d_x != nullptr) {
switch (GetDesiredBlockDim(feature_size)) {
......@@ -759,8 +795,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum,
LayerNormBackwardGradientScaleOrBias<
T, U, kBlockDim, false,
false><<<block_num, kBlockDim, 0, stream>>>(
T, U, kBlockDim, false, false,
ScaleBiasWithSameTypeX><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset));
}
......@@ -770,8 +806,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum,
LayerNormBackwardGradientScaleOrBias<
T, U, kBlockDim, false,
true><<<block_num, kBlockDim, 0, stream>>>(
T, U, kBlockDim, false, true,
ScaleBiasWithSameTypeX><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset));
}
......@@ -781,7 +817,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum,
LayerNormBackwardGradientAll<
T, U, kBlockDim, false><<<block_num, kBlockDim, 0, stream>>>(
T, U, kBlockDim, false,
ScaleBiasWithSameTypeX><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset));
}
......@@ -790,7 +827,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(
LayerNormBackwardGradientOnlyDX<
T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
T, U, kBlockDim,
ScaleBiasWithSameTypeX><<<batch_size, kBlockDim, 0, stream>>>(
x, d_y, d_x, mean, var, scale, epsilon, feature_size));
}
break;
......@@ -799,8 +837,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum,
LayerNormBackwardGradientScaleOrBias<
T, U, kBlockDim, true,
false><<<block_num, kBlockDim, 0, stream>>>(
T, U, kBlockDim, true, false,
ScaleBiasWithSameTypeX><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset));
}
......@@ -816,8 +854,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum,
LayerNormBackwardGradientScaleOrBias<
T, U, kBlockDim, true,
true><<<block_num, kBlockDim, 0, stream>>>(
T, U, kBlockDim, true, true,
ScaleBiasWithSameTypeX><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset));
}
......@@ -854,7 +892,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
dim3 threads3(BDIMX3, BDIMY3, 1);
const dim3 blocks3((feature_size + BDIMX2 - 1) / BDIMX2, 1, 1);
LayerNormBackwardSumGradGammaBeta<
T, U, BDIMX3, BDIMY3><<<blocks3, threads3, 0, stream>>>(
T, U, BDIMX3, BDIMY3,
ScaleBiasWithSameTypeX><<<blocks3, threads3, 0, stream>>>(
part_grad_gamma, part_grad_beta, part_size, batch_size, feature_size,
d_scale, d_bias);
......@@ -862,7 +901,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
constexpr int BDIMY1 = 4;
dim3 threads1(BDIMX1, BDIMY1, 1);
LayerNormBackwardComputeGradInput<
T, U, BDIMX1, BDIMY1><<<batch_size, threads1, 0, stream>>>(
T, U, BDIMX1, BDIMY1,
ScaleBiasWithSameTypeX><<<batch_size, threads1, 0, stream>>>(
d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x);
break;
}
......
......@@ -102,24 +102,6 @@ class LayerNormOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// By default, the type of the scale, bias, mean,
// and var tensors should both be float. (For float or float16 input tensor)
// or double (For double input tensor).
auto ln_param_type = framework::proto::VarType::FP32;
if (input_data_type == framework::proto::VarType::FP64) {
ln_param_type = framework::proto::VarType::FP64;
}
if (ctx.HasInput("Scale")) {
PADDLE_ENFORCE_EQ(ln_param_type, ctx.Input<Tensor>("Scale")->type(),
platform::errors::InvalidArgument(
"Scale input should be of float type"));
}
if (ctx.HasInput("Bias")) {
PADDLE_ENFORCE_EQ(ln_param_type, ctx.Input<Tensor>("Bias")->type(),
platform::errors::InvalidArgument(
"Bias input should be of float type"));
}
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
......
......@@ -63,8 +63,32 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
auto *y_data = y->mutable_data<T>(ctx.GetPlace());
auto *mean_data = mean->mutable_data<U>(ctx.GetPlace());
auto *var_data = var->mutable_data<U>(ctx.GetPlace());
auto *scale_data = (scale == nullptr ? nullptr : scale->data<U>());
auto *bias_data = (bias == nullptr ? nullptr : bias->data<U>());
auto *void_scale_data = (scale == nullptr ? nullptr : scale->data<void>());
auto *void_bias_data = (bias == nullptr ? nullptr : bias->data<void>());
framework::proto::VarType::Type x_dtype = x->type();
framework::proto::VarType::Type scale_bias_dtype;
if (void_scale_data != nullptr) {
scale_bias_dtype = scale->type();
if (void_bias_data != nullptr) {
PADDLE_ENFORCE_EQ(scale_bias_dtype, bias->type(),
platform::errors::InvalidArgument(
"Thie Scale and Bias of layer_norm op "
"should have the same data type."));
}
} else {
scale_bias_dtype = (void_bias_data != nullptr ? bias->type() : x_dtype);
}
bool is_scale_bias_same_dtype_with_x = x_dtype == scale_bias_dtype;
if (!is_scale_bias_same_dtype_with_x) {
PADDLE_ENFORCE_EQ(scale_bias_dtype,
framework::DataTypeTrait<U>::DataType(),
platform::errors::InvalidArgument(
"Unsupported data type of Scale and Bias: %s",
framework::DataTypeToString(scale_bias_dtype)));
}
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
......@@ -72,17 +96,28 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
auto stream = ctx.cuda_device_context().stream();
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(
LayerNormForward<T, U,
kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
x_data, scale_data, bias_data, y_data, mean_data, var_data,
epsilon, feature_size));
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Product from begin_norm_axis to end must be larger than 1"));
break;
#define PADDLE_LAUNCH_LAYERNORM_FWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \
do { \
switch (GetDesiredBlockDim(feature_size)) { \
FIXED_BLOCK_DIM_CASE( \
LayerNormForward<T, U, kBlockDim, IsScaleBiasSameDTypeWithX><<< \
batch_size, kBlockDim, 0, stream>>>( \
x_data, static_cast<const ScaleBiasT *>(void_scale_data), \
static_cast<const ScaleBiasT *>(void_bias_data), y_data, \
mean_data, var_data, epsilon, feature_size)); \
default: \
PADDLE_THROW(platform::errors::InvalidArgument( \
"Product from begin_norm_axis to end must be larger than 1")); \
break; \
} \
} while (0)
if (is_scale_bias_same_dtype_with_x) {
PADDLE_LAUNCH_LAYERNORM_FWD(T, true);
} else {
PADDLE_LAUNCH_LAYERNORM_FWD(U, false);
}
#undef PADDLE_LAUNCH_LAYERNORM_FWD
}
};
......@@ -102,32 +137,64 @@ class LayerNormGradKernel<platform::CUDADeviceContext, T>
auto *mean = ctx.Input<Tensor>("Mean");
auto *var = ctx.Input<Tensor>("Variance");
auto *scale = ctx.Input<Tensor>("Scale");
auto *bias = ctx.Input<Tensor>("Bias");
auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto &x_dims = x->dims();
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
auto *x_data = x->data<T>();
auto *d_y_data = d_y->data<T>();
auto *mean_data = mean->data<U>();
auto *var_data = var->data<U>();
auto *scale_data = (scale == nullptr ? nullptr : scale->data<U>());
auto *d_scale_data =
(d_scale == nullptr ? nullptr
: d_scale->mutable_data<U>(ctx.GetPlace()));
auto *d_bias_data =
(d_bias == nullptr ? nullptr : d_bias->mutable_data<U>(ctx.GetPlace()));
auto *d_x_data =
(d_x == nullptr ? nullptr : d_x->mutable_data<T>(ctx.GetPlace()));
const auto &x_dims = x->dims();
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
framework::proto::VarType::Type x_dtype = x->type();
framework::proto::VarType::Type scale_bias_dtype;
if (scale != nullptr) {
scale_bias_dtype = scale->type();
} else {
// FIXME(zengjinle): do not find a better way to get the right
// data type of the d_scale and d_bias if scale == nullptr.
auto *bias = ctx.Input<Tensor>("Bias");
if (bias != nullptr) {
scale_bias_dtype = bias->saved_type();
} else {
scale_bias_dtype = x_dtype;
}
}
#define PADDLE_LAUNCH_LAYERNORM_BWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \
do { \
auto *scale_data = \
(scale == nullptr ? nullptr : scale->data<ScaleBiasT>()); \
auto *d_scale_data = \
(d_scale == nullptr ? nullptr : d_scale->mutable_data<ScaleBiasT>( \
ctx.GetPlace())); \
auto *d_bias_data = \
(d_bias == nullptr ? nullptr : d_bias->mutable_data<ScaleBiasT>( \
ctx.GetPlace())); \
auto *d_x_data = \
(d_x == nullptr ? nullptr : d_x->mutable_data<T>(ctx.GetPlace())); \
LayerNormBackward<T, U, IsScaleBiasSameDTypeWithX>( \
x_data, d_y_data, scale_data, mean_data, var_data, d_x_data, \
d_scale_data, d_bias_data, epsilon, batch_size, feature_size, \
ctx.cuda_device_context()); \
} while (0)
if (scale_bias_dtype == x_dtype) {
PADDLE_LAUNCH_LAYERNORM_BWD(T, true);
} else {
PADDLE_LAUNCH_LAYERNORM_BWD(U, false);
}
LayerNormBackward<T, U>(x_data, d_y_data, scale_data, mean_data, var_data,
d_x_data, d_scale_data, d_bias_data, epsilon,
batch_size, feature_size,
ctx.cuda_device_context());
#undef PADDLE_LAUNCH_LAYERNORM_BWD
}
};
......
......@@ -152,6 +152,14 @@ class LambOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Moment2", "(Tensor) Input second moment.");
AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator.");
AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator.");
AddInput("MasterParam",
"(LoDTensor, default LoDTensor<float>) "
"Input master parameter that has to be updated.")
.AsDispensable();
AddInput(
"SkipUpdate",
"(Tensor) Input tensor to determine whether to update the parameter.")
.AsDispensable();
AddOutput("ParamOut", "(Tensor) Output parameter.");
AddOutput("Moment1Out", "(Tensor) Output first moment.");
......@@ -160,6 +168,8 @@ class LambOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDispensable();
AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator")
.AsDispensable();
AddOutput("MasterParamOut", "(Tensor) Output master parameter.")
.AsDispensable();
AddAttr<float>("weight_decay", "(float) Weight decay rate.");
AddAttr<float>("beta1",
"(float, default 0.9) The exponential decay rate for the "
......@@ -173,6 +183,10 @@ class LambOpMaker : public framework::OpProtoAndCheckerMaker {
"(float, default 1.0e-6) "
"Constant for numerical stability.")
.SetDefault(1.0e-6f);
AddAttr<bool>(
"multi_precision",
"(bool, default false) Whether to enable multi-precision mode.")
.SetDefault(false);
AddComment(R"DOC(
LAMB (Layer-wise Adaptive Moments optimizer for Batching training) Optimizer.
......
......@@ -16,5 +16,7 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
lamb, ops::LambOpKernel<paddle::platform::CUDADeviceContext, float>,
lamb, ops::LambOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::LambOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::LambOpKernel<paddle::platform::CUDADeviceContext, double>);
......@@ -72,6 +72,9 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"adamw",
{"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow",
"Beta2Pow", "MasterParam"}},
{"lamb",
{"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow",
"Beta2Pow", "MasterParam"}},
{"sparse_attention",
{"Q", "K", "V", "Offset", "Columns", "KeyPaddingMask", "AttnMask"}},
};
......@@ -112,8 +115,6 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
{"sparse_momentum", {"ParamOut", "VelocityOut"}},
{"rnn", {"DropoutState", "Reserve", "Out", "State"}},
{"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
{"run_program", {"DOut"}},
{"adam",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
......@@ -121,6 +122,9 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"adamw",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
{"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
};
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
......@@ -142,6 +146,9 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"adamw",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
{"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
{"average_accumulates",
{"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates",
"out_old_num_accumulates", "out_num_updates"}},
......@@ -173,8 +180,6 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"}},
{"moving_average_abs_max_scale",
{"Out", "OutScale", "OutAccum", "OutState"}},
{"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
{"rnn", {"DropoutState"}},
{"run_program", {"Out", "DOut", "OutScope"}},
{"clear_float_status", {"FloatStatusOut"}},
......
......@@ -371,6 +371,20 @@ class ClipGradByNorm(ClipGradBase):
return param, new_grad
_allow_pure_fp16_global_norm_clip_flag = False
def _allow_pure_fp16_global_norm_clip(*args):
global _allow_pure_fp16_global_norm_clip_flag
if len(args) == 0:
return _allow_pure_fp16_global_norm_clip_flag
else:
assert len(args) == 1 and isinstance(args[0], bool)
old_value = _allow_pure_fp16_global_norm_clip_flag
_allow_pure_fp16_global_norm_clip_flag = args[0]
return old_value
class ClipGradByGlobalNorm(ClipGradBase):
r"""
Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
......@@ -537,8 +551,12 @@ class ClipGradByGlobalNorm(ClipGradBase):
global_norm_var = []
if len(sum_square_list_fp16) > 0:
global_norm_var_fp16 = layers.sums(sum_square_list_fp16)
if sum_square_list_fp32 or sum_square_list or not _allow_pure_fp16_global_norm_clip(
):
global_norm_var.append(
global_norm_var_fp16.astype(sum_dtype))
else:
global_norm_var.append(global_norm_var_fp16)
if len(sum_square_list_fp32) > 0:
global_norm_var_fp32 = layers.sums(sum_square_list_fp32)
if sum_dtype == 'float32':
......@@ -573,8 +591,9 @@ class ClipGradByGlobalNorm(ClipGradBase):
with p.block.program._optimized_guard([p, g]):
# inplace
scale_input = (scale_var.astype('float16')
if g.dtype == core.VarDesc.VarType.FP16 else
scale_var)
if g.dtype == core.VarDesc.VarType.FP16 and
scale_var.dtype != core.VarDesc.VarType.FP16
else scale_var)
# NOTE(Yuang Liu): For pure dp with gradient merge, the p and g
# will be in different blocks with the gradient clip related ops.
# We need to handle the correct block, otherwise will encounter
......
......@@ -411,6 +411,8 @@ class OptimizerWithMixedPrecision(object):
found_inf = paddle.tensor.creation._memcpy(found_inf,
paddle.CPUPlace())
real_optimizer._set_auxiliary_var('found_inf', found_inf)
elif hasattr(real_optimizer, "_set_auxiliary_var"):
real_optimizer._set_auxiliary_var('found_inf', found_inf)
optimize_ops = self._optimizer.apply_gradients(params_grads)
return optimize_ops
......
......@@ -80,11 +80,27 @@ def _dtype_to_str(dtype):
return 'fp32'
_keep_layer_norm_scale_bias_to_fp32_flag = True
def _keep_layer_norm_scale_bias_to_fp32(*args):
global _keep_layer_norm_scale_bias_to_fp32_flag
if len(args) == 0:
return _keep_layer_norm_scale_bias_to_fp32_flag
else:
assert len(args) == 1 and isinstance(args[0], bool)
old_value = _keep_layer_norm_scale_bias_to_fp32_flag
_keep_layer_norm_scale_bias_to_fp32_flag = args[0]
return old_value
def _keep_fp32_input(op, in_name):
op_type = op.type
if op_type in ['batch_norm', 'layer_norm']:
if op_type == 'batch_norm':
# Scale, Bias, Mean, Variance should be float32.
return in_name != 'X'
if op_type == 'layer_norm' and _keep_layer_norm_scale_bias_to_fp32():
return in_name != 'X'
if op_type == 'fused_bn_add_activation':
return in_name not in {'X', 'Z'}
if op_type == 'resnet_unit':
......@@ -98,7 +114,9 @@ def _keep_fp32_input(op, in_name):
def _keep_fp32_output(op, out_name):
op_type = op.type
if op_type in ['batch_norm', 'fused_bn_add_activation', 'layer_norm']:
if op_type in ['batch_norm', 'fused_bn_add_activation']:
return out_name != 'Y'
if op_type == 'layer_norm' and _keep_layer_norm_scale_bias_to_fp32():
return out_name != 'Y'
if op_type == 'resnet_unit':
return out_name not in {'Y', 'ConvX', 'ConvZ'}
......
......@@ -3554,14 +3554,14 @@ class LambOptimizer(AdamOptimizer):
else:
weight_decay = self._weight_decay
lr = self._create_param_lr(param_and_grad)
master_weight = None
if framework.in_dygraph_mode():
_, _, _, _, _ = _C_ops.lamb(
param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1,
moment2, beta1_pow_acc, beta2_pow_acc, 'beta1', self._beta1,
'beta2', self._beta2, 'epsilon', self._epsilon, 'weight_decay',
weight_decay)
_C_ops.lamb(param_and_grad[0], param_and_grad[1], lr, moment1,
moment2, beta1_pow_acc, beta2_pow_acc, master_weight,
param_and_grad[0], moment1, moment2, beta1_pow_acc,
beta2_pow_acc, master_weight, 'beta1', self._beta1,
'beta2', self._beta2, 'epsilon', self._epsilon,
'weight_decay', weight_decay)
return None
# create the lamb optimize op
......
......@@ -21,6 +21,7 @@ import paddle.fluid.core as core
import paddle.fluid as fluid
import six
from fake_reader import fake_imdb_reader
from paddle.fluid.clip import _allow_pure_fp16_global_norm_clip
paddle.enable_static()
......@@ -566,5 +567,35 @@ class TestDygraphGradientClipFP64(unittest.TestCase):
% (a, b))
class TestPureFP16ClipGradByGlobalNorm(unittest.TestCase):
def check_main(self, expected_has_cast_op):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
names = ["p0", "p1"]
shapes = [[2, 3], [4, 5]]
param_and_grads = []
main_block = main_prog.global_block()
for name, shape in zip(names, shapes):
p = main_block.create_parameter(
name=name, shape=shape, dtype='float16')
g = main_block.create_parameter(
name=p.name + '@GRAD', shape=p.shape, dtype=p.dtype)
param_and_grads.append((p, g))
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
clip(param_and_grads)
actual_has_cast = any(op.type == 'cast' for op in main_block.ops)
self.assertEqual(actual_has_cast, expected_has_cast_op)
def test_main(self):
self.check_main(True)
_allow_pure_fp16_global_norm_clip(True)
self.check_main(False)
_allow_pure_fp16_global_norm_clip(False)
self.check_main(True)
if __name__ == '__main__':
unittest.main()
......@@ -19,6 +19,7 @@ import numpy as np
from op_test import OpTest
from paddle.fluid import core
from paddle.fluid.op import Operator
from paddle.fluid.dygraph.base import switch_to_static_graph
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
......@@ -181,5 +182,58 @@ class TestLambOpV2Group(TestLambOpV2):
adam.clear_gradients()
class TestLambOpMultiPrecision(unittest.TestCase):
def check_main(self, x_np, place, multi_precision=False, seed=10, n=10):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
paddle.seed(seed)
with paddle.static.amp.fp16_guard():
x = paddle.static.data(
name='x', shape=[None, 10], dtype='float32')
linear = paddle.nn.Linear(10, 2)
hidden = linear(x)
loss = paddle.mean(hidden)
optimizer = paddle.optimizer.Lamb(learning_rate=1e-3)
optimizer._multi_precision = multi_precision
if multi_precision:
optimizer = paddle.static.amp.decorate(
optimizer, use_pure_fp16=True, use_fp16_guard=True)
optimizer.minimize(loss)
weight, bias = linear.weight, linear.bias
scope = paddle.static.Scope()
exe = paddle.static.Executor(place)
scope = paddle.static.Scope()
x = main_prog.global_block().var(x.name)
if x.dtype == core.VarDesc.VarType.FP16:
x_np = x_np.astype(np.float16)
with paddle.static.scope_guard(scope):
exe.run(startup_prog)
if multi_precision:
optimizer.amp_init(place)
weight_np, bias_np = None, None
for i in range(n):
feed_dict = {x.name: x_np}
weight_np, bias_np = exe.run(main_prog,
feed=feed_dict,
fetch_list=[weight, bias])
return weight_np.astype('float32'), bias_np.astype('float32')
@switch_to_static_graph
def test_main(self):
if not paddle.is_compiled_with_cuda():
return
place = paddle.CUDAPlace(0)
x_np = np.random.random(size=[5, 10]).astype('float32')
weight_1, bias_1 = self.check_main(x_np, place, multi_precision=False)
weight_2, bias_2 = self.check_main(x_np, place, multi_precision=True)
self.assertTrue(np.all(np.abs(weight_1 - weight_2) < 1e-3))
self.assertTrue(np.all(np.abs(bias_1 - bias_2) < 1e-7))
if __name__ == "__main__":
unittest.main()
......@@ -20,9 +20,13 @@ import paddle
from operator import mul
import paddle.fluid.core as core
import paddle.fluid as fluid
import paddle.nn.functional as F
from functools import reduce
from op_test import _set_use_system_allocator
from paddle.fluid import Program, program_guard
from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_layer_norm_scale_bias_to_fp32
paddle.enable_static()
np.random.random(123)
......@@ -325,5 +329,58 @@ class TestDygraphLayerNormAPIError(unittest.TestCase):
self.assertRaises(TypeError, layer_norm, x2)
class TestFP16ScaleBiasLayerNorm(unittest.TestCase):
def check_main(self, x_np, weight_np, bias_np, dtype):
paddle.disable_static()
weight_np = weight_np.astype(dtype)
bias_np = bias_np.astype(dtype)
x = paddle.to_tensor(x_np)
weight = paddle.to_tensor(weight_np)
bias = paddle.to_tensor(bias_np)
x.stop_gradient = False
weight.stop_gradient = False
bias.stop_gradient = False
y = F.layer_norm(x, x.shape[1:], weight, bias)
x_g, w_g, b_g = paddle.grad(y, [x, weight, bias])
y_np = y.numpy().astype('float32')
x_g_np = x_g.numpy().astype('float32')
w_g_np = w_g.numpy().astype('float16')
b_g_np = b_g.numpy().astype('float32')
paddle.enable_static()
return y_np, x_g_np, w_g_np, b_g_np
def test_main(self):
if not paddle.is_compiled_with_cuda():
return
x_np = np.random.random([10, 20]).astype('float16')
weight_np = np.random.random([20]).astype('float16')
bias_np = np.random.random([20]).astype('float16')
y_np_1, x_g_np_1, w_g_np_1, b_g_np_1 = self.check_main(
x_np, weight_np, bias_np, 'float16')
y_np_2, x_g_np_2, w_g_np_2, b_g_np_2 = self.check_main(
x_np, weight_np, bias_np, 'float32')
def assert_equal(x, y):
self.assertTrue(np.array_equal(x, y))
assert_equal(y_np_1, y_np_2)
assert_equal(x_g_np_1, x_g_np_2)
assert_equal(w_g_np_1, w_g_np_2)
assert_equal(b_g_np_1, b_g_np_2)
class TestGetSetKeepLayerNormScaleBiasFP32Flag(unittest.TestCase):
def test_main(self):
self.assertTrue(_keep_layer_norm_scale_bias_to_fp32())
_keep_layer_norm_scale_bias_to_fp32(False)
self.assertFalse(_keep_layer_norm_scale_bias_to_fp32())
_keep_layer_norm_scale_bias_to_fp32(True)
self.assertTrue(_keep_layer_norm_scale_bias_to_fp32())
if __name__ == '__main__':
unittest.main()
......@@ -16,6 +16,9 @@ from .optimizer import Optimizer
from ..fluid import core
from ..fluid import framework
from ..fluid.framework import Variable
from ..fluid import layers
from ..fluid import unique_name
from ..fluid.layer_helper import LayerHelper
from paddle import _C_ops
__all__ = []
......@@ -127,6 +130,36 @@ class Lamb(Optimizer):
'lamb_weight_decay': lamb_weight_decay,
'exclude_from_weight_decay_fn': exclude_from_weight_decay_fn,
}
self._master_weights = {}
# TODO(zengjinle): expose API as soon as possible
self._multi_precision = False
def _create_master_weight(self, param):
assert self._multi_precision
if param.name in self._master_weights:
var = self._master_weights[param.name]
else:
assert isinstance(self.helper, LayerHelper)
var_name = param.name + "_fp32_master"
var_name = unique_name.generate(var_name)
var = layers.create_global_var(
name=var_name,
shape=param.shape,
value=0,
dtype='float32',
persistable=True)
block = self.helper.startup_program.global_block()
block.append_op(
type="cast",
inputs={"X": [param]},
outputs={"Out": [var]},
attrs={
"in_dtype": param.dtype,
"out_dtype": core.VarDesc.VarType.FP32
})
self._master_weights[param.name] = var
return var
def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block)
......@@ -135,11 +168,43 @@ class Lamb(Optimizer):
# Create accumulator tensors for first and second moments
for p in parameters:
self._add_accumulator(self._moment1_acc_str, p)
self._add_accumulator(self._moment2_acc_str, p)
if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
master_p = self._create_master_weight(p)
self._add_moments_pows(master_p)
else:
self._add_moments_pows(p)
def _get_accumulator(self, name, param):
"""Utility function to fetch an accumulator for a parameter
Args:
name: name of the accumulator
param: parameter variable for which accumulator is to be fetched
Returns:
accumulator variable for the parameter
"""
if self._name is not None:
name = self._name + "_" + name
find_master = self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
target_param = self._master_weights[
param.name] if find_master else param
target_name = target_param.name
if (name not in self._accumulators or
target_name not in self._accumulators[name]):
raise Exception("Accumulator {} does not exist for parameter {}".
format(name, target_name))
return self._accumulators[name][target_name]
def _add_moments_pows(self, p):
acc_dtype = p.dtype
if acc_dtype == core.VarDesc.VarType.FP16:
acc_dtype = core.VarDesc.VarType.FP32
self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
self._add_accumulator(
name=self._beta1_pow_acc_str,
param=p,
dtype=acc_dtype,
fill_value=0.9 if isinstance(self._beta1, Variable) \
else self._beta1,
shape=[1],
......@@ -147,6 +212,7 @@ class Lamb(Optimizer):
self._add_accumulator(
name=self._beta2_pow_acc_str,
param=p,
dtype=acc_dtype,
fill_value=0.999 if isinstance(self._beta2, Variable) \
else self._beta2,
shape=[1],
......@@ -175,13 +241,20 @@ class Lamb(Optimizer):
weight_decay = self._lamb_weight_decay
lr = self._create_param_lr(param_and_grad)
find_master = self._multi_precision and param_and_grad[
0].dtype == core.VarDesc.VarType.FP16
master_weight = self._master_weights[param_and_grad[0]
.name] if find_master else None
found_inf = self._get_auxiliary_var('found_inf')
if framework.in_dygraph_mode():
_, _, _, _, _ = _C_ops.lamb(
param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1,
moment2, beta1_pow_acc, beta2_pow_acc, 'beta1', self._beta1,
'beta2', self._beta2, 'epsilon', self._epsilon, 'weight_decay',
weight_decay)
_C_ops.lamb(param_and_grad[0], param_and_grad[1], lr, moment1,
moment2, beta1_pow_acc, beta2_pow_acc, master_weight,
param_and_grad[0], moment1, moment2, beta1_pow_acc,
beta2_pow_acc, master_weight, 'beta1', self._beta1,
'beta2', self._beta2, 'epsilon', self._epsilon,
'weight_decay', weight_decay, 'multi_precision',
find_master)
return None
# create the lamb optimize op
......@@ -205,9 +278,17 @@ class Lamb(Optimizer):
"beta1": self._beta1,
"beta2": self._beta2,
"epsilon": self._epsilon,
"weight_decay": weight_decay
"weight_decay": weight_decay,
"multi_precision": find_master,
}
if find_master:
inputs["MasterParam"] = master_weight
outputs["MasterParamOut"] = master_weight
if found_inf:
inputs["SkipUpdate"] = found_inf
lamb_op = block.append_op(
type=self.type,
inputs=inputs,
......
......@@ -217,6 +217,14 @@ class Optimizer(object):
else:
self._param_groups = self._parameter_list
self._auxiliary_vars = {}
def _set_auxiliary_var(self, key, val):
self._auxiliary_vars[key] = val
def _get_auxiliary_var(self, key):
return self._auxiliary_vars.get(key, None)
@framework.dygraph_only
def state_dict(self):
'''
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册