未验证 提交 6a0102b0 编写于 作者: C cc 提交者: GitHub

map matmul/squeeze2+matmul/reshape2+matmul to mul (#29911)

* map matmul/squeeze2+matmul/reshape2+matmul to mul
上级 d038746e
......@@ -60,6 +60,7 @@ pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base)
pass_library(fc_fuse_pass inference)
pass_library(map_matmul_to_mul_pass inference)
pass_library(attention_lstm_fuse_pass inference)
pass_library(fc_lstm_fuse_pass inference)
pass_library(embedding_fc_lstm_fuse_pass inference)
......
......@@ -1572,6 +1572,65 @@ PDNode *patterns::Reshape::operator()() {
}
PDNode *patterns::Matmul::operator()() {
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
auto matmul_in_x = pattern->NewNode(matmul_in_x_repr())
->AsInput()
->assert_is_op_input("matmul", "X");
auto matmul_in_y = pattern->NewNode(matmul_in_y_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsOutput()
->assert_is_op_output("matmul", "Out");
matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({matmul_out});
return matmul_out;
}
PDNode *patterns::Squeeze2Matmul::operator()() {
auto squeeze2_in_x = pattern->NewNode(squeeze2_in_x_repr())
->assert_is_op_input("squeeze2", "X")
->AsInput();
auto squeeze2_op =
pattern->NewNode(squeeze2_op_repr())->assert_is_op("squeeze2");
auto matmul_in_x = pattern->NewNode(matmul_in_x_repr())
->assert_is_op_output("squeeze2", "Out")
->assert_is_op_input("matmul", "X");
auto matmul_in_y =
pattern->NewNode(matmul_in_y_repr())->assert_is_op_input("matmul", "Y");
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsOutput()
->assert_is_op_output("matmul", "Out");
squeeze2_op->LinksFrom({squeeze2_in_x}).LinksTo({matmul_in_x});
matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({matmul_out});
return matmul_out;
}
PDNode *patterns::Reshape2Matmul::operator()() {
auto reshape2_in_x = pattern->NewNode(reshape2_in_x_repr())
->assert_is_op_input("reshape2", "X")
->AsInput();
auto reshape2_op =
pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2");
auto matmul_in_x = pattern->NewNode(matmul_in_x_repr())
->assert_is_op_output("reshape2", "Out")
->assert_is_op_input("matmul", "X");
auto matmul_in_y =
pattern->NewNode(matmul_in_y_repr())->assert_is_op_input("matmul", "Y");
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsOutput()
->assert_is_op_output("matmul", "Out");
reshape2_op->LinksFrom({reshape2_in_x}).LinksTo({matmul_in_x});
matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({matmul_out});
return matmul_out;
}
PDNode *patterns::MatmulWithInputOps::operator()() {
auto prev_op_x = pattern->NewNode(prev_op_x_repr())->assert_is_op();
auto prev_op_y = pattern->NewNode(prev_op_y_repr())->assert_is_op();
......
......@@ -961,10 +961,52 @@ struct Reshape : public PatternBase {
// Matmul op
// Forward pass for matmul.
// matmul_out is a result of the operator.
struct Matmul : public PatternBase {
Matmul(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "reshape2") {}
: PatternBase(pattern, name_scope, "matmul") {}
PDNode* operator()();
PATTERN_DECL_NODE(matmul_in_x);
PATTERN_DECL_NODE(matmul_in_y);
PATTERN_DECL_NODE(matmul_op);
PATTERN_DECL_NODE(matmul_out);
};
// Squeeze2 + Matmul
// Forward pass.
struct Squeeze2Matmul : public PatternBase {
Squeeze2Matmul(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "squeeze2_matmul") {}
PDNode* operator()();
PATTERN_DECL_NODE(squeeze2_in_x);
PATTERN_DECL_NODE(squeeze2_op);
PATTERN_DECL_NODE(matmul_in_x);
PATTERN_DECL_NODE(matmul_in_y);
PATTERN_DECL_NODE(matmul_op);
PATTERN_DECL_NODE(matmul_out);
};
// Reshape2 + Matmul
// Forward pass.
struct Reshape2Matmul : public PatternBase {
Reshape2Matmul(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "reshape2_matmul") {}
PDNode* operator()();
PATTERN_DECL_NODE(reshape2_in_x);
PATTERN_DECL_NODE(reshape2_op);
PATTERN_DECL_NODE(matmul_in_x);
PATTERN_DECL_NODE(matmul_in_y);
PATTERN_DECL_NODE(matmul_op);
PATTERN_DECL_NODE(matmul_out);
};
// Forward pass for two input ops and matmul op.
// matmul_out is a result of the operator.
struct MatmulWithInputOps : public PatternBase {
MatmulWithInputOps(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "matmul_with_input_ops") {}
PDNode* operator()();
PATTERN_DECL_NODE(prev_op_x);
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/map_matmul_to_mul_pass.h"
#include <cmath>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "map_matmul_to_mul_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
patterns::Matmul matmul_pattern(gpd.mutable_pattern(), name_scope);
matmul_pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "map matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern);
bool flag = true;
bool transpose_X =
BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_X"));
bool transpose_Y =
BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_Y"));
float alpha = BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha"));
flag = flag && !transpose_X && !transpose_Y && std::abs(alpha - 1.0) < 1e-5;
std::vector<int64_t> x_shape = matmul_in_x->Var()->GetShape();
std::vector<int64_t> y_shape = matmul_in_y->Var()->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
flag = flag && x_rank == 2 && y_rank == 2;
std::vector<Node*>& next_ops = matmul_out->outputs;
flag = flag && next_ops.size() == 1 &&
next_ops[0]->Name() == "elementwise_add";
if (flag) {
OpDesc desc;
desc.SetType("mul");
desc.SetInput("X", {matmul_in_x->Name()});
desc.SetInput("Y", {matmul_in_y->Name()});
desc.SetOutput("Out", {matmul_out->Name()});
desc.SetAttr("x_num_col_dims", 1);
desc.SetAttr("y_num_col_dims", 1);
auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(matmul_in_x, mul_node);
IR_NODE_LINK_TO(matmul_in_y, mul_node);
IR_NODE_LINK_TO(mul_node, matmul_out);
GraphSafeRemoveNodes(graph, {matmul_op});
++found_count;
}
};
gpd(graph, handler);
AddStatis(found_count);
}
void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "squeeze2_matmul_fuse_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
patterns::Squeeze2Matmul fuse_pattern(gpd.mutable_pattern(), name_scope);
fuse_pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "fuse squeeze2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(squeeze2_in_x, squeeze2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(squeeze2_op, squeeze2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, fuse_pattern);
bool flag = true;
size_t squeeze2_in_x_rank = (squeeze2_in_x->Var()->GetShape()).size();
std::vector<int> squeeze2_op_axes =
BOOST_GET_CONST(std::vector<int>, squeeze2_op->Op()->GetAttr("axes"));
flag = flag && squeeze2_in_x_rank == 4 &&
squeeze2_op_axes == std::vector<int>{2, 3} &&
(matmul_in_x->outputs).size() == 1;
bool transpose_X =
BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_X"));
bool transpose_Y =
BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_Y"));
float alpha = BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha"));
size_t matmul_in_x_rank = (matmul_in_x->Var()->GetShape()).size();
size_t matmul_in_y_rank = (matmul_in_y->Var()->GetShape()).size();
flag = flag && !transpose_X && !transpose_Y &&
std::abs(alpha - 1.0) < 1e-5 && matmul_in_x_rank == 2 &&
matmul_in_y_rank == 2;
std::vector<Node*>& next_ops = matmul_out->outputs;
flag = flag && next_ops.size() == 1 &&
next_ops[0]->Name() == "elementwise_add";
if (flag) {
OpDesc desc;
desc.SetType("mul");
desc.SetInput("X", {squeeze2_in_x->Name()});
desc.SetInput("Y", {matmul_in_y->Name()});
desc.SetOutput("Out", {matmul_out->Name()});
desc.SetAttr("x_num_col_dims", 1);
desc.SetAttr("y_num_col_dims", 1);
auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(squeeze2_in_x, mul_node);
IR_NODE_LINK_TO(matmul_in_y, mul_node);
IR_NODE_LINK_TO(mul_node, matmul_out);
GraphSafeRemoveNodes(graph, {squeeze2_op, matmul_in_x, matmul_op});
++found_count;
}
};
gpd(graph, handler);
AddStatis(found_count);
}
void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "reshape2_matmul_fuse_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
patterns::Reshape2Matmul fuse_pattern(gpd.mutable_pattern(), name_scope);
fuse_pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "fuse reshape2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(reshape2_in_x, reshape2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, fuse_pattern);
bool flag = true;
size_t reshape2_in_nums = reshape2_op->inputs.size();
auto reshape2_in_x_shape = reshape2_in_x->Var()->GetShape();
size_t reshape2_in_x_rank = reshape2_in_x_shape.size();
std::vector<int> reshape2_op_shape =
BOOST_GET_CONST(std::vector<int>, reshape2_op->Op()->GetAttr("shape"));
flag = flag && reshape2_in_nums == 1 && reshape2_in_x_rank == 4 &&
reshape2_in_x_shape[2] == 1 && reshape2_in_x_shape[3] == 1 &&
reshape2_op_shape.size() == 2 && (matmul_in_x->outputs).size() == 1;
bool transpose_X =
BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_X"));
bool transpose_Y =
BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_Y"));
float alpha = BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha"));
size_t matmul_in_x_rank = (matmul_in_x->Var()->GetShape()).size();
size_t matmul_in_y_rank = (matmul_in_y->Var()->GetShape()).size();
flag = flag && !transpose_X && !transpose_Y &&
std::abs(alpha - 1.0) < 1e-5 && matmul_in_x_rank == 2 &&
matmul_in_y_rank == 2;
std::vector<Node*>& next_ops = matmul_out->outputs;
flag = flag && next_ops.size() == 1 &&
next_ops[0]->Name() == "elementwise_add";
if (flag) {
OpDesc desc;
desc.SetType("mul");
desc.SetInput("X", {reshape2_in_x->Name()});
desc.SetInput("Y", {matmul_in_y->Name()});
desc.SetOutput("Out", {matmul_out->Name()});
desc.SetAttr("x_num_col_dims", 1);
desc.SetAttr("y_num_col_dims", 1);
auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(reshape2_in_x, mul_node);
IR_NODE_LINK_TO(matmul_in_y, mul_node);
IR_NODE_LINK_TO(mul_node, matmul_out);
GraphSafeRemoveNodes(graph, {reshape2_op, matmul_in_x, matmul_op});
++found_count;
}
};
gpd(graph, handler);
AddStatis(found_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(map_matmul_to_mul_pass, paddle::framework::ir::MapMatmul2MulPass);
REGISTER_PASS_CAPABILITY(map_matmul_to_mul_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul", 0)
.EQ("mul", 0));
REGISTER_PASS(squeeze2_matmul_fuse_pass,
paddle::framework::ir::Squeeze2MatmulFusePass);
REGISTER_PASS_CAPABILITY(squeeze2_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul", 0)
.EQ("squeeze2", 0)
.EQ("mul", 0));
REGISTER_PASS(reshape2_matmul_fuse_pass,
paddle::framework::ir::Reshape2MatmulFusePass);
REGISTER_PASS_CAPABILITY(reshape2_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul", 0)
.EQ("reshape2", 0)
.EQ("mul", 0));
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Map matmul to mul, so the optimization can use fc_fuse_pass.
* The mul op must satisfy the following conditions:
* 1. the transpose_X and transpose_Y attrs are false
* 2. the alpha attr is 1.0
* 3. the rank of input X and Y is 2
* 4. the next op of matmul is only elementwise_add
*
* Notice:
* the rank of input activation is obtained from var_desc,
* it maybe change in runtime.
*/
class Graph;
class MapMatmul2MulPass : public FusePassBase {
public:
virtual ~MapMatmul2MulPass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
/*
* Fuse squeeze2+matmul to mul, so the optimization can use fc_fuse_pass.
* The squeeze2 op must satisfy the following conditions:
* 1. the rank of input X is 4
* 2. the axis attr is [2, 3]
* 3. the next op is only matmul
*
* The matmul op must satisfy the following conditions:
* 1. the transpose_X and transpose_Y attrs are false
* 2. the alpha attr is 1.0
* 3. the rank of input X and Y is 2
* 4. the next op of matmul is only elementwise_add
*
* Notice:
* the rank of input activation is obtained from var_desc,
* it maybe change in runtime. Therefore, the pass considers
* the above passes to reduce the impact on other models.
*/
class Squeeze2MatmulFusePass : public FusePassBase {
public:
virtual ~Squeeze2MatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
/*
* Fuse reshape2+matmul to mul, so the optimization can use fc_fuse_pass.
* The reshape2 op must satisfy the following conditions:
* 1. reshape2 has one input node, which means it don't
* have Shape or ShapeTensor input
* 2. the rank of input X is 4 and the last two dims of input X is 1
* 3. the rank of shape attr is 2
* 4. the next op is only matmul
*
* The matmul op must satisfy the following conditions:
* 1. the transpose_X and transpose_Y attrs are false
* 2. the alpha attr is 1.0
* 3. the rank of input X and Y is 2
* 4. the next op of matmul is only elementwise_add
*
* Notice:
* the shape and rank of input activation is obtained from var_desc,
* they maybe change in runtime. Therefore, the pass considers
* the above passes to reduce the impact on other models.
*/
class Reshape2MatmulFusePass : public FusePassBase {
public:
virtual ~Reshape2MatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -679,7 +679,7 @@ void CPUQuantizePass::QuantizeReshape(Graph* graph) const {
void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::Matmul matmul_pattern{pattern, name_scope_};
patterns::MatmulWithInputOps matmul_pattern{pattern, name_scope_};
matmul_pattern();
int quantize_matmul_count = 0;
......
......@@ -82,8 +82,11 @@ const std::vector<std::string> kTRTSubgraphPasses({
"embedding_eltwise_layernorm_fuse_pass", //
"multihead_matmul_fuse_pass_v2", //
"skip_layernorm_fuse_pass", //
"unsqueeze2_eltwise_fuse_pass",
"conv_bn_fuse_pass", //
"unsqueeze2_eltwise_fuse_pass", //
"squeeze2_matmul_fuse_pass", //
"reshape2_matmul_fuse_pass", //
"map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", //
......@@ -113,6 +116,9 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_eltwiseadd_bn_fuse_pass", //
"embedding_eltwise_layernorm_fuse_pass", //
"multihead_matmul_fuse_pass_v2", //
"squeeze2_matmul_fuse_pass", //
"reshape2_matmul_fuse_pass", //
"map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"fc_elementwise_layernorm_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
......@@ -164,6 +170,9 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
"fc_gru_fuse_pass", //
"mul_gru_fuse_pass", //
"seq_concat_fc_fuse_pass", //
"squeeze2_matmul_fuse_pass", //
"reshape2_matmul_fuse_pass", //
"map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"repeated_fc_relu_fuse_pass", //
"squared_mat_sub_fuse_pass", //
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册