Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
20eafd79
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
20eafd79
编写于
6月 22, 2021
作者:
F
feng_shuai
提交者:
GitHub
6月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add squared_mat_sub_fuse_pass (#33597)
上级
cf3ddd3b
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
225 addition
and
7 deletion
+225
-7
paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc
paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc
+104
-3
paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h
paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h
+3
-1
paddle/fluid/operators/compat/elementwise_mul.pbtxt
paddle/fluid/operators/compat/elementwise_mul.pbtxt
+70
-0
paddle/fluid/operators/compat/fill_constant.pbtxt
paddle/fluid/operators/compat/fill_constant.pbtxt
+4
-3
paddle/fluid/operators/compat/square.pbtxt
paddle/fluid/operators/compat/square.pbtxt
+44
-0
未找到文件。
paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc
浏览文件 @
20eafd79
...
...
@@ -298,7 +298,8 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
return
last_out_var
;
}
static
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
)
{
static
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
const
SquaredMatSubFusePass
*
pass
)
{
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
...
...
@@ -320,6 +321,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
LOG
(
INFO
)
<<
"handle sqaure mat sub fuse"
;
if
(
!
pass
->
IsAcceptable
(
subgraph
,
g
))
{
LOG
(
WARNING
)
<<
"Pass in op compat failed."
;
return
;
}
auto
&
fused_pattern
=
gpd
.
pattern
();
auto
*
matx
=
retrieve_node
(
name_scope
+
"/x"
,
subgraph
,
fused_pattern
);
...
...
@@ -368,14 +374,109 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
++
fusion_count
;
};
gpd
(
graph
,
handler
);
return
fusion_count
;
}
SquaredMatSubFusePass
::
SquaredMatSubFusePass
()
{
AddOpCompat
(
OpCompat
(
"square"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
();
AddOpCompat
(
OpCompat
(
"matmul"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"alpha"
)
.
IsNumGE
(
0.99
f
)
.
IsNumLE
(
1.01
f
)
.
End
()
.
AddAttr
(
"transpose_X"
)
.
IsBoolEQ
(
false
)
.
End
()
.
AddAttr
(
"transpose_Y"
)
.
IsBoolEQ
(
false
)
.
End
();
AddOpCompat
(
OpCompat
(
"matmul_v2"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"trans_x"
)
.
IsBoolEQ
(
false
)
.
End
()
.
AddAttr
(
"trans_y"
)
.
IsBoolEQ
(
false
)
.
End
();
AddOpCompat
(
OpCompat
(
"elementwise_sub"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
.
IsNumEQ
(
-
1
)
.
End
();
AddOpCompat
(
OpCompat
(
"elementwise_mul"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
.
IsNumEQ
(
-
1
)
.
End
();
AddOpCompat
(
OpCompat
(
"fill_constant"
))
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"dtype"
)
.
IsNumGE
(
0
)
.
IsNumLE
(
25
)
.
End
()
.
AddAttr
(
"shape"
)
.
End
()
// type:float,there is no restriction
.
AddAttr
(
"value"
)
.
End
();
}
// to use IsCompat
bool
SquaredMatSubFusePass
::
IsAcceptable
(
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
const
{
return
IsCompat
(
subgraph
,
g
);
}
void
SquaredMatSubFusePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
);
int
fusion_count
=
BuildFusion
(
graph
,
name_scope_
);
int
fusion_count
=
BuildFusion
(
graph
,
name_scope_
,
this
);
AddStatis
(
fusion_count
);
}
...
...
paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h
浏览文件 @
20eafd79
...
...
@@ -31,11 +31,13 @@ class Graph;
class
SquaredMatSubFusePass
:
public
FusePassBase
{
public:
SquaredMatSubFusePass
();
bool
IsAcceptable
(
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
const
;
virtual
~
SquaredMatSubFusePass
()
{}
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
const
std
::
string
name_scope_
{
"squared_mat_sub_fuse"
};
};
...
...
paddle/fluid/operators/compat/elementwise_mul.pbtxt
0 → 100644
浏览文件 @
20eafd79
type: "elementwise_mul"
def {
inputs {
name: "X"
}
inputs {
name: "Y"
}
outputs {
name: "Out"
}
attrs {
name: "axis"
type: INT
}
}
extra {
attrs {
name: "use_mkldnn"
type: BOOLEAN
}
attrs {
name: "x_data_format"
type: STRING
}
attrs {
name: "y_data_format"
type: STRING
}
attrs {
name: "use_quantizer"
type: BOOLEAN
}
attrs {
name: "mkldnn_data_type"
type: STRING
}
attrs {
name: "Scale_x"
type: FLOAT
}
attrs {
name: "Scale_y"
type: FLOAT
}
attrs {
name: "Scale_out"
type: FLOAT
}
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
paddle/fluid/operators/compat/fill_constant.pbtxt
浏览文件 @
20eafd79
...
...
@@ -24,12 +24,13 @@ def {
name: "value"
type: FLOAT
}
attrs {
}
extra {
attrs {
name: "str_value"
type: STRING
}
}
extra {
attrs {
name: "force_cpu"
type: BOOLEAN
...
...
paddle/fluid/operators/compat/square.pbtxt
0 → 100644
浏览文件 @
20eafd79
type: "square"
def {
inputs {
name: "X"
}
outputs {
name: "Out"
}
}
extra {
attrs {
name: "is_test"
type: BOOLEAN
}
attrs {
name: "use_mkldnn"
type: BOOLEAN
}
attrs {
name: "use_cudnn"
type: BOOLEAN
}
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录