提交 c50c5377 编写于 作者: S sneaxiy

fix arithmetic error in backward kernel

......@@ -273,9 +273,9 @@ op_library(squeeze_op DEPS reshape_op)
op_library(extract_rows_op DEPS memory)
op_library(flatten_op DEPS reshape_op)
if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv im2col)
op_library(layer_norm_op DEPS cub)
else()
op_library(conv_op DEPS vol2col im2col)
endif()
......
......@@ -45,38 +45,55 @@ inline static int GetDesiredBlockDim(int block_dim) {
static __device__ __forceinline__ float real_sqrt(float x) { return sqrtf(x); }
static __device__ __forceinline__ double real_sqrt(double x) { return sqrt(x); }
template <typename T>
struct PairForLayerNorm {
__device__ __forceinline__ PairForLayerNorm() {}
__device__ __forceinline__ PairForLayerNorm(const T &first, const T &second)
: first_(first), second_(second) {}
T first_;
T second_;
};
template <typename T>
struct PairForLayerNormAddFunctor {
__device__ __forceinline__ PairForLayerNorm<T> operator()(
const PairForLayerNorm<T> &p1, const PairForLayerNorm<T> &p2) {
return PairForLayerNorm<T>(p1.first_ + p2.first_, p1.second_ + p2.second_);
}
};
template <typename T, int BlockDim>
__global__ void LayerNormForward(const T *x, const T *scale, const T *bias,
T *y, T *mean, T *var, float epsilon,
int feature_size) {
using BlockReduce = cub::BlockReduce<T, BlockDim>;
using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage;
int beg_idx = blockIdx.x * feature_size + threadIdx.x;
int end_idx = (blockIdx.x + 1) * feature_size;
// Step 1: Reduce to calculate mean
// Step 1: Reduce to calculate mean and var
T mean_val = static_cast<T>(0);
for (int i = beg_idx; i < end_idx; i += BlockDim) {
mean_val += x[i];
}
mean_val = BlockReduce(temp_storage).Reduce(mean_val, cub::Sum());
if (threadIdx.x == 0) mean[blockIdx.x] = mean_val / feature_size;
__syncthreads();
mean_val = mean[blockIdx.x];
// Step 2: Reduce to calculate var
T var_val = static_cast<T>(0);
for (int i = beg_idx; i < end_idx; i += BlockDim) {
T tmp = x[i] - mean_val;
T tmp = x[i];
mean_val += tmp;
var_val += (tmp * tmp);
}
var_val = BlockReduce(temp_storage).Reduce(var_val, cub::Sum());
if (threadIdx.x == 0) var[blockIdx.x] = var_val / feature_size;
auto pair = BlockReduce(temp_storage)
.Reduce(PairForLayerNorm<T>(mean_val, var_val),
PairForLayerNormAddFunctor<T>());
if (threadIdx.x == 0) {
auto tmp = pair.first_ / feature_size;
mean[blockIdx.x] = tmp;
var[blockIdx.x] = pair.second_ / feature_size - tmp * tmp;
}
__syncthreads();
mean_val = mean[blockIdx.x];
var_val = static_cast<T>(real_sqrt(var[blockIdx.x] + epsilon));
// Step 3: Calculate y
// Step 2: Calculate y
if (scale != nullptr) {
if (bias != nullptr) {
for (int i = beg_idx, j = threadIdx.x; i < end_idx;
......@@ -104,27 +121,6 @@ __global__ void LayerNormForward(const T *x, const T *scale, const T *bias,
}
}
template <typename T>
struct PairForLayerNormBackward {
__device__ __forceinline__ PairForLayerNormBackward() {}
__device__ __forceinline__ PairForLayerNormBackward(const T &first,
const T &second)
: first_(first), second_(second) {}
T first_;
T second_;
};
template <typename T>
struct PairForLayerNormBackwardAddFunctor {
__device__ __forceinline__ PairForLayerNormBackward<T> operator()(
const PairForLayerNormBackward<T> &p1,
const PairForLayerNormBackward<T> &p2) {
return PairForLayerNormBackward<T>(p1.first_ + p2.first_,
p1.second_ + p2.second_);
}
};
// Make sure that d_scale != nullptr && d_bias != nullptr
// Since d_scale != nullptr, scale would not be nullptr
template <typename T, int BlockDim, bool HasDx>
......@@ -133,12 +129,13 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
const T *mean, const T *var,
const T *scale, float epsilon,
int batch_size, int feature_size) {
using BlockReduce = cub::BlockReduce<PairForLayerNormBackward<T>, BlockDim>;
using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage;
int beg_idx = threadIdx.x * feature_size + blockIdx.x;
int end_idx = batch_size * feature_size + blockIdx.x;
int stride = BlockDim * feature_size;
T d_scale_partial = 0, d_bias_partial = 0;
for (int i = beg_idx; i < end_idx; i += stride) {
......@@ -146,13 +143,14 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
auto var_val = static_cast<T>(real_sqrt(var[row_idx] + epsilon));
d_scale_partial += d_y[i] * (x[i] - mean[row_idx]) / var_val;
d_bias_partial += d_y[i];
if (HasDx) d_x[i] = d_y[i] * scale[blockIdx.x] / var_val;
if (HasDx) {
d_x[i] = d_y[i] * scale[blockIdx.x] / var_val;
}
}
auto pair =
BlockReduce(temp_storage)
.Reduce(PairForLayerNormBackward<T>(d_scale_partial, d_bias_partial),
PairForLayerNormBackwardAddFunctor<T>());
auto pair = BlockReduce(temp_storage)
.Reduce(PairForLayerNorm<T>(d_scale_partial, d_bias_partial),
PairForLayerNormAddFunctor<T>());
if (threadIdx.x == 0) {
d_scale[blockIdx.x] = pair.first_;
......@@ -205,22 +203,90 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
}
}
template <typename T, int BlockDim>
__global__ void LayerNormBackwardPostProcessToCalculateDX(const T *x, T *d_x,
const T *mean,
const T *var,
float epsilon,
int feature_size) {
using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T d_x_reduce_tmp[2];
int beg_idx = blockIdx.x * feature_size + threadIdx.x;
int end_idx = (blockIdx.x + 1) * feature_size;
T block_mean = mean[blockIdx.x];
T block_var = var[blockIdx.x];
T d_x_mean_partial = 0, d_x_var_partial = 0;
for (int i = beg_idx; i < end_idx; i += BlockDim) {
d_x_mean_partial += d_x[i];
d_x_var_partial += d_x[i] * (x[i] - block_mean);
}
auto pair =
BlockReduce(temp_storage)
.Reduce(PairForLayerNorm<T>(d_x_mean_partial, d_x_var_partial),
PairForLayerNormAddFunctor<T>());
if (threadIdx.x == 0) {
d_x_reduce_tmp[0] = pair.first_ / feature_size;
d_x_reduce_tmp[1] = pair.second_ / (feature_size * (block_var + epsilon));
}
__syncthreads();
d_x_mean_partial = d_x_reduce_tmp[0];
d_x_var_partial = d_x_reduce_tmp[1];
for (int i = beg_idx; i < end_idx; i += BlockDim) {
d_x[i] -= d_x_mean_partial;
d_x[i] -= (x[i] - block_mean) * d_x_var_partial;
}
}
// Here, we only calculate d_x
template <typename T>
__global__ void LayerNormBackwardGradientOnlyX(const T *d_y, T *d_x,
template <typename T, int BlockDim>
__global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
T *d_x, const T *mean,
const T *var, const T *scale,
float epsilon, int batch_size,
float epsilon,
int feature_size) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < batch_size * feature_size) {
int row_idx = idx / feature_size;
auto var_val = static_cast<T>(real_sqrt(var[row_idx] + epsilon));
using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T d_x_reduce_tmp[2];
int beg_idx = blockIdx.x * feature_size + threadIdx.x;
int end_idx = (blockIdx.x + 1) * feature_size;
T block_mean = mean[blockIdx.x], block_var = var[blockIdx.x];
T d_x_mean_partial = 0, d_x_var_partial = 0;
for (int i = beg_idx; i < end_idx; i += BlockDim) {
auto var_val = static_cast<T>(real_sqrt(block_var + epsilon));
if (scale != nullptr) {
int col_idx = idx % feature_size;
d_x[idx] = d_y[idx] * scale[col_idx] / var_val;
int col_idx = i % feature_size;
d_x[i] = d_y[i] * scale[col_idx] / var_val;
} else {
d_x[idx] = d_y[idx] / var_val;
d_x[i] = d_y[i] / var_val;
}
d_x_mean_partial += d_x[i];
d_x_var_partial += d_x[i] * (x[i] - block_mean);
}
auto pair =
BlockReduce(temp_storage)
.Reduce(PairForLayerNorm<T>(d_x_mean_partial, d_x_var_partial),
PairForLayerNormAddFunctor<T>());
if (threadIdx.x == 0) {
d_x_reduce_tmp[0] = pair.first_ / feature_size;
d_x_reduce_tmp[1] = pair.second_ / (feature_size * (block_var + epsilon));
}
__syncthreads();
d_x_mean_partial = d_x_reduce_tmp[0];
d_x_var_partial = d_x_reduce_tmp[1];
for (int i = beg_idx; i < end_idx; i += BlockDim) {
d_x[i] -= d_x_mean_partial;
d_x[i] -= (x[i] - block_mean) * d_x_var_partial;
}
}
......@@ -263,6 +329,14 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
T><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim, 0,
stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale, epsilon,
feature_size);
if (d_x != nullptr) {
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(LayerNormBackwardPostProcessToCalculateDX<
T, kBlockDim><<<1, kBlockDim, 0, stream>>>(
x, d_x, mean, var, epsilon, feature_size));
}
}
return;
}
......@@ -296,10 +370,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
}
break;
case 4: // d_x != nullptr, d_scale == nullptr, d_bias == nullptr
LayerNormBackwardGradientOnlyX<
T><<<(batch_size * feature_size + kMaxBlockDim - 1) / kMaxBlockDim,
kMaxBlockDim, 0, stream>>>(d_y, d_x, var, scale, epsilon,
batch_size, feature_size);
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(
LayerNormBackwardGradientOnlyDX<
T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
x, d_y, d_x, mean, var, scale, epsilon, feature_size));
}
break;
case 5: // d_x != nulptr, d_scale == nullptr, d_bias != nullptr
switch (block_dim) {
......@@ -309,6 +385,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size,
feature_size));
}
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(
LayerNormBackwardPostProcessToCalculateDX<
T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
x, d_x, mean, var, epsilon, feature_size));
}
break;
case 6: // d_x != nullptr, d_scale != nullptr, d_bias == nullptr
switch (block_dim) {
......@@ -318,6 +400,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size,
feature_size));
}
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(
LayerNormBackwardPostProcessToCalculateDX<
T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
x, d_x, mean, var, epsilon, feature_size));
}
break;
case 7: // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
switch (block_dim) {
......@@ -327,6 +415,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size));
}
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(
LayerNormBackwardPostProcessToCalculateDX<
T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
x, d_x, mean, var, epsilon, feature_size));
}
break;
default:
break;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册