未验证 提交 b6786ff7 编写于 作者: R Rane2021 提交者: GitHub

[ROCM]fix bfloat16 to float error! (#56517)

上级 ea4182d7
...@@ -157,7 +157,9 @@ struct PADDLE_ALIGN(2) bfloat16 { ...@@ -157,7 +157,9 @@ struct PADDLE_ALIGN(2) bfloat16 {
uint16_t temp = x; uint16_t temp = x;
uint16_t* temp_ptr = reinterpret_cast<uint16_t*>(&temp); uint16_t* temp_ptr = reinterpret_cast<uint16_t*>(&temp);
res = *temp_ptr; res = *temp_ptr;
return res; // return res;
res = res << 16;
return *reinterpret_cast<float*>(&res);
#else #else
#ifdef PADDLE_CUDA_BF16 #ifdef PADDLE_CUDA_BF16
return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x)); return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册