未验证 提交 a9fd0807 编写于 作者: Z Zhang Zheng 提交者: GitHub

Optimize performance of batch_norm_bwd with NHWC layout and infer mode (#49209)

* Optimize performance of batch_norm_bwd with NHWC layout and infer mode

* fix
上级 dc694f1e
......@@ -382,6 +382,7 @@ static __global__ void BNBackward2DChannelLastStage2(
const int N,
const int HxW,
const double epsilon,
const bool is_test,
BatchNormParamType<T> *block_data_ptr,
BatchNormParamType<T> *dscale,
BatchNormParamType<T> *dbias,
......@@ -402,7 +403,8 @@ static __global__ void BNBackward2DChannelLastStage2(
BatchNormParamType<T> ds_sum = static_cast<BatchNormParamType<T>>(0);
BatchNormParamType<T> db_sum = static_cast<BatchNormParamType<T>>(0);
BatchNormParamType<T> mean_val = means[i];
BatchNormParamType<T> inv_var_val = variances[i];
BatchNormParamType<T> inv_var_val =
is_test ? 1.0 / sqrt(variances[i] + epsilon) : variances[i];
for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < inner_size;
j += inner_loop_stride) {
......@@ -561,6 +563,51 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData(
}
}
template <typename T, typename Context>
void SetLaunchConfigInfoForChannelLast(const Context &ctx,
DenseTensor *block_data_tensor,
DenseTensor *flag_tensor,
BatchNormParamType<T> **block_data_ptr,
int **flag_ptr,
const int N,
const int H,
const int W,
const int D,
const int C,
const int block_size,
dim3 *block,
dim3 *grid) {
const int MAX_GRID_SIZE = 128;
const int WARP_SIZE = 32;
int block_x = std::min(phi::funcs::details::GetLastPow2(C), WARP_SIZE);
int block_y = std::min(phi::funcs::details::GetLastPow2(N * H * W * D / 16),
block_size / block_x);
if (block_x * block_y != block_size) {
block_x =
std::min(phi::funcs::details::GetLastPow2(C), block_size / block_y);
}
int grid_x = (C + block_x - 1) / block_x;
int grid_y = std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16),
MAX_GRID_SIZE);
block->x = block_x;
block->y = block_y;
grid->x = grid_x;
grid->y = grid_y;
if (grid->y > 1) {
*block_data_tensor =
phi::Empty<BatchNormParamType<T>, Context>(ctx, {2 * C * grid->y});
*flag_tensor = phi::Empty<int, Context>(ctx, {grid->x});
*block_data_ptr = block_data_tensor->data<BatchNormParamType<T>>();
*flag_ptr = flag_tensor->data<int>();
funcs::SetConstant<Context, int> set_zero;
set_zero(ctx, flag_tensor, static_cast<int>(0));
}
}
template <typename T, typename Context>
void BatchNormGradRawKernel(const Context &ctx,
const DenseTensor &x,
......@@ -875,8 +922,6 @@ void BatchNormGradRawKernel(const Context &ctx,
dim3 block;
dim3 grid;
const int block_size = 512;
const int MAX_GRID_SIZE = 128;
const int WARP_SIZE = 32;
// init intermediate storage
DenseTensor block_data_tensor;
......@@ -889,35 +934,20 @@ void BatchNormGradRawKernel(const Context &ctx,
BatchNormParamType<T> *block_data_ptr = nullptr;
int *flag_ptr = nullptr;
int block_x =
std::min(phi::funcs::details::GetLastPow2(C), WARP_SIZE);
int block_y =
std::min(phi::funcs::details::GetLastPow2(N * H * W * D / 16),
block_size / block_x);
if (block_x * block_y != block_size) {
block_x = std::min(phi::funcs::details::GetLastPow2(C),
block_size / block_y);
}
int grid_x = (C + block_x - 1) / block_x;
int grid_y =
std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16),
MAX_GRID_SIZE);
block.x = block_x;
block.y = block_y;
grid.x = grid_x;
grid.y = grid_y;
if (grid.y > 1) {
block_data_tensor = phi::Empty<BatchNormParamType<T>, Context>(
ctx, {2 * C * grid.y});
flag_tensor = phi::Empty<int, Context>(ctx, {grid.x});
block_data_ptr = block_data_tensor.data<BatchNormParamType<T>>();
flag_ptr = flag_tensor.data<int>();
funcs::SetConstant<Context, int> set_zero;
set_zero(ctx, &flag_tensor, static_cast<int>(0));
}
SetLaunchConfigInfoForChannelLast<T>(ctx,
&block_data_tensor,
&flag_tensor,
&block_data_ptr,
&flag_ptr,
N,
H,
W,
D,
C,
block_size,
&block,
&grid);
// 1. reduce_sum(x) => mean, inv_var
auto *mean_ptr =
saved_mean_data == nullptr
......@@ -967,6 +997,7 @@ void BatchNormGradRawKernel(const Context &ctx,
N,
H * W * D,
epsilon,
false,
block_data_ptr,
dscale,
dbias,
......@@ -1256,18 +1287,44 @@ void BatchNormGradRawKernel(const Context &ctx,
d_x->data<T>());
}
if (d_scale && d_bias) {
KeBNBackwardScaleBias<T, block, phi::DataLayout::kNHWC>
<<<grid2, block, 0, stream>>>(
d_y->data<T>(),
x.data<T>(),
dim3 block;
dim3 grid;
const int block_size = 512;
// init intermediate storage
DenseTensor block_data_tensor;
DenseTensor flag_tensor;
BatchNormParamType<T> *block_data_ptr = nullptr;
int *flag_ptr = nullptr;
SetLaunchConfigInfoForChannelLast<T>(ctx,
&block_data_tensor,
&flag_tensor,
&block_data_ptr,
&flag_ptr,
N,
H,
W,
D,
C,
block_size,
&block,
&grid);
BNBackward2DChannelLastStage2<T, block_size>
<<<grid, block, 0, ctx.stream()>>>(
transformed_d_y.template data<T>(),
transformed_x.template data<T>(),
running_mean_data,
running_var_data,
epsilon,
N,
C,
N,
H * W * D,
epsilon,
true,
block_data_ptr,
d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>());
d_bias->data<BatchNormParamType<T>>(),
flag_ptr);
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册