diff --git a/paddle/fluid/operators/matmul_v2_op_xpu.cc b/paddle/fluid/operators/matmul_v2_op_xpu.cc index 1524a50f1ac6d6afa67722bc5d1c16a581395bb2..87df75ac465042a0f7894abecb4be4c213e5d807 100644 --- a/paddle/fluid/operators/matmul_v2_op_xpu.cc +++ b/paddle/fluid/operators/matmul_v2_op_xpu.cc @@ -38,7 +38,7 @@ static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out, auto mat_dim_b = phi::funcs::CreateMatrixDescriptor( ColumnMatrixFromVector(y_dims), 0, trans_y); - if (x_dims.size() == 3 && y_dims.size() <= 2) { + if (x_dims.size() >= 3 && y_dims.size() <= 2) { // if transpose_X is true, the transpose cost much time if (!trans_x) { mat_dim_a.height_ *= mat_dim_a.batch_size_; diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 3789ec322ac9952011764dd5230b72eb4b9ada19..14f516235a720c1fb8f46fe6606ac8f0bdb149f9 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -249,7 +249,7 @@ XPUOpMap& get_kl2_ops() { {"not_equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"one_hot_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + {"one_hot_v2", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace())})}, {"pool2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, diff --git a/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py index 45d60c8538e092f4c5d97f6525870af33a6ad9d5..9891da6ea21d9ac7c8c591f71183099858832140 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py @@ -289,6 +289,18 @@ class TestMatMulOp17(TestMatMulV2Op): self.trans_y = False 
class TestMatMulOp18(TestMatMulV2Op):
    """
    case 18 : for ppyoloe model

    Configures a rank-4 x of shape (8, 111, 4, 17) and a rank-1 y of
    length 17, with both transpose flags disabled.
    """

    def config(self):
        """Set the operand shapes and transpose flags for this case."""
        self.x_shape = (8, 111, 4, 17)
        # The trailing comma makes this a 1-tuple like x_shape; a bare
        # `(17)` is just the int 17, not a shape tuple.
        self.y_shape = (17,)
        self.trans_x = False
        self.trans_y = False


# class TestMatMulOpBroadcast1(TestMatMulV2Op):
#     """
#     case 14_3