Unverified commit 36f08826, authored by zhangkaihuo, committed via GitHub

opt bn1d backward (#44783)

Parent 65f38869
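This change adds a native CUDA backward path for batch norm on 2-D (NC) inputs: a BlockReduceByVetical helper plus three channel-last kernels (Stage1: per-channel mean and inverse variance, Stage2: dscale/dbias, Stage3: dx), and lowers CUDNN_PER_ACTIVATION_THRESHOLD so this path is taken once N >= 10240. For reference, the quantities the Stage2/Stage3 kernels below compute are the standard batch-norm backward terms; the notation here (m, \mu_c, \sigma_c^{-1}) is introduced only for readability:

db_c = \sum_i dy_{i,c}, \qquad ds_c = \sum_i dy_{i,c}\,(x_{i,c} - \mu_c)
d\beta_c = db_c, \qquad d\gamma_c = ds_c \cdot \sigma_c^{-1}
dx_{i,c} = \gamma_c\,\sigma_c^{-1}\Big(dy_{i,c} - \frac{db_c}{m} - (x_{i,c} - \mu_c)\,\sigma_c^{-1}\,\frac{ds_c}{m}\Big)

with \sigma_c^{-1} = 1/\sqrt{\mathrm{Var}_c + \epsilon} and m = N \cdot HxW elements per channel c.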
@@ -21,8 +21,10 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/batch_norm_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/norm_utils.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/gpu/batch_norm_utils.h"
#ifdef __HIPCC__
@@ -197,6 +199,7 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackward(
x_sum += x_i;
x_square_sum += x_i * x_i;
}
x_sum = BlockReduce(mean_storage).Reduce(x_sum, cub::Sum());
x_square_sum =
BlockReduce(variance_storeage).Reduce(x_square_sum, cub::Sum());
@@ -218,6 +221,7 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackward(
dy_i * (static_cast<BatchNormParamType<T>>(x[index]) - mean_val);
db_sum += dy_i;
}
ds_sum = BlockReduce(ds_storage).Reduce(ds_sum, cub::Sum());
db_sum = BlockReduce(db_storage).Reduce(db_sum, cub::Sum());
if (threadIdx.x == 0) {
@@ -241,6 +245,263 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackward(
}
}
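// BlockReduceByVetical: tree reduction over the threadIdx.y dimension of a 2D
// thread block. It reduces two running values per thread (sum and square-sum
// here; ds and db in Stage2): at each step the upper half of the active rows
// writes its partials to shared memory and the lower half accumulates them,
// halving the active rows until row 0 holds the per-block totals for this
// channel column.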
template <typename T>
__device__ __forceinline__ void BlockReduceByVetical(
BatchNormParamType<T> x_sum,
BatchNormParamType<T> x_square_sum,
BatchNormParamType<T> *smem_sum,
BatchNormParamType<T> *smem_square_sum,
BatchNormParamType<T> *x_sum_out,
BatchNormParamType<T> *x_square_sum_out) {
int tid = threadIdx.x + threadIdx.y * blockDim.x;
#pragma unroll
for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
if (threadIdx.y < offset * 2) {
smem_sum[tid] = x_sum;
smem_square_sum[tid] = x_square_sum;
}
__syncthreads();
if (threadIdx.y < offset) {
int pair_tid = tid + offset * blockDim.x;
x_sum += smem_sum[pair_tid];
x_square_sum += smem_square_sum[pair_tid];
}
}
if (threadIdx.y == 0) {
*x_sum_out = x_sum;
*x_square_sum_out = x_square_sum;
}
}
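// Stage1: per-channel mean and inverse variance for the channel-last layout.
// Threads are laid out as (channel, element): blockDim.x spans channels and
// blockDim.y strides over the N*HxW elements of a channel. When gridDim.y > 1,
// each block stages its partial sums in block_data_ptr; the last block to
// finish a channel column (detected via atomicAdd on flag_ptr) reduces the
// staged partials and writes compute_mean / compute_inv_var.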
template <typename T, int BlockDim>
static __global__ void BNBackward2DChannelLastStage1(
const T *x,
const int C,
const int N,
const int HxW,
const double epsilon,
BatchNormParamType<T> *block_data_ptr,
BatchNormParamType<T> *compute_mean,
BatchNormParamType<T> *compute_inv_var,
int *flag_ptr) {
int outer_size = C;
int inner_size = N * HxW;
__shared__ BatchNormParamType<T> smem_sum[BlockDim];
__shared__ BatchNormParamType<T> smem_square_sum[BlockDim];
__shared__ BatchNormParamType<T> inv_var_val;
__shared__ BatchNormParamType<T> mean_val;
int outer_loop_stride = gridDim.x * blockDim.x;
int inner_loop_stride = gridDim.y * blockDim.y;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < outer_size;
i += outer_loop_stride) {
BatchNormParamType<T> x_sum = static_cast<BatchNormParamType<T>>(0);
BatchNormParamType<T> x_square_sum = static_cast<BatchNormParamType<T>>(0);
for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < inner_size;
j += inner_loop_stride) {
const int index = j * outer_size + i;
BatchNormParamType<T> x_i = static_cast<BatchNormParamType<T>>(x[index]);
x_sum += x_i;
x_square_sum += x_i * x_i;
}
// vertical block sum
BlockReduceByVetical<T>(x_sum,
x_square_sum,
&smem_sum[0],
&smem_square_sum[0],
&x_sum,
&x_square_sum);
if (gridDim.y > 1) {
volatile BatchNormParamType<T> *staging_sum = block_data_ptr;
volatile BatchNormParamType<T> *staging_square_sum =
&block_data_ptr[C * gridDim.y];
// write block data to global memory
if (threadIdx.y == 0) {
staging_sum[i + blockIdx.y * C] = x_sum;
staging_square_sum[i + blockIdx.y * C] = x_square_sum;
}
// make sure write is visible to all blocks
__threadfence();
__syncthreads();
__shared__ bool is_last_block_done;
// mark block done
if (threadIdx.x == 0 && threadIdx.y == 0) {
int old = atomicAdd(&flag_ptr[blockIdx.x], 1);
is_last_block_done = (old == (gridDim.y - 1));
}
__syncthreads();
if (is_last_block_done) {
x_sum = static_cast<BatchNormParamType<T>>(0);
x_square_sum = static_cast<BatchNormParamType<T>>(0);
// thread sum
for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
x_sum += staging_sum[i + y * C];
x_square_sum += staging_square_sum[i + y * C];
}
// vertical block sum
BlockReduceByVetical<T>(x_sum,
x_square_sum,
&smem_sum[0],
&smem_square_sum[0],
&x_sum,
&x_square_sum);
// final compute
if (threadIdx.y == 0) {
BatchNormParamType<T> compute_mean_val = x_sum / inner_size;
BatchNormParamType<T> variance_val =
x_square_sum / inner_size - compute_mean_val * compute_mean_val;
BatchNormParamType<T> compute_inv_var_val =
1 / sqrt(variance_val + epsilon);
compute_mean[i] = compute_mean_val;
compute_inv_var[i] = compute_inv_var_val;
}
}
}
}
}
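// Stage2: per-channel gradients of scale and bias. Same block-then-grid
// reduction pattern as Stage1, but it accumulates ds = sum(dy * (x - mean))
// and db = sum(dy), then the last finishing block writes
// dscale = ds * inv_var and dbias = db.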
template <typename T, int BlockDim>
static __global__ void BNBackward2DChannelLastStage2(
const T *dy,
const T *x,
const BatchNormParamType<T> *means,
const BatchNormParamType<T> *variances,
const int C,
const int N,
const int HxW,
const double epsilon,
BatchNormParamType<T> *block_data_ptr,
BatchNormParamType<T> *dscale,
BatchNormParamType<T> *dbias,
int *flag_ptr) {
int outer_size = C;
int inner_size = N * HxW;
__shared__ BatchNormParamType<T> smem_ds_sum[BlockDim];
__shared__ BatchNormParamType<T> smem_db_sum[BlockDim];
__shared__ BatchNormParamType<T> inv_var_val;
__shared__ BatchNormParamType<T> mean_val;
int outer_loop_stride = gridDim.x * blockDim.x;
int inner_loop_stride = gridDim.y * blockDim.y;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < outer_size;
i += outer_loop_stride) {
BatchNormParamType<T> ds_sum = static_cast<BatchNormParamType<T>>(0);
BatchNormParamType<T> db_sum = static_cast<BatchNormParamType<T>>(0);
BatchNormParamType<T> mean_val = means[i];
BatchNormParamType<T> inv_var_val = variances[i];
for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < inner_size;
j += inner_loop_stride) {
const int index = j * outer_size + i;
BatchNormParamType<T> dy_i =
static_cast<BatchNormParamType<T>>(dy[index]);
ds_sum +=
dy_i * (static_cast<BatchNormParamType<T>>(x[index]) - mean_val);
db_sum += dy_i;
}
// vertical block sum
BlockReduceByVetical<T>(
ds_sum, db_sum, &smem_ds_sum[0], &smem_db_sum[0], &ds_sum, &db_sum);
if (gridDim.y > 1) {
volatile BatchNormParamType<T> *staging_ds_sum = block_data_ptr;
volatile BatchNormParamType<T> *staging_db_sum =
&block_data_ptr[C * gridDim.y];
// write block data to global memory
if (threadIdx.y == 0) {
staging_ds_sum[i + blockIdx.y * C] = ds_sum;
staging_db_sum[i + blockIdx.y * C] = db_sum;
}
// make sure write is visible to all blocks
__threadfence();
__syncthreads();
__shared__ bool is_last_block_done;
// mark block done
if (threadIdx.x == 0 && threadIdx.y == 0) {
int old = atomicAdd(&flag_ptr[blockIdx.x], 1);
is_last_block_done = (old == (gridDim.y - 1));
}
__syncthreads();
if (is_last_block_done) {
ds_sum = static_cast<BatchNormParamType<T>>(0);
db_sum = static_cast<BatchNormParamType<T>>(0);
// thread sum
for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
ds_sum += staging_ds_sum[i + y * C];
db_sum += staging_db_sum[i + y * C];
}
// vertical block sum
BlockReduceByVetical<T>(
ds_sum, db_sum, &smem_ds_sum[0], &smem_db_sum[0], &ds_sum, &db_sum);
// final compute
if (threadIdx.y == 0) {
dscale[i] = ds_sum * inv_var_val;
dbias[i] = db_sum;
}
}
}
}
}
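// Stage3: purely elementwise pass that produces dx from the already reduced
// dscale / dbias and the saved mean / inv_var:
//   dx = scale * inv_var * (dy - dbias / m - (x - mean) * inv_var * dscale / m)
// with m = N * HxW elements per channel.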
template <typename T, int BlockDim>
static __global__ void BNBackward2DChannelLastStage3(
const T *dy,
const T *x,
const BatchNormParamType<T> *scale,
const BatchNormParamType<T> *dscales,
const BatchNormParamType<T> *dbias,
const BatchNormParamType<T> *means,
const BatchNormParamType<T> *variances,
const int C,
const int N,
const int HxW,
const double epsilon,
T *dx) {
const int outer_size = C;
const int inner_size = N * HxW;
int outer_loop_stride = gridDim.x * blockDim.x;
int inner_loop_stride = gridDim.y * blockDim.y;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < outer_size;
i += outer_loop_stride) {
BatchNormParamType<T> mean_val = means[i];
BatchNormParamType<T> inv_var_val = variances[i];
BatchNormParamType<T> dscale_val = dscales[i];
BatchNormParamType<T> dbias_val = dbias[i];
for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < inner_size;
j += inner_loop_stride) {
const int index = j * outer_size + i;
dx[index] = scale[i] * inv_var_val *
(static_cast<BatchNormParamType<T>>(dy[index]) -
dbias_val / static_cast<BatchNormParamType<T>>(inner_size) -
(static_cast<BatchNormParamType<T>>(x[index]) - mean_val) *
inv_var_val * dscale_val / inner_size);
}
}
}
template <typename T, int BlockDim, phi::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData(
const T *dy,
@@ -592,42 +853,147 @@ void BatchNormGradRawKernel(const Context &ctx,
// epsilon, saved_mean_data, saved_var_data));
#else
// CUDNN only support small batch size
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
// const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
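// Lowering the per-activation threshold routes 2-D (NC) inputs with
// N >= 10240 through the native kernels above instead of cuDNN.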
const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
if (use_native_kernel) {
if (compute_format == DataLayout::kNCHW) {
BNBackward<T, block, DataLayout::kNCHW>
<<<grid2, block, 0, ctx.stream()>>>(
if (x_dims.size() == 2) {
dim3 block;
dim3 grid;
const int block_size = 512;
const int MAX_GRID_SIZE = 128;
const int WARP_SIZE = 32;
// init intermediate storage
DenseTensor block_data_tensor;
DenseTensor flag_tensor;
DenseTensor compute_mean_tensor =
phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
DenseTensor compute_inv_var_tensor =
phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
BatchNormParamType<T> *block_data_ptr = nullptr;
int *flag_ptr = nullptr;
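// Launch configuration: block.x covers channels (at most one warp wide),
// block.y covers the reduction dimension, and each thread strides over
// roughly 16 elements along y so that grid.y stays within MAX_GRID_SIZE.
// The block_data / flag buffers are only allocated when grid.y > 1, i.e.
// when a second, cross-block reduction pass is needed.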
int block_x =
std::min(phi::funcs::details::GetLastPow2(C), WARP_SIZE);
int block_y =
std::min(phi::funcs::details::GetLastPow2(N * H * W * D / 16),
block_size / block_x);
if (block_x * block_y != block_size) {
block_x = std::min(phi::funcs::details::GetLastPow2(C),
block_size / block_y);
}
int grid_x = (C + block_x - 1) / block_x;
int grid_y =
std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16),
MAX_GRID_SIZE);
block.x = block_x;
block.y = block_y;
grid.x = grid_x;
grid.y = grid_y;
if (grid.y > 1) {
block_data_tensor = phi::Empty<BatchNormParamType<T>, Context>(
ctx, {2 * C * grid.y});
flag_tensor = phi::Empty<int, Context>(ctx, {grid.x});
block_data_ptr = block_data_tensor.data<BatchNormParamType<T>>();
flag_ptr = flag_tensor.data<int>();
funcs::SetConstant<Context, int> set_zero;
set_zero(ctx, &flag_tensor, static_cast<int>(0));
}
// 1. reduce_sum(x) => mean, inv_var
auto *mean_ptr =
saved_mean_data == nullptr
? compute_mean_tensor.data<BatchNormParamType<T>>()
: saved_mean_data;
auto *variance_ptr =
saved_var_data == nullptr
? compute_inv_var_tensor.data<BatchNormParamType<T>>()
: saved_var_data;
if (saved_mean_data == nullptr) {
BNBackward2DChannelLastStage1<T, block_size>
<<<grid, block, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
C,
N,
H * W * D,
epsilon,
block_data_ptr,
compute_mean_tensor.data<BatchNormParamType<T>>(),
compute_inv_var_tensor.data<BatchNormParamType<T>>(),
flag_ptr);
}
// 2. reduce_sum(x, dy, mean) => dscale, dbias
BNBackward2DChannelLastStage2<T, block_size>
<<<grid, block, 0, ctx.stream()>>>(
transformed_d_y.template data<T>(),
transformed_x.template data<T>(),
scale.template data<BatchNormParamType<T>>(),
saved_mean_data,
saved_var_data,
mean_ptr,
variance_ptr,
C,
N,
H * W * D,
epsilon,
transformed_d_x.template data<T>(),
block_data_ptr,
ctx.template Alloc<BatchNormParamType<T>>(d_scale),
ctx.template Alloc<BatchNormParamType<T>>(d_bias));
} else {
BNBackward<T, block, DataLayout::kNHWC>
<<<grid2, block, 0, ctx.stream()>>>(
ctx.template Alloc<BatchNormParamType<T>>(d_bias),
flag_ptr);
// 3. elementwise_mul(scale, mean, inv_var, dy, dscale, dbias) => dx
BNBackward2DChannelLastStage3<T, block_size>
<<<grid, block, 0, ctx.stream()>>>(
transformed_d_y.template data<T>(),
transformed_x.template data<T>(),
scale.template data<BatchNormParamType<T>>(),
saved_mean_data,
saved_var_data,
d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>(),
mean_ptr,
variance_ptr,
C,
N,
H * W * D,
epsilon,
transformed_d_x.template data<T>(),
ctx.template Alloc<BatchNormParamType<T>>(d_scale),
ctx.template Alloc<BatchNormParamType<T>>(d_bias));
transformed_d_x.template data<T>());
} else {
if (compute_format == DataLayout::kNCHW) {
BNBackward<T, block, DataLayout::kNCHW>
<<<grid2, block, 0, ctx.stream()>>>(
transformed_d_y.template data<T>(),
transformed_x.template data<T>(),
scale.template data<BatchNormParamType<T>>(),
saved_mean_data,
saved_var_data,
C,
N,
H * W * D,
epsilon,
transformed_d_x.template data<T>(),
ctx.template Alloc<BatchNormParamType<T>>(d_scale),
ctx.template Alloc<BatchNormParamType<T>>(d_bias));
} else {
BNBackward<T, block, DataLayout::kNHWC>
<<<grid2, block, 0, ctx.stream()>>>(
transformed_d_y.template data<T>(),
transformed_x.template data<T>(),
scale.template data<BatchNormParamType<T>>(),
saved_mean_data,
saved_var_data,
C,
N,
H * W * D,
epsilon,
transformed_d_x.template data<T>(),
ctx.template Alloc<BatchNormParamType<T>>(d_scale),
ctx.template Alloc<BatchNormParamType<T>>(d_bias));
}
}
} else {
#if CUDNN_VERSION_MIN(7, 4, 1)
......
@@ -908,7 +908,8 @@ void BatchNormKernel(const Context &ctx,
// static_cast<void *>(saved_variance->template mutable_data<
// BatchNormParamType<T>>(ctx.GetPlace()))));
#else
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
// const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
......
@@ -323,6 +323,35 @@ class TestBatchNormChannelLast(unittest.TestCase):
else:
self.assertEqual(np.allclose(y1.numpy(), y2.numpy()), True)
def test_1d_opt(self):
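# batch_size is chosen above the 10240 per-activation threshold so that
# BatchNorm1D exercises the new native 2-D backward kernels, while the
# equivalent (N, C, 1, 1) BatchNorm2D run serves as the reference; outputs
# and weight gradients of the two must still match.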
with fluid.dygraph.guard():
batch_size = 13700
channels = 16
shape = (batch_size, channels)
x = paddle.randn(shape)
x_4d = x.reshape((batch_size, channels, 1, 1))
x.stop_gradient = False
x_4d.stop_gradient = False
bn1d = paddle.nn.BatchNorm1D(channels)
bn2d = paddle.nn.BatchNorm2D(channels)
y = bn1d(x)
y2 = bn2d(x_4d)
y.backward()
y2.backward()
assert np.allclose(y.numpy().flatten(),
y2.numpy().flatten(),
atol=1e-5,
rtol=1e-5)
assert np.allclose(bn1d.weight.grad.numpy().flatten(),
bn2d.weight.grad.numpy().flatten(),
atol=1e-5,
rtol=1e-5)
class TestBatchNormUseGlobalStats(unittest.TestCase):
......