未验证 提交 594e412d 编写于 作者: H houj04 提交者: GitHub

minor fix matmul and onehot xpu. test=kunlun (#40419)

上级 82c30f71
...@@ -38,7 +38,7 @@ static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out, ...@@ -38,7 +38,7 @@ static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out,
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor( auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(
ColumnMatrixFromVector(y_dims), 0, trans_y); ColumnMatrixFromVector(y_dims), 0, trans_y);
if (x_dims.size() == 3 && y_dims.size() <= 2) { if (x_dims.size() >= 3 && y_dims.size() <= 2) {
// if transpose_X is true, the transpose cost much time // if transpose_X is true, the transpose cost much time
if (!trans_x) { if (!trans_x) {
mat_dim_a.height_ *= mat_dim_a.batch_size_; mat_dim_a.height_ *= mat_dim_a.batch_size_;
......
...@@ -249,7 +249,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -249,7 +249,7 @@ XPUOpMap& get_kl2_ops() {
{"not_equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), {"not_equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::FP32, XPUPlace())})},
{"one_hot_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), {"one_hot_v2", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})}, pOpKernelType(vartype::INT64, XPUPlace())})},
{"pool2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), {"pool2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
......
...@@ -289,6 +289,18 @@ class TestMatMulOp17(TestMatMulV2Op): ...@@ -289,6 +289,18 @@ class TestMatMulOp17(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMulOp18(TestMatMulV2Op):
    """
    case 18 : for ppyoloe model

    Configures a 4-D x shape (8, 111, 4, 17) and a 1-D y shape of length 17,
    with both transpose flags turned off.
    """

    def config(self):
        # NOTE(review): presumably consumed by the base class's setup to build
        # the operator inputs -- the base class is not visible here; confirm.
        self.x_shape = (8, 111, 4, 17)
        # The trailing comma matters: a bare `(17)` is the int 17, not a tuple,
        # which would make y_shape a different type from x_shape.
        self.y_shape = (17, )
        self.trans_x = False
        self.trans_y = False
# class TestMatMulOpBroadcast1(TestMatMulV2Op): # class TestMatMulOpBroadcast1(TestMatMulV2Op):
# """ # """
# case 14_3 # case 14_3
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册