diff --git a/paddle/fluid/operators/group_norm_op.cu b/paddle/fluid/operators/group_norm_op.cu index 668f69b4c75d9710d90c52ed617bfaa2027678ec..105d4d6c75efe18560c806d8ec9a456beb52fbf3 100644 --- a/paddle/fluid/operators/group_norm_op.cu +++ b/paddle/fluid/operators/group_norm_op.cu @@ -427,8 +427,21 @@ __global__ void GroupNormBackwardGetMeanAndVar(const T* x, } CudaAtomicAddWithWarp(&(d_mean[bid * groups + gid]), d_mean_data); CudaAtomicAddWithWarp(&(d_var[bid * groups + gid]), d_var_data); - if (flags & kHasScale) CudaAtomicAddWithWarp(&(d_scale[ccid]), d_scale_data); - if (flags & kHasBias) CudaAtomicAddWithWarp(&(d_bias[ccid]), d_bias_data); + + if (flags & kHasScale) { +#if CUDA_VERSION >= 11070 + platform::CudaAtomicAdd(&(d_scale[ccid]), d_scale_data); +#else + CudaAtomicAddWithWarp(&(d_scale[ccid]), d_scale_data); +#endif + } + if (flags & kHasBias) { +#if CUDA_VERSION >= 11070 + platform::CudaAtomicAdd(&(d_bias[ccid]), d_bias_data); +#else + CudaAtomicAddWithWarp(&(d_bias[ccid]), d_bias_data); +#endif + } } template diff --git a/paddle/phi/kernels/gpu/group_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/group_norm_grad_kernel.cu index 359dc8a0095f84528a79acad7893d11ecc8485b8..c33fbfbd51f475ae186798d2befb33d4368d5292 100644 --- a/paddle/phi/kernels/gpu/group_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/group_norm_grad_kernel.cu @@ -68,8 +68,21 @@ __global__ void GroupNormBackwardGetMeanAndVar(const T* x, } CudaAtomicAddWithWarp(&(d_mean[bid * groups + gid]), d_mean_data); CudaAtomicAddWithWarp(&(d_var[bid * groups + gid]), d_var_data); - if (flags & kHasScale) CudaAtomicAddWithWarp(&(d_scale[ccid]), d_scale_data); - if (flags & kHasBias) CudaAtomicAddWithWarp(&(d_bias[ccid]), d_bias_data); + + if (flags & kHasScale) { +#if CUDA_VERSION >= 11070 + platform::CudaAtomicAdd(&(d_scale[ccid]), d_scale_data); +#else + CudaAtomicAddWithWarp(&(d_scale[ccid]), d_scale_data); +#endif + } + if (flags & kHasBias) { +#if CUDA_VERSION >= 11070 + platform::CudaAtomicAdd(&(d_bias[ccid]), d_bias_data); +#else + CudaAtomicAddWithWarp(&(d_bias[ccid]), d_bias_data); +#endif + } } template