Unverified commit d80fe268, authored by sneaxiy, committed by GitHub

Refine some AMP operators for BERT (#37923)

* support multi precision update for LAMB

* hide some api

* fix ci uts

* fix lamb output of dygraph

* remove some changes to some PR

* try to fix Py3 CI compile error

* fix test_imperative_optimizer, add lars ut, add layer_norm ut

* fix ut, fix format

* fix ut

* fix windows ci
Parent cff03734
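Background for the "support multi precision update for LAMB" item: with float16 parameters, applying optimizer updates directly in float16 loses small increments, so a float32 "master" copy of each parameter receives the update and the float16 working copy is refreshed from it. A minimal NumPy sketch of that idea (illustrative only, not the Paddle kernel):

import numpy as np

# Hypothetical fp16 parameter and its fp32 "master" copy.
param_fp16 = np.random.rand(4).astype(np.float16)
master_fp32 = param_fp16.astype(np.float32)

grad_fp16 = (np.random.rand(4) * 1e-4).astype(np.float16)
lr = 1e-3

# The update is accumulated in fp32, then cast back for the fp16 working copy.
master_fp32 -= lr * grad_fp16.astype(np.float32)
param_fp16 = master_fp32.astype(np.float16)

print(param_fp16, master_fp32)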
@@ -169,10 +169,16 @@ __inline__ __device__ half rsqrt_(const half val) {
 }
 #endif
 
-template <typename T, typename U, int BlockDim>
-__global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
-                                 T *y, U *mean, U *var, float epsilon,
-                                 int64_t feature_size) {
+template <typename T, typename U, bool ScaleBiasWithSameTypeX>
+using LayerNormScaleBiasT =
+    typename std::conditional<ScaleBiasWithSameTypeX, T, U>::type;
+
+template <typename T, typename U, int BlockDim,
+          bool ScaleBiasWithSameTypeX = false>
+__global__ void LayerNormForward(
+    const T *x, const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
+    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *bias, T *y,
+    U *mean, U *var, float epsilon, int64_t feature_size) {
   __shared__ U mean_share;
   __shared__ U var_share;
   __shared__ U shared_mean[32];  // threadIdx.x / warpSize <= kMaxBlockDim /
@@ -212,14 +218,15 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
     if (bias != nullptr) {
       for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
            i += BlockDim, j += BlockDim) {
-        y[i] = static_cast<T>(
-            scale[j] * (static_cast<U>(x[i]) - mean_val) * invvar + bias[j]);
+        y[i] = static_cast<T>(static_cast<U>(scale[j]) *
+                                  (static_cast<U>(x[i]) - mean_val) * invvar +
+                              static_cast<U>(bias[j]));
       }
     } else {
       for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
            i += BlockDim, j += BlockDim) {
-        y[i] = static_cast<T>(scale[j] * (static_cast<U>(x[i]) - mean_val) *
-                              invvar);
+        y[i] = static_cast<T>(static_cast<U>(scale[j]) *
+                              (static_cast<U>(x[i]) - mean_val) * invvar);
       }
     }
   } else {  // scale == nullptr
@@ -227,7 +234,7 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
     for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
          i += BlockDim, j += BlockDim) {
       y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar +
-                            bias[j]);
+                            static_cast<U>(bias[j]));
     }
   } else {
     for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
@@ -336,12 +343,15 @@ __global__ void LayerNormBackwardPartGradGammaBeta(
   }
 }
 
-template <typename T, typename U, int BDIMX, int BDIMY>
+template <typename T, typename U, int BDIMX, int BDIMY, bool ScaleBiasSameTypeX>
 __global__ void LayerNormBackwardSumGradGammaBeta(
     const U *part_grad_gamma, const U *part_grad_beta, const int part_size,
     // const int n1, const int n2, T* grad_gamma, T* grad_beta) {
-    const int n1, const int n2, U *grad_gamma, U *grad_beta) {
+    const int n1, const int n2,
+    LayerNormScaleBiasT<T, U, ScaleBiasSameTypeX> *grad_gamma,
+    LayerNormScaleBiasT<T, U, ScaleBiasSameTypeX> *grad_beta) {
   // sum partial gradients for gamma and beta
+  using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasSameTypeX>;
   __shared__ U buf[BDIMX * BDIMY];
   int64_t i2 = blockIdx.x * BDIMX + threadIdx.x;
   if (i2 < n2) {
@@ -378,20 +388,18 @@ __global__ void LayerNormBackwardSumGradGammaBeta(
     }
     // write out fully summed gradients
     if (threadIdx.y == 0) {
-      grad_gamma[i2] = sum_gamma;
-      grad_beta[i2] = sum_beta;
+      grad_gamma[i2] = static_cast<ScaleBiasT>(sum_gamma);
+      grad_beta[i2] = static_cast<ScaleBiasT>(sum_beta);
     }
   }
 }
 
-template <typename T, typename U, int BDIMX, int BDIMY>
+template <typename T, typename U, int BDIMX, int BDIMY, bool ScaleBiasSameTypeX>
 __global__ void LayerNormBackwardComputeGradInput(
     const T *__restrict__ dout, const T *__restrict__ input, const int n1,
-    const int n2,
-    // const U* __restrict__ mean, const U* __restrict__ var, const float
-    // epsilon, const T* gamma,
-    const U *__restrict__ mean, const U *__restrict__ var, const float epsilon,
-    const U *gamma, T *grad_input) {
+    const int n2, const U *__restrict__ mean, const U *__restrict__ var,
+    const float epsilon,
+    const LayerNormScaleBiasT<T, U, ScaleBiasSameTypeX> *gamma, T *grad_input) {
 #ifdef __HIPCC__
   for (auto i1 = hipBlockIdx_x; i1 < n1; i1 += hipGridDim_x) {
 #else
@@ -411,15 +419,17 @@ __global__ void LayerNormBackwardComputeGradInput(
       for (int k = 0; k < 4; ++k) {
         const U c_h = static_cast<U>(k_input[l + k]);
         const U c_loss = static_cast<U>(k_dout[l + k]);
-        sum_loss1 += c_loss * gamma[l + k];
-        sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
+        sum_loss1 += c_loss * static_cast<U>(gamma[l + k]);
+        sum_loss2 +=
+            c_loss * static_cast<U>(gamma[l + k]) * (c_h - c_mean) * c_invvar;
       }
     }
     for (; l < n2; ++l) {
       const U c_h = static_cast<U>(k_input[l]);
       const U c_loss = static_cast<U>(k_dout[l]);
-      sum_loss1 += c_loss * gamma[l];
-      sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
+      sum_loss1 += c_loss * static_cast<U>(gamma[l]);
+      sum_loss2 +=
+          c_loss * static_cast<U>(gamma[l]) * (c_h - c_mean) * c_invvar;
     }
   } else {
     int l = 4 * thrx;
@@ -491,7 +501,7 @@ __global__ void LayerNormBackwardComputeGradInput(
     for (int l = thrx; l < n2; l += numx) {
       const U c_h = static_cast<U>(k_input[l]);
       const U c_loss = static_cast<U>(k_dout[l]);
-      U f_grad_input = fH * c_loss * gamma[l];
+      U f_grad_input = fH * c_loss * static_cast<U>(gamma[l]);
       f_grad_input -= sum_loss1;
       f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
       f_grad_input *= term1;
@@ -513,11 +523,17 @@ __global__ void LayerNormBackwardComputeGradInput(
 
 // Make sure that d_scale != nullptr && d_bias != nullptr
 // Since d_scale != nullptr, scale would not be nullptr
-template <typename T, typename U, int BlockDim, bool HasDx>
+template <typename T, typename U, int BlockDim, bool HasDx,
+          bool ScaleBiasWithSameTypeX>
 __global__ void LayerNormBackwardGradientAll(
-    const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean,
-    const U *var, const U *scale, float epsilon, int64_t batch_size,
-    int64_t feature_size, int64_t col_offset) {
+    const T *x, const T *d_y,
+    LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
+    LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias, T *d_x,
+    const U *mean, const U *var,
+    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
+    float epsilon, int64_t batch_size, int64_t feature_size,
+    int64_t col_offset) {
+  using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>;
   int64_t beg_idx = threadIdx.x * feature_size + (blockIdx.x + col_offset);
   int64_t end_idx = batch_size * feature_size + (blockIdx.x + col_offset);
   int64_t stride = BlockDim * feature_size;
@@ -532,7 +548,8 @@ __global__ void LayerNormBackwardGradientAll(
       d_bias_partial += static_cast<U>(d_y[i]);
       if (HasDx) {
         d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
-                                scale[blockIdx.x + col_offset] / var_val);
+                                static_cast<U>(scale[blockIdx.x + col_offset]) /
+                                var_val);
       }
     }
 
@@ -543,19 +560,24 @@ __global__ void LayerNormBackwardGradientAll(
   d_bias_partial = BlockReduceSum<U>(d_bias_partial, shared_bias);
 
   if (threadIdx.x == 0) {
-    d_scale[blockIdx.x + col_offset] = d_scale_partial;
-    d_bias[blockIdx.x + col_offset] = d_bias_partial;
+    d_scale[blockIdx.x + col_offset] = static_cast<ScaleBiasT>(d_scale_partial);
+    d_bias[blockIdx.x + col_offset] = static_cast<ScaleBiasT>(d_bias_partial);
   }
 }
 
 // Make sure that there is only one true expression: d_scale != nullptr
 // or d_bias != nullptr
 // Notice: scale may be nullptr
-template <typename T, typename U, int BlockDim, bool HasDx, bool HasDScale>
+template <typename T, typename U, int BlockDim, bool HasDx, bool HasDScale,
+          bool ScaleBiasWithSameTypeX>
 __global__ void LayerNormBackwardGradientScaleOrBias(
-    const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean,
-    const U *var, const U *scale, float epsilon, int64_t batch_size,
-    int64_t feature_size, int col_offset) {
+    const T *x, const T *d_y,
+    LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
+    LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias, T *d_x,
+    const U *mean, const U *var,
+    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
+    float epsilon, int64_t batch_size, int64_t feature_size, int col_offset) {
+  using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>;
   using BlockReduce = cub::BlockReduce<U, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
   int64_t beg_idx = threadIdx.x * feature_size + blockIdx.x + col_offset;
@@ -578,7 +600,8 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
     if (HasDx) {
       if (scale != nullptr) {
         d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
-                                scale[blockIdx.x + col_offset] / var_val);
+                                static_cast<U>(scale[blockIdx.x + col_offset]) /
+                                var_val);
       } else {
         d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
       }
@@ -590,9 +613,11 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
 
   if (threadIdx.x == 0) {
     if (HasDScale) {
-      d_scale[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
+      d_scale[blockIdx.x + col_offset] =
+          static_cast<ScaleBiasT>(d_scale_or_d_bias_partial);
     } else {
-      d_bias[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
+      d_bias[blockIdx.x + col_offset] =
+          static_cast<ScaleBiasT>(d_scale_or_d_bias_partial);
     }
   }
 }
@@ -640,12 +665,12 @@ __global__ void LayerNormBackwardPostProcessToCalculateDX(
 }
 
 // Here, we only calculate d_x
-template <typename T, typename U, int BlockDim>
-__global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
-                                                T *d_x, const U *mean,
-                                                const U *var, const U *scale,
-                                                float epsilon,
-                                                int64_t feature_size) {
+template <typename T, typename U, int BlockDim, bool ScaleBiasWithSameTypeX>
+__global__ void LayerNormBackwardGradientOnlyDX(
+    const T *x, const T *d_y, T *d_x, const U *mean, const U *var,
+    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
+    float epsilon, int64_t feature_size) {
+  using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>;
   using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
   __shared__ U d_x_reduce_tmp[2];
@@ -660,8 +685,8 @@ __global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
         static_cast<U>(real_sqrt(static_cast<float>(block_var) + epsilon));
     if (scale != nullptr) {
       int col_idx = i % feature_size;
-      d_x[i] =
-          static_cast<T>(static_cast<U>(d_y[i]) * scale[col_idx] / var_val);
+      d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
+                              static_cast<U>(scale[col_idx]) / var_val);
     } else {
       d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
     }
@@ -692,11 +717,16 @@ __global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
   }
 }
 
-template <typename T, typename U>
+template <typename T, typename U, bool ScaleBiasWithSameTypeX>
 __global__ void LayerNormBackwardWhenBatchSizeIsOne(
-    const T *x, const T *d_y, T *d_x, U *d_scale, U *d_bias, const U *mean,
-    const U *var, const U *scale, float epsilon, int64_t feature_size) {
+    const T *x, const T *d_y, T *d_x,
+    LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
+    LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias, const U *mean,
+    const U *var,
+    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
+    float epsilon, int64_t feature_size) {
   int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
+  using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>;
   if (idx < feature_size) {
     auto var_val =
         static_cast<U>(real_sqrt(static_cast<float>(var[0]) + epsilon));
@@ -704,25 +734,31 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne(
       if (d_scale == nullptr) {
         d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) / var_val);
       } else {
-        d_x[idx] =
-            static_cast<T>(static_cast<U>(d_y[idx]) * scale[idx] / var_val);
+        d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) *
+                                  static_cast<U>(scale[idx]) / var_val);
       }
     }
 
     if (d_scale != nullptr) {
-      d_scale[idx] = static_cast<U>(d_y[idx]) *
-                     (static_cast<U>(x[idx]) - mean[0]) / var_val;
+      d_scale[idx] =
+          static_cast<ScaleBiasT>(static_cast<U>(d_y[idx]) *
+                                  (static_cast<U>(x[idx]) - mean[0]) / var_val);
     }
 
-    if (d_bias != nullptr) d_bias[idx] = static_cast<U>(d_y[idx]);
+    if (d_bias != nullptr) {
+      d_bias[idx] = static_cast<ScaleBiasT>(d_y[idx]);
+    }
   }
 }
 
-template <typename T, typename U>
-static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
-                              const U *mean, const U *var, T *d_x, U *d_scale,
-                              U *d_bias, float epsilon, int64_t batch_size,
-                              int64_t feature_size,
-                              const platform::CUDADeviceContext &dev_ctx) {
+template <typename T, typename U, bool ScaleBiasWithSameTypeX = false>
+static void LayerNormBackward(
+    const T *x, const T *d_y,
+    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
+    const U *mean, const U *var, T *d_x,
+    LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
+    LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias, float epsilon,
+    int64_t batch_size, int64_t feature_size,
+    const platform::CUDADeviceContext &dev_ctx) {
   auto stream = dev_ctx.stream();
 #ifdef __HIPCC__
@@ -737,10 +773,10 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
   if (gradient_flag == 0) return;
 
   if (batch_size == 1) {
-    LayerNormBackwardWhenBatchSizeIsOne<
-        T, U><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim,
-                0, stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale,
-                             epsilon, feature_size);
+    LayerNormBackwardWhenBatchSizeIsOne<T, U, ScaleBiasWithSameTypeX><<<
+        (feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim, 0,
+        stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale, epsilon,
+                  feature_size);
 
     if (d_x != nullptr) {
       switch (GetDesiredBlockDim(feature_size)) {
@@ -759,8 +795,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
           FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
               feature_size, kMaxBlockNum,
               LayerNormBackwardGradientScaleOrBias<
-                  T, U, kBlockDim, false,
-                  false><<<block_num, kBlockDim, 0, stream>>>(
+                  T, U, kBlockDim, false, false,
+                  ScaleBiasWithSameTypeX><<<block_num, kBlockDim, 0, stream>>>(
                   x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                   batch_size, feature_size, col_offset));
         }
@@ -770,8 +806,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
           FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
               feature_size, kMaxBlockNum,
               LayerNormBackwardGradientScaleOrBias<
-                  T, U, kBlockDim, false,
-                  true><<<block_num, kBlockDim, 0, stream>>>(
+                  T, U, kBlockDim, false, true,
+                  ScaleBiasWithSameTypeX><<<block_num, kBlockDim, 0, stream>>>(
                   x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                   batch_size, feature_size, col_offset));
         }
@@ -781,7 +817,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
           FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
               feature_size, kMaxBlockNum,
               LayerNormBackwardGradientAll<
-                  T, U, kBlockDim, false><<<block_num, kBlockDim, 0, stream>>>(
+                  T, U, kBlockDim, false,
+                  ScaleBiasWithSameTypeX><<<block_num, kBlockDim, 0, stream>>>(
                   x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                   batch_size, feature_size, col_offset));
         }
@@ -790,7 +827,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
       switch (GetDesiredBlockDim(feature_size)) {
         FIXED_BLOCK_DIM_CASE(
             LayerNormBackwardGradientOnlyDX<
-                T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
+                T, U, kBlockDim,
+                ScaleBiasWithSameTypeX><<<batch_size, kBlockDim, 0, stream>>>(
                 x, d_y, d_x, mean, var, scale, epsilon, feature_size));
       }
       break;
@@ -799,8 +837,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
           FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
               feature_size, kMaxBlockNum,
               LayerNormBackwardGradientScaleOrBias<
-                  T, U, kBlockDim, true,
-                  false><<<block_num, kBlockDim, 0, stream>>>(
+                  T, U, kBlockDim, true, false,
+                  ScaleBiasWithSameTypeX><<<block_num, kBlockDim, 0, stream>>>(
                   x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                   batch_size, feature_size, col_offset));
         }
@@ -816,8 +854,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
           FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
               feature_size, kMaxBlockNum,
               LayerNormBackwardGradientScaleOrBias<
-                  T, U, kBlockDim, true,
-                  true><<<block_num, kBlockDim, 0, stream>>>(
+                  T, U, kBlockDim, true, true,
+                  ScaleBiasWithSameTypeX><<<block_num, kBlockDim, 0, stream>>>(
                   x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                   batch_size, feature_size, col_offset));
         }
@@ -854,7 +892,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
       dim3 threads3(BDIMX3, BDIMY3, 1);
       const dim3 blocks3((feature_size + BDIMX2 - 1) / BDIMX2, 1, 1);
       LayerNormBackwardSumGradGammaBeta<
-          T, U, BDIMX3, BDIMY3><<<blocks3, threads3, 0, stream>>>(
+          T, U, BDIMX3, BDIMY3,
+          ScaleBiasWithSameTypeX><<<blocks3, threads3, 0, stream>>>(
           part_grad_gamma, part_grad_beta, part_size, batch_size, feature_size,
           d_scale, d_bias);
@@ -862,7 +901,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
       constexpr int BDIMY1 = 4;
      dim3 threads1(BDIMX1, BDIMY1, 1);
       LayerNormBackwardComputeGradInput<
-          T, U, BDIMX1, BDIMY1><<<batch_size, threads1, 0, stream>>>(
+          T, U, BDIMX1, BDIMY1,
+          ScaleBiasWithSameTypeX><<<batch_size, threads1, 0, stream>>>(
           d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x);
       break;
     }
...
@@ -102,24 +102,6 @@ class LayerNormOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-
-    // By default, the type of the scale, bias, mean,
-    // and var tensors should both be float. (For float or float16 input tensor)
-    // or double (For double input tensor).
-    auto ln_param_type = framework::proto::VarType::FP32;
-    if (input_data_type == framework::proto::VarType::FP64) {
-      ln_param_type = framework::proto::VarType::FP64;
-    }
-    if (ctx.HasInput("Scale")) {
-      PADDLE_ENFORCE_EQ(ln_param_type, ctx.Input<Tensor>("Scale")->type(),
-                        platform::errors::InvalidArgument(
-                            "Scale input should be of float type"));
-    }
-    if (ctx.HasInput("Bias")) {
-      PADDLE_ENFORCE_EQ(ln_param_type, ctx.Input<Tensor>("Bias")->type(),
-                        platform::errors::InvalidArgument(
-                            "Bias input should be of float type"));
-    }
 
     framework::LibraryType library = framework::LibraryType::kPlain;
     framework::DataLayout layout = framework::DataLayout::kAnyLayout;
...
@@ -63,8 +63,32 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
     auto *y_data = y->mutable_data<T>(ctx.GetPlace());
     auto *mean_data = mean->mutable_data<U>(ctx.GetPlace());
     auto *var_data = var->mutable_data<U>(ctx.GetPlace());
-    auto *scale_data = (scale == nullptr ? nullptr : scale->data<U>());
-    auto *bias_data = (bias == nullptr ? nullptr : bias->data<U>());
+
+    auto *void_scale_data = (scale == nullptr ? nullptr : scale->data<void>());
+    auto *void_bias_data = (bias == nullptr ? nullptr : bias->data<void>());
+
+    framework::proto::VarType::Type x_dtype = x->type();
+    framework::proto::VarType::Type scale_bias_dtype;
+    if (void_scale_data != nullptr) {
+      scale_bias_dtype = scale->type();
+      if (void_bias_data != nullptr) {
+        PADDLE_ENFORCE_EQ(scale_bias_dtype, bias->type(),
+                          platform::errors::InvalidArgument(
+                              "The Scale and Bias of layer_norm op "
+                              "should have the same data type."));
+      }
+    } else {
+      scale_bias_dtype = (void_bias_data != nullptr ? bias->type() : x_dtype);
+    }
+
+    bool is_scale_bias_same_dtype_with_x = x_dtype == scale_bias_dtype;
+    if (!is_scale_bias_same_dtype_with_x) {
+      PADDLE_ENFORCE_EQ(scale_bias_dtype,
+                        framework::DataTypeTrait<U>::DataType(),
+                        platform::errors::InvalidArgument(
+                            "Unsupported data type of Scale and Bias: %s",
+                            framework::DataTypeToString(scale_bias_dtype)));
+    }
 
     auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
     int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
@@ -72,17 +96,28 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
     auto stream = ctx.cuda_device_context().stream();
 
-    switch (GetDesiredBlockDim(feature_size)) {
-      FIXED_BLOCK_DIM_CASE(
-          LayerNormForward<T, U,
-                           kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
-              x_data, scale_data, bias_data, y_data, mean_data, var_data,
-              epsilon, feature_size));
-      default:
-        PADDLE_THROW(platform::errors::InvalidArgument(
-            "Product from begin_norm_axis to end must be larger than 1"));
-        break;
+#define PADDLE_LAUNCH_LAYERNORM_FWD(ScaleBiasT, IsScaleBiasSameDTypeWithX)  \
+  do {                                                                      \
+    switch (GetDesiredBlockDim(feature_size)) {                             \
+      FIXED_BLOCK_DIM_CASE(                                                 \
+          LayerNormForward<T, U, kBlockDim, IsScaleBiasSameDTypeWithX><<<   \
+              batch_size, kBlockDim, 0, stream>>>(                          \
+              x_data, static_cast<const ScaleBiasT *>(void_scale_data),     \
+              static_cast<const ScaleBiasT *>(void_bias_data), y_data,      \
+              mean_data, var_data, epsilon, feature_size));                 \
+      default:                                                              \
+        PADDLE_THROW(platform::errors::InvalidArgument(                     \
+            "Product from begin_norm_axis to end must be larger than 1"));  \
+        break;                                                              \
+    }                                                                       \
+  } while (0)
+
+    if (is_scale_bias_same_dtype_with_x) {
+      PADDLE_LAUNCH_LAYERNORM_FWD(T, true);
+    } else {
+      PADDLE_LAUNCH_LAYERNORM_FWD(U, false);
     }
+
+#undef PADDLE_LAUNCH_LAYERNORM_FWD
   }
 };
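The forward-kernel changes above boil down to: x keeps its own type T, the statistics are accumulated in the compute type U (float for fp16 inputs), and Scale/Bias may now be either the same type as x or fp32; they are always cast to U before the arithmetic. A rough NumPy equivalent of that forward path, assuming row-wise normalization (a sketch, not the CUDA kernel):

import numpy as np

def layer_norm_fwd(x, scale, bias, epsilon=1e-5):
    # Compute in float32 regardless of the input/parameter dtypes,
    # mirroring the U (compute) type used by the kernel.
    x32 = x.astype(np.float32)
    mean = x32.mean(axis=-1, keepdims=True)
    var = x32.var(axis=-1, keepdims=True)
    inv_std = 1.0 / np.sqrt(var + epsilon)
    y32 = (x32 - mean) * inv_std
    if scale is not None:
        y32 = y32 * scale.astype(np.float32)
    if bias is not None:
        y32 = y32 + bias.astype(np.float32)
    return y32.astype(x.dtype), mean.squeeze(-1), var.squeeze(-1)

x = np.random.rand(2, 8).astype(np.float16)
for param_dtype in (np.float16, np.float32):  # both dtypes are accepted now
    scale = np.ones(8, dtype=param_dtype)
    bias = np.zeros(8, dtype=param_dtype)
    y, mean, var = layer_norm_fwd(x, scale, bias)
    print(param_dtype, y.dtype, mean.dtype)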
@@ -102,32 +137,64 @@ class LayerNormGradKernel<platform::CUDADeviceContext, T>
     auto *mean = ctx.Input<Tensor>("Mean");
     auto *var = ctx.Input<Tensor>("Variance");
     auto *scale = ctx.Input<Tensor>("Scale");
-    auto *bias = ctx.Input<Tensor>("Bias");
     auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
 
+    const auto &x_dims = x->dims();
+    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
+    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
+    int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
+    int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
+
     auto *x_data = x->data<T>();
     auto *d_y_data = d_y->data<T>();
     auto *mean_data = mean->data<U>();
     auto *var_data = var->data<U>();
-    auto *scale_data = (scale == nullptr ? nullptr : scale->data<U>());
-    auto *d_scale_data =
-        (d_scale == nullptr ? nullptr
-                            : d_scale->mutable_data<U>(ctx.GetPlace()));
-    auto *d_bias_data =
-        (d_bias == nullptr ? nullptr : d_bias->mutable_data<U>(ctx.GetPlace()));
+
     auto *d_x_data =
         (d_x == nullptr ? nullptr : d_x->mutable_data<T>(ctx.GetPlace()));
 
-    const auto &x_dims = x->dims();
-    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
-    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
-    int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
-    int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
+    framework::proto::VarType::Type x_dtype = x->type();
+    framework::proto::VarType::Type scale_bias_dtype;
+    if (scale != nullptr) {
+      scale_bias_dtype = scale->type();
+    } else {
+      // FIXME(zengjinle): do not find a better way to get the right
+      // data type of the d_scale and d_bias if scale == nullptr.
+      auto *bias = ctx.Input<Tensor>("Bias");
+      if (bias != nullptr) {
+        scale_bias_dtype = bias->saved_type();
+      } else {
+        scale_bias_dtype = x_dtype;
+      }
+    }
+
+#define PADDLE_LAUNCH_LAYERNORM_BWD(ScaleBiasT, IsScaleBiasSameDTypeWithX)  \
+  do {                                                                      \
+    auto *scale_data =                                                      \
+        (scale == nullptr ? nullptr : scale->data<ScaleBiasT>());           \
+    auto *d_scale_data =                                                    \
+        (d_scale == nullptr ? nullptr : d_scale->mutable_data<ScaleBiasT>(  \
+                                            ctx.GetPlace()));               \
+    auto *d_bias_data =                                                     \
+        (d_bias == nullptr ? nullptr : d_bias->mutable_data<ScaleBiasT>(    \
+                                           ctx.GetPlace()));                \
+    auto *d_x_data =                                                        \
+        (d_x == nullptr ? nullptr : d_x->mutable_data<T>(ctx.GetPlace()));  \
+    LayerNormBackward<T, U, IsScaleBiasSameDTypeWithX>(                     \
+        x_data, d_y_data, scale_data, mean_data, var_data, d_x_data,        \
+        d_scale_data, d_bias_data, epsilon, batch_size, feature_size,       \
+        ctx.cuda_device_context());                                         \
+  } while (0)
 
-    LayerNormBackward<T, U>(x_data, d_y_data, scale_data, mean_data, var_data,
-                            d_x_data, d_scale_data, d_bias_data, epsilon,
-                            batch_size, feature_size,
-                            ctx.cuda_device_context());
+    if (scale_bias_dtype == x_dtype) {
+      PADDLE_LAUNCH_LAYERNORM_BWD(T, true);
+    } else {
+      PADDLE_LAUNCH_LAYERNORM_BWD(U, false);
+    }
+
+#undef PADDLE_LAUNCH_LAYERNORM_BWD
   }
 };
...
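The backward kernels follow the same convention: the gamma/beta gradients are reduced in the compute type U and only cast to the scale/bias type (ScaleBiasT) when written out. A NumPy sketch of the d_scale/d_bias reduction under that assumption:

import numpy as np

def layer_norm_grad_scale_bias(x, d_y, mean, var, epsilon=1e-5,
                               scale_bias_dtype=np.float16):
    # Reduce in float32 (the "U" type); cast the result to the
    # scale/bias dtype only at the end, as the kernels now do.
    x32, dy32 = x.astype(np.float32), d_y.astype(np.float32)
    inv_std = 1.0 / np.sqrt(var[:, None] + epsilon)
    x_hat = (x32 - mean[:, None]) * inv_std
    d_scale = (dy32 * x_hat).sum(axis=0).astype(scale_bias_dtype)
    d_bias = dy32.sum(axis=0).astype(scale_bias_dtype)
    return d_scale, d_bias

x = np.random.rand(4, 8).astype(np.float16)
d_y = np.random.rand(4, 8).astype(np.float16)
mean = x.astype(np.float32).mean(axis=-1)
var = x.astype(np.float32).var(axis=-1)
print(layer_norm_grad_scale_bias(x, d_y, mean, var))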
@@ -152,6 +152,14 @@ class LambOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Moment2", "(Tensor) Input second moment.");
     AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator.");
     AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator.");
+    AddInput("MasterParam",
+             "(LoDTensor, default LoDTensor<float>) "
+             "Input master parameter that has to be updated.")
+        .AsDispensable();
+    AddInput(
+        "SkipUpdate",
+        "(Tensor) Input tensor to determine whether to update the parameter.")
+        .AsDispensable();
     AddOutput("ParamOut", "(Tensor) Output parameter.");
     AddOutput("Moment1Out", "(Tensor) Output first moment.");
...@@ -160,6 +168,8 @@ class LambOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -160,6 +168,8 @@ class LambOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDispensable(); .AsDispensable();
AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator") AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator")
.AsDispensable(); .AsDispensable();
AddOutput("MasterParamOut", "(Tensor) Output master parameter.")
.AsDispensable();
AddAttr<float>("weight_decay", "(float) Weight decay rate."); AddAttr<float>("weight_decay", "(float) Weight decay rate.");
AddAttr<float>("beta1", AddAttr<float>("beta1",
"(float, default 0.9) The exponential decay rate for the " "(float, default 0.9) The exponential decay rate for the "
@@ -173,6 +183,10 @@ class LambOpMaker : public framework::OpProtoAndCheckerMaker {
                    "(float, default 1.0e-6) "
                    "Constant for numerical stability.")
         .SetDefault(1.0e-6f);
+    AddAttr<bool>(
+        "multi_precision",
+        "(bool, default false) Whether to enable multi-precision mode.")
+        .SetDefault(false);
     AddComment(R"DOC(
 LAMB (Layer-wise Adaptive Moments optimizer for Batching training) Optimizer.
...
@@ -16,5 +16,7 @@ limitations under the License. */
 namespace ops = paddle::operators;
 
 REGISTER_OP_CUDA_KERNEL(
-    lamb, ops::LambOpKernel<paddle::platform::CUDADeviceContext, float>,
+    lamb, ops::LambOpKernel<paddle::platform::CUDADeviceContext,
+                            paddle::platform::float16>,
+    ops::LambOpKernel<paddle::platform::CUDADeviceContext, float>,
     ops::LambOpKernel<paddle::platform::CUDADeviceContext, double>);
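The new SkipUpdate input gives the op a device-side switch to leave parameters and moments untouched for a step, which is how an AMP loss-scaled run can skip an update when non-finite gradients were detected. A toy Python sketch of that control flow (names are illustrative, not the operator's internals):

import numpy as np

def lamb_like_step(param, grad, lr, skip_update=False):
    # When skip_update is set (e.g. AMP found inf/nan in the gradients),
    # return the parameter unchanged instead of applying the update.
    if skip_update:
        return param
    return param - lr * grad

p = np.float32([1.0, 2.0])
g = np.float32([0.1, np.inf])
skip = not np.all(np.isfinite(g))
print(lamb_like_step(p, g, lr=0.1, skip_update=skip))  # unchanged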
@@ -72,6 +72,9 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
     {"adamw",
      {"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow",
       "Beta2Pow", "MasterParam"}},
+    {"lamb",
+     {"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow",
+      "Beta2Pow", "MasterParam"}},
     {"sparse_attention",
      {"Q", "K", "V", "Offset", "Columns", "KeyPaddingMask", "AttnMask"}},
 };
@@ -112,8 +115,6 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
     {"momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
     {"sparse_momentum", {"ParamOut", "VelocityOut"}},
     {"rnn", {"DropoutState", "Reserve", "Out", "State"}},
-    {"lamb",
-     {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
     {"run_program", {"DOut"}},
     {"adam",
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
@@ -121,6 +122,9 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
     {"adamw",
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
       "MasterParamOut"}},
+    {"lamb",
+     {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
+      "MasterParamOut"}},
 };
 
 // NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
@@ -142,6 +146,9 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
     {"adamw",
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
       "MasterParamOut"}},
+    {"lamb",
+     {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
+      "MasterParamOut"}},
     {"average_accumulates",
      {"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates",
       "out_old_num_accumulates", "out_num_updates"}},
@@ -173,8 +180,6 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
      {"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"}},
     {"moving_average_abs_max_scale",
      {"Out", "OutScale", "OutAccum", "OutState"}},
-    {"lamb",
-     {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
     {"rnn", {"DropoutState"}},
     {"run_program", {"Out", "DOut", "OutScope"}},
     {"clear_float_status", {"FloatStatusOut"}},
...
@@ -371,6 +371,20 @@ class ClipGradByNorm(ClipGradBase):
         return param, new_grad
 
 
+_allow_pure_fp16_global_norm_clip_flag = False
+
+
+def _allow_pure_fp16_global_norm_clip(*args):
+    global _allow_pure_fp16_global_norm_clip_flag
+    if len(args) == 0:
+        return _allow_pure_fp16_global_norm_clip_flag
+    else:
+        assert len(args) == 1 and isinstance(args[0], bool)
+        old_value = _allow_pure_fp16_global_norm_clip_flag
+        _allow_pure_fp16_global_norm_clip_flag = args[0]
+        return old_value
+
+
 class ClipGradByGlobalNorm(ClipGradBase):
     r"""
     Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
@@ -537,8 +551,12 @@ class ClipGradByGlobalNorm(ClipGradBase):
         global_norm_var = []
         if len(sum_square_list_fp16) > 0:
             global_norm_var_fp16 = layers.sums(sum_square_list_fp16)
-            global_norm_var.append(
-                global_norm_var_fp16.astype(sum_dtype))
+            if sum_square_list_fp32 or sum_square_list or not _allow_pure_fp16_global_norm_clip(
+            ):
+                global_norm_var.append(
+                    global_norm_var_fp16.astype(sum_dtype))
+            else:
+                global_norm_var.append(global_norm_var_fp16)
         if len(sum_square_list_fp32) > 0:
             global_norm_var_fp32 = layers.sums(sum_square_list_fp32)
             if sum_dtype == 'float32':
@@ -573,8 +591,9 @@ class ClipGradByGlobalNorm(ClipGradBase):
             with p.block.program._optimized_guard([p, g]):
                 # inplace
                 scale_input = (scale_var.astype('float16')
-                               if g.dtype == core.VarDesc.VarType.FP16 else
-                               scale_var)
+                               if g.dtype == core.VarDesc.VarType.FP16 and
+                               scale_var.dtype != core.VarDesc.VarType.FP16
+                               else scale_var)
                 # NOTE(Yuang Liu): For pure dp with gradient merge, the p and g
                 # will be in different blocks with the gradient clip related ops.
                 # We need to handle the correct block, otherwise will encounter
...
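The new `_allow_pure_fp16_global_norm_clip` helper is a get/set switch: called with no argument it returns the current flag, called with a bool it sets the flag and returns the old value. With the flag on, clipping a purely fp16 gradient list keeps the summed norm in fp16 instead of casting it up. A usage sketch, mirroring how TestPureFP16ClipGradByGlobalNorm drives it further below:

from paddle.fluid.clip import _allow_pure_fp16_global_norm_clip

old = _allow_pure_fp16_global_norm_clip(True)   # enable, remember old value
try:
    # build and clip a pure-fp16 program here; no fp16->fp32 casts are inserted
    pass
finally:
    _allow_pure_fp16_global_norm_clip(old)      # restore previous behaviour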
@@ -411,6 +411,8 @@ class OptimizerWithMixedPrecision(object):
                 found_inf = paddle.tensor.creation._memcpy(found_inf,
                                                            paddle.CPUPlace())
                 real_optimizer._set_auxiliary_var('found_inf', found_inf)
+            elif hasattr(real_optimizer, "_set_auxiliary_var"):
+                real_optimizer._set_auxiliary_var('found_inf', found_inf)
             optimize_ops = self._optimizer.apply_gradients(params_grads)
             return optimize_ops
...
@@ -80,11 +80,27 @@ def _dtype_to_str(dtype):
         return 'fp32'
 
 
+_keep_layer_norm_scale_bias_to_fp32_flag = True
+
+
+def _keep_layer_norm_scale_bias_to_fp32(*args):
+    global _keep_layer_norm_scale_bias_to_fp32_flag
+    if len(args) == 0:
+        return _keep_layer_norm_scale_bias_to_fp32_flag
+    else:
+        assert len(args) == 1 and isinstance(args[0], bool)
+        old_value = _keep_layer_norm_scale_bias_to_fp32_flag
+        _keep_layer_norm_scale_bias_to_fp32_flag = args[0]
+        return old_value
+
+
 def _keep_fp32_input(op, in_name):
     op_type = op.type
-    if op_type in ['batch_norm', 'layer_norm']:
+    if op_type == 'batch_norm':
         # Scale, Bias, Mean, Variance should be float32.
         return in_name != 'X'
+    if op_type == 'layer_norm' and _keep_layer_norm_scale_bias_to_fp32():
+        return in_name != 'X'
     if op_type == 'fused_bn_add_activation':
         return in_name not in {'X', 'Z'}
     if op_type == 'resnet_unit':
@@ -98,7 +114,9 @@ def _keep_fp32_input(op, in_name):
 
 def _keep_fp32_output(op, out_name):
     op_type = op.type
-    if op_type in ['batch_norm', 'fused_bn_add_activation', 'layer_norm']:
+    if op_type in ['batch_norm', 'fused_bn_add_activation']:
+        return out_name != 'Y'
+    if op_type == 'layer_norm' and _keep_layer_norm_scale_bias_to_fp32():
         return out_name != 'Y'
     if op_type == 'resnet_unit':
         return out_name not in {'Y', 'ConvX', 'ConvZ'}
...
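Likewise, `_keep_layer_norm_scale_bias_to_fp32` is a get/set flag: with no argument it reports whether the AMP rewrite still forces layer_norm Scale/Bias to fp32; passing False lets them (and the op's inputs/outputs other than X/Y) be cast to fp16 as well, which the new kernels can handle. A short usage sketch, matching TestGetSetKeepLayerNormScaleBiasFP32Flag below:

from paddle.fluid.contrib.mixed_precision.fp16_utils import \
    _keep_layer_norm_scale_bias_to_fp32

assert _keep_layer_norm_scale_bias_to_fp32()      # default keeps fp32
old = _keep_layer_norm_scale_bias_to_fp32(False)  # allow fp16 scale/bias
# ... build an AMP program here ...
_keep_layer_norm_scale_bias_to_fp32(old)          # restore the default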
@@ -3554,14 +3554,14 @@ class LambOptimizer(AdamOptimizer):
         else:
             weight_decay = self._weight_decay
         lr = self._create_param_lr(param_and_grad)
+        master_weight = None
         if framework.in_dygraph_mode():
-            _, _, _, _, _ = _C_ops.lamb(
-                param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
-                beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1,
-                moment2, beta1_pow_acc, beta2_pow_acc, 'beta1', self._beta1,
-                'beta2', self._beta2, 'epsilon', self._epsilon, 'weight_decay',
-                weight_decay)
+            _C_ops.lamb(param_and_grad[0], param_and_grad[1], lr, moment1,
+                        moment2, beta1_pow_acc, beta2_pow_acc, master_weight,
+                        param_and_grad[0], moment1, moment2, beta1_pow_acc,
+                        beta2_pow_acc, master_weight, 'beta1', self._beta1,
+                        'beta2', self._beta2, 'epsilon', self._epsilon,
+                        'weight_decay', weight_decay)
             return None
 
         # create the lamb optimize op
...
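For reference, the dygraph branch above is what runs under an ordinary LAMB step; a minimal usage sketch with the public `paddle.optimizer.Lamb` API (assuming an installed Paddle build):

import paddle

linear = paddle.nn.Linear(4, 4)
opt = paddle.optimizer.Lamb(learning_rate=1e-3,
                            parameters=linear.parameters())

loss = linear(paddle.rand([2, 4])).mean()
loss.backward()
opt.step()        # dispatches to _C_ops.lamb in dygraph mode
opt.clear_grad()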
@@ -21,6 +21,7 @@ import paddle.fluid.core as core
 import paddle.fluid as fluid
 import six
 from fake_reader import fake_imdb_reader
+from paddle.fluid.clip import _allow_pure_fp16_global_norm_clip
 
 paddle.enable_static()
@@ -566,5 +567,35 @@ class TestDygraphGradientClipFP64(unittest.TestCase):
                     % (a, b))
 
 
+class TestPureFP16ClipGradByGlobalNorm(unittest.TestCase):
+    def check_main(self, expected_has_cast_op):
+        main_prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog, startup_prog):
+            names = ["p0", "p1"]
+            shapes = [[2, 3], [4, 5]]
+
+            param_and_grads = []
+            main_block = main_prog.global_block()
+            for name, shape in zip(names, shapes):
+                p = main_block.create_parameter(
+                    name=name, shape=shape, dtype='float16')
+                g = main_block.create_parameter(
+                    name=p.name + '@GRAD', shape=p.shape, dtype=p.dtype)
+                param_and_grads.append((p, g))
+
+            clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
+            clip(param_and_grads)
+            actual_has_cast = any(op.type == 'cast' for op in main_block.ops)
+            self.assertEqual(actual_has_cast, expected_has_cast_op)
+
+    def test_main(self):
+        self.check_main(True)
+        _allow_pure_fp16_global_norm_clip(True)
+        self.check_main(False)
+        _allow_pure_fp16_global_norm_clip(False)
+        self.check_main(True)
+
+
 if __name__ == '__main__':
     unittest.main()
@@ -19,6 +19,7 @@ import numpy as np
 from op_test import OpTest
 from paddle.fluid import core
 from paddle.fluid.op import Operator
+from paddle.fluid.dygraph.base import switch_to_static_graph
 import paddle
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
@@ -181,5 +182,58 @@ class TestLambOpV2Group(TestLambOpV2):
             adam.clear_gradients()
 
 
+class TestLambOpMultiPrecision(unittest.TestCase):
+    def check_main(self, x_np, place, multi_precision=False, seed=10, n=10):
+        main_prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog, startup_prog):
+            paddle.seed(seed)
+            with paddle.static.amp.fp16_guard():
+                x = paddle.static.data(
+                    name='x', shape=[None, 10], dtype='float32')
+                linear = paddle.nn.Linear(10, 2)
+                hidden = linear(x)
+                loss = paddle.mean(hidden)
+
+            optimizer = paddle.optimizer.Lamb(learning_rate=1e-3)
+            optimizer._multi_precision = multi_precision
+            if multi_precision:
+                optimizer = paddle.static.amp.decorate(
+                    optimizer, use_pure_fp16=True, use_fp16_guard=True)
+            optimizer.minimize(loss)
+
+        weight, bias = linear.weight, linear.bias
+        scope = paddle.static.Scope()
+        exe = paddle.static.Executor(place)
+        scope = paddle.static.Scope()
+        x = main_prog.global_block().var(x.name)
+        if x.dtype == core.VarDesc.VarType.FP16:
+            x_np = x_np.astype(np.float16)
+
+        with paddle.static.scope_guard(scope):
+            exe.run(startup_prog)
+            if multi_precision:
+                optimizer.amp_init(place)
+            weight_np, bias_np = None, None
+            for i in range(n):
+                feed_dict = {x.name: x_np}
+                weight_np, bias_np = exe.run(main_prog,
+                                             feed=feed_dict,
+                                             fetch_list=[weight, bias])
+        return weight_np.astype('float32'), bias_np.astype('float32')
+
+    @switch_to_static_graph
+    def test_main(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+
+        place = paddle.CUDAPlace(0)
+        x_np = np.random.random(size=[5, 10]).astype('float32')
+        weight_1, bias_1 = self.check_main(x_np, place, multi_precision=False)
+        weight_2, bias_2 = self.check_main(x_np, place, multi_precision=True)
+        self.assertTrue(np.all(np.abs(weight_1 - weight_2) < 1e-3))
+        self.assertTrue(np.all(np.abs(bias_1 - bias_2) < 1e-7))
+
+
 if __name__ == "__main__":
     unittest.main()
@@ -20,9 +20,13 @@ import paddle
 from operator import mul
 import paddle.fluid.core as core
 import paddle.fluid as fluid
+import paddle.nn.functional as F
 from functools import reduce
 from op_test import _set_use_system_allocator
 from paddle.fluid import Program, program_guard
+from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_layer_norm_scale_bias_to_fp32
+
+paddle.enable_static()
 
 np.random.random(123)
@@ -325,5 +329,58 @@ class TestDygraphLayerNormAPIError(unittest.TestCase):
             self.assertRaises(TypeError, layer_norm, x2)
 
 
+class TestFP16ScaleBiasLayerNorm(unittest.TestCase):
+    def check_main(self, x_np, weight_np, bias_np, dtype):
+        paddle.disable_static()
+
+        weight_np = weight_np.astype(dtype)
+        bias_np = bias_np.astype(dtype)
+
+        x = paddle.to_tensor(x_np)
+        weight = paddle.to_tensor(weight_np)
+        bias = paddle.to_tensor(bias_np)
+        x.stop_gradient = False
+        weight.stop_gradient = False
+        bias.stop_gradient = False
+        y = F.layer_norm(x, x.shape[1:], weight, bias)
+        x_g, w_g, b_g = paddle.grad(y, [x, weight, bias])
+        y_np = y.numpy().astype('float32')
+        x_g_np = x_g.numpy().astype('float32')
+        w_g_np = w_g.numpy().astype('float16')
+        b_g_np = b_g.numpy().astype('float32')
+
+        paddle.enable_static()
+        return y_np, x_g_np, w_g_np, b_g_np
+
+    def test_main(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+        x_np = np.random.random([10, 20]).astype('float16')
+        weight_np = np.random.random([20]).astype('float16')
+        bias_np = np.random.random([20]).astype('float16')
+
+        y_np_1, x_g_np_1, w_g_np_1, b_g_np_1 = self.check_main(
+            x_np, weight_np, bias_np, 'float16')
+        y_np_2, x_g_np_2, w_g_np_2, b_g_np_2 = self.check_main(
+            x_np, weight_np, bias_np, 'float32')
+
+        def assert_equal(x, y):
+            self.assertTrue(np.array_equal(x, y))
+
+        assert_equal(y_np_1, y_np_2)
+        assert_equal(x_g_np_1, x_g_np_2)
+        assert_equal(w_g_np_1, w_g_np_2)
+        assert_equal(b_g_np_1, b_g_np_2)
+
+
+class TestGetSetKeepLayerNormScaleBiasFP32Flag(unittest.TestCase):
+    def test_main(self):
+        self.assertTrue(_keep_layer_norm_scale_bias_to_fp32())
+        _keep_layer_norm_scale_bias_to_fp32(False)
+        self.assertFalse(_keep_layer_norm_scale_bias_to_fp32())
+        _keep_layer_norm_scale_bias_to_fp32(True)
+        self.assertTrue(_keep_layer_norm_scale_bias_to_fp32())
+
+
 if __name__ == '__main__':
     unittest.main()
@@ -16,6 +16,9 @@ from .optimizer import Optimizer
 from ..fluid import core
 from ..fluid import framework
 from ..fluid.framework import Variable
+from ..fluid import layers
+from ..fluid import unique_name
+from ..fluid.layer_helper import LayerHelper
 from paddle import _C_ops
 
 __all__ = []
@@ -127,6 +130,36 @@ class Lamb(Optimizer):
             'lamb_weight_decay': lamb_weight_decay,
             'exclude_from_weight_decay_fn': exclude_from_weight_decay_fn,
         }
+        self._master_weights = {}
+        # TODO(zengjinle): expose API as soon as possible
+        self._multi_precision = False
+
+    def _create_master_weight(self, param):
+        assert self._multi_precision
+        if param.name in self._master_weights:
+            var = self._master_weights[param.name]
+        else:
+            assert isinstance(self.helper, LayerHelper)
+
+            var_name = param.name + "_fp32_master"
+            var_name = unique_name.generate(var_name)
+            var = layers.create_global_var(
+                name=var_name,
+                shape=param.shape,
+                value=0,
+                dtype='float32',
+                persistable=True)
+            block = self.helper.startup_program.global_block()
+            block.append_op(
+                type="cast",
+                inputs={"X": [param]},
+                outputs={"Out": [var]},
+                attrs={
+                    "in_dtype": param.dtype,
+                    "out_dtype": core.VarDesc.VarType.FP32
+                })
+            self._master_weights[param.name] = var
+        return var
 
     def _create_accumulators(self, block, parameters):
         assert isinstance(block, framework.Block)
@@ -135,11 +168,43 @@ class Lamb(Optimizer):
 
         # Create accumulator tensors for first and second moments
         for p in parameters:
-            self._add_accumulator(self._moment1_acc_str, p)
-            self._add_accumulator(self._moment2_acc_str, p)
+            if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
+                master_p = self._create_master_weight(p)
+                self._add_moments_pows(master_p)
+            else:
+                self._add_moments_pows(p)
+
+    def _get_accumulator(self, name, param):
+        """Utility function to fetch an accumulator for a parameter
+
+        Args:
+            name: name of the accumulator
+            param: parameter variable for which accumulator is to be fetched
+
+        Returns:
+            accumulator variable for the parameter
+        """
+        if self._name is not None:
+            name = self._name + "_" + name
+        find_master = self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
+        target_param = self._master_weights[
+            param.name] if find_master else param
+        target_name = target_param.name
+        if (name not in self._accumulators or
+                target_name not in self._accumulators[name]):
+            raise Exception("Accumulator {} does not exist for parameter {}".
+                            format(name, target_name))
+        return self._accumulators[name][target_name]
+
+    def _add_moments_pows(self, p):
+        acc_dtype = p.dtype
+        if acc_dtype == core.VarDesc.VarType.FP16:
+            acc_dtype = core.VarDesc.VarType.FP32
+        self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
+        self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
         self._add_accumulator(
             name=self._beta1_pow_acc_str,
             param=p,
+            dtype=acc_dtype,
             fill_value=0.9 if isinstance(self._beta1, Variable) \
                 else self._beta1,
             shape=[1],
@@ -147,6 +212,7 @@ class Lamb(Optimizer):
         self._add_accumulator(
             name=self._beta2_pow_acc_str,
             param=p,
+            dtype=acc_dtype,
             fill_value=0.999 if isinstance(self._beta2, Variable) \
                 else self._beta2,
             shape=[1],
@@ -175,13 +241,20 @@ class Lamb(Optimizer):
         weight_decay = self._lamb_weight_decay
         lr = self._create_param_lr(param_and_grad)
 
+        find_master = self._multi_precision and param_and_grad[
+            0].dtype == core.VarDesc.VarType.FP16
+        master_weight = self._master_weights[param_and_grad[0]
+                                             .name] if find_master else None
+        found_inf = self._get_auxiliary_var('found_inf')
+
         if framework.in_dygraph_mode():
-            _, _, _, _, _ = _C_ops.lamb(
-                param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
-                beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1,
-                moment2, beta1_pow_acc, beta2_pow_acc, 'beta1', self._beta1,
-                'beta2', self._beta2, 'epsilon', self._epsilon, 'weight_decay',
-                weight_decay)
+            _C_ops.lamb(param_and_grad[0], param_and_grad[1], lr, moment1,
+                        moment2, beta1_pow_acc, beta2_pow_acc, master_weight,
+                        param_and_grad[0], moment1, moment2, beta1_pow_acc,
+                        beta2_pow_acc, master_weight, 'beta1', self._beta1,
+                        'beta2', self._beta2, 'epsilon', self._epsilon,
+                        'weight_decay', weight_decay, 'multi_precision',
+                        find_master)
             return None
 
         # create the lamb optimize op
@@ -205,9 +278,17 @@ class Lamb(Optimizer):
             "beta1": self._beta1,
             "beta2": self._beta2,
             "epsilon": self._epsilon,
-            "weight_decay": weight_decay
+            "weight_decay": weight_decay,
+            "multi_precision": find_master,
         }
 
+        if find_master:
+            inputs["MasterParam"] = master_weight
+            outputs["MasterParamOut"] = master_weight
+
+        if found_inf:
+            inputs["SkipUpdate"] = found_inf
+
         lamb_op = block.append_op(
             type=self.type,
             inputs=inputs,
...
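Putting the pieces of this file together: in multi-precision mode the op performs the LAMB moment and trust-ratio update against the fp32 master weight and then refreshes the fp16 parameter from it. A compact NumPy sketch of one such step, using the standard LAMB formulas (an illustration, not the exact kernel):

import numpy as np

def lamb_step(p16, master, m, v, grad16, t, lr=1e-3, beta1=0.9,
              beta2=0.999, eps=1e-6, weight_decay=0.01):
    g = grad16.astype(np.float32)
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    r = m_hat / (np.sqrt(v_hat) + eps) + weight_decay * master
    trust = np.linalg.norm(master) / max(np.linalg.norm(r), 1e-12)
    master = master - lr * trust * r
    # fp16 working copy is refreshed from the fp32 master weight.
    return master.astype(np.float16), master, m, v

p16 = np.random.rand(8).astype(np.float16)
master = p16.astype(np.float32)
m = np.zeros(8, np.float32)
v = np.zeros(8, np.float32)
g16 = np.random.rand(8).astype(np.float16)
p16, master, m, v = lamb_step(p16, master, m, v, g16, t=1)
print(p16.dtype, master.dtype)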
@@ -217,6 +217,14 @@ class Optimizer(object):
         else:
             self._param_groups = self._parameter_list
 
+        self._auxiliary_vars = {}
+
+    def _set_auxiliary_var(self, key, val):
+        self._auxiliary_vars[key] = val
+
+    def _get_auxiliary_var(self, key):
+        return self._auxiliary_vars.get(key, None)
+
     @framework.dygraph_only
     def state_dict(self):
         '''
...
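The `_set_auxiliary_var`/`_get_auxiliary_var` pair added here is the channel through which the AMP decorator hands `found_inf` to an optimizer such as Lamb, which then wires it into the op's SkipUpdate input. A stripped-down sketch of that pattern:

class ToyOptimizer:
    # Mirrors the small auxiliary-var store added to paddle's Optimizer.
    def __init__(self):
        self._auxiliary_vars = {}

    def _set_auxiliary_var(self, key, val):
        self._auxiliary_vars[key] = val

    def _get_auxiliary_var(self, key):
        return self._auxiliary_vars.get(key, None)

    def apply(self, param, grad, lr=0.1):
        if self._get_auxiliary_var('found_inf'):
            return param            # behaves like SkipUpdate
        return param - lr * grad

opt = ToyOptimizer()
opt._set_auxiliary_var('found_inf', True)   # e.g. set by the AMP decorator
print(opt.apply(1.0, 0.5))                  # 1.0, update skipped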