From 594e412d35dc307acc0c93f2c44a2ce5ecaeb42f Mon Sep 17 00:00:00 2001 From: houj04 <35131887+houj04@users.noreply.github.com> Date: Fri, 11 Mar 2022 14:47:48 +0800 Subject: [PATCH] minor fix matmul and onehot xpu. test=kunlun (#40419) --- paddle/fluid/operators/matmul_v2_op_xpu.cc | 2 +- paddle/fluid/platform/device/xpu/xpu2_op_list.h | 2 +- .../tests/unittests/xpu/test_matmul_v2_op_xpu.py | 12 ++++++++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/matmul_v2_op_xpu.cc b/paddle/fluid/operators/matmul_v2_op_xpu.cc index 1524a50f1ac..87df75ac465 100644 --- a/paddle/fluid/operators/matmul_v2_op_xpu.cc +++ b/paddle/fluid/operators/matmul_v2_op_xpu.cc @@ -38,7 +38,7 @@ static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out, auto mat_dim_b = phi::funcs::CreateMatrixDescriptor( ColumnMatrixFromVector(y_dims), 0, trans_y); - if (x_dims.size() == 3 && y_dims.size() <= 2) { + if (x_dims.size() >= 3 && y_dims.size() <= 2) { // if transpose_X is true, the transpose cost much time if (!trans_x) { mat_dim_a.height_ *= mat_dim_a.batch_size_; diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 3789ec322ac..14f516235a7 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -249,7 +249,7 @@ XPUOpMap& get_kl2_ops() { {"not_equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"one_hot_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + {"one_hot_v2", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace())})}, {"pool2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, diff --git a/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py index 45d60c8538e..9891da6ea21 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py @@ -289,6 +289,18 @@ class TestMatMulOp17(TestMatMulV2Op): self.trans_y = False +class TestMatMulOp18(TestMatMulV2Op): + """ + case 18 : for ppyoloe model + """ + + def config(self): + self.x_shape = (8, 111, 4, 17) + self.y_shape = (17) + self.trans_x = False + self.trans_y = False + + # class TestMatMulOpBroadcast1(TestMatMulV2Op): # """ # case 14_3 -- GitLab