From 1066094a035ec7d21522a90d22c93f72b9e8e4ba Mon Sep 17 00:00:00 2001
From: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Date: Wed, 23 Nov 2022 15:39:00 +0800
Subject: [PATCH] Make bfloat16 implicitly convert to float/double (#48238)

* make bfloat16 implicit convert to float/double

* fix bfloat16_test ut compile
---
 paddle/fluid/platform/bfloat16_test.cu              | 2 +-
 paddle/phi/backends/gpu/cuda/cuda_device_function.h | 8 +++-----
 paddle/phi/common/bfloat16.h                        | 6 +++---
 3 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/paddle/fluid/platform/bfloat16_test.cu b/paddle/fluid/platform/bfloat16_test.cu
index 1e1919bfca0..cec83cbd11f 100644
--- a/paddle/fluid/platform/bfloat16_test.cu
+++ b/paddle/fluid/platform/bfloat16_test.cu
@@ -39,7 +39,7 @@ TEST(bfloat16, convert_float32_to_bfloat16_on_gpu) {
 TEST(bfloat16, assignment_operator_on_gpu) {
   // Assignment operator
   bfloat16 v_assign;
-  v_assign = nv_bfloat16(bfloat16(1.0f));
+  v_assign = bfloat16(1.0f).to_nv_bfloat16();
   EXPECT_EQ(v_assign.x, 0x3f80);
   v_assign = 0.33333;
   EXPECT_EQ(v_assign.x, 0x3eab);
diff --git a/paddle/phi/backends/gpu/cuda/cuda_device_function.h b/paddle/phi/backends/gpu/cuda/cuda_device_function.h
index 10aee53c45c..4ff2e528a91 100644
--- a/paddle/phi/backends/gpu/cuda/cuda_device_function.h
+++ b/paddle/phi/backends/gpu/cuda/cuda_device_function.h
@@ -67,10 +67,8 @@ template <>
 __forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleDownSync(
     unsigned mask, phi::dtype::bfloat16 val, int delta, int width) {
 #if defined(PADDLE_CUDA_BF16)
-  return phi::dtype::bfloat16(__shfl_down_sync(mask,
-                                               static_cast<nv_bfloat16>(val),
-                                               static_cast<unsigned>(delta),
-                                               width));
+  return phi::dtype::bfloat16(__shfl_down_sync(
+      mask, val.to_nv_bfloat16(), static_cast<unsigned>(delta), width));
 #else
   PADDLE_ENFORCE(
       false, "__shfl_down_sync with bfloat16 is not supported on cuda <= 11.");
@@ -114,7 +112,7 @@ __forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleXorSync(
     unsigned mask, phi::dtype::bfloat16 val, int width) {
 #if defined(PADDLE_CUDA_BF16)
   return phi::dtype::bfloat16(
-      __shfl_xor_sync(mask, static_cast<nv_bfloat16>(val), width));
+      __shfl_xor_sync(mask, val.to_nv_bfloat16(), width));
 #else
   PADDLE_ENFORCE(
       false, "__shfl_xor_sync with bfloat16 is not supported on cuda <= 11.");
diff --git a/paddle/phi/common/bfloat16.h b/paddle/phi/common/bfloat16.h
index 6a11f0c0714..37e4b55fbbc 100644
--- a/paddle/phi/common/bfloat16.h
+++ b/paddle/phi/common/bfloat16.h
@@ -145,7 +145,7 @@ struct PADDLE_ALIGN(2) bfloat16 {
   }
 
   // Conversion opertors
-  HOSTDEVICE inline explicit operator float() const {
+  HOSTDEVICE inline operator float() const {
 #ifdef PADDLE_WITH_HIP
     uint32_t res = 0;
     // We should be using memcpy in order to respect the strict aliasing rule
@@ -168,7 +168,7 @@ struct PADDLE_ALIGN(2) bfloat16 {
   }
 
 #ifdef PADDLE_CUDA_BF16
-  HOSTDEVICE inline explicit operator __nv_bfloat16() const {
+  HOSTDEVICE inline __nv_bfloat16 to_nv_bfloat16() const {
     return *reinterpret_cast<const __nv_bfloat16*>(&x);
   }
 #endif
@@ -207,7 +207,7 @@ struct PADDLE_ALIGN(2) bfloat16 {
     return static_cast<uint64_t>(static_cast<float>(*this));
   }
 
-  HOSTDEVICE inline explicit operator double() const {
+  HOSTDEVICE inline operator double() const {
     return static_cast<double>(static_cast<float>(*this));
   }
 };
-- 
GitLab
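
Note for reviewers: below is a minimal usage sketch of the API change in this patch. It is illustrative only and not part of the diff; the helper names AsFloat, AsDouble, and ToNative are hypothetical, and the sketch assumes only what the diff shows (phi::dtype::bfloat16, its now-implicit float/double conversion operators, the new to_nv_bfloat16() member, and the PADDLE_CUDA_BF16 guard).

    #include "paddle/phi/common/bfloat16.h"

    using phi::dtype::bfloat16;

    // With the now-implicit conversion operators, a bfloat16 value can be
    // used directly where a float or double is expected; previously this
    // required static_cast<float>(v) or static_cast<double>(v).
    float AsFloat(bfloat16 v) { return v; }
    double AsDouble(bfloat16 v) { return v; }

    #if defined(PADDLE_CUDA_BF16)
    // Device code obtains the CUDA-native type via the named member function
    // to_nv_bfloat16(), which replaces the removed explicit
    // operator __nv_bfloat16().
    __device__ __nv_bfloat16 ToNative(bfloat16 v) {
      return v.to_nv_bfloat16();
    }
    #endif

Using a named member function instead of an explicit conversion operator keeps the __nv_bfloat16 conversion unambiguous at call sites now that the float/double conversions are implicit.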