diff --git a/lite/kernels/xpu/bridges/matmul_op.cc b/lite/kernels/xpu/bridges/matmul_op.cc
index eaf2370ada95e77f25c1b75fa09e19a669c15b93..330b336840148fa54d5c9f2eae39a08fdfad9557 100644
--- a/lite/kernels/xpu/bridges/matmul_op.cc
+++ b/lite/kernels/xpu/bridges/matmul_op.cc
@@ -49,9 +49,10 @@ int MatmulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto out_type = kernel->GetOutputDeclType("Out");
   CHECK(out_type->precision() == PRECISION(kFloat));
   CHECK(out_type->layout() == DATALAYOUT(kNCHW));
+  auto out = scope->FindMutableTensor(out_name);
+  auto out_dims = out->dims();
 
   auto transpose_x = op_info->GetAttr<bool>("transpose_X");
-  CHECK(!transpose_x) << "XPU only support transpose_x == true now";
   auto transpose_y = op_info->GetAttr<bool>("transpose_Y");
   auto alpha = op_info->GetAttr<float>("alpha");
 
@@ -71,11 +72,68 @@ int MatmulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
     y_node = graph->AddNode(y_name, y_dims);
   }
 
-  auto matmul_node =
-      graph->builder_.CreateMatmul2D(*x_node, *y_node, transpose_y);
-  graph->AddNode(out_name, graph->builder_.CreateScale(matmul_node, alpha));
-
-  return SUCCESS;
+  // Matmul node
+  if (x_dims.size() > 2 && y_dims.size() >= 2) {
+    // x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N]
+    // x: [B, M, K], y: [K, N], out: [B, M, N]
+    // Reshape and transposed X node
+    if (x_dims.size() != 3) {
+      auto m = static_cast<int>(x_dims[x_dims.size() - 2]);
+      auto k = static_cast<int>(x_dims[x_dims.size() - 1]);
+      x_node =
+          graph->AddNode(x_name + "/reshape",
+                         graph->builder_.CreateReshape(*x_node, {-1, m, k}));
+      if (transpose_x) {
+        x_node =
+            graph->AddNode(x_name + "/reshape/transpose",
+                           graph->builder_.CreateTranspose(*x_node, {0, 2, 1}));
+      }
+    }
+    // Reshape and transposed Y node
+    if (y_dims.size() != 3) {
+      auto k = static_cast<int>(y_dims[y_dims.size() - 2]);
+      auto n = static_cast<int>(y_dims[y_dims.size() - 1]);
+      y_node =
+          graph->AddNode(y_name + "/reshape",
+                         graph->builder_.CreateReshape(*y_node, {-1, k, n}));
+      if (!transpose_y) {
+        y_node =
+            graph->AddNode(y_name + "/reshape/transpose",
+                           graph->builder_.CreateTranspose(*y_node, {0, 2, 1}));
+      }
+    }
+    // Matmul node
+    auto matmul_node = graph->AddNode(
+        out_name, graph->builder_.CreateBatchMatmul(*x_node, *y_node));
+    if (fabs(alpha - 1) > 1e-6f) {
+      matmul_node = graph->AddNode(
+          out_name, graph->builder_.CreateScale(*matmul_node, alpha));
+    }
+    if (out_dims.size() != 3) {
+      graph->AddNode(out_name,
+                     graph->builder_.CreateReshape(
+                         *matmul_node, CvtShape(out_dims)));
+    }
+  } else if (x_dims.size() == 2 && y_dims.size() == 2) {
+    // x: [M, K], y: [K, N], out: [M, N]
+    if (transpose_x) {
+      x_node = graph->AddNode(x_name + "/transpose",
+                              graph->builder_.CreateTranspose(*x_node, {1, 0}));
+    }
+    auto matmul_node = graph->AddNode(
+        out_name,
+        graph->builder_.CreateMatmul2D(*x_node, *y_node, transpose_y));
+    if (fabs(alpha - 1) > 1e-6f) {
+      matmul_node = graph->AddNode(
+          out_name, graph->builder_.CreateScale(*matmul_node, alpha));
+    }
+  } else if (x_dims.size() == 1 && y_dims.size() == 1) {
+    // x: [K], y: [K], out: [1]
+    // x: [M], y: [N], x_transpose: true, y_transpose: true, out: [M, N]
+    LOG(FATAL) << "[XPU] Not supported.";
+    return FAILED;
+  }
+  return REBUILD_WHEN_SHAPE_CHANGED;
 }
 
 }  // namespace xpu
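For context, the batched branch above collapses x from [B1, ..., Bn, M, K] to the 3-D shape {-1, M, K} before calling CreateBatchMatmul, and it only emits a scale node when alpha differs from 1 by more than a small tolerance. The standalone C++ sketch below is not part of the patch; the helper names CollapseToBatch3D and NeedsScaleNode are made up purely to illustrate those two pieces of logic in isolation.

// Standalone sketch (not Paddle-Lite code): mirrors the shape collapsing and
// the alpha guard used by the converter above. Helper names are hypothetical.
#include <cmath>
#include <cstdint>
#include <vector>

// Collapse [B1, ..., Bn, M, K] into the 3-D shape {-1, M, K} that a batched
// matmul expects, as the CreateReshape(*x_node, {-1, m, k}) call does.
std::vector<int> CollapseToBatch3D(const std::vector<int64_t>& dims) {
  int m = static_cast<int>(dims[dims.size() - 2]);
  int k = static_cast<int>(dims[dims.size() - 1]);
  return {-1, m, k};
}

// A scale node is only needed when alpha differs from 1 beyond a small
// tolerance, matching the fabs(alpha - 1) > 1e-6f guard in the converter.
bool NeedsScaleNode(float alpha) { return std::fabs(alpha - 1.0f) > 1e-6f; }

int main() {
  std::vector<int64_t> x_dims = {2, 4, 8, 16};             // [B1, B2, M, K]
  std::vector<int> collapsed = CollapseToBatch3D(x_dims);  // {-1, 8, 16}
  bool scale = NeedsScaleNode(1.0f);  // false: the scale node is skipped
  return (collapsed.size() == 3 && !scale) ? 0 : 1;
}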