未验证 提交 6cdc18af 编写于 作者: Z zhangkaihuo 提交者: GitHub

[Sparse]Optimize BatchNorm1D forward in test mode (#47736)

上级 1ad95e97
...@@ -72,6 +72,40 @@ static __global__ void BNForwardInference(const T *x, ...@@ -72,6 +72,40 @@ static __global__ void BNForwardInference(const T *x,
} }
} }
template <typename T>
static __global__ void InverseVariance(const BatchNormParamType<T> *variance,
                                       const double epsilon,
                                       const int C,
                                       BatchNormParamType<T> *inv_variance) {
  // One thread per channel: precompute inv_variance[c] = 1 / sqrt(var[c] + eps)
  // once, so the per-element inference kernel can multiply instead of
  // recomputing the rsqrt for every element.
  // Launch expectation: a 1-D grid with at least C total threads.
  const int channel = blockIdx.x * blockDim.x + threadIdx.x;
  if (channel >= C) {
    return;  // tail guard: grid may overshoot C
  }
  // Computed in double (epsilon is double) before narrowing to the param type.
  inv_variance[channel] = 1 / sqrt(variance[channel] + epsilon);
}
template <typename T, phi::DataLayout layout>
static __global__ void BN1DForwardInference(
    const T *x,
    const BatchNormParamType<T> *mean,
    const BatchNormParamType<T> *inv_variance,
    const BatchNormParamType<T> *scale,
    const BatchNormParamType<T> *bias,
    const int C,
    const int N,
    const int HxW,
    const double epsilon,
    T *y) {
  // Inference-mode batch norm using a precomputed reciprocal std-dev:
  //   y = scale[c] * (x - mean[c]) * inv_variance[c] + bias[c]
  // NOTE(review): `epsilon` is unused here — it is already folded into
  // `inv_variance` (see InverseVariance); kept only for signature parity
  // with BNForwardInference.
  const int total = N * C * HxW;
  const int stride = blockDim.x * gridDim.x;
  // Grid-stride loop: any 1-D launch configuration covers all elements.
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total;
       idx += stride) {
    // Channel index depends on layout: NCHW groups HxW elements per channel,
    // NHWC interleaves channels in the innermost dimension.
    const int c =
        (layout == phi::DataLayout::kNCHW) ? (idx / HxW) % C : idx % C;
    const BatchNormParamType<T> centered =
        static_cast<BatchNormParamType<T>>(x[idx]) - mean[c];
    y[idx] = static_cast<T>(scale[c] * centered * inv_variance[c] + bias[c]);
  }
}
template <typename T, int BlockDim, phi::DataLayout layout> template <typename T, int BlockDim, phi::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining( static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining(
const T *x, const T *x,
...@@ -795,7 +829,7 @@ void BatchNormKernel(const Context &ctx, ...@@ -795,7 +829,7 @@ void BatchNormKernel(const Context &ctx,
// epsilon)); // epsilon));
#else #else
const bool use_native_kernel = const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) || (x_dims.size() == 2 ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD)); (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
if (use_native_kernel) { if (use_native_kernel) {
const int block_size = 256; const int block_size = 256;
...@@ -814,18 +848,43 @@ void BatchNormKernel(const Context &ctx, ...@@ -814,18 +848,43 @@ void BatchNormKernel(const Context &ctx,
epsilon, epsilon,
transformed_y.template data<T>()); transformed_y.template data<T>());
} else { } else {
BNForwardInference<T, DataLayout::kNHWC> if (x_dims.size() == 2) {
<<<grid_size, block_size, 0, ctx.stream()>>>( DenseTensor inv_var = phi::Empty<BatchNormParamType<T>>(ctx, {C});
transformed_x.template data<T>(), auto *inv_var_ptr = inv_var.data<BatchNormParamType<T>>();
est_mean->template data<BatchNormParamType<T>>(), const int threads = 512 > C ? C : 512;
est_var->template data<BatchNormParamType<T>>(), const int blocks = (C + 511) / 512;
scale.template data<BatchNormParamType<T>>(), InverseVariance<T><<<blocks, threads>>>(
bias.template data<BatchNormParamType<T>>(), est_var->template data<BatchNormParamType<T>>(),
C, epsilon,
N, C,
H * W * D, inv_var_ptr);
epsilon, BN1DForwardInference<T, DataLayout::kNHWC>
transformed_y.template data<T>()); <<<grid_size, block_size, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
est_mean->template data<BatchNormParamType<T>>(),
// est_var->template data<BatchNormParamType<T>>(),
inv_var_ptr,
scale.template data<BatchNormParamType<T>>(),
bias.template data<BatchNormParamType<T>>(),
C,
N,
H * W * D,
epsilon,
transformed_y.template data<T>());
} else {
BNForwardInference<T, DataLayout::kNHWC>
<<<grid_size, block_size, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
est_mean->template data<BatchNormParamType<T>>(),
est_var->template data<BatchNormParamType<T>>(),
scale.template data<BatchNormParamType<T>>(),
bias.template data<BatchNormParamType<T>>(),
C,
N,
H * W * D,
epsilon,
transformed_y.template data<T>());
}
} }
} else { } else {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册