未验证 提交 da558f0e 编写于 作者: Q Qi Li 提交者: GitHub

[ROCm] fix bfloat16 support, test=develop (#40401)

上级 60899549
...@@ -423,7 +423,7 @@ void TensorAdd(const VarType& src, VarType* dst) { ...@@ -423,7 +423,7 @@ void TensorAdd(const VarType& src, VarType* dst) {
} }
if (data_type == framework::proto::VarType::BF16) { if (data_type == framework::proto::VarType::BF16) {
if (platform::is_gpu_place(place)) { if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
return TensorAddImpl<platform::CUDADeviceContext, platform::bfloat16>( return TensorAddImpl<platform::CUDADeviceContext, platform::bfloat16>(
src_tensor, dst_tensor, place); src_tensor, dst_tensor, place);
#else #else
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册