Unverified commit a9fd0807, authored by Zhang Zheng, committed by GitHub

Optimize performance of batch_norm_bwd with NHWC layout and infer mode (#49209)

* Optimize performance of batch_norm_bwd with NHWC layout and infer mode

* fix
Parent: dc694f1e
@@ -382,6 +382,7 @@ static __global__ void BNBackward2DChannelLastStage2(
     const int N,
     const int HxW,
     const double epsilon,
+    const bool is_test,
     BatchNormParamType<T> *block_data_ptr,
     BatchNormParamType<T> *dscale,
     BatchNormParamType<T> *dbias,
@@ -402,7 +403,8 @@ static __global__ void BNBackward2DChannelLastStage2(
     BatchNormParamType<T> ds_sum = static_cast<BatchNormParamType<T>>(0);
     BatchNormParamType<T> db_sum = static_cast<BatchNormParamType<T>>(0);
     BatchNormParamType<T> mean_val = means[i];
-    BatchNormParamType<T> inv_var_val = variances[i];
+    BatchNormParamType<T> inv_var_val =
+        is_test ? 1.0 / sqrt(variances[i] + epsilon) : variances[i];
     for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < inner_size;
          j += inner_loop_stride) {
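The new is_test flag changes where the kernel gets its inverse standard deviation: in training mode, variances already holds the inverse standard deviation saved by the forward pass, while in inference mode it holds the running variance, so the kernel converts it on the fly. A minimal host-side sketch of that rule, with illustrative names (not the Paddle API):

    #include <cmath>

    // Mirrors the is_test branch added to BNBackward2DChannelLastStage2.
    // "var_or_inv_std" plays the role of variances[i]; names are hypothetical.
    double inv_std(double var_or_inv_std, double epsilon, bool is_test) {
      // Training: the forward pass already stored 1 / sqrt(var + eps).
      // Inference: only the running variance is available, so compute it here.
      return is_test ? 1.0 / std::sqrt(var_or_inv_std + epsilon)
                     : var_or_inv_std;
    }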
@@ -561,6 +563,51 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData(
   }
 }
 
+template <typename T, typename Context>
+void SetLaunchConfigInfoForChannelLast(const Context &ctx,
+                                       DenseTensor *block_data_tensor,
+                                       DenseTensor *flag_tensor,
+                                       BatchNormParamType<T> **block_data_ptr,
+                                       int **flag_ptr,
+                                       const int N,
+                                       const int H,
+                                       const int W,
+                                       const int D,
+                                       const int C,
+                                       const int block_size,
+                                       dim3 *block,
+                                       dim3 *grid) {
+  const int MAX_GRID_SIZE = 128;
+  const int WARP_SIZE = 32;
+  int block_x = std::min(phi::funcs::details::GetLastPow2(C), WARP_SIZE);
+  int block_y = std::min(phi::funcs::details::GetLastPow2(N * H * W * D / 16),
+                         block_size / block_x);
+  if (block_x * block_y != block_size) {
+    block_x =
+        std::min(phi::funcs::details::GetLastPow2(C), block_size / block_y);
+  }
+  int grid_x = (C + block_x - 1) / block_x;
+  int grid_y = std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16),
+                        MAX_GRID_SIZE);
+
+  block->x = block_x;
+  block->y = block_y;
+  grid->x = grid_x;
+  grid->y = grid_y;
+
+  if (grid->y > 1) {
+    *block_data_tensor =
+        phi::Empty<BatchNormParamType<T>, Context>(ctx, {2 * C * grid->y});
+    *flag_tensor = phi::Empty<int, Context>(ctx, {grid->x});
+    *block_data_ptr = block_data_tensor->data<BatchNormParamType<T>>();
+    *flag_ptr = flag_tensor->data<int>();
+    funcs::SetConstant<Context, int> set_zero;
+    set_zero(ctx, flag_tensor, static_cast<int>(0));
+  }
+}
+
 template <typename T, typename Context>
 void BatchNormGradRawKernel(const Context &ctx,
                             const DenseTensor &x,
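The helper above packages launch-configuration logic that previously lived inline in BatchNormGradRawKernel, so the training stage-2 call and the new inference-mode call can share it; when grid->y ends up greater than 1, it also allocates the block_data/flag workspace used for the cross-block reduction. A standalone sketch of the sizing arithmetic, assuming a last-power-of-two helper equivalent to phi::funcs::details::GetLastPow2 (the names below are illustrative):

    #include <algorithm>
    #include <cstdio>

    // Hypothetical stand-in for phi::funcs::details::GetLastPow2:
    // the largest power of two that is <= n (at least 1).
    static int last_pow2(int n) {
      int p = 1;
      while (p * 2 <= n) p *= 2;
      return p;
    }

    // Sketch of the block/grid sizing in SetLaunchConfigInfoForChannelLast:
    // block.x covers the C channels (capped at a warp), block.y covers the
    // N*H*W*D rows with a target of roughly 16 rows per thread, and grid.y
    // is capped at 128.
    void channel_last_launch_config(int N, int H, int W, int D, int C,
                                    int block_size, int *bx, int *by,
                                    int *gx, int *gy) {
      const int MAX_GRID_SIZE = 128;
      const int WARP_SIZE = 32;
      int block_x = std::min(last_pow2(C), WARP_SIZE);
      int block_y =
          std::min(last_pow2(N * H * W * D / 16), block_size / block_x);
      if (block_x * block_y != block_size) {
        block_x = std::min(last_pow2(C), block_size / block_y);
      }
      *bx = block_x;
      *by = block_y;
      *gx = (C + block_x - 1) / block_x;
      *gy = std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16),
                     MAX_GRID_SIZE);
    }

    int main() {
      int bx, by, gx, gy;
      // Example: NHWC input with N=32, H=W=56, C=64 (D=1 for 4-D tensors).
      channel_last_launch_config(32, 56, 56, 1, 64, 512, &bx, &by, &gx, &gy);
      std::printf("block=(%d,%d) grid=(%d,%d)\n", bx, by, gx, gy);
      // Prints block=(32,16) grid=(2,128) for this shape.
      return 0;
    }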
@@ -875,8 +922,6 @@ void BatchNormGradRawKernel(const Context &ctx,
         dim3 block;
         dim3 grid;
         const int block_size = 512;
-        const int MAX_GRID_SIZE = 128;
-        const int WARP_SIZE = 32;
 
         // init intermediate storage
         DenseTensor block_data_tensor;
@@ -889,35 +934,20 @@ void BatchNormGradRawKernel(const Context &ctx,
         BatchNormParamType<T> *block_data_ptr = nullptr;
         int *flag_ptr = nullptr;
 
-        int block_x =
-            std::min(phi::funcs::details::GetLastPow2(C), WARP_SIZE);
-        int block_y =
-            std::min(phi::funcs::details::GetLastPow2(N * H * W * D / 16),
-                     block_size / block_x);
-        if (block_x * block_y != block_size) {
-          block_x = std::min(phi::funcs::details::GetLastPow2(C),
-                             block_size / block_y);
-        }
-        int grid_x = (C + block_x - 1) / block_x;
-        int grid_y =
-            std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16),
-                     MAX_GRID_SIZE);
-        block.x = block_x;
-        block.y = block_y;
-        grid.x = grid_x;
-        grid.y = grid_y;
-        if (grid.y > 1) {
-          block_data_tensor = phi::Empty<BatchNormParamType<T>, Context>(
-              ctx, {2 * C * grid.y});
-          flag_tensor = phi::Empty<int, Context>(ctx, {grid.x});
-          block_data_ptr = block_data_tensor.data<BatchNormParamType<T>>();
-          flag_ptr = flag_tensor.data<int>();
-          funcs::SetConstant<Context, int> set_zero;
-          set_zero(ctx, &flag_tensor, static_cast<int>(0));
-        }
+        SetLaunchConfigInfoForChannelLast<T>(ctx,
+                                             &block_data_tensor,
+                                             &flag_tensor,
+                                             &block_data_ptr,
+                                             &flag_ptr,
+                                             N,
+                                             H,
+                                             W,
+                                             D,
+                                             C,
+                                             block_size,
+                                             &block,
+                                             &grid);
 
         // 1. reduce_sum(x) => mean, inv_var
         auto *mean_ptr =
             saved_mean_data == nullptr
@@ -967,6 +997,7 @@ void BatchNormGradRawKernel(const Context &ctx,
                 N,
                 H * W * D,
                 epsilon,
+                false,
                 block_data_ptr,
                 dscale,
                 dbias,
@@ -1256,18 +1287,44 @@ void BatchNormGradRawKernel(const Context &ctx,
               d_x->data<T>());
         }
         if (d_scale && d_bias) {
-          KeBNBackwardScaleBias<T, block, phi::DataLayout::kNHWC>
-              <<<grid2, block, 0, stream>>>(
-                  d_y->data<T>(),
-                  x.data<T>(),
+          dim3 block;
+          dim3 grid;
+          const int block_size = 512;
+
+          // init intermediate storage
+          DenseTensor block_data_tensor;
+          DenseTensor flag_tensor;
+          BatchNormParamType<T> *block_data_ptr = nullptr;
+          int *flag_ptr = nullptr;
+
+          SetLaunchConfigInfoForChannelLast<T>(ctx,
+                                               &block_data_tensor,
+                                               &flag_tensor,
+                                               &block_data_ptr,
+                                               &flag_ptr,
+                                               N,
+                                               H,
+                                               W,
+                                               D,
+                                               C,
+                                               block_size,
+                                               &block,
+                                               &grid);
+          BNBackward2DChannelLastStage2<T, block_size>
+              <<<grid, block, 0, ctx.stream()>>>(
+                  transformed_d_y.template data<T>(),
+                  transformed_x.template data<T>(),
                   running_mean_data,
                   running_var_data,
-                  epsilon,
-                  N,
                   C,
+                  N,
                   H * W * D,
+                  epsilon,
+                  true,
+                  block_data_ptr,
                   d_scale->data<BatchNormParamType<T>>(),
-                  d_bias->data<BatchNormParamType<T>>());
+                  d_bias->data<BatchNormParamType<T>>(),
+                  flag_ptr);
         }
       }
     }
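For the inference-mode NHWC path, the dscale/dbias reduction now reuses BNBackward2DChannelLastStage2 instead of KeBNBackwardScaleBias, so it gets the same 2D channel-last tiling and, when grid.y > 1, the grid-wide reduction that block_data_ptr and flag_ptr make possible (the flag tensor is zero-initialized by the helper). A much-simplified sketch of that "last block folds the partials" pattern, not the actual Paddle kernel and with illustrative names:

    // Two-level column sum over a rows x cols matrix, one flag counter per
    // x-column of blocks. Assumes blockDim.y == 1 and that `flag` is zeroed
    // before launch, as SetLaunchConfigInfoForChannelLast does for flag_ptr.
    __global__ void ColumnSum2D(const float *in, float *partial, int *flag,
                                float *out, int rows, int cols) {
      int c = blockIdx.x * blockDim.x + threadIdx.x;
      bool active = (c < cols);

      // Each y-block accumulates its stripe of rows for column c and
      // publishes the partial sum to global memory.
      float sum = 0.f;
      if (active) {
        for (int r = blockIdx.y; r < rows; r += gridDim.y) {
          sum += in[r * cols + c];
        }
        partial[blockIdx.y * cols + c] = sum;
      }
      __threadfence();

      // The last y-block to arrive for this x-column folds all partials.
      __shared__ bool is_last;
      if (threadIdx.x == 0) {
        is_last = (atomicAdd(&flag[blockIdx.x], 1) == gridDim.y - 1);
      }
      __syncthreads();
      if (is_last && active) {
        float total = 0.f;
        for (int b = 0; b < gridDim.y; ++b) total += partial[b * cols + c];
        out[c] = total;
      }
    }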