未验证 提交 1066094a 编写于 作者: S sneaxiy 提交者: GitHub

Make bfloat16 implicitly convert to float/double (#48238)

* make bfloat16 implicit convert to float/double

* fix bfloat16_test ut compile
上级 29d75c14
...@@ -39,7 +39,7 @@ TEST(bfloat16, convert_float32_to_bfloat16_on_gpu) { ...@@ -39,7 +39,7 @@ TEST(bfloat16, convert_float32_to_bfloat16_on_gpu) {
TEST(bfloat16, assignment_operator_on_gpu) { TEST(bfloat16, assignment_operator_on_gpu) {
// Assignment operator // Assignment operator
bfloat16 v_assign; bfloat16 v_assign;
v_assign = nv_bfloat16(bfloat16(1.0f)); v_assign = bfloat16(1.0f).to_nv_bfloat16();
EXPECT_EQ(v_assign.x, 0x3f80); EXPECT_EQ(v_assign.x, 0x3f80);
v_assign = 0.33333; v_assign = 0.33333;
EXPECT_EQ(v_assign.x, 0x3eab); EXPECT_EQ(v_assign.x, 0x3eab);
......
...@@ -67,10 +67,8 @@ template <> ...@@ -67,10 +67,8 @@ template <>
__forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleDownSync( __forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleDownSync(
unsigned mask, phi::dtype::bfloat16 val, int delta, int width) { unsigned mask, phi::dtype::bfloat16 val, int delta, int width) {
#if defined(PADDLE_CUDA_BF16) #if defined(PADDLE_CUDA_BF16)
return phi::dtype::bfloat16(__shfl_down_sync(mask, return phi::dtype::bfloat16(__shfl_down_sync(
static_cast<nv_bfloat16>(val), mask, val.to_nv_bfloat16(), static_cast<unsigned>(delta), width));
static_cast<unsigned>(delta),
width));
#else #else
PADDLE_ENFORCE( PADDLE_ENFORCE(
false, "__shfl_down_sync with bfloat16 is not supported on cuda <= 11."); false, "__shfl_down_sync with bfloat16 is not supported on cuda <= 11.");
...@@ -114,7 +112,7 @@ __forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleXorSync( ...@@ -114,7 +112,7 @@ __forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleXorSync(
unsigned mask, phi::dtype::bfloat16 val, int width) { unsigned mask, phi::dtype::bfloat16 val, int width) {
#if defined(PADDLE_CUDA_BF16) #if defined(PADDLE_CUDA_BF16)
return phi::dtype::bfloat16( return phi::dtype::bfloat16(
__shfl_xor_sync(mask, static_cast<nv_bfloat16>(val), width)); __shfl_xor_sync(mask, val.to_nv_bfloat16(), width));
#else #else
PADDLE_ENFORCE( PADDLE_ENFORCE(
false, "__shfl_xor_sync with bfloat16 is not supported on cuda <= 11."); false, "__shfl_xor_sync with bfloat16 is not supported on cuda <= 11.");
......
...@@ -145,7 +145,7 @@ struct PADDLE_ALIGN(2) bfloat16 { ...@@ -145,7 +145,7 @@ struct PADDLE_ALIGN(2) bfloat16 {
} }
// Conversion opertors // Conversion opertors
HOSTDEVICE inline explicit operator float() const { HOSTDEVICE inline operator float() const {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
uint32_t res = 0; uint32_t res = 0;
// We should be using memcpy in order to respect the strict aliasing rule // We should be using memcpy in order to respect the strict aliasing rule
...@@ -168,7 +168,7 @@ struct PADDLE_ALIGN(2) bfloat16 { ...@@ -168,7 +168,7 @@ struct PADDLE_ALIGN(2) bfloat16 {
} }
#ifdef PADDLE_CUDA_BF16 #ifdef PADDLE_CUDA_BF16
HOSTDEVICE inline explicit operator __nv_bfloat16() const { HOSTDEVICE inline __nv_bfloat16 to_nv_bfloat16() const {
return *reinterpret_cast<const __nv_bfloat16*>(&x); return *reinterpret_cast<const __nv_bfloat16*>(&x);
} }
#endif #endif
...@@ -207,7 +207,7 @@ struct PADDLE_ALIGN(2) bfloat16 { ...@@ -207,7 +207,7 @@ struct PADDLE_ALIGN(2) bfloat16 {
return static_cast<uint64_t>(static_cast<float>(*this)); return static_cast<uint64_t>(static_cast<float>(*this));
} }
HOSTDEVICE inline explicit operator double() const { HOSTDEVICE inline operator double() const {
return static_cast<double>(static_cast<float>(*this)); return static_cast<double>(static_cast<float>(*this));
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册