Unverified · Commit 9a4acfee authored by zhangkaihuo, committed by GitHub

optimize 2D sync_batch_norm (#49663)

Parent 118aee6f
@@ -26,6 +26,7 @@ namespace cub = hipcub;
#endif
#include "paddle/phi/common/layout.h"
#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/reduce_function.h"
#ifdef __HIPCC__
#define LAUNCH_BOUNDS(BlockDim) __launch_bounds__(BlockDim)
@@ -36,8 +37,6 @@ namespace cub = hipcub;
namespace phi {
namespace funcs {
-using DataLayout = phi::DataLayout;
// math: dx = scale * ((x - mean) * inv_var / NxHxW * (np.mean(ddx,
//                     axis=(n,h,w)) *
//                     np.sum(dy, axis=(n,h,w)) -
@@ -670,5 +669,122 @@ void NormDoubleGradFunctor(const DeviceContext &ctx,
}
}
}
template <typename T, typename BnT>
__device__ __forceinline__ void BlockReduceByVetical(BnT x_sum,
BnT x_square_sum,
BnT *smem_sum,
BnT *smem_square_sum,
BnT *x_sum_out,
BnT *x_square_sum_out) {
int tid = threadIdx.x + threadIdx.y * blockDim.x;
#pragma unroll
for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
if (threadIdx.y < offset * 2) {
smem_sum[tid] = x_sum;
smem_square_sum[tid] = x_square_sum;
}
__syncthreads();
if (threadIdx.y < offset) {
int pair_tid = tid + offset * blockDim.x;
x_sum += smem_sum[pair_tid];
x_square_sum += smem_square_sum[pair_tid];
}
}
if (threadIdx.y == 0) {
*x_sum_out = x_sum;
*x_square_sum_out = x_square_sum;
}
}
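
Note: the helper above folds per-thread partial sums along threadIdx.y with a shared-memory tree reduction, leaving the block-wide result in the threads with threadIdx.y == 0. The following is a minimal standalone sketch of the same pattern applied to summing each column of a row-major n x c matrix; it is illustrative only and not part of this commit, the kernel and parameter names (ColumnSum, in, out) are invented, BLOCK_Y is assumed to be a power of two, and the kernel must be launched with dim3(BLOCK_X, BLOCK_Y).

#include <cuda_runtime.h>

template <int BLOCK_X, int BLOCK_Y>
__global__ void ColumnSum(const float *in, float *out, int n, int c) {
  __shared__ float smem[BLOCK_X * BLOCK_Y];
  const int col = blockIdx.x * BLOCK_X + threadIdx.x;
  const int tid = threadIdx.x + threadIdx.y * BLOCK_X;
  // Each thread accumulates a strided partial sum over the rows of its column.
  float partial = 0.f;
  if (col < c) {
    for (int row = threadIdx.y; row < n; row += BLOCK_Y) {
      partial += in[row * c + col];
    }
  }
  // Tree reduction along threadIdx.y, mirroring BlockReduceByVetical above.
  for (int offset = BLOCK_Y / 2; offset > 0; offset >>= 1) {
    if (threadIdx.y < offset * 2) {
      smem[tid] = partial;
    }
    __syncthreads();
    if (threadIdx.y < offset) {
      partial += smem[tid + offset * BLOCK_X];
    }
  }
  // Threads with threadIdx.y == 0 now hold the column sums for this block.
  if (threadIdx.y == 0 && col < c) {
    out[col] = partial;
  }
}
// Example launch: ColumnSum<32, 16><<<(c + 31) / 32, dim3(32, 16)>>>(in, out, n, c);
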
template <typename T, typename BnT>
__device__ __forceinline__ void ReduceSumPost(const int C, // channels
const int c, // channel index
BnT *sum1,
BnT *sum2,
bool *is_last_block_done,
BnT *cache1,
BnT *cache2,
BnT *block_data_ptr,
int *flag_ptr) {
volatile BnT *staging_sum = block_data_ptr;
volatile BnT *staging_sum2 = &block_data_ptr[C * gridDim.y];
// write block data to global memory
if (threadIdx.y == 0) {
staging_sum[c + blockIdx.y * C] = *sum1;
staging_sum2[c + blockIdx.y * C] = *sum2;
}
// make sure write is visible to all blocks
__threadfence();
__syncthreads();
// mark block done
if (threadIdx.x == 0 && threadIdx.y == 0) {
int old = atomicAdd(&flag_ptr[blockIdx.x], 1);
*is_last_block_done = (old == (gridDim.y - 1));
}
__syncthreads();
if (*is_last_block_done) {
*sum1 = static_cast<BnT>(0);
*sum2 = static_cast<BnT>(0);
// thread sum
for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
*sum1 += staging_sum[c + y * C];
*sum2 += staging_sum2[c + y * C];
}
// vertical block sum
funcs::BlockReduceByVetical<T, BnT>(
*sum1, *sum2, &cache1[0], &cache2[0], sum1, sum2);
}
}
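
Note: ReduceSumPost is the usual cross-block hand-off, where every block publishes its partial sums to a global staging buffer, atomically bumps a counter, and whichever block observes the final count folds all partials into the result. Below is a minimal 1-D sketch of the same idea, illustrative only and not part of this commit; the names (GridSum, partials, flag) are invented, partials must hold gridDim.x floats, flag must be zeroed before launch, and the block size is fixed at 256 threads (a power of two).

#include <cuda_runtime.h>

__global__ void GridSum(const float *in, int n, float *partials, int *flag,
                        float *out) {
  __shared__ float smem[256];
  __shared__ bool is_last_block_done;
  // 1. Block-local partial sum over a grid-strided range.
  float v = 0.f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    v += in[i];
  }
  smem[threadIdx.x] = v;
  __syncthreads();
  for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
    if (threadIdx.x < offset) smem[threadIdx.x] += smem[threadIdx.x + offset];
    __syncthreads();
  }
  // 2. Publish this block's result, then mark the block as done.
  if (threadIdx.x == 0) {
    partials[blockIdx.x] = smem[0];
    __threadfence();  // make the store visible before the counter bump
    int done = atomicAdd(flag, 1);
    is_last_block_done = (done == gridDim.x - 1);
  }
  __syncthreads();
  // 3. Only the last block to finish folds every partial into the result.
  if (is_last_block_done) {
    float total = 0.f;
    for (int b = threadIdx.x; b < gridDim.x; b += blockDim.x) {
      total += partials[b];
    }
    smem[threadIdx.x] = total;
    __syncthreads();
    for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
      if (threadIdx.x < offset) smem[threadIdx.x] += smem[threadIdx.x + offset];
      __syncthreads();
    }
    if (threadIdx.x == 0) *out = smem[0];
  }
}
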
template <typename T, typename BnT, typename Context>
void SetLaunchConfigInfoForChannelLast(const Context &ctx,
DenseTensor *block_data_tensor,
DenseTensor *flag_tensor,
BnT **block_data_ptr,
int **flag_ptr,
const int N,
const int H,
const int W,
const int D,
const int C,
const int block_size,
dim3 *block,
dim3 *grid) {
const int MAX_GRID_SIZE = 128;
const int WARP_SIZE = 32;
int block_x = std::min(phi::funcs::details::GetLastPow2(C), WARP_SIZE);
int block_y = std::min(phi::funcs::details::GetLastPow2(N * H * W * D / 16),
block_size / block_x);
if (block_x * block_y != block_size) {
block_x =
std::min(phi::funcs::details::GetLastPow2(C), block_size / block_y);
}
int grid_x = (C + block_x - 1) / block_x;
int grid_y = std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16),
MAX_GRID_SIZE);
block->x = block_x;
block->y = block_y;
grid->x = grid_x;
grid->y = grid_y;
if (grid->y > 1) {
*block_data_tensor = phi::Empty<BnT, Context>(ctx, {2 * C * grid->y});
*flag_tensor = phi::Empty<int, Context>(ctx, {grid->x});
*block_data_ptr = block_data_tensor->data<BnT>();
*flag_ptr = flag_tensor->data<int>();
funcs::SetConstant<Context, int> set_zero;
set_zero(ctx, flag_tensor, static_cast<int>(0));
}
}
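
Note: a host-side walk-through of the launch-config arithmetic above for a 2-D NHWC input (H = W = D = 1), illustrative only and not part of this commit. GetLastPow2 below is a local stand-in for phi::funcs::details::GetLastPow2 (largest power of two not exceeding n), and the sample sizes N = 1 << 20, C = 64 are assumptions.

#include <algorithm>
#include <cstdio>

static int GetLastPow2(int n) {
  int p = 1;
  while (p * 2 <= n) p *= 2;
  return p;
}

int main() {
  const int N = 1 << 20, H = 1, W = 1, D = 1, C = 64;
  const int block_size = 512, MAX_GRID_SIZE = 128, WARP_SIZE = 32;
  int block_x = std::min(GetLastPow2(C), WARP_SIZE);                  // 32
  int block_y = std::min(GetLastPow2(N * H * W * D / 16),
                         block_size / block_x);                       // 16
  if (block_x * block_y != block_size) {
    block_x = std::min(GetLastPow2(C), block_size / block_y);
  }
  int grid_x = (C + block_x - 1) / block_x;                           // 2
  int grid_y = std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16),
                        MAX_GRID_SIZE);                               // 128
  // grid_y > 1, so the kernel needs a 2 * C * grid_y staging buffer and
  // grid_x zero-initialized flags for the cross-block reduction.
  std::printf("block=(%d,%d) grid=(%d,%d) staging=%d flags=%d\n",
              block_x, block_y, grid_x, grid_y, 2 * C * grid_y, grid_x);
  return 0;
}
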
} // namespace funcs
} // namespace phi
@@ -245,34 +245,6 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackward(
}
}
template <typename T>
__device__ __forceinline__ void BlockReduceByVetical(
BatchNormParamType<T> x_sum,
BatchNormParamType<T> x_square_sum,
BatchNormParamType<T> *smem_sum,
BatchNormParamType<T> *smem_square_sum,
BatchNormParamType<T> *x_sum_out,
BatchNormParamType<T> *x_square_sum_out) {
int tid = threadIdx.x + threadIdx.y * blockDim.x;
#pragma unroll
for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
if (threadIdx.y < offset * 2) {
smem_sum[tid] = x_sum;
smem_square_sum[tid] = x_square_sum;
}
__syncthreads();
if (threadIdx.y < offset) {
int pair_tid = tid + offset * blockDim.x;
x_sum += smem_sum[pair_tid];
x_square_sum += smem_square_sum[pair_tid];
}
}
if (threadIdx.y == 0) {
*x_sum_out = x_sum;
*x_square_sum_out = x_square_sum;
}
}
template <typename T, int BlockDim>
static __global__ void BNBackward2DChannelLastStage1(
const T *x,
@@ -309,53 +281,25 @@ static __global__ void BNBackward2DChannelLastStage1(
}
// vertical block sum
-BlockReduceByVetical<T>(x_sum,
+funcs::BlockReduceByVetical<T, BatchNormParamType<T>>(x_sum,
x_square_sum,
&smem_sum[0],
&smem_square_sum[0],
&x_sum,
&x_square_sum);
if (gridDim.y > 1) {
-volatile BatchNormParamType<T> *staging_sum = block_data_ptr;
-volatile BatchNormParamType<T> *staging_square_sum =
-&block_data_ptr[C * gridDim.y];
-// write block data to global memory
-if (threadIdx.y == 0) {
-staging_sum[i + blockIdx.y * C] = x_sum;
-staging_square_sum[i + blockIdx.y * C] = x_square_sum;
-}
-// make sure write is visible to all blocks
-__threadfence();
-__syncthreads();
__shared__ bool is_last_block_done;
-// mark block done
-if (threadIdx.x == 0 && threadIdx.y == 0) {
-int old = atomicAdd(&flag_ptr[blockIdx.x], 1);
-is_last_block_done = (old == (gridDim.y - 1));
-}
-__syncthreads();
+funcs::ReduceSumPost<T, BatchNormParamType<T>>(C,
+i,
+&x_sum,
+&x_square_sum,
+&is_last_block_done,
+smem_sum,
+smem_square_sum,
+block_data_ptr,
+flag_ptr);
if (is_last_block_done) {
-x_sum = static_cast<BatchNormParamType<T>>(0);
-x_square_sum = static_cast<BatchNormParamType<T>>(0);
-// thread sum
-for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
-x_sum += staging_sum[i + y * C];
-x_square_sum += staging_square_sum[i + y * C];
-}
-// vertical block sum
-BlockReduceByVetical<T>(x_sum,
-x_square_sum,
-&smem_sum[0],
-&smem_square_sum[0],
-&x_sum,
-&x_square_sum);
// final compute
if (threadIdx.y == 0) {
BatchNormParamType<T> compute_mean_val = x_sum / inner_size;
@@ -417,45 +361,21 @@ static __global__ void BNBackward2DChannelLastStage2(
}
// vertical block sum
-BlockReduceByVetical<T>(
+funcs::BlockReduceByVetical<T, BatchNormParamType<T>>(
ds_sum, db_sum, &smem_ds_sum[0], &smem_db_sum[0], &ds_sum, &db_sum);
if (gridDim.y > 1) {
-volatile BatchNormParamType<T> *staging_ds_sum = block_data_ptr;
-volatile BatchNormParamType<T> *staging_db_sum =
-&block_data_ptr[C * gridDim.y];
-// write block data to global memory
-if (threadIdx.y == 0) {
-staging_ds_sum[i + blockIdx.y * C] = ds_sum;
-staging_db_sum[i + blockIdx.y * C] = db_sum;
-}
-// make sure write is visible to all blocks
-__threadfence();
-__syncthreads();
__shared__ bool is_last_block_done;
-// mark block done
-if (threadIdx.x == 0 && threadIdx.y == 0) {
-int old = atomicAdd(&flag_ptr[blockIdx.x], 1);
-is_last_block_done = (old == (gridDim.y - 1));
-}
-__syncthreads();
+funcs::ReduceSumPost<T, BatchNormParamType<T>>(C,
+i,
+&ds_sum,
+&db_sum,
+&is_last_block_done,
+smem_ds_sum,
+smem_db_sum,
+block_data_ptr,
+flag_ptr);
if (is_last_block_done) {
-ds_sum = static_cast<BatchNormParamType<T>>(0);
-db_sum = static_cast<BatchNormParamType<T>>(0);
-// thread sum
-for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
-ds_sum += staging_ds_sum[i + y * C];
-db_sum += staging_db_sum[i + y * C];
-}
-// vertical block sum
-BlockReduceByVetical<T>(
-ds_sum, db_sum, &smem_ds_sum[0], &smem_db_sum[0], &ds_sum, &db_sum);
// final compute
if (threadIdx.y == 0) {
dscale[i] = ds_sum * inv_var_val;
@@ -563,51 +483,6 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData(
}
}
template <typename T, typename Context>
void SetLaunchConfigInfoForChannelLast(const Context &ctx,
DenseTensor *block_data_tensor,
DenseTensor *flag_tensor,
BatchNormParamType<T> **block_data_ptr,
int **flag_ptr,
const int N,
const int H,
const int W,
const int D,
const int C,
const int block_size,
dim3 *block,
dim3 *grid) {
const int MAX_GRID_SIZE = 128;
const int WARP_SIZE = 32;
int block_x = std::min(phi::funcs::details::GetLastPow2(C), WARP_SIZE);
int block_y = std::min(phi::funcs::details::GetLastPow2(N * H * W * D / 16),
block_size / block_x);
if (block_x * block_y != block_size) {
block_x =
std::min(phi::funcs::details::GetLastPow2(C), block_size / block_y);
}
int grid_x = (C + block_x - 1) / block_x;
int grid_y = std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16),
MAX_GRID_SIZE);
block->x = block_x;
block->y = block_y;
grid->x = grid_x;
grid->y = grid_y;
if (grid->y > 1) {
*block_data_tensor =
phi::Empty<BatchNormParamType<T>, Context>(ctx, {2 * C * grid->y});
*flag_tensor = phi::Empty<int, Context>(ctx, {grid->x});
*block_data_ptr = block_data_tensor->data<BatchNormParamType<T>>();
*flag_ptr = flag_tensor->data<int>();
funcs::SetConstant<Context, int> set_zero;
set_zero(ctx, flag_tensor, static_cast<int>(0));
}
}
template <typename T, typename Context>
void BatchNormGradRawKernel(const Context &ctx,
const DenseTensor &x,
@@ -931,19 +806,20 @@ void BatchNormGradRawKernel(const Context &ctx,
BatchNormParamType<T> *block_data_ptr = nullptr;
int *flag_ptr = nullptr;
-SetLaunchConfigInfoForChannelLast<T>(ctx,
+funcs::SetLaunchConfigInfoForChannelLast<T, BatchNormParamType<T>>(
+ctx,
&block_data_tensor,
&flag_tensor,
&block_data_ptr,
&flag_ptr,
N,
H,
W,
D,
C,
block_size,
&block,
&grid);
// 1. reduce_sum(x) => mean, inv_var
auto *mean_ptr =
@@ -1294,19 +1170,20 @@ void BatchNormGradRawKernel(const Context &ctx,
BatchNormParamType<T> *block_data_ptr = nullptr;
int *flag_ptr = nullptr;
-SetLaunchConfigInfoForChannelLast<T>(ctx,
+funcs::SetLaunchConfigInfoForChannelLast<T, BatchNormParamType<T>>(
+ctx,
&block_data_tensor,
&flag_tensor,
&block_data_ptr,
&flag_ptr,
N,
H,
W,
D,
C,
block_size,
&block,
&grid);
BNBackward2DChannelLastStage2<T, block_size>
<<<grid, block, 0, ctx.stream()>>>(
transformed_d_y.template data<T>(),
......
@@ -30,6 +30,7 @@ namespace cub = hipcub;
#include "paddle/phi/kernels/batch_norm_kernel.h"
#include "paddle/phi/kernels/funcs/batch_norm_utils.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/norm_utils.cu.h"
#include "paddle/phi/kernels/funcs/norm_utils.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
@@ -171,34 +172,6 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining(
}
}
template <typename T>
__device__ __forceinline__ void merge_block_vertical(
BatchNormParamType<T> x_sum,
BatchNormParamType<T> x_square_sum,
BatchNormParamType<T> *smem_sum,
BatchNormParamType<T> *smem_square_sum,
BatchNormParamType<T> *x_sum_out,
BatchNormParamType<T> *x_square_sum_out) {
int tid = threadIdx.x + threadIdx.y * blockDim.x;
#pragma unroll
for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
if (threadIdx.y < offset * 2) {
smem_sum[tid] = x_sum;
smem_square_sum[tid] = x_square_sum;
}
__syncthreads();
if (threadIdx.y < offset) {
int pair_tid = tid + offset * blockDim.x;
x_sum += smem_sum[pair_tid];
x_square_sum += smem_square_sum[pair_tid];
}
}
if (threadIdx.y == 0) {
*x_sum_out = x_sum;
*x_square_sum_out = x_square_sum;
}
}
template <typename T>
__device__ __forceinline__ void merge_block_horizonal(
BatchNormParamType<T> x_sum,
@@ -269,53 +242,26 @@ static __global__ void BNForwardTraining2DChannelLastCompStat(
}
// vertical block sum
-merge_block_vertical<T>(x_sum,
+funcs::BlockReduceByVetical<T, BatchNormParamType<T>>(x_sum,
x_square_sum,
&smem_sum[0],
&smem_square_sum[0],
&x_sum,
&x_square_sum);
if (gridDim.y > 1) {
-volatile BatchNormParamType<T> *staging_sum = block_data_ptr;
-volatile BatchNormParamType<T> *staging_square_sum =
-&block_data_ptr[C * gridDim.y];
-// write block data to global memory
-if (threadIdx.y == 0) {
-staging_sum[i + blockIdx.y * C] = x_sum;
-staging_square_sum[i + blockIdx.y * C] = x_square_sum;
-}
-// make sure write is visible to all blocks
-__threadfence();
-__syncthreads();
__shared__ bool is_last_block_done;
-// mark block done
-if (threadIdx.x == 0 && threadIdx.y == 0) {
-int old = atomicAdd(&flag_ptr[blockIdx.x], 1);
-is_last_block_done = (old == (gridDim.y - 1));
-}
-__syncthreads();
+funcs::ReduceSumPost<T, BatchNormParamType<T>>(C,
+i,
+&x_sum,
+&x_square_sum,
+&is_last_block_done,
+smem_sum,
+smem_square_sum,
+block_data_ptr,
+flag_ptr);
if (is_last_block_done) {
-x_sum = static_cast<BatchNormParamType<T>>(0);
-x_square_sum = static_cast<BatchNormParamType<T>>(0);
-// thread sum
-for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
-x_sum += staging_sum[i + y * C];
-x_square_sum += staging_square_sum[i + y * C];
-}
-// vertical block sum
-merge_block_vertical<T>(x_sum,
-x_square_sum,
-&smem_sum[0],
-&smem_square_sum[0],
-&x_sum,
-&x_square_sum);
// final compute
if (threadIdx.y == 0) {
BatchNormParamType<T> compute_mean_val = x_sum / inner_size;
......
@@ -34,6 +34,7 @@ namespace cub = hipcub;
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/backends/gpu/gpu_dnn.h"
#include "paddle/phi/common/layout.h"
+#include "paddle/phi/kernels/funcs/norm_utils.cu.h"
#include "paddle/phi/kernels/funcs/norm_utils.h"
namespace phi {
@@ -168,6 +169,61 @@ __global__ void KeBackwardLocalStats(const T *dy,
}
}
template <typename T, const int BlockDim, DataLayout layout>
__global__ void KeBackwardLocalStats2D(const T *dy,
const T *x,
const BatchNormParamType<T> *means,
int N,
int M,
int C,
BatchNormParamType<T> *block_data_ptr,
int *flag_ptr,
BatchNormParamType<T> *sum_dy_prod) {
__shared__ BatchNormParamType<T> smem_sum[BlockDim];
__shared__ BatchNormParamType<T> smem_square_sum[BlockDim];
for (int k = blockIdx.x * blockDim.x + threadIdx.x; k < C;
k += gridDim.x * blockDim.x) {
BatchNormParamType<T> sum1 = 0.;
BatchNormParamType<T> sum2 = 0.;
auto mean = means[k];
for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < N * M;
i += gridDim.y * blockDim.y) {
int id = layout == DataLayout::kNCHW ? (i / M) * C * M + k * M + i % M
: i * C + k;
auto g = static_cast<BatchNormParamType<T>>(dy[id]);
sum1 += g;
auto x_i = static_cast<BatchNormParamType<T>>(x[id]);
sum2 += g * (x_i - mean);
}
funcs::BlockReduceByVetical<T, BatchNormParamType<T>>(
sum1, sum2, &smem_sum[0], &smem_square_sum[0], &sum1, &sum2);
if (gridDim.y > 1) {
__shared__ bool is_last_block_done;
funcs::ReduceSumPost<T, BatchNormParamType<T>>(C,
k,
&sum1,
&sum2,
&is_last_block_done,
smem_sum,
smem_square_sum,
block_data_ptr,
flag_ptr);
if (is_last_block_done) {
// final compute
if (threadIdx.y == 0) {
sum_dy_prod[k] = sum1;
sum_dy_prod[k + C] = sum2;
}
}
}
}
if (blockIdx.y == 0 && blockIdx.x == 0 && threadIdx.y == 0 &&
threadIdx.x == 0) {
sum_dy_prod[2 * C] = 1.0;
}
}
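
Note: for NHWC input, my reading of the kernel above is that it accumulates, per channel k, the sum of dy and the sum of dy * (x - mean), with the trailing slot set to 1.0 as in the original KeBackwardLocalStats. The CPU reference below is illustrative only and not part of this commit; the function name and the `rows = N * M` flattening are assumptions.

#include <vector>

std::vector<float> BackwardLocalStatsRef(const std::vector<float> &dy,
                                         const std::vector<float> &x,
                                         const std::vector<float> &means,
                                         int rows, int C) {
  std::vector<float> out(2 * C + 1, 0.f);
  for (int i = 0; i < rows; ++i) {
    for (int k = 0; k < C; ++k) {
      float g = dy[i * C + k];          // NHWC: element i of channel k
      out[k] += g;                      // sum of dy per channel
      out[k + C] += g * (x[i * C + k] - means[k]);  // sum of dy * (x - mean)
    }
  }
  out[2 * C] = 1.f;  // trailing slot set to 1.0, matching the kernel above
  return out;
}
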
template <typename T, int BlockDim, DataLayout layout>
static __global__ void KeBNBackwardScaleBias(
const T *dy,
@@ -213,6 +269,68 @@ static __global__ void KeBNBackwardScaleBias(
}
}
template <typename T, int BlockDim, DataLayout layout>
static __global__ void KeBNBackwardScaleBias2D(
const T *dy,
const T *x,
const BatchNormParamType<T> *mean,
const BatchNormParamType<T> *inv_variance,
const double epsilon,
const int N,
const int C,
const int HxW,
BatchNormParamType<T> *block_data_ptr,
int *flag_ptr,
BatchNormParamType<T> *dscale,
BatchNormParamType<T> *dbias) {
const int outer_size = C;
const int inner_size = N * HxW;
__shared__ BatchNormParamType<T> smem_sum[BlockDim];
__shared__ BatchNormParamType<T> smem_square_sum[BlockDim];
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < outer_size;
i += gridDim.x * blockDim.x) {
BatchNormParamType<T> ds_sum = 0.;
BatchNormParamType<T> db_sum = 0.;
auto inv_var_i = inv_variance[i];
auto mean_i = mean[i];
for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < inner_size;
j += gridDim.y * blockDim.y) {
const int id = layout == DataLayout::kNCHW
? ((j / HxW) * C + i) * HxW + (j % HxW)
: j * outer_size + i;
auto x_i = static_cast<BatchNormParamType<T>>(x[id]);
auto dy_i = static_cast<BatchNormParamType<T>>(dy[id]);
ds_sum += dy_i * (x_i - mean_i);
db_sum += dy_i;
}
funcs::BlockReduceByVetical<T, BatchNormParamType<T>>(
ds_sum, db_sum, &smem_sum[0], &smem_square_sum[0], &ds_sum, &db_sum);
if (gridDim.y > 1) {
__shared__ bool is_last_block_done;
funcs::ReduceSumPost<T, BatchNormParamType<T>>(C,
i,
&ds_sum,
&db_sum,
&is_last_block_done,
smem_sum,
smem_square_sum,
block_data_ptr,
flag_ptr);
if (is_last_block_done) {
// final compute
if (threadIdx.y == 0) {
dscale[i] = ds_sum * inv_var_i;
dbias[i] = db_sum;
}
}
}
}
}
template <typename T, DataLayout layout>
static __global__ void KeBNRestoreData(T *x,
const BatchNormParamType<T> *scale,
@@ -410,9 +528,46 @@ void SyncBatchNormGradFunctor(
<<<grid, threads, 0, stream>>>(
dy_d, x_d, saved_mean_ptr, N, fsize, C, stats);
} else {
-KeBackwardLocalStats<T, threads, DataLayout::kNHWC>
-<<<grid, threads, 0, stream>>>(
-dy_d, x_d, saved_mean_ptr, N, fsize, C, stats);
+if (x_dims.size() == 2 && N >= 65535) {
+dim3 block;
+dim3 grid;
+const int block_size = 512;
+// init intermediate storage
+DenseTensor block_data_tensor;
+DenseTensor flag_tensor;
+BatchNormParamType<T> *block_data_ptr = nullptr;
+int *flag_ptr = nullptr;
+funcs::SetLaunchConfigInfoForChannelLast<T, BatchNormParamType<T>>(
+ctx,
+&block_data_tensor,
+&flag_tensor,
+&block_data_ptr,
+&flag_ptr,
+N,
+H,
+W,
+D,
+C,
+block_size,
+&block,
+&grid);
+KeBackwardLocalStats2D<T, block_size, DataLayout::kNHWC>
+<<<grid, block, 0, stream>>>(dy_d,
+x_d,
+saved_mean_ptr,
+N,
+fsize,
+C,
+block_data_ptr,
+flag_ptr,
+stats);
+} else {
+KeBackwardLocalStats<T, threads, DataLayout::kNHWC>
+<<<grid, threads, 0, stream>>>(
+dy_d, x_d, saved_mean_ptr, N, fsize, C, stats);
+}
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
@@ -476,8 +631,33 @@ void SyncBatchNormGradFunctor(
}
} else {
if (d_scale && d_bias) {
-KeBNBackwardScaleBias<T, threads, DataLayout::kNHWC>
-<<<grid, threads, 0, stream>>>(dy_d,
+if (x_dims.size() == 2 && N >= 65535) {
+dim3 block;
+dim3 grid;
+const int block_size = 512;
+// init intermediate storage
+DenseTensor block_data_tensor;
+DenseTensor flag_tensor;
+BatchNormParamType<T> *block_data_ptr = nullptr;
+int *flag_ptr = nullptr;
+funcs::SetLaunchConfigInfoForChannelLast<T, BatchNormParamType<T>>(
+ctx,
+&block_data_tensor,
+&flag_tensor,
+&block_data_ptr,
+&flag_ptr,
+N,
+H,
+W,
+D,
+C,
+block_size,
+&block,
+&grid);
+KeBNBackwardScaleBias2D<T, block_size, DataLayout::kNHWC>
+<<<grid, block, 0, stream>>>(dy_d,
x_d,
saved_mean_ptr,
saved_inv_var,
@@ -485,8 +665,24 @@ void SyncBatchNormGradFunctor(
N,
C,
fsize,
+block_data_ptr,
+flag_ptr,
d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>());
+} else {
+KeBNBackwardScaleBias<T, threads, DataLayout::kNHWC>
+<<<grid, threads, 0, stream>>>(
+dy_d,
+x_d,
+saved_mean_ptr,
+saved_inv_var,
+epsilon,
+N,
+C,
+fsize,
+d_scale->data<BatchNormParamType<T>>(),
+d_bias->data<BatchNormParamType<T>>());
+}
}
if (d_x) {
KeBNBackwardData<T, DataLayout::kNHWC><<<grid2, block, 0, stream>>>(
......