From 9a4acfee2fb1e90ded399511cf0f8ee1def0229f Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Tue, 31 Jan 2023 14:56:13 +0800 Subject: [PATCH] optimize 2D sync_batch_norm (#49663) --- paddle/phi/kernels/funcs/norm_utils.cu.h | 120 ++++++++- .../phi/kernels/gpu/batch_norm_grad_kernel.cu | 229 ++++-------------- paddle/phi/kernels/gpu/batch_norm_kernel.cu | 86 ++----- .../phi/kernels/gpu/sync_batch_norm_utils.h | 206 +++++++++++++++- 4 files changed, 388 insertions(+), 253 deletions(-) diff --git a/paddle/phi/kernels/funcs/norm_utils.cu.h b/paddle/phi/kernels/funcs/norm_utils.cu.h index 0971db1052..80f37750ad 100644 --- a/paddle/phi/kernels/funcs/norm_utils.cu.h +++ b/paddle/phi/kernels/funcs/norm_utils.cu.h @@ -26,6 +26,7 @@ namespace cub = hipcub; #endif #include "paddle/phi/common/layout.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/reduce_function.h" #ifdef __HIPCC__ #define LAUNCH_BOUNDS(BlockDim) __launch_bounds__(BlockDim) @@ -36,8 +37,6 @@ namespace cub = hipcub; namespace phi { namespace funcs { -using DataLayout = phi::DataLayout; - // math: dx = scale * ((x - mean) * inv_var / NxHxW * (np.mean(ddx, // axis=(n,h,w)) * // np.sum(dy, axis=(n,h,w)) - @@ -670,5 +669,122 @@ void NormDoubleGradFunctor(const DeviceContext &ctx, } } } + +template +__device__ __forceinline__ void BlockReduceByVetical(BnT x_sum, + BnT x_square_sum, + BnT *smem_sum, + BnT *smem_square_sum, + BnT *x_sum_out, + BnT *x_square_sum_out) { + int tid = threadIdx.x + threadIdx.y * blockDim.x; +#pragma unroll + for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) { + if (threadIdx.y < offset * 2) { + smem_sum[tid] = x_sum; + smem_square_sum[tid] = x_square_sum; + } + __syncthreads(); + if (threadIdx.y < offset) { + int pair_tid = tid + offset * blockDim.x; + x_sum += smem_sum[pair_tid]; + x_square_sum += smem_square_sum[pair_tid]; + } + } + if (threadIdx.y == 0) { + *x_sum_out = x_sum; + *x_square_sum_out = x_square_sum; + } +} + +template +__device__ __forceinline__ void ReduceSumPost(const int C, // channels + const int c, // channel index + BnT *sum1, + BnT *sum2, + bool *is_last_block_done, + BnT *cache1, + BnT *cache2, + BnT *block_data_ptr, + int *flag_ptr) { + volatile BnT *staging_sum = block_data_ptr; + volatile BnT *staging_sum2 = &block_data_ptr[C * gridDim.y]; + // write block data to global memory + if (threadIdx.y == 0) { + staging_sum[c + blockIdx.y * C] = *sum1; + staging_sum2[c + blockIdx.y * C] = *sum2; + } + + // make sure write is visible to all blocks + __threadfence(); + __syncthreads(); + + // mark block done + if (threadIdx.x == 0 && threadIdx.y == 0) { + int old = atomicAdd(&flag_ptr[blockIdx.x], 1); + *is_last_block_done = (old == (gridDim.y - 1)); + } + + __syncthreads(); + + if (*is_last_block_done) { + *sum1 = static_cast(0); + *sum2 = static_cast(0); + // thread sum + for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { + *sum1 += staging_sum[c + y * C]; + *sum2 += staging_sum2[c + y * C]; + } + + // vertical block sum + funcs::BlockReduceByVetical( + *sum1, *sum2, &cache1[0], &cache2[0], sum1, sum2); + } +} + +template +void SetLaunchConfigInfoForChannelLast(const Context &ctx, + DenseTensor *block_data_tensor, + DenseTensor *flag_tensor, + BnT **block_data_ptr, + int **flag_ptr, + const int N, + const int H, + const int W, + const int D, + const int C, + const int block_size, + dim3 *block, + dim3 *grid) { + const int MAX_GRID_SIZE = 128; + const int WARP_SIZE = 32; + + int block_x = std::min(phi::funcs::details::GetLastPow2(C), WARP_SIZE); + int block_y = std::min(phi::funcs::details::GetLastPow2(N * H * W * D / 16), + block_size / block_x); + if (block_x * block_y != block_size) { + block_x = + std::min(phi::funcs::details::GetLastPow2(C), block_size / block_y); + } + int grid_x = (C + block_x - 1) / block_x; + int grid_y = std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16), + MAX_GRID_SIZE); + + block->x = block_x; + block->y = block_y; + grid->x = grid_x; + grid->y = grid_y; + + if (grid->y > 1) { + *block_data_tensor = phi::Empty(ctx, {2 * C * grid->y}); + *flag_tensor = phi::Empty(ctx, {grid->x}); + + *block_data_ptr = block_data_tensor->data(); + *flag_ptr = flag_tensor->data(); + funcs::SetConstant set_zero; + set_zero(ctx, flag_tensor, static_cast(0)); + } +} + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu index 01a7aa0162..58d05d6075 100644 --- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu @@ -245,34 +245,6 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackward( } } -template -__device__ __forceinline__ void BlockReduceByVetical( - BatchNormParamType x_sum, - BatchNormParamType x_square_sum, - BatchNormParamType *smem_sum, - BatchNormParamType *smem_square_sum, - BatchNormParamType *x_sum_out, - BatchNormParamType *x_square_sum_out) { - int tid = threadIdx.x + threadIdx.y * blockDim.x; -#pragma unroll - for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) { - if (threadIdx.y < offset * 2) { - smem_sum[tid] = x_sum; - smem_square_sum[tid] = x_square_sum; - } - __syncthreads(); - if (threadIdx.y < offset) { - int pair_tid = tid + offset * blockDim.x; - x_sum += smem_sum[pair_tid]; - x_square_sum += smem_square_sum[pair_tid]; - } - } - if (threadIdx.y == 0) { - *x_sum_out = x_sum; - *x_square_sum_out = x_square_sum; - } -} - template static __global__ void BNBackward2DChannelLastStage1( const T *x, @@ -309,53 +281,25 @@ static __global__ void BNBackward2DChannelLastStage1( } // vertical block sum - BlockReduceByVetical(x_sum, - x_square_sum, - &smem_sum[0], - &smem_square_sum[0], - &x_sum, - &x_square_sum); + funcs::BlockReduceByVetical>(x_sum, + x_square_sum, + &smem_sum[0], + &smem_square_sum[0], + &x_sum, + &x_square_sum); if (gridDim.y > 1) { - volatile BatchNormParamType *staging_sum = block_data_ptr; - volatile BatchNormParamType *staging_square_sum = - &block_data_ptr[C * gridDim.y]; - // write block data to global memory - if (threadIdx.y == 0) { - staging_sum[i + blockIdx.y * C] = x_sum; - staging_square_sum[i + blockIdx.y * C] = x_square_sum; - } - - // make sure write is visible to all blocks - __threadfence(); - __syncthreads(); - __shared__ bool is_last_block_done; - // mark block done - if (threadIdx.x == 0 && threadIdx.y == 0) { - int old = atomicAdd(&flag_ptr[blockIdx.x], 1); - is_last_block_done = (old == (gridDim.y - 1)); - } - - __syncthreads(); - + funcs::ReduceSumPost>(C, + i, + &x_sum, + &x_square_sum, + &is_last_block_done, + smem_sum, + smem_square_sum, + block_data_ptr, + flag_ptr); if (is_last_block_done) { - x_sum = static_cast>(0); - x_square_sum = static_cast>(0); - // thread sum - for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { - x_sum += staging_sum[i + y * C]; - x_square_sum += staging_square_sum[i + y * C]; - } - - // vertical block sum - BlockReduceByVetical(x_sum, - x_square_sum, - &smem_sum[0], - &smem_square_sum[0], - &x_sum, - &x_square_sum); - // final compute if (threadIdx.y == 0) { BatchNormParamType compute_mean_val = x_sum / inner_size; @@ -417,45 +361,21 @@ static __global__ void BNBackward2DChannelLastStage2( } // vertical block sum - BlockReduceByVetical( + funcs::BlockReduceByVetical>( ds_sum, db_sum, &smem_ds_sum[0], &smem_db_sum[0], &ds_sum, &db_sum); if (gridDim.y > 1) { - volatile BatchNormParamType *staging_ds_sum = block_data_ptr; - volatile BatchNormParamType *staging_db_sum = - &block_data_ptr[C * gridDim.y]; - // write block data to global memory - if (threadIdx.y == 0) { - staging_ds_sum[i + blockIdx.y * C] = ds_sum; - staging_db_sum[i + blockIdx.y * C] = db_sum; - } - - // make sure write is visible to all blocks - __threadfence(); - __syncthreads(); - __shared__ bool is_last_block_done; - // mark block done - if (threadIdx.x == 0 && threadIdx.y == 0) { - int old = atomicAdd(&flag_ptr[blockIdx.x], 1); - is_last_block_done = (old == (gridDim.y - 1)); - } - - __syncthreads(); - + funcs::ReduceSumPost>(C, + i, + &ds_sum, + &db_sum, + &is_last_block_done, + smem_ds_sum, + smem_db_sum, + block_data_ptr, + flag_ptr); if (is_last_block_done) { - ds_sum = static_cast>(0); - db_sum = static_cast>(0); - // thread sum - for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { - ds_sum += staging_ds_sum[i + y * C]; - db_sum += staging_db_sum[i + y * C]; - } - - // vertical block sum - BlockReduceByVetical( - ds_sum, db_sum, &smem_ds_sum[0], &smem_db_sum[0], &ds_sum, &db_sum); - // final compute if (threadIdx.y == 0) { dscale[i] = ds_sum * inv_var_val; @@ -563,51 +483,6 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData( } } -template -void SetLaunchConfigInfoForChannelLast(const Context &ctx, - DenseTensor *block_data_tensor, - DenseTensor *flag_tensor, - BatchNormParamType **block_data_ptr, - int **flag_ptr, - const int N, - const int H, - const int W, - const int D, - const int C, - const int block_size, - dim3 *block, - dim3 *grid) { - const int MAX_GRID_SIZE = 128; - const int WARP_SIZE = 32; - - int block_x = std::min(phi::funcs::details::GetLastPow2(C), WARP_SIZE); - int block_y = std::min(phi::funcs::details::GetLastPow2(N * H * W * D / 16), - block_size / block_x); - if (block_x * block_y != block_size) { - block_x = - std::min(phi::funcs::details::GetLastPow2(C), block_size / block_y); - } - int grid_x = (C + block_x - 1) / block_x; - int grid_y = std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16), - MAX_GRID_SIZE); - - block->x = block_x; - block->y = block_y; - grid->x = grid_x; - grid->y = grid_y; - - if (grid->y > 1) { - *block_data_tensor = - phi::Empty, Context>(ctx, {2 * C * grid->y}); - *flag_tensor = phi::Empty(ctx, {grid->x}); - - *block_data_ptr = block_data_tensor->data>(); - *flag_ptr = flag_tensor->data(); - funcs::SetConstant set_zero; - set_zero(ctx, flag_tensor, static_cast(0)); - } -} - template void BatchNormGradRawKernel(const Context &ctx, const DenseTensor &x, @@ -931,19 +806,20 @@ void BatchNormGradRawKernel(const Context &ctx, BatchNormParamType *block_data_ptr = nullptr; int *flag_ptr = nullptr; - SetLaunchConfigInfoForChannelLast(ctx, - &block_data_tensor, - &flag_tensor, - &block_data_ptr, - &flag_ptr, - N, - H, - W, - D, - C, - block_size, - &block, - &grid); + funcs::SetLaunchConfigInfoForChannelLast>( + ctx, + &block_data_tensor, + &flag_tensor, + &block_data_ptr, + &flag_ptr, + N, + H, + W, + D, + C, + block_size, + &block, + &grid); // 1. reduce_sum(x) => mean, inv_var auto *mean_ptr = @@ -1294,19 +1170,20 @@ void BatchNormGradRawKernel(const Context &ctx, BatchNormParamType *block_data_ptr = nullptr; int *flag_ptr = nullptr; - SetLaunchConfigInfoForChannelLast(ctx, - &block_data_tensor, - &flag_tensor, - &block_data_ptr, - &flag_ptr, - N, - H, - W, - D, - C, - block_size, - &block, - &grid); + funcs::SetLaunchConfigInfoForChannelLast>( + ctx, + &block_data_tensor, + &flag_tensor, + &block_data_ptr, + &flag_ptr, + N, + H, + W, + D, + C, + block_size, + &block, + &grid); BNBackward2DChannelLastStage2 <<>>( transformed_d_y.template data(), diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 60d0d1a01b..fc460574b7 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -30,6 +30,7 @@ namespace cub = hipcub; #include "paddle/phi/kernels/batch_norm_kernel.h" #include "paddle/phi/kernels/funcs/batch_norm_utils.h" #include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/norm_utils.cu.h" #include "paddle/phi/kernels/funcs/norm_utils.h" #include "paddle/phi/kernels/funcs/reduce_function.h" @@ -171,34 +172,6 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining( } } -template -__device__ __forceinline__ void merge_block_vertical( - BatchNormParamType x_sum, - BatchNormParamType x_square_sum, - BatchNormParamType *smem_sum, - BatchNormParamType *smem_square_sum, - BatchNormParamType *x_sum_out, - BatchNormParamType *x_square_sum_out) { - int tid = threadIdx.x + threadIdx.y * blockDim.x; -#pragma unroll - for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) { - if (threadIdx.y < offset * 2) { - smem_sum[tid] = x_sum; - smem_square_sum[tid] = x_square_sum; - } - __syncthreads(); - if (threadIdx.y < offset) { - int pair_tid = tid + offset * blockDim.x; - x_sum += smem_sum[pair_tid]; - x_square_sum += smem_square_sum[pair_tid]; - } - } - if (threadIdx.y == 0) { - *x_sum_out = x_sum; - *x_square_sum_out = x_square_sum; - } -} - template __device__ __forceinline__ void merge_block_horizonal( BatchNormParamType x_sum, @@ -269,53 +242,26 @@ static __global__ void BNForwardTraining2DChannelLastCompStat( } // vertical block sum - merge_block_vertical(x_sum, - x_square_sum, - &smem_sum[0], - &smem_square_sum[0], - &x_sum, - &x_square_sum); + funcs::BlockReduceByVetical>(x_sum, + x_square_sum, + &smem_sum[0], + &smem_square_sum[0], + &x_sum, + &x_square_sum); if (gridDim.y > 1) { - volatile BatchNormParamType *staging_sum = block_data_ptr; - volatile BatchNormParamType *staging_square_sum = - &block_data_ptr[C * gridDim.y]; - // write block data to global memory - if (threadIdx.y == 0) { - staging_sum[i + blockIdx.y * C] = x_sum; - staging_square_sum[i + blockIdx.y * C] = x_square_sum; - } - - // make sure write is visible to all blocks - __threadfence(); - __syncthreads(); - __shared__ bool is_last_block_done; - // mark block done - if (threadIdx.x == 0 && threadIdx.y == 0) { - int old = atomicAdd(&flag_ptr[blockIdx.x], 1); - is_last_block_done = (old == (gridDim.y - 1)); - } - - __syncthreads(); + funcs::ReduceSumPost>(C, + i, + &x_sum, + &x_square_sum, + &is_last_block_done, + smem_sum, + smem_square_sum, + block_data_ptr, + flag_ptr); if (is_last_block_done) { - x_sum = static_cast>(0); - x_square_sum = static_cast>(0); - // thread sum - for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { - x_sum += staging_sum[i + y * C]; - x_square_sum += staging_square_sum[i + y * C]; - } - - // vertical block sum - merge_block_vertical(x_sum, - x_square_sum, - &smem_sum[0], - &smem_square_sum[0], - &x_sum, - &x_square_sum); - // final compute if (threadIdx.y == 0) { BatchNormParamType compute_mean_val = x_sum / inner_size; diff --git a/paddle/phi/kernels/gpu/sync_batch_norm_utils.h b/paddle/phi/kernels/gpu/sync_batch_norm_utils.h index 81717cd445..71d0ccfa0e 100644 --- a/paddle/phi/kernels/gpu/sync_batch_norm_utils.h +++ b/paddle/phi/kernels/gpu/sync_batch_norm_utils.h @@ -34,6 +34,7 @@ namespace cub = hipcub; #include "paddle/fluid/platform/device/gpu/nccl_helper.h" #include "paddle/phi/backends/gpu/gpu_dnn.h" #include "paddle/phi/common/layout.h" +#include "paddle/phi/kernels/funcs/norm_utils.cu.h" #include "paddle/phi/kernels/funcs/norm_utils.h" namespace phi { @@ -168,6 +169,61 @@ __global__ void KeBackwardLocalStats(const T *dy, } } +template +__global__ void KeBackwardLocalStats2D(const T *dy, + const T *x, + const BatchNormParamType *means, + int N, + int M, + int C, + BatchNormParamType *block_data_ptr, + int *flag_ptr, + BatchNormParamType *sum_dy_prod) { + __shared__ BatchNormParamType smem_sum[BlockDim]; + __shared__ BatchNormParamType smem_square_sum[BlockDim]; + for (int k = blockIdx.x * blockDim.x + threadIdx.x; k < C; + k += gridDim.x * blockDim.x) { + BatchNormParamType sum1 = 0.; + BatchNormParamType sum2 = 0.; + auto mean = means[k]; + for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < N * M; + i += gridDim.y * blockDim.y) { + int id = layout == DataLayout::kNCHW ? (i / M) * C * M + k * M + i % M + : i * C + k; + auto g = static_cast>(dy[id]); + sum1 += g; + auto x_i = static_cast>(x[id]); + sum2 += g * (x_i - mean); + } + funcs::BlockReduceByVetical>( + sum1, sum2, &smem_sum[0], &smem_square_sum[0], &sum1, &sum2); + + if (gridDim.y > 1) { + __shared__ bool is_last_block_done; + funcs::ReduceSumPost>(C, + k, + &sum1, + &sum2, + &is_last_block_done, + smem_sum, + smem_square_sum, + block_data_ptr, + flag_ptr); + if (is_last_block_done) { + // final compute + if (threadIdx.y == 0) { + sum_dy_prod[k] = sum1; + sum_dy_prod[k + C] = sum2; + } + } + } + } + if (blockIdx.y == 0 && blockIdx.x == 0 && threadIdx.y == 0 && + threadIdx.x == 0) { + sum_dy_prod[2 * C] = 1.0; + } +} + template static __global__ void KeBNBackwardScaleBias( const T *dy, @@ -213,6 +269,68 @@ static __global__ void KeBNBackwardScaleBias( } } +template +static __global__ void KeBNBackwardScaleBias2D( + const T *dy, + const T *x, + const BatchNormParamType *mean, + const BatchNormParamType *inv_variance, + const double epsilon, + const int N, + const int C, + const int HxW, + BatchNormParamType *block_data_ptr, + int *flag_ptr, + BatchNormParamType *dscale, + BatchNormParamType *dbias) { + const int outer_size = C; + const int inner_size = N * HxW; + __shared__ BatchNormParamType smem_sum[BlockDim]; + __shared__ BatchNormParamType smem_square_sum[BlockDim]; + + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < outer_size; + i += gridDim.x * blockDim.x) { + BatchNormParamType ds_sum = 0.; + BatchNormParamType db_sum = 0.; + + auto inv_var_i = inv_variance[i]; + auto mean_i = mean[i]; + for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < inner_size; + j += gridDim.y * blockDim.y) { + const int id = layout == DataLayout::kNCHW + ? ((j / HxW) * C + i) * HxW + (j % HxW) + : j * outer_size + i; + auto x_i = static_cast>(x[id]); + auto dy_i = static_cast>(dy[id]); + ds_sum += dy_i * (x_i - mean_i); + db_sum += dy_i; + } + + funcs::BlockReduceByVetical>( + ds_sum, db_sum, &smem_sum[0], &smem_square_sum[0], &ds_sum, &db_sum); + + if (gridDim.y > 1) { + __shared__ bool is_last_block_done; + funcs::ReduceSumPost>(C, + i, + &ds_sum, + &db_sum, + &is_last_block_done, + smem_sum, + smem_square_sum, + block_data_ptr, + flag_ptr); + if (is_last_block_done) { + // final compute + if (threadIdx.y == 0) { + dscale[i] = ds_sum * inv_var_i; + dbias[i] = db_sum; + } + } + } + } +} + template static __global__ void KeBNRestoreData(T *x, const BatchNormParamType *scale, @@ -410,9 +528,46 @@ void SyncBatchNormGradFunctor( <<>>( dy_d, x_d, saved_mean_ptr, N, fsize, C, stats); } else { - KeBackwardLocalStats - <<>>( - dy_d, x_d, saved_mean_ptr, N, fsize, C, stats); + if (x_dims.size() == 2 && N >= 65535) { + dim3 block; + dim3 grid; + const int block_size = 512; + + // init intermediate storage + DenseTensor block_data_tensor; + DenseTensor flag_tensor; + BatchNormParamType *block_data_ptr = nullptr; + int *flag_ptr = nullptr; + + funcs::SetLaunchConfigInfoForChannelLast>( + ctx, + &block_data_tensor, + &flag_tensor, + &block_data_ptr, + &flag_ptr, + N, + H, + W, + D, + C, + block_size, + &block, + &grid); + KeBackwardLocalStats2D + <<>>(dy_d, + x_d, + saved_mean_ptr, + N, + fsize, + C, + block_data_ptr, + flag_ptr, + stats); + } else { + KeBackwardLocalStats + <<>>( + dy_d, x_d, saved_mean_ptr, N, fsize, C, stats); + } } #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) @@ -476,8 +631,33 @@ void SyncBatchNormGradFunctor( } } else { if (d_scale && d_bias) { - KeBNBackwardScaleBias - <<>>(dy_d, + if (x_dims.size() == 2 && N >= 65535) { + dim3 block; + dim3 grid; + const int block_size = 512; + + // init intermediate storage + DenseTensor block_data_tensor; + DenseTensor flag_tensor; + BatchNormParamType *block_data_ptr = nullptr; + int *flag_ptr = nullptr; + + funcs::SetLaunchConfigInfoForChannelLast>( + ctx, + &block_data_tensor, + &flag_tensor, + &block_data_ptr, + &flag_ptr, + N, + H, + W, + D, + C, + block_size, + &block, + &grid); + KeBNBackwardScaleBias2D + <<>>(dy_d, x_d, saved_mean_ptr, saved_inv_var, @@ -485,8 +665,24 @@ void SyncBatchNormGradFunctor( N, C, fsize, + block_data_ptr, + flag_ptr, d_scale->data>(), d_bias->data>()); + } else { + KeBNBackwardScaleBias + <<>>( + dy_d, + x_d, + saved_mean_ptr, + saved_inv_var, + epsilon, + N, + C, + fsize, + d_scale->data>(), + d_bias->data>()); + } } if (d_x) { KeBNBackwardData<<>>( -- GitLab