From 04a4bdf8822688f8290bbd27d936d59e66fb2f9e Mon Sep 17 00:00:00 2001 From: crystal <62974595+Zjq9409@users.noreply.github.com> Date: Fri, 8 Apr 2022 20:46:11 +0800 Subject: [PATCH] fix group_norm (#41531) fix group_norm vectorized address misalignment --- paddle/fluid/operators/group_norm_op.cu | 39 ++----------------------- 1 file changed, 3 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/operators/group_norm_op.cu b/paddle/fluid/operators/group_norm_op.cu index c08f192020..c93910bde5 100644 --- a/paddle/fluid/operators/group_norm_op.cu +++ b/paddle/fluid/operators/group_norm_op.cu @@ -419,23 +419,6 @@ __global__ void GroupNormBackward(const T* x, const T* d_y, const T* scale, } } -template -__global__ void VectorizedGetDsDbCUDAKernel(int imsize, const T* x, const T* dy, - T* ds, T* db) { - int i = blockIdx.x; - AccT ds_sum = static_cast(0); - AccT db_sum = static_cast(0); - x += i * imsize; - const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T); - - phi::Array ins; - ins[0] = x; - ins[1] = dy; - ThreadReduce(ins, imsize, input_offset, &db_sum, - &ds_sum); - ReduceMeanAndVar(db, ds, db_sum, ds_sum, 1); -} - template __global__ void ScalarGetDsDbCUDAKernel(int imsize, const T* x, const T* dy, T* ds, T* db) { @@ -622,25 +605,9 @@ class GroupNormGradKernel int flags = (scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias; if (data_layout == DataLayout::kNCHW) { - using AccT = typename details::MPTypeTrait::Type; - constexpr int vec_size = sizeof(float4) / sizeof(T); - const int max_num_threads = 1024; - int max_block_size = std::min(imsize / vec_size, max_num_threads); - int block_size_nchw = 1; - while (block_size_nchw < max_block_size) { - block_size_nchw *= 2; - } - block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize); - dim3 blocks(block_size_nchw); - if (imsize < vec_size * block_size_nchw) { - ScalarGetDsDbCUDAKernel< - T><<>>( - imsize, x_data, dy_data, ds_data, db_data); - } else { - VectorizedGetDsDbCUDAKernel< - T, AccT, vec_size><<>>( - imsize, x_data, dy_data, ds_data, db_data); - } + ScalarGetDsDbCUDAKernel< + T><<>>( + imsize, x_data, dy_data, ds_data, db_data); if (d_scale || d_bias) { const int block = 256; -- GitLab