Unverified commit fe94db6c, authored by zhiboniu, committed by GitHub

Fix LayerNorm Problem (#33420)

* Eliminate numerical differences in LayerNorm; fix LayerNorm NaN bug with large data input

* Fix bug with large input shapes
Parent 24bde98f
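Background on the two fixes above (an illustrative sketch added for this write-up, not code from the commit; the helper names below are hypothetical): the forward kernel estimates variance in one pass as E[x^2] - E[x]^2, a difference of two nearly equal quantities that can round below zero in floating point, after which rsqrt yields NaN, so the new kernel code clamps the estimate at zero. The "large shape" fix concerns indexing: products such as row * feature_size can exceed the 32-bit int range, hence the switch to int64_t.

// Illustrative host-side sketch (compiles as plain C++ or CUDA host code).
// ClampedVariance is a hypothetical helper, not a Paddle API.
#include <cstdint>
#include <cstdio>
#include <vector>

float ClampedVariance(const std::vector<float>& x) {
  float sum = 0.f, sum_sq = 0.f;  // single-precision accumulation, as in the U = float kernel path
  for (float v : x) {
    sum += v;
    sum_sq += v * v;
  }
  const float inv_n = 1.f / static_cast<float>(x.size());
  const float mean = sum * inv_n;
  float var = sum_sq * inv_n - mean * mean;  // E[x^2] - E[x]^2; may round below zero
  var = var > 0.f ? var : 0.f;               // same guard the new kernel applies before rsqrt
  return var;
}

int main() {
  // Nearly constant, large-magnitude input: the two terms of the variance
  // formula agree to many digits, so their difference is dominated by rounding.
  std::vector<float> x(1 << 20, 1e4f);
  std::printf("clamped variance = %g\n", ClampedVariance(x));

  // The "large shape" part of the fix: this product needs 64-bit arithmetic;
  // the same expression in 32-bit int would overflow (3,000,000 * 1,024 is
  // about 3.07e9, which is larger than 2^31 - 1).
  const int64_t row = 3000000, feature_size = 1024;
  const int64_t idx = row * feature_size;
  std::printf("flat index = %lld\n", static_cast<long long>(idx));
  return 0;
}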
@@ -42,15 +42,46 @@ using CudnnDataType = platform::CudnnDataType<T>;
 template <typename T>
 using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;
-inline static int GetDesiredBlockDim(int block_dim) {
+inline static int GetDesiredBlockDim(int64_t block_dim) {
 #ifdef __HIPCC__
   const int kMaxBlockDim = 256;
+  const int lwarpSize = 64;
 #else
   const int kMaxBlockDim = 512;
+  const int lwarpSize = 32;
 #endif
-  return block_dim >= kMaxBlockDim
-             ? kMaxBlockDim
-             : (1 << (static_cast<int>(std::log2f(block_dim))));
+  return block_dim >= kMaxBlockDim ? kMaxBlockDim : lwarpSize;
+}
+
+template <typename U>
+static __forceinline__ __device__ U WarpReduceSum(U val) {
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, true);
+  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
+    val += paddle::platform::CudaShuffleDownSync(mask, val, offset);
+  }
+  return val;
+}
+
+template <typename U>
+__forceinline__ __device__ U BlockReduceSum(U val) {
+  static __shared__ U shared[32];
+  int lane = threadIdx.x % warpSize;
+  int wid = threadIdx.x / warpSize;
+
+  val = WarpReduceSum(val);          // Each warp performs partial reduction
+
+  if (lane == 0) shared[wid] = val;  // Write reduced value to shared memory
+
+  __syncthreads();                   // Wait for all partial reductions
+  // read from shared memory only if that warp existed
+  val =
+      (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : static_cast<U>(0);
+
+  if (wid == 0) val = WarpReduceSum(val);  // Final reduce within first warp
+
+  return val;
 }
 #define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...) \
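Side note on the reduction helpers introduced in the hunk above: WarpReduceSum folds values within a warp using shuffle-down, and BlockReduceSum stages one partial sum per warp in shared memory before a final warp-level pass. A self-contained CUDA sketch of the same pattern, written against the raw __shfl_down_sync intrinsic instead of Paddle's CREATE_SHFL_MASK / CudaShuffleDownSync wrappers (kernel name and launch configuration here are illustrative, not part of the commit):

#include <cstdio>
#include <cuda_runtime.h>

__inline__ __device__ float WarpSum(float val) {
  // Each step folds the upper half of the warp onto the lower half.
  for (int offset = warpSize / 2; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xffffffffu, val, offset);
  return val;
}

__inline__ __device__ float BlockSum(float val) {
  __shared__ float shared[32];       // one slot per warp (at most 32 warps per block)
  const int lane = threadIdx.x % warpSize;
  const int wid = threadIdx.x / warpSize;

  val = WarpSum(val);                // 1) partial reduction inside each warp
  if (lane == 0) shared[wid] = val;  // 2) warp leaders publish their partials
  __syncthreads();

  // 3) the first warp reduces the per-warp partials
  val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0.f;
  if (wid == 0) val = WarpSum(val);
  return val;                        // final sum is valid in thread 0
}

__global__ void SumKernel(const float* x, float* out, int n) {
  float v = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) v += x[i];
  v = BlockSum(v);
  if (threadIdx.x == 0) *out = v;
}

int main() {
  const int n = 1 << 16;
  float *x = nullptr, *out = nullptr;
  cudaMallocManaged(&x, n * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; ++i) x[i] = 1.f;
  SumKernel<<<1, 256>>>(x, out, n);
  cudaDeviceSynchronize();
  std::printf("sum = %g (expected %d)\n", *out, n);
  cudaFree(x);
  cudaFree(out);
  return 0;
}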
@@ -73,9 +104,11 @@ inline static int GetDesiredBlockDim(int block_dim) {
 #define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE( \
     log2_block_dim, feature_size, kMaxBlockNum, ...) \
   case (1 << (log2_block_dim)): { \
-    for (int i = 0; i < std::ceil(feature_size / (1.0 * kMaxBlockNum)); i++) { \
-      int col_offset = i * kMaxBlockNum; \
-      int block_num = std::min(feature_size - col_offset, kMaxBlockNum); \
+    for (int64_t i = 0; i < std::ceil(feature_size / (1.0 * kMaxBlockNum)); \
+         i++) { \
+      int64_t col_offset = i * static_cast<int64_t>(kMaxBlockNum); \
+      int block_num = static_cast<int>(std::min( \
+          feature_size - col_offset, static_cast<int64_t>(kMaxBlockNum))); \
     constexpr auto kBlockDim = (1 << (log2_block_dim)); \
     __VA_ARGS__; \
   } \
@@ -147,31 +180,32 @@ __inline__ __device__ half rsqrt_(const half val) {
 template <typename T, typename U, int BlockDim>
 __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
                                  T *y, U *mean, U *var, float epsilon,
-                                 int feature_size) {
-  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
+                                 int64_t feature_size) {
   __shared__ U mean_share;
   __shared__ U var_share;
-  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
-  int end_idx = (blockIdx.x + 1) * feature_size;
+  int64_t beg_idx = blockIdx.x * feature_size + threadIdx.x;
+  int64_t end_idx = (blockIdx.x + 1) * feature_size;
   // Step 1: Reduce to calculate mean and var
   U mean_val = 0;
   U var_val = 0;
-  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+  for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
     U tmp = static_cast<U>(x[i]);
     mean_val += tmp;
     var_val += (tmp * tmp);
   }
-  auto pair = BlockReduce(temp_storage)
-                  .Reduce(PairForLayerNorm<U>(mean_val, var_val),
-                          PairForLayerNormAddFunctor<U>());
+  mean_val = BlockReduceSum<U>(mean_val);
+  var_val = BlockReduceSum<U>(var_val);
   if (threadIdx.x == 0) {
-    auto tmp = pair.first_ / feature_size;
+    auto scale = static_cast<float>(1.) / static_cast<float>(feature_size);
+    auto tmp = mean_val * scale;
     mean[blockIdx.x] = mean_share = static_cast<U>(tmp);
-    var[blockIdx.x] = var_share =
-        static_cast<U>(pair.second_ / feature_size - tmp * tmp);
+    var_share = static_cast<U>(var_val * scale - mean_share * mean_share);
+    var_share = var_share > U(0) ? var_share : U(0);
+    var[blockIdx.x] = var_share;
   }
   __syncthreads();
@@ -181,13 +215,13 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
   // Step 2: Calculate y
   if (scale != nullptr) {
     if (bias != nullptr) {
-      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
+      for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>(
            scale[j] * (static_cast<U>(x[i]) - mean_val) * invvar + bias[j]);
      }
    } else {
-      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
+      for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
          i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>(scale[j] * (static_cast<U>(x[i]) - mean_val) *
                              invvar);
@@ -195,13 +229,13 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
     }
   } else {  // scale == nullptr
     if (bias != nullptr) {
-      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
+      for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
          i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar +
                              bias[j]);
      }
    } else {
-      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
+      for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
          i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar);
      }
@@ -278,18 +312,18 @@ __global__ void LayerNormForwardFP16(const T *x, const U *scale, const U *bias,
 template <typename T, typename U, int VPT>
 __inline__ __device__ void cuLoadAddStridedInputs(
-    const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
-    const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
-    const T *input, const T *dout, const int i1_end, const int n2,
-    const U *__restrict__ mean, const U *__restrict__ var,
-    const float epsilon) {
-  const int i1 = i1_block + thr_load_row_off;
+    const int64_t i1_block, const int thr_load_row_off,
+    const int thr_load_col_off, const int i2_off, const int row_stride,
+    U *warp_buf1, U *warp_buf2, const T *input, const T *dout,
+    const int64_t i1_end, const int64_t n2, const U *__restrict__ mean,
+    const U *__restrict__ var, const float epsilon) {
+  const int64_t i1 = i1_block + thr_load_row_off;
   if (i1 >= i1_end) return;
   U curr_mean = mean[i1];
   U curr_invvar = rsqrt_<U>(var[i1] + epsilon);
   for (int k = 0; k < VPT; ++k) {
     const int i2 = i2_off + k;
-    const int load_idx = i1 * n2 + i2;
+    const int64_t load_idx = i1 * n2 + i2;
     const int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
     if (i2 < n2) {
       U curr_input = static_cast<U>(input[load_idx]);
@@ -303,8 +337,8 @@ __inline__ __device__ void cuLoadAddStridedInputs(
 template <typename T, typename U, int BDIMX, int BDIMY, int VPTX>
 __global__ void LayerNormBackwardPartGradGammaBeta(
-    const T *__restrict__ dout, const T *__restrict__ input, const int n1,
-    const int n2, const U *__restrict__ mean, const U *__restrict__ var,
+    const T *__restrict__ dout, const T *__restrict__ input, const int64_t n1,
+    const int64_t n2, const U *__restrict__ mean, const U *__restrict__ var,
     float epsilon, U *part_grad_gamma, U *part_grad_beta) {
   // VPTX -> value per thread.x, BDIMX -> blockDim.x, BDIMY -> blockDim.y, BDIMX
   // -> blockDim.x
@@ -330,7 +364,7 @@ __global__ void LayerNormBackwardPartGradGammaBeta(
   }
   __syncthreads();
-  for (int i1_block = blockIdx.y * BDIMY * VPTX; i1_block < n1;
+  for (int64_t i1_block = blockIdx.y * BDIMY * VPTX; i1_block < n1;
        i1_block += VPTX * BDIMY * gridDim.y) {
     cuLoadAddStridedInputs<T, U, VPTX>(
         i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride,
@@ -363,7 +397,7 @@ __global__ void LayerNormBackwardPartGradGammaBeta(
     }
     __syncthreads();
   }
-  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t i2 = blockIdx.x * blockDim.x + threadIdx.x;
   if (threadIdx.y == 0 && i2 < n2) {
     int row1 = threadIdx.y;
     int row2 = threadIdx.y + 1;
@@ -381,7 +415,7 @@ __global__ void LayerNormBackwardSumGradGammaBeta(
     const int n1, const int n2, U *grad_gamma, U *grad_beta) {
   // sum partial gradients for gamma and beta
   __shared__ U buf[BDIMX * BDIMY];
-  int i2 = blockIdx.x * BDIMX + threadIdx.x;
+  int64_t i2 = blockIdx.x * BDIMX + threadIdx.x;
   if (i2 < n2) {
     // each warp does sequential reductions until reduced part_size is num_warps
     int num_warp_reductions = part_size / BDIMY;
@@ -552,22 +586,17 @@ __global__ void LayerNormBackwardComputeGradInput(
 // Make sure that d_scale != nullptr && d_bias != nullptr
 // Since d_scale != nullptr, scale would not be nullptr
 template <typename T, typename U, int BlockDim, bool HasDx>
-__global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
-                                             U *d_scale, U *d_bias, T *d_x,
-                                             const U *mean, const U *var,
-                                             const U *scale, float epsilon,
-                                             int batch_size, int feature_size,
-                                             int col_offset) {
-  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
-  int beg_idx = threadIdx.x * feature_size + (blockIdx.x + col_offset);
-  int end_idx = batch_size * feature_size + (blockIdx.x + col_offset);
-  int stride = BlockDim * feature_size;
+__global__ void LayerNormBackwardGradientAll(
+    const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean,
+    const U *var, const U *scale, float epsilon, int64_t batch_size,
+    int64_t feature_size, int64_t col_offset) {
+  int64_t beg_idx = threadIdx.x * feature_size + (blockIdx.x + col_offset);
+  int64_t end_idx = batch_size * feature_size + (blockIdx.x + col_offset);
+  int64_t stride = BlockDim * feature_size;
   U d_scale_partial = static_cast<U>(0), d_bias_partial = static_cast<U>(0);
-  for (int i = beg_idx; i < end_idx; i += stride) {
+  for (int64_t i = beg_idx; i < end_idx; i += stride) {
     int row_idx = i / feature_size;
     auto var_val = real_sqrt(static_cast<U>(var[row_idx]) + epsilon);
     d_scale_partial += static_cast<U>(d_y[i]) *
@@ -579,13 +608,12 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
     }
   }
-  auto pair = BlockReduce(temp_storage)
-                  .Reduce(PairForLayerNorm<U>(d_scale_partial, d_bias_partial),
-                          PairForLayerNormAddFunctor<U>());
+  d_scale_partial = BlockReduceSum<U>(d_scale_partial);
+  d_bias_partial = BlockReduceSum<U>(d_bias_partial);
   if (threadIdx.x == 0) {
-    d_scale[blockIdx.x + col_offset] = pair.first_;
-    d_bias[blockIdx.x + col_offset] = pair.second_;
+    d_scale[blockIdx.x + col_offset] = d_scale_partial;
+    d_bias[blockIdx.x + col_offset] = d_bias_partial;
   }
 }
@@ -595,16 +623,16 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
 template <typename T, typename U, int BlockDim, bool HasDx, bool HasDScale>
 __global__ void LayerNormBackwardGradientScaleOrBias(
     const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean,
-    const U *var, const U *scale, float epsilon, int batch_size,
-    int feature_size, int col_offset) {
+    const U *var, const U *scale, float epsilon, int64_t batch_size,
+    int64_t feature_size, int col_offset) {
   using BlockReduce = cub::BlockReduce<U, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
-  int beg_idx = threadIdx.x * feature_size + blockIdx.x + col_offset;
-  int end_idx = batch_size * feature_size + blockIdx.x + col_offset;
+  int64_t beg_idx = threadIdx.x * feature_size + blockIdx.x + col_offset;
+  int64_t end_idx = batch_size * feature_size + blockIdx.x + col_offset;
   int stride = BlockDim * feature_size;
   U d_scale_or_d_bias_partial = static_cast<U>(0);
-  for (int i = beg_idx; i < end_idx; i += stride) {
+  for (int64_t i = beg_idx; i < end_idx; i += stride) {
     int row_idx = i / feature_size;
     auto var_val =
         static_cast<U>(real_sqrt(static_cast<float>(var[row_idx]) + epsilon));
@@ -639,22 +667,20 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
 }
 template <typename T, typename U, int BlockDim>
-__global__ void LayerNormBackwardPostProcessToCalculateDX(const T *x, T *d_x,
-                                                           const U *mean,
-                                                           const U *var,
-                                                           float epsilon,
-                                                           int feature_size) {
+__global__ void LayerNormBackwardPostProcessToCalculateDX(
+    const T *x, T *d_x, const U *mean, const U *var, float epsilon,
+    int64_t feature_size) {
   using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
   __shared__ U d_x_reduce_tmp[2];
-  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
-  int end_idx = (blockIdx.x + 1) * feature_size;
+  int64_t beg_idx = blockIdx.x * feature_size + threadIdx.x;
+  int64_t end_idx = (blockIdx.x + 1) * feature_size;
   U block_mean = mean[blockIdx.x];
   U block_var = var[blockIdx.x];
   U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
-  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+  for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
     d_x_mean_partial += static_cast<U>(d_x[i]);
     d_x_var_partial +=
         static_cast<U>(d_x[i]) * (static_cast<U>(x[i]) - block_mean);
@@ -675,7 +701,7 @@ __global__ void LayerNormBackwardPostProcessToCalculateDX(const T *x, T *d_x,
   d_x_mean_partial = d_x_reduce_tmp[0];
   d_x_var_partial = d_x_reduce_tmp[1];
-  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+  for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
     d_x[i] -= static_cast<T>(d_x_mean_partial);
     d_x[i] -=
         static_cast<T>((static_cast<U>(x[i]) - block_mean) * d_x_var_partial);
@@ -688,17 +714,17 @@ __global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
                                                 T *d_x, const U *mean,
                                                 const U *var, const U *scale,
                                                 float epsilon,
-                                                int feature_size) {
+                                                int64_t feature_size) {
   using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
   __shared__ U d_x_reduce_tmp[2];
-  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
-  int end_idx = (blockIdx.x + 1) * feature_size;
+  int64_t beg_idx = blockIdx.x * feature_size + threadIdx.x;
+  int64_t end_idx = (blockIdx.x + 1) * feature_size;
   U block_mean = mean[blockIdx.x], block_var = var[blockIdx.x];
   U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
-  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+  for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
     auto var_val =
         static_cast<U>(real_sqrt(static_cast<float>(block_var) + epsilon));
     if (scale != nullptr) {
@@ -728,7 +754,7 @@ __global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
   d_x_mean_partial = d_x_reduce_tmp[0];
   d_x_var_partial = d_x_reduce_tmp[1];
-  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+  for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
     d_x[i] -= static_cast<T>(d_x_mean_partial);
     d_x[i] -=
         static_cast<T>((static_cast<U>(x[i]) - block_mean) * d_x_var_partial);
@@ -738,8 +764,8 @@ __global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
 template <typename T, typename U>
 __global__ void LayerNormBackwardWhenBatchSizeIsOne(
     const T *x, const T *d_y, T *d_x, U *d_scale, U *d_bias, const U *mean,
-    const U *var, const U *scale, float epsilon, int feature_size) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
+    const U *var, const U *scale, float epsilon, int64_t feature_size) {
+  int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx < feature_size) {
     auto var_val =
         static_cast<U>(real_sqrt(static_cast<float>(var[idx]) + epsilon));
@@ -764,8 +790,8 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne(
 template <typename T, typename U>
 static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
                               const U *mean, const U *var, T *d_x, U *d_scale,
-                              U *d_bias, float epsilon, int batch_size,
-                              int feature_size,
+                              U *d_bias, float epsilon, int64_t batch_size,
+                              int64_t feature_size,
                               const framework::ExecutionContext &ctx) {
   auto &dev_ctx = ctx.cuda_device_context();
   auto stream = dev_ctx.stream();
@@ -925,8 +951,8 @@ void LayerNormDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
                                                int begin_norm_axis, float eps) {
   const auto x_dims = framework::make_ddim(input_shape);
   auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
-  int batch_size = static_cast<int>(matrix_dim[0]);
-  int feature_size = static_cast<int>(matrix_dim[1]);
+  int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
+  int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
   switch (GetDesiredBlockDim(feature_size)) {
     FIXED_BLOCK_DIM_CASE(
         LayerNormForward<T, T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
@@ -986,8 +1012,8 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
     auto *bias_data = (bias == nullptr ? nullptr : bias->data<U>());
     auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
-    int batch_size = static_cast<int>(matrix_dim[0]);
-    int feature_size = static_cast<int>(matrix_dim[1]);
+    int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
+    int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
     auto stream = ctx.cuda_device_context().stream();
@@ -1040,8 +1066,8 @@ class LayerNormGradKernel<platform::CUDADeviceContext, T>
     const auto &x_dims = x->dims();
     const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
     auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
-    int batch_size = static_cast<int>(matrix_dim[0]);
-    int feature_size = static_cast<int>(matrix_dim[1]);
+    int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
+    int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
     LayerNormBackward<T, U>(x_data, d_y_data, scale_data, mean_data, var_data,
                             d_x_data, d_scale_data, d_bias_data, epsilon,
...
@@ -51,6 +51,7 @@ class TestDygraphLayerNormv2(unittest.TestCase):
         self.assertTrue(np.allclose(y1, y2))
     def test_static(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda() and core.op_support_gpu("layer_norm"):
             places.append(fluid.CUDAPlace(0))
...