From 53f5edbd1107b22a51e7f025a47ecfec7c438b18 Mon Sep 17 00:00:00 2001 From: houj04 <35131887+houj04@users.noreply.github.com> Date: Fri, 31 Mar 2023 10:56:35 +0800 Subject: [PATCH] [XPU] register bmm fp16 (#52354) --- paddle/phi/backends/xpu/xpu2_op_list.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index a3988fbe2c8..023cfa33bef 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -59,8 +59,9 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"batch_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"batch_norm", XPUKernelSet({phi::DataType::FLOAT32})}, - {"bmm", XPUKernelSet({phi::DataType::FLOAT32})}, - {"bmm_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"bmm", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"bmm_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"bce_loss_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"bce_loss", XPUKernelSet({phi::DataType::FLOAT32})}, {"beam_search", -- GitLab