未验证 提交 b0a3638f 编写于 作者: L LielinJiang 提交者: GitHub

Fix grad error of groupnorm op when cuda version==11.7 (#45738)

* fix grad error of grounorm op when cuda version==11.7
上级 31efe00a
......@@ -427,8 +427,21 @@ __global__ void GroupNormBackwardGetMeanAndVar(const T* x,
}
CudaAtomicAddWithWarp(&(d_mean[bid * groups + gid]), d_mean_data);
CudaAtomicAddWithWarp(&(d_var[bid * groups + gid]), d_var_data);
if (flags & kHasScale) CudaAtomicAddWithWarp(&(d_scale[ccid]), d_scale_data);
if (flags & kHasBias) CudaAtomicAddWithWarp(&(d_bias[ccid]), d_bias_data);
if (flags & kHasScale) {
#if CUDA_VERSION >= 11070
platform::CudaAtomicAdd(&(d_scale[ccid]), d_scale_data);
#else
CudaAtomicAddWithWarp(&(d_scale[ccid]), d_scale_data);
#endif
}
if (flags & kHasBias) {
#if CUDA_VERSION >= 11070
platform::CudaAtomicAdd(&(d_bias[ccid]), d_bias_data);
#else
CudaAtomicAddWithWarp(&(d_bias[ccid]), d_bias_data);
#endif
}
}
template <typename T, int flags>
......
......@@ -68,8 +68,21 @@ __global__ void GroupNormBackwardGetMeanAndVar(const T* x,
}
CudaAtomicAddWithWarp(&(d_mean[bid * groups + gid]), d_mean_data);
CudaAtomicAddWithWarp(&(d_var[bid * groups + gid]), d_var_data);
if (flags & kHasScale) CudaAtomicAddWithWarp(&(d_scale[ccid]), d_scale_data);
if (flags & kHasBias) CudaAtomicAddWithWarp(&(d_bias[ccid]), d_bias_data);
if (flags & kHasScale) {
#if CUDA_VERSION >= 11070
platform::CudaAtomicAdd(&(d_scale[ccid]), d_scale_data);
#else
CudaAtomicAddWithWarp(&(d_scale[ccid]), d_scale_data);
#endif
}
if (flags & kHasBias) {
#if CUDA_VERSION >= 11070
platform::CudaAtomicAdd(&(d_bias[ccid]), d_bias_data);
#else
CudaAtomicAddWithWarp(&(d_bias[ccid]), d_bias_data);
#endif
}
}
template <typename T, int flags>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册