Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
6a0102b0
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6a0102b0
编写于
12月 29, 2020
作者:
C
cc
提交者:
GitHub
12月 29, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
map matmul/squeeze2+matmul/reshape2+matmul to mul (#29911)
* map matmul/squeeze2+matmul/reshape2+matmul to mul
上级
d038746e
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
474 addition
and
8 deletion
+474
-8
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+1
-0
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+59
-0
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+44
-2
paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc
paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc
+249
-0
paddle/fluid/framework/ir/map_matmul_to_mul_pass.h
paddle/fluid/framework/ir/map_matmul_to_mul_pass.h
+106
-0
paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
+1
-1
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+14
-5
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
6a0102b0
...
@@ -60,6 +60,7 @@ pass_library(graph_to_program_pass base)
...
@@ -60,6 +60,7 @@ pass_library(graph_to_program_pass base)
pass_library
(
graph_viz_pass base
)
pass_library
(
graph_viz_pass base
)
pass_library
(
lock_free_optimize_pass base
)
pass_library
(
lock_free_optimize_pass base
)
pass_library
(
fc_fuse_pass inference
)
pass_library
(
fc_fuse_pass inference
)
pass_library
(
map_matmul_to_mul_pass inference
)
pass_library
(
attention_lstm_fuse_pass inference
)
pass_library
(
attention_lstm_fuse_pass inference
)
pass_library
(
fc_lstm_fuse_pass inference
)
pass_library
(
fc_lstm_fuse_pass inference
)
pass_library
(
embedding_fc_lstm_fuse_pass inference
)
pass_library
(
embedding_fc_lstm_fuse_pass inference
)
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
6a0102b0
...
@@ -1572,6 +1572,65 @@ PDNode *patterns::Reshape::operator()() {
...
@@ -1572,6 +1572,65 @@ PDNode *patterns::Reshape::operator()() {
}
}
PDNode
*
patterns
::
Matmul
::
operator
()()
{
PDNode
*
patterns
::
Matmul
::
operator
()()
{
auto
matmul_op
=
pattern
->
NewNode
(
matmul_op_repr
())
->
assert_is_op
(
"matmul"
);
auto
matmul_in_x
=
pattern
->
NewNode
(
matmul_in_x_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul"
,
"X"
);
auto
matmul_in_y
=
pattern
->
NewNode
(
matmul_in_y_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul"
,
"Y"
);
auto
matmul_out
=
pattern
->
NewNode
(
matmul_out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"matmul"
,
"Out"
);
matmul_op
->
LinksFrom
({
matmul_in_x
,
matmul_in_y
}).
LinksTo
({
matmul_out
});
return
matmul_out
;
}
PDNode
*
patterns
::
Squeeze2Matmul
::
operator
()()
{
auto
squeeze2_in_x
=
pattern
->
NewNode
(
squeeze2_in_x_repr
())
->
assert_is_op_input
(
"squeeze2"
,
"X"
)
->
AsInput
();
auto
squeeze2_op
=
pattern
->
NewNode
(
squeeze2_op_repr
())
->
assert_is_op
(
"squeeze2"
);
auto
matmul_in_x
=
pattern
->
NewNode
(
matmul_in_x_repr
())
->
assert_is_op_output
(
"squeeze2"
,
"Out"
)
->
assert_is_op_input
(
"matmul"
,
"X"
);
auto
matmul_in_y
=
pattern
->
NewNode
(
matmul_in_y_repr
())
->
assert_is_op_input
(
"matmul"
,
"Y"
);
auto
matmul_op
=
pattern
->
NewNode
(
matmul_op_repr
())
->
assert_is_op
(
"matmul"
);
auto
matmul_out
=
pattern
->
NewNode
(
matmul_out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"matmul"
,
"Out"
);
squeeze2_op
->
LinksFrom
({
squeeze2_in_x
}).
LinksTo
({
matmul_in_x
});
matmul_op
->
LinksFrom
({
matmul_in_x
,
matmul_in_y
}).
LinksTo
({
matmul_out
});
return
matmul_out
;
}
PDNode
*
patterns
::
Reshape2Matmul
::
operator
()()
{
auto
reshape2_in_x
=
pattern
->
NewNode
(
reshape2_in_x_repr
())
->
assert_is_op_input
(
"reshape2"
,
"X"
)
->
AsInput
();
auto
reshape2_op
=
pattern
->
NewNode
(
reshape2_op_repr
())
->
assert_is_op
(
"reshape2"
);
auto
matmul_in_x
=
pattern
->
NewNode
(
matmul_in_x_repr
())
->
assert_is_op_output
(
"reshape2"
,
"Out"
)
->
assert_is_op_input
(
"matmul"
,
"X"
);
auto
matmul_in_y
=
pattern
->
NewNode
(
matmul_in_y_repr
())
->
assert_is_op_input
(
"matmul"
,
"Y"
);
auto
matmul_op
=
pattern
->
NewNode
(
matmul_op_repr
())
->
assert_is_op
(
"matmul"
);
auto
matmul_out
=
pattern
->
NewNode
(
matmul_out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"matmul"
,
"Out"
);
reshape2_op
->
LinksFrom
({
reshape2_in_x
}).
LinksTo
({
matmul_in_x
});
matmul_op
->
LinksFrom
({
matmul_in_x
,
matmul_in_y
}).
LinksTo
({
matmul_out
});
return
matmul_out
;
}
PDNode
*
patterns
::
MatmulWithInputOps
::
operator
()()
{
auto
prev_op_x
=
pattern
->
NewNode
(
prev_op_x_repr
())
->
assert_is_op
();
auto
prev_op_x
=
pattern
->
NewNode
(
prev_op_x_repr
())
->
assert_is_op
();
auto
prev_op_y
=
pattern
->
NewNode
(
prev_op_y_repr
())
->
assert_is_op
();
auto
prev_op_y
=
pattern
->
NewNode
(
prev_op_y_repr
())
->
assert_is_op
();
...
...
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
6a0102b0
...
@@ -961,10 +961,52 @@ struct Reshape : public PatternBase {
...
@@ -961,10 +961,52 @@ struct Reshape : public PatternBase {
// Matmul op
// Matmul op
// Forward pass for matmul.
// Forward pass for matmul.
// matmul_out is a result of the operator.
struct
Matmul
:
public
PatternBase
{
struct
Matmul
:
public
PatternBase
{
Matmul
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
Matmul
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"reshape2"
)
{}
:
PatternBase
(
pattern
,
name_scope
,
"matmul"
)
{}
PDNode
*
operator
()();
PATTERN_DECL_NODE
(
matmul_in_x
);
PATTERN_DECL_NODE
(
matmul_in_y
);
PATTERN_DECL_NODE
(
matmul_op
);
PATTERN_DECL_NODE
(
matmul_out
);
};
// Squeeze2 + Matmul
// Forward pass.
struct
Squeeze2Matmul
:
public
PatternBase
{
Squeeze2Matmul
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"squeeze2_matmul"
)
{}
PDNode
*
operator
()();
PATTERN_DECL_NODE
(
squeeze2_in_x
);
PATTERN_DECL_NODE
(
squeeze2_op
);
PATTERN_DECL_NODE
(
matmul_in_x
);
PATTERN_DECL_NODE
(
matmul_in_y
);
PATTERN_DECL_NODE
(
matmul_op
);
PATTERN_DECL_NODE
(
matmul_out
);
};
// Reshape2 + Matmul
// Forward pass.
struct
Reshape2Matmul
:
public
PatternBase
{
Reshape2Matmul
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"reshape2_matmul"
)
{}
PDNode
*
operator
()();
PATTERN_DECL_NODE
(
reshape2_in_x
);
PATTERN_DECL_NODE
(
reshape2_op
);
PATTERN_DECL_NODE
(
matmul_in_x
);
PATTERN_DECL_NODE
(
matmul_in_y
);
PATTERN_DECL_NODE
(
matmul_op
);
PATTERN_DECL_NODE
(
matmul_out
);
};
// Forward pass for two input ops and matmul op.
// matmul_out is a result of the operator.
struct
MatmulWithInputOps
:
public
PatternBase
{
MatmulWithInputOps
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"matmul_with_input_ops"
)
{}
PDNode
*
operator
()();
PDNode
*
operator
()();
PATTERN_DECL_NODE
(
prev_op_x
);
PATTERN_DECL_NODE
(
prev_op_x
);
...
...
paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc
0 → 100644
浏览文件 @
6a0102b0
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/map_matmul_to_mul_pass.h"
#include <cmath>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
MapMatmul2MulPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
InvalidArgument
(
"Graph cannot be nullptr."
));
std
::
string
name_scope
=
"map_matmul_to_mul_pass"
;
FusePassBase
::
Init
(
name_scope
,
graph
);
GraphPatternDetector
gpd
;
patterns
::
Matmul
matmul_pattern
(
gpd
.
mutable_pattern
(),
name_scope
);
matmul_pattern
();
int
found_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
VLOG
(
4
)
<<
"map matmul to mul"
;
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_in_x
,
matmul_in_x
,
matmul_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_in_y
,
matmul_in_y
,
matmul_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_op
,
matmul_op
,
matmul_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_out
,
matmul_out
,
matmul_pattern
);
bool
flag
=
true
;
bool
transpose_X
=
BOOST_GET_CONST
(
bool
,
matmul_op
->
Op
()
->
GetAttr
(
"transpose_X"
));
bool
transpose_Y
=
BOOST_GET_CONST
(
bool
,
matmul_op
->
Op
()
->
GetAttr
(
"transpose_Y"
));
float
alpha
=
BOOST_GET_CONST
(
float
,
matmul_op
->
Op
()
->
GetAttr
(
"alpha"
));
flag
=
flag
&&
!
transpose_X
&&
!
transpose_Y
&&
std
::
abs
(
alpha
-
1.0
)
<
1e-5
;
std
::
vector
<
int64_t
>
x_shape
=
matmul_in_x
->
Var
()
->
GetShape
();
std
::
vector
<
int64_t
>
y_shape
=
matmul_in_y
->
Var
()
->
GetShape
();
size_t
x_rank
=
x_shape
.
size
();
size_t
y_rank
=
y_shape
.
size
();
flag
=
flag
&&
x_rank
==
2
&&
y_rank
==
2
;
std
::
vector
<
Node
*>&
next_ops
=
matmul_out
->
outputs
;
flag
=
flag
&&
next_ops
.
size
()
==
1
&&
next_ops
[
0
]
->
Name
()
==
"elementwise_add"
;
if
(
flag
)
{
OpDesc
desc
;
desc
.
SetType
(
"mul"
);
desc
.
SetInput
(
"X"
,
{
matmul_in_x
->
Name
()});
desc
.
SetInput
(
"Y"
,
{
matmul_in_y
->
Name
()});
desc
.
SetOutput
(
"Out"
,
{
matmul_out
->
Name
()});
desc
.
SetAttr
(
"x_num_col_dims"
,
1
);
desc
.
SetAttr
(
"y_num_col_dims"
,
1
);
auto
mul_node
=
g
->
CreateOpNode
(
&
desc
);
IR_NODE_LINK_TO
(
matmul_in_x
,
mul_node
);
IR_NODE_LINK_TO
(
matmul_in_y
,
mul_node
);
IR_NODE_LINK_TO
(
mul_node
,
matmul_out
);
GraphSafeRemoveNodes
(
graph
,
{
matmul_op
});
++
found_count
;
}
};
gpd
(
graph
,
handler
);
AddStatis
(
found_count
);
}
void
Squeeze2MatmulFusePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
InvalidArgument
(
"Graph cannot be nullptr."
));
std
::
string
name_scope
=
"squeeze2_matmul_fuse_pass"
;
FusePassBase
::
Init
(
name_scope
,
graph
);
GraphPatternDetector
gpd
;
patterns
::
Squeeze2Matmul
fuse_pattern
(
gpd
.
mutable_pattern
(),
name_scope
);
fuse_pattern
();
int
found_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
VLOG
(
4
)
<<
"fuse squeeze2+matmul to mul"
;
GET_IR_NODE_FROM_SUBGRAPH
(
squeeze2_in_x
,
squeeze2_in_x
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
squeeze2_op
,
squeeze2_op
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_in_x
,
matmul_in_x
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_in_y
,
matmul_in_y
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_op
,
matmul_op
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_out
,
matmul_out
,
fuse_pattern
);
bool
flag
=
true
;
size_t
squeeze2_in_x_rank
=
(
squeeze2_in_x
->
Var
()
->
GetShape
()).
size
();
std
::
vector
<
int
>
squeeze2_op_axes
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
squeeze2_op
->
Op
()
->
GetAttr
(
"axes"
));
flag
=
flag
&&
squeeze2_in_x_rank
==
4
&&
squeeze2_op_axes
==
std
::
vector
<
int
>
{
2
,
3
}
&&
(
matmul_in_x
->
outputs
).
size
()
==
1
;
bool
transpose_X
=
BOOST_GET_CONST
(
bool
,
matmul_op
->
Op
()
->
GetAttr
(
"transpose_X"
));
bool
transpose_Y
=
BOOST_GET_CONST
(
bool
,
matmul_op
->
Op
()
->
GetAttr
(
"transpose_Y"
));
float
alpha
=
BOOST_GET_CONST
(
float
,
matmul_op
->
Op
()
->
GetAttr
(
"alpha"
));
size_t
matmul_in_x_rank
=
(
matmul_in_x
->
Var
()
->
GetShape
()).
size
();
size_t
matmul_in_y_rank
=
(
matmul_in_y
->
Var
()
->
GetShape
()).
size
();
flag
=
flag
&&
!
transpose_X
&&
!
transpose_Y
&&
std
::
abs
(
alpha
-
1.0
)
<
1e-5
&&
matmul_in_x_rank
==
2
&&
matmul_in_y_rank
==
2
;
std
::
vector
<
Node
*>&
next_ops
=
matmul_out
->
outputs
;
flag
=
flag
&&
next_ops
.
size
()
==
1
&&
next_ops
[
0
]
->
Name
()
==
"elementwise_add"
;
if
(
flag
)
{
OpDesc
desc
;
desc
.
SetType
(
"mul"
);
desc
.
SetInput
(
"X"
,
{
squeeze2_in_x
->
Name
()});
desc
.
SetInput
(
"Y"
,
{
matmul_in_y
->
Name
()});
desc
.
SetOutput
(
"Out"
,
{
matmul_out
->
Name
()});
desc
.
SetAttr
(
"x_num_col_dims"
,
1
);
desc
.
SetAttr
(
"y_num_col_dims"
,
1
);
auto
mul_node
=
g
->
CreateOpNode
(
&
desc
);
IR_NODE_LINK_TO
(
squeeze2_in_x
,
mul_node
);
IR_NODE_LINK_TO
(
matmul_in_y
,
mul_node
);
IR_NODE_LINK_TO
(
mul_node
,
matmul_out
);
GraphSafeRemoveNodes
(
graph
,
{
squeeze2_op
,
matmul_in_x
,
matmul_op
});
++
found_count
;
}
};
gpd
(
graph
,
handler
);
AddStatis
(
found_count
);
}
void
Reshape2MatmulFusePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
InvalidArgument
(
"Graph cannot be nullptr."
));
std
::
string
name_scope
=
"reshape2_matmul_fuse_pass"
;
FusePassBase
::
Init
(
name_scope
,
graph
);
GraphPatternDetector
gpd
;
patterns
::
Reshape2Matmul
fuse_pattern
(
gpd
.
mutable_pattern
(),
name_scope
);
fuse_pattern
();
int
found_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
VLOG
(
4
)
<<
"fuse reshape2+matmul to mul"
;
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_in_x
,
reshape2_in_x
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_op
,
reshape2_op
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_in_x
,
matmul_in_x
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_in_y
,
matmul_in_y
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_op
,
matmul_op
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_out
,
matmul_out
,
fuse_pattern
);
bool
flag
=
true
;
size_t
reshape2_in_nums
=
reshape2_op
->
inputs
.
size
();
auto
reshape2_in_x_shape
=
reshape2_in_x
->
Var
()
->
GetShape
();
size_t
reshape2_in_x_rank
=
reshape2_in_x_shape
.
size
();
std
::
vector
<
int
>
reshape2_op_shape
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
reshape2_op
->
Op
()
->
GetAttr
(
"shape"
));
flag
=
flag
&&
reshape2_in_nums
==
1
&&
reshape2_in_x_rank
==
4
&&
reshape2_in_x_shape
[
2
]
==
1
&&
reshape2_in_x_shape
[
3
]
==
1
&&
reshape2_op_shape
.
size
()
==
2
&&
(
matmul_in_x
->
outputs
).
size
()
==
1
;
bool
transpose_X
=
BOOST_GET_CONST
(
bool
,
matmul_op
->
Op
()
->
GetAttr
(
"transpose_X"
));
bool
transpose_Y
=
BOOST_GET_CONST
(
bool
,
matmul_op
->
Op
()
->
GetAttr
(
"transpose_Y"
));
float
alpha
=
BOOST_GET_CONST
(
float
,
matmul_op
->
Op
()
->
GetAttr
(
"alpha"
));
size_t
matmul_in_x_rank
=
(
matmul_in_x
->
Var
()
->
GetShape
()).
size
();
size_t
matmul_in_y_rank
=
(
matmul_in_y
->
Var
()
->
GetShape
()).
size
();
flag
=
flag
&&
!
transpose_X
&&
!
transpose_Y
&&
std
::
abs
(
alpha
-
1.0
)
<
1e-5
&&
matmul_in_x_rank
==
2
&&
matmul_in_y_rank
==
2
;
std
::
vector
<
Node
*>&
next_ops
=
matmul_out
->
outputs
;
flag
=
flag
&&
next_ops
.
size
()
==
1
&&
next_ops
[
0
]
->
Name
()
==
"elementwise_add"
;
if
(
flag
)
{
OpDesc
desc
;
desc
.
SetType
(
"mul"
);
desc
.
SetInput
(
"X"
,
{
reshape2_in_x
->
Name
()});
desc
.
SetInput
(
"Y"
,
{
matmul_in_y
->
Name
()});
desc
.
SetOutput
(
"Out"
,
{
matmul_out
->
Name
()});
desc
.
SetAttr
(
"x_num_col_dims"
,
1
);
desc
.
SetAttr
(
"y_num_col_dims"
,
1
);
auto
mul_node
=
g
->
CreateOpNode
(
&
desc
);
IR_NODE_LINK_TO
(
reshape2_in_x
,
mul_node
);
IR_NODE_LINK_TO
(
matmul_in_y
,
mul_node
);
IR_NODE_LINK_TO
(
mul_node
,
matmul_out
);
GraphSafeRemoveNodes
(
graph
,
{
reshape2_op
,
matmul_in_x
,
matmul_op
});
++
found_count
;
}
};
gpd
(
graph
,
handler
);
AddStatis
(
found_count
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
map_matmul_to_mul_pass
,
paddle
::
framework
::
ir
::
MapMatmul2MulPass
);
REGISTER_PASS_CAPABILITY
(
map_matmul_to_mul_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
EQ
(
"matmul"
,
0
)
.
EQ
(
"mul"
,
0
));
REGISTER_PASS
(
squeeze2_matmul_fuse_pass
,
paddle
::
framework
::
ir
::
Squeeze2MatmulFusePass
);
REGISTER_PASS_CAPABILITY
(
squeeze2_matmul_fuse_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
EQ
(
"matmul"
,
0
)
.
EQ
(
"squeeze2"
,
0
)
.
EQ
(
"mul"
,
0
));
REGISTER_PASS
(
reshape2_matmul_fuse_pass
,
paddle
::
framework
::
ir
::
Reshape2MatmulFusePass
);
REGISTER_PASS_CAPABILITY
(
reshape2_matmul_fuse_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
EQ
(
"matmul"
,
0
)
.
EQ
(
"reshape2"
,
0
)
.
EQ
(
"mul"
,
0
));
paddle/fluid/framework/ir/map_matmul_to_mul_pass.h
0 → 100644
浏览文件 @
6a0102b0
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
/*
* Map matmul to mul, so the optimization can use fc_fuse_pass.
* The mul op must satisfy the following conditions:
* 1. the transpose_X and transpose_Y attrs are false
* 2. the alpha attr is 1.0
* 3. the rank of input X and Y is 2
* 4. the next op of matmul is only elementwise_add
*
* Notice:
* the rank of input activation is obtained from var_desc,
* it maybe change in runtime.
*/
class
Graph
;
class
MapMatmul2MulPass
:
public
FusePassBase
{
public:
virtual
~
MapMatmul2MulPass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
override
;
};
/*
* Fuse squeeze2+matmul to mul, so the optimization can use fc_fuse_pass.
* The squeeze2 op must satisfy the following conditions:
* 1. the rank of input X is 4
* 2. the axis attr is [2, 3]
* 3. the next op is only matmul
*
* The matmul op must satisfy the following conditions:
* 1. the transpose_X and transpose_Y attrs are false
* 2. the alpha attr is 1.0
* 3. the rank of input X and Y is 2
* 4. the next op of matmul is only elementwise_add
*
* Notice:
* the rank of input activation is obtained from var_desc,
* it maybe change in runtime. Therefore, the pass considers
* the above passes to reduce the impact on other models.
*/
class
Squeeze2MatmulFusePass
:
public
FusePassBase
{
public:
virtual
~
Squeeze2MatmulFusePass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
override
;
};
/*
* Fuse reshape2+matmul to mul, so the optimization can use fc_fuse_pass.
* The reshape2 op must satisfy the following conditions:
* 1. reshape2 has one input node, which means it don't
* have Shape or ShapeTensor input
* 2. the rank of input X is 4 and the last two dims of input X is 1
* 3. the rank of shape attr is 2
* 4. the next op is only matmul
*
* The matmul op must satisfy the following conditions:
* 1. the transpose_X and transpose_Y attrs are false
* 2. the alpha attr is 1.0
* 3. the rank of input X and Y is 2
* 4. the next op of matmul is only elementwise_add
*
* Notice:
* the shape and rank of input activation is obtained from var_desc,
* they maybe change in runtime. Therefore, the pass considers
* the above passes to reduce the impact on other models.
*/
class
Reshape2MatmulFusePass
:
public
FusePassBase
{
public:
virtual
~
Reshape2MatmulFusePass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
浏览文件 @
6a0102b0
...
@@ -679,7 +679,7 @@ void CPUQuantizePass::QuantizeReshape(Graph* graph) const {
...
@@ -679,7 +679,7 @@ void CPUQuantizePass::QuantizeReshape(Graph* graph) const {
void
CPUQuantizePass
::
QuantizeMatmul
(
Graph
*
graph
)
const
{
void
CPUQuantizePass
::
QuantizeMatmul
(
Graph
*
graph
)
const
{
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
auto
pattern
=
gpd
.
mutable_pattern
();
auto
pattern
=
gpd
.
mutable_pattern
();
patterns
::
Matmul
matmul_pattern
{
pattern
,
name_scope_
};
patterns
::
Matmul
WithInputOps
matmul_pattern
{
pattern
,
name_scope_
};
matmul_pattern
();
matmul_pattern
();
int
quantize_matmul_count
=
0
;
int
quantize_matmul_count
=
0
;
...
...
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
6a0102b0
...
@@ -82,8 +82,11 @@ const std::vector<std::string> kTRTSubgraphPasses({
...
@@ -82,8 +82,11 @@ const std::vector<std::string> kTRTSubgraphPasses({
"embedding_eltwise_layernorm_fuse_pass"
,
//
"embedding_eltwise_layernorm_fuse_pass"
,
//
"multihead_matmul_fuse_pass_v2"
,
//
"multihead_matmul_fuse_pass_v2"
,
//
"skip_layernorm_fuse_pass"
,
//
"skip_layernorm_fuse_pass"
,
//
"unsqueeze2_eltwise_fuse_pass"
,
"conv_bn_fuse_pass"
,
//
"conv_bn_fuse_pass"
,
//
"unsqueeze2_eltwise_fuse_pass"
,
//
"squeeze2_matmul_fuse_pass"
,
//
"reshape2_matmul_fuse_pass"
,
//
"map_matmul_to_mul_pass"
,
//
"fc_fuse_pass"
,
//
"fc_fuse_pass"
,
//
"tensorrt_subgraph_pass"
,
//
"tensorrt_subgraph_pass"
,
//
"conv_bn_fuse_pass"
,
//
"conv_bn_fuse_pass"
,
//
...
@@ -113,6 +116,9 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
...
@@ -113,6 +116,9 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_eltwiseadd_bn_fuse_pass"
,
//
"conv_eltwiseadd_bn_fuse_pass"
,
//
"embedding_eltwise_layernorm_fuse_pass"
,
//
"embedding_eltwise_layernorm_fuse_pass"
,
//
"multihead_matmul_fuse_pass_v2"
,
//
"multihead_matmul_fuse_pass_v2"
,
//
"squeeze2_matmul_fuse_pass"
,
//
"reshape2_matmul_fuse_pass"
,
//
"map_matmul_to_mul_pass"
,
//
"fc_fuse_pass"
,
//
"fc_fuse_pass"
,
//
"fc_elementwise_layernorm_fuse_pass"
,
//
"fc_elementwise_layernorm_fuse_pass"
,
//
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
...
@@ -164,6 +170,9 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
...
@@ -164,6 +170,9 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
"fc_gru_fuse_pass"
,
//
"fc_gru_fuse_pass"
,
//
"mul_gru_fuse_pass"
,
//
"mul_gru_fuse_pass"
,
//
"seq_concat_fc_fuse_pass"
,
//
"seq_concat_fc_fuse_pass"
,
//
"squeeze2_matmul_fuse_pass"
,
//
"reshape2_matmul_fuse_pass"
,
//
"map_matmul_to_mul_pass"
,
//
"fc_fuse_pass"
,
//
"fc_fuse_pass"
,
//
"repeated_fc_relu_fuse_pass"
,
//
"repeated_fc_relu_fuse_pass"
,
//
"squared_mat_sub_fuse_pass"
,
//
"squared_mat_sub_fuse_pass"
,
//
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录