未验证 提交 98e803e0 编写于 作者: P Pei Yang 提交者: GitHub

map_matmul_to_mul_pass support 3dim (#31958)

上级 a37a7f67
......@@ -57,7 +57,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
std::vector<int64_t> y_shape = matmul_in_y->Var()->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
flag = flag && x_rank == 2 && y_rank == 2;
flag = flag && (x_rank == 2 || x_rank == 3) && y_rank == 2;
std::vector<Node*>& next_ops = matmul_out->outputs;
flag = flag && next_ops.size() == 1 &&
......@@ -69,7 +69,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
desc.SetInput("X", {matmul_in_x->Name()});
desc.SetInput("Y", {matmul_in_y->Name()});
desc.SetOutput("Out", {matmul_out->Name()});
desc.SetAttr("x_num_col_dims", 1);
desc.SetAttr("x_num_col_dims", static_cast<int>(x_rank - 1));
desc.SetAttr("y_num_col_dims", 1);
if (matmul_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册