未验证 提交 da558f0e 编写于 作者: Q Qi Li 提交者: GitHub

[ROCm] fix bfloat16 support, test=develop (#40401)

上级 60899549
......@@ -423,7 +423,7 @@ void TensorAdd(const VarType& src, VarType* dst) {
}
if (data_type == framework::proto::VarType::BF16) {
if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
return TensorAddImpl<platform::CUDADeviceContext, platform::bfloat16>(
src_tensor, dst_tensor, place);
#else
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册