Unverified commit 033ebe7e authored by sneaxiy, committed by GitHub

Refine CUDA atomicAdd for FP16 by CUDA primitive methods (#37895)

* fix cuda atomicAdd for FP16

* try to fix ci
Parent 491d4f01
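What the patch does: on CUDA 10.0+ builds targeting compute capability 7.0 or higher, the toolkit provides a native `atomicAdd(__half*, __half)` overload, so `CUDA_ATOMIC_WRAPPER(Add, float16)` can forward to it instead of emulating the add with a 32-bit CAS loop. Below is a minimal standalone sketch of that primitive, not part of the patch; the kernel and buffer names are illustrative, and it assumes compilation with `nvcc -arch=sm_70` (or newer) on CUDA 10.0+.

```cuda
#include <cstdio>
#include <cuda_fp16.h>

// Accumulate all inputs into one FP16 cell using the hardware atomic add
// available on compute capability 7.0+ (no CAS emulation needed).
__global__ void fp16_atomic_sum(const __half *in, __half *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    atomicAdd(out, in[i]);  // native atomicAdd(__half*, __half)
  }
}

int main() {
  const int n = 1024;
  __half *in = nullptr, *out = nullptr;
  cudaMallocManaged(&in, n * sizeof(__half));
  cudaMallocManaged(&out, sizeof(__half));
  for (int i = 0; i < n; ++i) in[i] = __float2half(0.5f);
  *out = __float2half(0.0f);
  fp16_atomic_sum<<<(n + 255) / 256, 256>>>(in, out, n);
  cudaDeviceSynchronize();
  printf("sum = %f (expected 512)\n", __half2float(*out));  // 1024 * 0.5
  cudaFree(in);
  cudaFree(out);
  return 0;
}
```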
...@@ -101,6 +101,20 @@ inline static __device__ uint32_t add_to_high_half(uint32_t val, float x) {
  return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}
#if CUDA_VERSION >= 10000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
static __device__ __forceinline__ float16 CUDAFP16ToPDFP16(__half x) {
return *reinterpret_cast<float16 *>(&x);
}
static __device__ __forceinline__ __half PDFP16ToCUDAFP16(float16 x) {
return *reinterpret_cast<__half *>(&x);
}
CUDA_ATOMIC_WRAPPER(Add, float16) {
return CUDAFP16ToPDFP16(
atomicAdd(reinterpret_cast<__half *>(address), PDFP16ToCUDAFP16(val)));
}
#else
CUDA_ATOMIC_WRAPPER(Add, float16) {
  // the packed float16 value may exist in the lower or the higher 16 bits
  // of the containing 32-bit word.
...@@ -133,6 +147,7 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
  }
}
#endif
#endif
CUDA_ATOMIC_WRAPPER(Add, complex<float>) {
  float *real = reinterpret_cast<float *>(address);
......
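The body of the `#else` branch is elided in this hunk. For reference, here is a rough sketch of the general CAS-based emulation technique that branch retains for pre-Volta GPUs; the helper name `atomic_add_half_emulated` and the exact bit manipulation are illustrative assumptions, not PaddlePaddle's code (the in-tree version uses the `add_to_low_half`/`add_to_high_half` helpers visible in the first hunk).

```cuda
#include <cstdint>
#include <cuda_fp16.h>

// Emulated FP16 atomic add (sketch, not PaddlePaddle's exact implementation):
// operate on the aligned 32-bit word that contains the 16-bit target and retry
// with atomicCAS until no other thread changed the word in between. Returns
// the previous value, matching the usual atomicAdd semantics.
__device__ __half atomic_add_half_emulated(__half *address, __half val) {
  uintptr_t addr = reinterpret_cast<uintptr_t>(address);
  uint32_t *aligned = reinterpret_cast<uint32_t *>(addr & ~uintptr_t(3));
  bool high_half = (addr & 2) != 0;  // does the target sit in the upper 16 bits?

  uint32_t old = *aligned, assumed;
  do {
    assumed = old;
    // Extract the current FP16 bits, add in float, and repack the 32-bit word.
    uint16_t cur_bits =
        high_half ? uint16_t(assumed >> 16) : uint16_t(assumed & 0xFFFFu);
    float sum = __half2float(__ushort_as_half(cur_bits)) + __half2float(val);
    uint16_t sum_bits = __half_as_ushort(__float2half(sum));
    uint32_t replaced =
        high_half ? ((assumed & 0x0000FFFFu) | (uint32_t(sum_bits) << 16))
                  : ((assumed & 0xFFFF0000u) | uint32_t(sum_bits));
    old = atomicCAS(aligned, assumed, replaced);
  } while (assumed != old);

  uint16_t prev_bits = high_half ? uint16_t(old >> 16) : uint16_t(old & 0xFFFFu);
  return __ushort_as_half(prev_bits);
}
```

On Volta and newer, the new fast path above avoids this read-modify-write loop entirely, which is the point of the commit.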