BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit 249b55c5 (unverified)
Authored Jun 25, 2021 by MissPenguin; committed via GitHub on Jun 25, 2021.

add pass enhance for map_matmul_to_mul_pass and flatten2_matmul_fuse_… (#33463)

Parent: 77a880c0
Showing 2 changed files with 120 additions and 0 deletions (+120, −0):

paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc  +118 −0
paddle/fluid/framework/ir/map_matmul_to_mul_pass.h   +2 −0
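The "pass enhance" in this commit gives both passes explicit constructors that register, through chained AddOpCompat(OpCompat(...)) calls, the inputs, outputs, and attribute ranges under which a matched matmul (or flatten2 + matmul) subgraph may safely be rewritten into mul; the ApplyImpl handlers then consult those constraints before fusing. Below is a minimal, self-contained sketch of that chained constraint-builder idiom, restricted to attribute checks only; OpSpec, AttrValue, and AsNum are illustrative stand-ins, not Paddle's actual OpCompat API.

// Hypothetical, simplified stand-in for the chained AddOpCompat(OpCompat(...))
// builder used in the constructors in this diff. Paddle's real IsCompat
// machinery (inputs, outputs, tensor checks) is not reproduced here.
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <variant>
#include <vector>

using AttrValue = std::variant<bool, int, float>;

class OpSpec {
 public:
  explicit OpSpec(std::string type) : type_(std::move(type)) {}

  const std::string& type() const { return type_; }

  // Start a constraint group for one attribute; predicates added next apply to it.
  OpSpec& AddAttr(const std::string& name) {
    current_ = name;
    return *this;
  }
  OpSpec& IsNumGE(float v) {
    return AddRule([v](const AttrValue& a) { return AsNum(a) >= v; });
  }
  OpSpec& IsNumLE(float v) {
    return AddRule([v](const AttrValue& a) { return AsNum(a) <= v; });
  }
  OpSpec& IsNumEQ(float v) {
    return AddRule([v](const AttrValue& a) { return AsNum(a) == v; });
  }
  OpSpec& IsBoolEQ(bool v) {
    return AddRule([v](const AttrValue& a) {
      return std::holds_alternative<bool>(a) && std::get<bool>(a) == v;
    });
  }
  OpSpec& End() {
    current_.clear();
    return *this;
  }

  // True only if every registered predicate holds for the given attributes.
  bool IsCompat(const std::map<std::string, AttrValue>& attrs) const {
    for (const auto& [name, preds] : rules_) {
      auto it = attrs.find(name);
      if (it == attrs.end()) return false;
      for (const auto& pred : preds) {
        if (!pred(it->second)) return false;
      }
    }
    return true;
  }

 private:
  using Pred = std::function<bool(const AttrValue&)>;

  static float AsNum(const AttrValue& a) {
    if (std::holds_alternative<float>(a)) return std::get<float>(a);
    if (std::holds_alternative<int>(a)) return static_cast<float>(std::get<int>(a));
    return std::get<bool>(a) ? 1.0f : 0.0f;
  }
  OpSpec& AddRule(Pred p) {
    rules_[current_].push_back(std::move(p));
    return *this;
  }

  std::string type_;
  std::string current_;
  std::map<std::string, std::vector<Pred>> rules_;
};

int main() {
  // Constraints analogous to the "matmul" block registered in the diff below.
  OpSpec matmul("matmul");
  matmul.AddAttr("alpha").IsNumGE(0.99f).IsNumLE(1.01f).End()
      .AddAttr("transpose_X").IsBoolEQ(false).End()
      .AddAttr("transpose_Y").IsBoolEQ(false).End();

  std::map<std::string, AttrValue> ok = {
      {"alpha", 1.0f}, {"transpose_X", false}, {"transpose_Y", false}};
  std::map<std::string, AttrValue> bad = {
      {"alpha", 2.0f}, {"transpose_X", false}, {"transpose_Y", true}};

  std::cout << std::boolalpha << matmul.type() << " compat: "
            << matmul.IsCompat(ok) << " / " << matmul.IsCompat(bad) << "\n";
  // Prints: matmul compat: true / false
  return 0;
}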
paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc (file mode 100644 → 100755)
@@ -16,6 +16,7 @@
 #include <cmath>
 #include <string>
+#include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -26,6 +27,103 @@ namespace ir {
 class Node;
 
+MapMatmul2MulPass::MapMatmul2MulPass() {
+  AddOpCompat(OpCompat("matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("alpha")
+      .IsNumGE(0.99f)
+      .IsNumLE(1.01f)
+      .End()
+      .AddAttr("transpose_X")
+      .IsBoolEQ(false)
+      .End()
+      .AddAttr("transpose_Y")
+      .IsBoolEQ(false)
+      .End();
+
+  AddOpCompat(OpCompat("mul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("x_num_col_dims")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("y_num_col_dims")
+      .IsNumEQ(1)
+      .End();
+}
+
+Flatten2MatmulFusePass::Flatten2MatmulFusePass() {
+  AddOpCompat(OpCompat("matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("alpha")
+      .IsNumGE(0.99f)
+      .IsNumLE(1.01f)
+      .End()
+      .AddAttr("transpose_X")
+      .IsBoolEQ(false)
+      .End()
+      .AddAttr("transpose_Y")
+      .IsBoolEQ(false)
+      .End();
+
+  AddOpCompat(OpCompat("flatten2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("XShape")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumGE(0)
+      .End();
+
+  AddOpCompat(OpCompat("mul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("x_num_col_dims")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("y_num_col_dims")
+      .IsNumEQ(1)
+      .End();
+}
+
 void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
@@ -39,6 +137,11 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
   int found_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass in op compat failed.";
+      return;
+    }
+
     VLOG(4) << "map matmul to mul";
     GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
@@ -82,6 +185,11 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
       IR_NODE_LINK_TO(mul_node, matmul_out);
       GraphSafeRemoveNodes(graph, {matmul_op});
       ++found_count;
+
+      if (!IsCompat(desc)) {
+        LOG(WARNING) << "MapMatmul2MulPass in out mul op compat failed.";
+        return;
+      }
     }
   };
@@ -244,6 +352,11 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
   int found_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass in op compat failed.";
+      return;
+    }
+
     VLOG(4) << "fuse flatten2+matmul to mul";
     GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern);
@@ -301,6 +414,11 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
      IR_NODE_LINK_TO(mul_node, matmul_out);
      GraphSafeRemoveNodes(graph, {flatten2_op, matmul_in_x, matmul_op});
      ++found_count;
+
+      if (!IsCompat(desc)) {
+        LOG(WARNING) << "Flatten2MatmulFusePass in out mul op compat failed.";
+        return;
+      }
     }
   };
paddle/fluid/framework/ir/map_matmul_to_mul_pass.h
@@ -39,6 +39,7 @@ class Graph;
 class MapMatmul2MulPass : public FusePassBase {
  public:
+  MapMatmul2MulPass();
   virtual ~MapMatmul2MulPass() {}
 
  protected:
@@ -103,6 +104,7 @@ class Reshape2MatmulFusePass : public FusePassBase {
 class Flatten2MatmulFusePass : public FusePassBase {
  public:
+  Flatten2MatmulFusePass();
   virtual ~Flatten2MatmulFusePass() {}
 
  protected:
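In both ApplyImpl handlers the new IsCompat(subgraph, g) guard runs before any graph mutation and returns early with a warning when the matched ops fall outside the registered constraints. A rough, self-contained sketch of that early-return flow, using placeholder Subgraph and Graph types rather than Paddle's IR classes:

// Sketch of the guard-before-rewrite flow in the handlers above; the types and
// the is_compat predicate are placeholders, not Paddle's IR framework.
#include <iostream>

struct Subgraph { bool attrs_ok; };
struct Graph { int fused_ops = 0; };

int main() {
  int found_count = 0;
  auto is_compat = [](const Subgraph& s, Graph*) { return s.attrs_ok; };

  auto handler = [&](const Subgraph& subgraph, Graph* g) {
    if (!is_compat(subgraph, g)) {   // mirrors IsCompat(subgraph, g)
      std::cerr << "Pass in op compat failed.\n";
      return;                        // graph left untouched, nothing counted
    }
    ++g->fused_ops;                  // stand-in for the real matmul->mul rewrite
    ++found_count;
  };

  Graph g;
  handler(Subgraph{true}, &g);   // compatible: rewritten and counted
  handler(Subgraph{false}, &g);  // incompatible: skipped with a warning
  std::cout << "found_count = " << found_count << "\n";  // prints 1
  return 0;
}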