diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc index a2443c86986ec87cc29e9897fe0d38883f8fafa1..c36123f65f6644289cfba2b2729862efa601e2fd 100644 --- a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc +++ b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc @@ -57,7 +57,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { std::vector y_shape = matmul_in_y->Var()->GetShape(); size_t x_rank = x_shape.size(); size_t y_rank = y_shape.size(); - flag = flag && x_rank == 2 && y_rank == 2; + flag = flag && (x_rank == 2 || x_rank == 3) && y_rank == 2; std::vector& next_ops = matmul_out->outputs; flag = flag && next_ops.size() == 1 && @@ -69,7 +69,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { desc.SetInput("X", {matmul_in_x->Name()}); desc.SetInput("Y", {matmul_in_y->Name()}); desc.SetOutput("Out", {matmul_out->Name()}); - desc.SetAttr("x_num_col_dims", 1); + desc.SetAttr("x_num_col_dims", static_cast(x_rank - 1)); desc.SetAttr("y_num_col_dims", 1); if (matmul_op->Op()->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));