未验证 提交 3fe5cddf 编写于 作者: H hong19860320 提交者: GitHub

[LITE][XPU] Fix matmul op bridge (#2668)

上级 28481458
......@@ -49,9 +49,10 @@ int MatmulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto out = scope->FindMutableTensor(out_name);
auto out_dims = out->dims();
auto transpose_x = op_info->GetAttr<bool>("transpose_X");
CHECK(!transpose_x) << "XPU only support transpose_x == true now";
auto transpose_y = op_info->GetAttr<bool>("transpose_Y");
auto alpha = op_info->GetAttr<float>("alpha");
......@@ -71,11 +72,68 @@ int MatmulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
y_node = graph->AddNode(y_name, y_dims);
}
auto matmul_node =
graph->builder_.CreateMatmul2D(*x_node, *y_node, transpose_y);
graph->AddNode(out_name, graph->builder_.CreateScale(matmul_node, alpha));
return SUCCESS;
// Matmul node
if (x_dims.size() > 2 && y_dims.size() >= 2) {
// x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N]
// x: [B, M, K], y: [K, N], out: [B, M, N]
// Reshape and transposed X node
if (x_dims.size() != 3) {
auto m = static_cast<int>(x_dims[x_dims.size() - 2]);
auto k = static_cast<int>(x_dims[x_dims.size() - 1]);
x_node =
graph->AddNode(x_name + "/reshape",
graph->builder_.CreateReshape(*x_node, {-1, m, k}));
if (transpose_x) {
x_node =
graph->AddNode(x_name + "/reshape/transpose",
graph->builder_.CreateTranspose(*x_node, {0, 2, 1}));
}
}
// Reshape and transposed Y node
if (y_dims.size() != 3) {
auto k = static_cast<int>(y_dims[y_dims.size() - 2]);
auto n = static_cast<int>(y_dims[y_dims.size() - 1]);
y_node =
graph->AddNode(y_name + "/reshape",
graph->builder_.CreateReshape(*y_node, {-1, k, n}));
if (!transpose_y) {
y_node =
graph->AddNode(y_name + "/reshape/transpose",
graph->builder_.CreateTranspose(*y_node, {0, 2, 1}));
}
}
// Matmul node
auto matmul_node = graph->AddNode(
out_name, graph->builder_.CreateBatchMatmul(*x_node, *y_node));
if (fabs(alpha - 1) > 1e-6f) {
matmul_node = graph->AddNode(
out_name, graph->builder_.CreateScale(*matmul_node, alpha));
}
if (out_dims.size() != 3) {
graph->AddNode(out_name,
graph->builder_.CreateReshape(
*matmul_node, CvtShape<xtcl::Integer>(out_dims)));
}
} else if (x_dims.size() == 2 && y_dims.size() == 2) {
// x: [M, K], y: [K, N], out: [M, N]
if (transpose_x) {
x_node = graph->AddNode(x_name + "/transpose",
graph->builder_.CreateTranspose(*x_node, {1, 0}));
}
auto matmul_node = graph->AddNode(
out_name,
graph->builder_.CreateMatmul2D(*x_node, *y_node, transpose_y));
if (fabs(alpha - 1) > 1e-6f) {
matmul_node = graph->AddNode(
out_name, graph->builder_.CreateScale(*matmul_node, alpha));
}
} else if (x_dims.size() == 1 && y_dims.size() == 1) {
// x: [K], y: [K], out: [1]
// x: [M], y: [N], x_transpose: true, y_transpose: true, out: [M, N]
LOG(FATAL) << "[XPU] Not supported.";
return FAILED;
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace xpu
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册