diff --git a/paddle/fluid/operators/group_norm_op.cu b/paddle/fluid/operators/group_norm_op.cu
index ab8c50d90b8ece68b8e4e05d46cecd13fa84d7e0..c08f1920205daa7f6d5d3f032afcf5d832230ced 100644
--- a/paddle/fluid/operators/group_norm_op.cu
+++ b/paddle/fluid/operators/group_norm_op.cu
@@ -152,6 +152,21 @@ __device__ __forceinline__ void ThreadReduce(phi::Array<const T*, Num> arrs,
   }
 }
 
+template <typename T>
+__device__ __forceinline__ void ReduceMeanAndVar(T* mean, T* var, T x_mean,
+                                                 T x_var, int size) {
+  const int nc = blockIdx.x;
+  x_mean = kps::details::BlockXReduce<T, kps::AddFunctor<T>>(
+      x_mean, kps::AddFunctor<T>());
+  x_var = kps::details::BlockXReduce<T, kps::AddFunctor<T>>(
+      x_var, kps::AddFunctor<T>());
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    mean[nc] = static_cast<T>(x_mean / size);
+    var[nc] = static_cast<T>(x_var / size);
+  }
+}
+
 template <typename T>
 __global__ void ScalarGetMeanAndVarNCHW(const T* x, T* mean, T* var, int size) {
   int i = blockIdx.x;
@@ -162,10 +177,7 @@ __global__ void ScalarGetMeanAndVarNCHW(const T* x, T* mean, T* var, int size) {
     x_mean += val;
     x_var += val * val;
   }
-  x_mean /= size;
-  x_var /= size;
-  CudaAtomicAddWithWarp(&mean[i], x_mean);
-  CudaAtomicAddWithWarp(&var[i], x_var);
+  ReduceMeanAndVar(mean, var, x_mean, x_var, size);
 }
 
 template <typename T, typename AccT, int VecSize>
@@ -174,21 +186,12 @@ __global__ void VectorizedGetMeanAndVarNCHW(const T* x, T* mean, T* var,
   int i = blockIdx.x;
   AccT x_mean = static_cast<AccT>(0);
   AccT x_var = static_cast<AccT>(0);
-  const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
   x += i * size;
+  const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
   phi::Array<const T*, 1> ins;
   ins[0] = x;
   ThreadReduce<T, AccT, VecSize, 1>(ins, size, input_offset, &x_mean, &x_var);
-
-  x_mean = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
-      x_mean, kps::AddFunctor<AccT>());
-  x_var = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
-      x_var, kps::AddFunctor<AccT>());
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    mean[i] = static_cast<T>(x_mean / size);
-    var[i] = static_cast<T>(x_var / size);
-  }
+  ReduceMeanAndVar(mean, var, x_mean, x_var, size);
 }
 
 template <typename T>
@@ -272,10 +275,6 @@ class GroupNormKernel
     auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     Tensor temp_var;
     temp_var.mutable_data<T>(var->dims(), ctx.GetPlace());
-
-    set_zero(dev_ctx, mean, static_cast<T>(0));
-    set_zero(dev_ctx, &temp_var, static_cast<T>(0));
-
     auto* x_data = x->data<T>();
     auto* y_data = y->data<T>();
     auto* mean_data = mean->data<T>();
@@ -319,7 +318,7 @@ class GroupNormKernel
       block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
       dim3 grids(x_dims[0] * groups);
       dim3 blocks(block_size_nchw);
-      if (size < vec_size) {
+      if (size < vec_size * block_size_nchw) {
         ScalarGetMeanAndVarNCHW<T><<<grids, blocks, 0, dev_ctx.stream()>>>(
             x_data, mean_data, temp_var_data, size);
       } else {
@@ -328,6 +327,8 @@ class GroupNormKernel
             x_data, mean_data, temp_var_data, size);
       }
     } else {
+      set_zero(dev_ctx, mean, static_cast<T>(0));
+      set_zero(dev_ctx, &temp_var, static_cast<T>(0));
      GroupNormForwardGetMeanAndVar<T><<<grid, threads, 0, dev_ctx.stream()>>>(
           x_data, x_dims[0], C, W, imsize, groups, group_size, mean_data,
           temp_var_data);
@@ -424,24 +425,15 @@ __global__ void VectorizedGetDsDbCUDAKernel(int imsize, const T* x, const T* dy,
                                             T* ds, T* db) {
   int i = blockIdx.x;
   AccT ds_sum = static_cast<AccT>(0);
   AccT db_sum = static_cast<AccT>(0);
-  const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
   x += i * imsize;
+  const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
   phi::Array<const T*, 2> ins;
   ins[0] = x;
   ins[1] = dy;
   ThreadReduce<T, AccT, VecSize, 2>(ins, imsize, input_offset, &db_sum,
                                     &ds_sum);
-
-  ds_sum = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
-      ds_sum, kps::AddFunctor<AccT>());
-  db_sum = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
-      db_sum, kps::AddFunctor<AccT>());
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    ds[i] = ds_sum;
-    db[i] = db_sum;
-  }
+  ReduceMeanAndVar(db, ds, db_sum, ds_sum, 1);
 }
 
 template <typename T>
@@ -455,8 +447,7 @@ __global__ void ScalarGetDsDbCUDAKernel(int imsize, const T* x, const T* dy,
     ds_sum += dy[index] * x[index];
     db_sum += dy[index];
   }
-  CudaAtomicAddWithWarp(&ds[nc], ds_sum);
-  CudaAtomicAddWithWarp(&db[nc], db_sum);
+  ReduceMeanAndVar(db, ds, db_sum, ds_sum, 1);
 }
 
 template <typename T>
@@ -641,13 +632,7 @@ class GroupNormGradKernel
       }
       block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
       dim3 blocks(block_size_nchw);
-      if (imsize < vec_size) {
-        if (d_scale) {
-          set_zero(dev_ctx, d_scale, static_cast<T>(0));
-        }
-        if (d_bias) {
-          set_zero(dev_ctx, d_bias, static_cast<T>(0));
-        }
+      if (imsize < vec_size * block_size_nchw) {
         ScalarGetDsDbCUDAKernel<
             T><<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
             imsize, x_data, dy_data, ds_data, db_data);
@@ -687,7 +672,6 @@ class GroupNormGradKernel
             imsize, C, group_size, groups, p1_data, p2_data, p3_data, x_data,
             dy_data, d_x_data);
       }
-
     } else {
       if (d_scale) {
         set_zero(dev_ctx, d_scale, static_cast<T>(0));
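
Note on the reduction pattern (not part of the patch): the change replaces CudaAtomicAddWithWarp on global mean/var (which needed the set_zero calls removed above) with a block-wide reduction, so each (batch, group) block writes its statistics exactly once. Below is a minimal standalone CUDA sketch of that idea, assuming one block per (batch, group) slice and a power-of-two block size; it uses a plain shared-memory tree reduction in place of Paddle's kps::details::BlockXReduce, and the names BLOCK_SIZE and MeanAndVarNCHWSketch are illustrative, not from the PR.

// Illustrative sketch only; a simplified analogue of the ReduceMeanAndVar path.
#include <cuda_runtime.h>

constexpr int BLOCK_SIZE = 256;  // assumed launch block size, power of two

template <typename T>
__global__ void MeanAndVarNCHWSketch(const T* x, T* mean, T* var, int size) {
  __shared__ T s_sum[BLOCK_SIZE];
  __shared__ T s_sqs[BLOCK_SIZE];

  const int nc = blockIdx.x;  // one block per (batch, group) slice
  x += static_cast<int64_t>(nc) * size;

  // Each thread folds a strided subset of the slice into private partial sums.
  T x_sum = static_cast<T>(0);
  T x_sqs = static_cast<T>(0);
  for (int i = threadIdx.x; i < size; i += blockDim.x) {
    T val = x[i];
    x_sum += val;
    x_sqs += val * val;
  }
  s_sum[threadIdx.x] = x_sum;
  s_sqs[threadIdx.x] = x_sqs;
  __syncthreads();

  // Shared-memory tree reduction across the block; this is the role
  // kps::details::BlockXReduce plays in the patch, and it removes the need
  // for CudaAtomicAddWithWarp plus a prior set_zero on the outputs.
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) {
      s_sum[threadIdx.x] += s_sum[threadIdx.x + stride];
      s_sqs[threadIdx.x] += s_sqs[threadIdx.x + stride];
    }
    __syncthreads();
  }

  // Thread 0 owns the final result: exactly one write per (batch, group).
  if (threadIdx.x == 0) {
    mean[nc] = s_sum[0] / size;
    var[nc] = s_sqs[0] / size;  // E[x^2]; the caller derives E[x^2] - mean^2
  }
}

The same per-thread-work reasoning motivates the threshold change in the patch: the vectorized kernel only pays off once every thread in the block has at least vec_size elements to load, so the scalar fallback is now taken when size < vec_size * block_size_nchw instead of size < vec_size.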