Commit c50c5377 authored by S sneaxiy

fix arithmetic error in backward kernel

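The change below does two things: the forward kernel now obtains the row mean and variance from a single fused reduction instead of two passes, and the backward path gains a post-processing kernel so that d_x receives the full layer-norm input gradient rather than only the d_y * scale / sqrt(var + epsilon) term. For reference, with D = feature_size, the quantities involved are (in LaTeX):

\[
\mu = \frac{1}{D}\sum_{i=1}^{D} x_i, \qquad
\sigma^2 = \frac{1}{D}\sum_{i=1}^{D} (x_i-\mu)^2 = \frac{1}{D}\sum_{i=1}^{D} x_i^2 - \mu^2, \qquad
y_i = \mathrm{scale}_i \cdot \frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}} + \mathrm{bias}_i
\]

The second form of the variance, E[x^2] - mu^2, is exactly what the new single-pass "Step 1" in LayerNormForward computes.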
@@ -273,9 +273,9 @@ op_library(squeeze_op DEPS reshape_op)
 op_library(extract_rows_op DEPS memory)
 op_library(flatten_op DEPS reshape_op)
 if (WITH_GPU)
     op_library(conv_op DEPS vol2col depthwise_conv im2col)
+    op_library(layer_norm_op DEPS cub)
 else()
     op_library(conv_op DEPS vol2col im2col)
 endif()
......
@@ -45,38 +45,55 @@ inline static int GetDesiredBlockDim(int block_dim) {
 static __device__ __forceinline__ float real_sqrt(float x) { return sqrtf(x); }
 static __device__ __forceinline__ double real_sqrt(double x) { return sqrt(x); }
+
+template <typename T>
+struct PairForLayerNorm {
+  __device__ __forceinline__ PairForLayerNorm() {}
+  __device__ __forceinline__ PairForLayerNorm(const T &first, const T &second)
+      : first_(first), second_(second) {}
+
+  T first_;
+  T second_;
+};
+
+template <typename T>
+struct PairForLayerNormAddFunctor {
+  __device__ __forceinline__ PairForLayerNorm<T> operator()(
+      const PairForLayerNorm<T> &p1, const PairForLayerNorm<T> &p2) {
+    return PairForLayerNorm<T>(p1.first_ + p2.first_, p1.second_ + p2.second_);
+  }
+};
+
 template <typename T, int BlockDim>
 __global__ void LayerNormForward(const T *x, const T *scale, const T *bias,
                                  T *y, T *mean, T *var, float epsilon,
                                  int feature_size) {
-  using BlockReduce = cub::BlockReduce<T, BlockDim>;
+  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
   int beg_idx = blockIdx.x * feature_size + threadIdx.x;
   int end_idx = (blockIdx.x + 1) * feature_size;
-  // Step 1: Reduce to calculate mean
+  // Step 1: Reduce to calculate mean and var
   T mean_val = static_cast<T>(0);
-  for (int i = beg_idx; i < end_idx; i += BlockDim) {
-    mean_val += x[i];
-  }
-  mean_val = BlockReduce(temp_storage).Reduce(mean_val, cub::Sum());
-  if (threadIdx.x == 0) mean[blockIdx.x] = mean_val / feature_size;
-  __syncthreads();
-  mean_val = mean[blockIdx.x];
-  // Step 2: Reduce to calculate var
   T var_val = static_cast<T>(0);
   for (int i = beg_idx; i < end_idx; i += BlockDim) {
-    T tmp = x[i] - mean_val;
+    T tmp = x[i];
+    mean_val += tmp;
     var_val += (tmp * tmp);
   }
-  var_val = BlockReduce(temp_storage).Reduce(var_val, cub::Sum());
-  if (threadIdx.x == 0) var[blockIdx.x] = var_val / feature_size;
+  auto pair = BlockReduce(temp_storage)
+                  .Reduce(PairForLayerNorm<T>(mean_val, var_val),
+                          PairForLayerNormAddFunctor<T>());
+  if (threadIdx.x == 0) {
+    auto tmp = pair.first_ / feature_size;
+    mean[blockIdx.x] = tmp;
+    var[blockIdx.x] = pair.second_ / feature_size - tmp * tmp;
+  }
   __syncthreads();
+  mean_val = mean[blockIdx.x];
   var_val = static_cast<T>(real_sqrt(var[blockIdx.x] + epsilon));
-  // Step 3: Calculate y
+  // Step 2: Calculate y
   if (scale != nullptr) {
     if (bias != nullptr) {
       for (int i = beg_idx, j = threadIdx.x; i < end_idx;
@@ -104,27 +121,6 @@ __global__ void LayerNormForward(const T *x, const T *scale, const T *bias,
   }
 }
-
-template <typename T>
-struct PairForLayerNormBackward {
-  __device__ __forceinline__ PairForLayerNormBackward() {}
-  __device__ __forceinline__ PairForLayerNormBackward(const T &first,
-                                                      const T &second)
-      : first_(first), second_(second) {}
-  T first_;
-  T second_;
-};
-
-template <typename T>
-struct PairForLayerNormBackwardAddFunctor {
-  __device__ __forceinline__ PairForLayerNormBackward<T> operator()(
-      const PairForLayerNormBackward<T> &p1,
-      const PairForLayerNormBackward<T> &p2) {
-    return PairForLayerNormBackward<T>(p1.first_ + p2.first_,
-                                       p1.second_ + p2.second_);
-  }
-};
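The PairForLayerNormBackward helpers removed here are superseded by the shared PairForLayerNorm / PairForLayerNormAddFunctor defined above the forward kernel, so the forward and backward passes now reuse one reduction type. As a minimal standalone sketch of that pattern (illustrative names, not code from this repository), a cub::BlockReduce over a custom pair lets one pass over a row produce both sum(x) and sum(x*x):

// Standalone sketch only: one block reduces one row of length feature_size.
#include <cub/cub.cuh>

template <typename T>
struct SumPair {
  T sum;     // running sum of x
  T sq_sum;  // running sum of x * x
};

template <typename T>
struct SumPairAdd {
  __device__ __forceinline__ SumPair<T> operator()(const SumPair<T> &a,
                                                   const SumPair<T> &b) const {
    return SumPair<T>{a.sum + b.sum, a.sq_sum + b.sq_sum};
  }
};

// One block per row; BlockDim threads stride across the columns.
template <typename T, int BlockDim>
__global__ void RowMeanVarOnePass(const T *x, T *mean, T *var,
                                  int feature_size) {
  using BlockReduce = cub::BlockReduce<SumPair<T>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  int beg = blockIdx.x * feature_size + threadIdx.x;
  int end = (blockIdx.x + 1) * feature_size;

  SumPair<T> local{static_cast<T>(0), static_cast<T>(0)};
  for (int i = beg; i < end; i += BlockDim) {
    local.sum += x[i];
    local.sq_sum += x[i] * x[i];
  }

  // A single block-wide reduction combines both partial sums at once.
  SumPair<T> total = BlockReduce(temp_storage).Reduce(local, SumPairAdd<T>());
  if (threadIdx.x == 0) {
    T m = total.sum / feature_size;
    mean[blockIdx.x] = m;
    var[blockIdx.x] = total.sq_sum / feature_size - m * m;  // E[x^2] - E[x]^2
  }
}

Note that cub::BlockReduce only returns the valid aggregate to thread 0, which is why the kernels in this diff broadcast the reduced values through global or shared memory (followed by __syncthreads) before the second loop.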
 // Make sure that d_scale != nullptr && d_bias != nullptr
 // Since d_scale != nullptr, scale would not be nullptr
 template <typename T, int BlockDim, bool HasDx>
@@ -133,12 +129,13 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
                                              const T *mean, const T *var,
                                              const T *scale, float epsilon,
                                              int batch_size, int feature_size) {
-  using BlockReduce = cub::BlockReduce<PairForLayerNormBackward<T>, BlockDim>;
+  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
   int beg_idx = threadIdx.x * feature_size + blockIdx.x;
   int end_idx = batch_size * feature_size + blockIdx.x;
   int stride = BlockDim * feature_size;
   T d_scale_partial = 0, d_bias_partial = 0;
   for (int i = beg_idx; i < end_idx; i += stride) {
@@ -146,13 +143,14 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
     auto var_val = static_cast<T>(real_sqrt(var[row_idx] + epsilon));
     d_scale_partial += d_y[i] * (x[i] - mean[row_idx]) / var_val;
     d_bias_partial += d_y[i];
-    if (HasDx) d_x[i] = d_y[i] * scale[blockIdx.x] / var_val;
+    if (HasDx) {
+      d_x[i] = d_y[i] * scale[blockIdx.x] / var_val;
+    }
   }
-  auto pair =
-      BlockReduce(temp_storage)
-          .Reduce(PairForLayerNormBackward<T>(d_scale_partial, d_bias_partial),
-                  PairForLayerNormBackwardAddFunctor<T>());
+  auto pair = BlockReduce(temp_storage)
+                  .Reduce(PairForLayerNorm<T>(d_scale_partial, d_bias_partial),
+                          PairForLayerNormAddFunctor<T>());
   if (threadIdx.x == 0) {
     d_scale[blockIdx.x] = pair.first_;
@@ -205,22 +203,90 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
   }
 }
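In math terms, LayerNormBackwardGradientAll accumulates, for the column j handled by blockIdx.x (r runs over the batch_size rows, and mu_r, sigma_r^2 are the per-row statistics):

\[
\frac{\partial L}{\partial\,\mathrm{scale}_j} = \sum_{r} \frac{\partial L}{\partial y_{r,j}}\,\frac{x_{r,j}-\mu_r}{\sqrt{\sigma_r^2+\epsilon}},
\qquad
\frac{\partial L}{\partial\,\mathrm{bias}_j} = \sum_{r} \frac{\partial L}{\partial y_{r,j}}
\]

which is exactly what the d_scale_partial and d_bias_partial accumulators above compute before the block-wide pair reduction.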
+template <typename T, int BlockDim>
+__global__ void LayerNormBackwardPostProcessToCalculateDX(const T *x, T *d_x,
+                                                          const T *mean,
+                                                          const T *var,
+                                                          float epsilon,
+                                                          int feature_size) {
+  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+  __shared__ T d_x_reduce_tmp[2];
+  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
+  int end_idx = (blockIdx.x + 1) * feature_size;
+  T block_mean = mean[blockIdx.x];
+  T block_var = var[blockIdx.x];
+  T d_x_mean_partial = 0, d_x_var_partial = 0;
+  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+    d_x_mean_partial += d_x[i];
+    d_x_var_partial += d_x[i] * (x[i] - block_mean);
+  }
+  auto pair =
+      BlockReduce(temp_storage)
+          .Reduce(PairForLayerNorm<T>(d_x_mean_partial, d_x_var_partial),
+                  PairForLayerNormAddFunctor<T>());
+  if (threadIdx.x == 0) {
+    d_x_reduce_tmp[0] = pair.first_ / feature_size;
+    d_x_reduce_tmp[1] = pair.second_ / (feature_size * (block_var + epsilon));
+  }
+  __syncthreads();
+  d_x_mean_partial = d_x_reduce_tmp[0];
+  d_x_var_partial = d_x_reduce_tmp[1];
+  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+    d_x[i] -= d_x_mean_partial;
+    d_x[i] -= (x[i] - block_mean) * d_x_var_partial;
+  }
+}
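This new kernel finishes the d_x computation. The gradient kernels above first write d_x_i = g_i / sqrt(sigma^2 + epsilon), where g_i = scale_i * dL/dy_i (or simply dL/dy_i when scale is null); this pass then subtracts the mean and variance correction terms, so the stored value becomes the complete layer-norm input gradient (D = feature_size):

\[
\frac{\partial L}{\partial x_i} =
  \frac{g_i}{\sqrt{\sigma^2+\epsilon}}
  \;-\; \frac{1}{D}\sum_{j=1}^{D}\frac{g_j}{\sqrt{\sigma^2+\epsilon}}
  \;-\; \frac{x_i-\mu}{\sigma^2+\epsilon}\cdot\frac{1}{D}\sum_{j=1}^{D}\frac{g_j\,(x_j-\mu)}{\sqrt{\sigma^2+\epsilon}}
\]

The previous backward kernels stopped after the first term, which appears to be the arithmetic error the commit message refers to.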
 // Here, we only calculate d_x
-template <typename T>
-__global__ void LayerNormBackwardGradientOnlyX(const T *d_y, T *d_x,
-                                               const T *var, const T *scale,
-                                               float epsilon, int batch_size,
-                                               int feature_size) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  if (idx < batch_size * feature_size) {
-    int row_idx = idx / feature_size;
-    auto var_val = static_cast<T>(real_sqrt(var[row_idx] + epsilon));
-    if (scale != nullptr) {
-      int col_idx = idx % feature_size;
-      d_x[idx] = d_y[idx] * scale[col_idx] / var_val;
-    } else {
-      d_x[idx] = d_y[idx] / var_val;
-    }
-  }
-}
+template <typename T, int BlockDim>
+__global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
+                                                T *d_x, const T *mean,
+                                                const T *var, const T *scale,
+                                                float epsilon,
+                                                int feature_size) {
+  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+  __shared__ T d_x_reduce_tmp[2];
+  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
+  int end_idx = (blockIdx.x + 1) * feature_size;
+  T block_mean = mean[blockIdx.x], block_var = var[blockIdx.x];
+  T d_x_mean_partial = 0, d_x_var_partial = 0;
+  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+    auto var_val = static_cast<T>(real_sqrt(block_var + epsilon));
+    if (scale != nullptr) {
+      int col_idx = i % feature_size;
+      d_x[i] = d_y[i] * scale[col_idx] / var_val;
+    } else {
+      d_x[i] = d_y[i] / var_val;
+    }
+    d_x_mean_partial += d_x[i];
+    d_x_var_partial += d_x[i] * (x[i] - block_mean);
+  }
+  auto pair =
+      BlockReduce(temp_storage)
+          .Reduce(PairForLayerNorm<T>(d_x_mean_partial, d_x_var_partial),
+                  PairForLayerNormAddFunctor<T>());
+  if (threadIdx.x == 0) {
+    d_x_reduce_tmp[0] = pair.first_ / feature_size;
+    d_x_reduce_tmp[1] = pair.second_ / (feature_size * (block_var + epsilon));
+  }
+  __syncthreads();
+  d_x_mean_partial = d_x_reduce_tmp[0];
+  d_x_var_partial = d_x_reduce_tmp[1];
+  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+    d_x[i] -= d_x_mean_partial;
+    d_x[i] -= (x[i] - block_mean) * d_x_var_partial;
+  }
+}
@@ -263,6 +329,14 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
         T><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim, 0,
              stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale, epsilon,
                        feature_size);
+    if (d_x != nullptr) {
+      switch (GetDesiredBlockDim(feature_size)) {
+        FIXED_BLOCK_DIM_CASE(LayerNormBackwardPostProcessToCalculateDX<
+                             T, kBlockDim><<<1, kBlockDim, 0, stream>>>(
+            x, d_x, mean, var, epsilon, feature_size));
+      }
+    }
     return;
   }
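GetDesiredBlockDim and FIXED_BLOCK_DIM_CASE are defined earlier in this file and are not part of the diff. Purely as an illustration of the idea (the macro body and the selection logic below are assumptions, not the repository's code), the pattern maps a runtime feature_size onto a compile-time block dimension so each kernel can be instantiated with a fixed BlockDim template argument. The sketch reuses RowMeanVarOnePass from the earlier example:

// Illustrative only: dispatch a runtime block size to compile-time
// kernel instantiations; NOT the repository's FIXED_BLOCK_DIM_CASE macro.
#define ILLUSTRATIVE_BLOCK_DIM_CASE(dim, ...) \
  case (dim): {                               \
    constexpr int kBlockDim = (dim);          \
    __VA_ARGS__;                              \
  } break

template <typename T>
void LaunchRowMeanVar(const T *x, T *mean, T *var, int rows, int feature_size,
                      cudaStream_t stream) {
  // Stand-in for GetDesiredBlockDim: largest power of two <= feature_size,
  // capped at 512 and floored at 32.
  int block_dim = 512;
  while (block_dim > 32 && block_dim > feature_size) block_dim /= 2;

  switch (block_dim) {
    ILLUSTRATIVE_BLOCK_DIM_CASE(
        512, RowMeanVarOnePass<T, kBlockDim><<<rows, kBlockDim, 0, stream>>>(
                 x, mean, var, feature_size));
    ILLUSTRATIVE_BLOCK_DIM_CASE(
        256, RowMeanVarOnePass<T, kBlockDim><<<rows, kBlockDim, 0, stream>>>(
                 x, mean, var, feature_size));
    ILLUSTRATIVE_BLOCK_DIM_CASE(
        128, RowMeanVarOnePass<T, kBlockDim><<<rows, kBlockDim, 0, stream>>>(
                 x, mean, var, feature_size));
    default:
      break;
  }
}

Because the kernel launch is passed through the macro's variadic arguments, the commas inside the template argument list and the launch configuration are preserved when the macro expands.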
@@ -296,10 +370,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
       }
       break;
     case 4:  // d_x != nullptr, d_scale == nullptr, d_bias == nullptr
-      LayerNormBackwardGradientOnlyX<
-          T><<<(batch_size * feature_size + kMaxBlockDim - 1) / kMaxBlockDim,
-               kMaxBlockDim, 0, stream>>>(d_y, d_x, var, scale, epsilon,
-                                          batch_size, feature_size);
+      switch (GetDesiredBlockDim(feature_size)) {
+        FIXED_BLOCK_DIM_CASE(
+            LayerNormBackwardGradientOnlyDX<
+                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
+                x, d_y, d_x, mean, var, scale, epsilon, feature_size));
+      }
      break;
    case 5:  // d_x != nulptr, d_scale == nullptr, d_bias != nullptr
      switch (block_dim) {
@@ -309,6 +385,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
            x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size,
            feature_size));
      }
+      switch (GetDesiredBlockDim(feature_size)) {
+        FIXED_BLOCK_DIM_CASE(
+            LayerNormBackwardPostProcessToCalculateDX<
+                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
+                x, d_x, mean, var, epsilon, feature_size));
+      }
      break;
    case 6:  // d_x != nullptr, d_scale != nullptr, d_bias == nullptr
      switch (block_dim) {
@@ -318,6 +400,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
            x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size,
            feature_size));
      }
+      switch (GetDesiredBlockDim(feature_size)) {
+        FIXED_BLOCK_DIM_CASE(
+            LayerNormBackwardPostProcessToCalculateDX<
+                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
+                x, d_x, mean, var, epsilon, feature_size));
+      }
      break;
    case 7:  // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
      switch (block_dim) {
@@ -327,6 +415,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
            x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
            batch_size, feature_size));
      }
+      switch (GetDesiredBlockDim(feature_size)) {
+        FIXED_BLOCK_DIM_CASE(
+            LayerNormBackwardPostProcessToCalculateDX<
+                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
+                x, d_x, mean, var, epsilon, feature_size));
+      }
      break;
    default:
      break;
......