未验证 提交 b0a3638f 编写于 作者: L LielinJiang 提交者: GitHub

Fix grad error of groupnorm op when cuda version==11.7 (#45738)

* fix grad error of groupnorm op when cuda version==11.7
上级 31efe00a
...@@ -427,8 +427,21 @@ __global__ void GroupNormBackwardGetMeanAndVar(const T* x, ...@@ -427,8 +427,21 @@ __global__ void GroupNormBackwardGetMeanAndVar(const T* x,
} }
CudaAtomicAddWithWarp(&(d_mean[bid * groups + gid]), d_mean_data); CudaAtomicAddWithWarp(&(d_mean[bid * groups + gid]), d_mean_data);
CudaAtomicAddWithWarp(&(d_var[bid * groups + gid]), d_var_data); CudaAtomicAddWithWarp(&(d_var[bid * groups + gid]), d_var_data);
if (flags & kHasScale) CudaAtomicAddWithWarp(&(d_scale[ccid]), d_scale_data);
if (flags & kHasBias) CudaAtomicAddWithWarp(&(d_bias[ccid]), d_bias_data); if (flags & kHasScale) {
#if CUDA_VERSION >= 11070
platform::CudaAtomicAdd(&(d_scale[ccid]), d_scale_data);
#else
CudaAtomicAddWithWarp(&(d_scale[ccid]), d_scale_data);
#endif
}
if (flags & kHasBias) {
#if CUDA_VERSION >= 11070
platform::CudaAtomicAdd(&(d_bias[ccid]), d_bias_data);
#else
CudaAtomicAddWithWarp(&(d_bias[ccid]), d_bias_data);
#endif
}
} }
template <typename T, int flags> template <typename T, int flags>
......
...@@ -68,8 +68,21 @@ __global__ void GroupNormBackwardGetMeanAndVar(const T* x, ...@@ -68,8 +68,21 @@ __global__ void GroupNormBackwardGetMeanAndVar(const T* x,
} }
CudaAtomicAddWithWarp(&(d_mean[bid * groups + gid]), d_mean_data); CudaAtomicAddWithWarp(&(d_mean[bid * groups + gid]), d_mean_data);
CudaAtomicAddWithWarp(&(d_var[bid * groups + gid]), d_var_data); CudaAtomicAddWithWarp(&(d_var[bid * groups + gid]), d_var_data);
if (flags & kHasScale) CudaAtomicAddWithWarp(&(d_scale[ccid]), d_scale_data);
if (flags & kHasBias) CudaAtomicAddWithWarp(&(d_bias[ccid]), d_bias_data); if (flags & kHasScale) {
#if CUDA_VERSION >= 11070
platform::CudaAtomicAdd(&(d_scale[ccid]), d_scale_data);
#else
CudaAtomicAddWithWarp(&(d_scale[ccid]), d_scale_data);
#endif
}
if (flags & kHasBias) {
#if CUDA_VERSION >= 11070
platform::CudaAtomicAdd(&(d_bias[ccid]), d_bias_data);
#else
CudaAtomicAddWithWarp(&(d_bias[ccid]), d_bias_data);
#endif
}
} }
template <typename T, int flags> template <typename T, int flags>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册