Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
5f0f476f
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5f0f476f
编写于
8月 19, 2020
作者:
S
SunAhong1993
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify optimizer
上级
cbc3efdb
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
341 addition
and
190 deletion
+341
-190
x2paddle/optimizer/fusion/__init__.py
x2paddle/optimizer/fusion/__init__.py
+16
-0
x2paddle/optimizer/fusion/__pycache__/__init__.cpython-37.pyc
...ddle/optimizer/fusion/__pycache__/__init__.cpython-37.pyc
+0
-0
x2paddle/optimizer/fusion/__pycache__/fc_fuse_pass.cpython-37.pyc
.../optimizer/fusion/__pycache__/fc_fuse_pass.cpython-37.pyc
+0
-0
x2paddle/optimizer/fusion/__pycache__/fc_fuser.cpython-37.pyc
...ddle/optimizer/fusion/__pycache__/fc_fuser.cpython-37.pyc
+0
-0
x2paddle/optimizer/fusion/fc_fuse_pass.py
x2paddle/optimizer/fusion/fc_fuse_pass.py
+33
-0
x2paddle/optimizer/fusion/fc_fuser.py
x2paddle/optimizer/fusion/fc_fuser.py
+45
-53
x2paddle/optimizer/optimizer.py
x2paddle/optimizer/optimizer.py
+7
-28
x2paddle/optimizer/pass_.py
x2paddle/optimizer/pass_.py
+44
-0
x2paddle/optimizer/pass_manager.py
x2paddle/optimizer/pass_manager.py
+42
-0
x2paddle/optimizer/passes.py
x2paddle/optimizer/passes.py
+0
-109
x2paddle/optimizer/pattern_matcher.py
x2paddle/optimizer/pattern_matcher.py
+154
-0
未找到文件。
x2paddle/optimizer/fusion/__init__.py
0 → 100644
浏览文件 @
5f0f476f
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.fc_fuser
import
FcFuser
from
.fc_fuse_pass
import
FcFusePass
x2paddle/optimizer/fusion/__pycache__/__init__.cpython-37.pyc
0 → 100644
浏览文件 @
5f0f476f
文件已添加
x2paddle/optimizer/fusion/__pycache__/fc_fuse_pass.cpython-37.pyc
0 → 100644
浏览文件 @
5f0f476f
文件已添加
x2paddle/optimizer/fusion/__pycache__/fc_fuser.cpython-37.pyc
0 → 100644
浏览文件 @
5f0f476f
文件已添加
x2paddle/optimizer/fusion/fc_fuse_pass.py
0 → 100644
浏览文件 @
5f0f476f
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
x2paddle.optimizer.pass_
import
ProgramPass
from
x2paddle.optimizer.fusion
import
FcFuser
from
x2paddle.optimizer.pass_manager
import
pass_register
@
pass_register
class
FcFusePass
(
ProgramPass
):
name
=
"fc_fuse_pass"
def
__init__
(
self
):
ProgramPass
.
__init__
(
self
)
def
apply
(
self
,
graph
):
fuser
=
FcFuser
()
fuser
.
operate
(
graph
)
# 用于注册
fc_fuse_pass
=
FcFusePass
()
x2paddle/optimizer/
linear_pass
.py
→
x2paddle/optimizer/
fusion/fc_fuser
.py
浏览文件 @
5f0f476f
...
...
@@ -13,17 +13,18 @@
# limitations under the License.
import
numpy
as
np
from
x2paddle.optimizer.pattern_matcher
import
FuseBase
from
x2paddle.core.program
import
PaddleGraph
,
PaddleLayer
from
x2paddle.core.util
import
*
from
x2paddle.core.program
import
PaddleLayer
,
PaddleGraph
from
x2paddle.optimizer.passes
import
Pass
,
Matcher
,
PyTorchMatcher
class
LinearPass
(
Pass
):
class
FcFuser
(
FuseBase
):
def
__init__
(
self
):
super
(
LinearPass
,
self
).
__init__
()
self
.
linear_index
=
0
super
(
FcFuser
,
self
).
__init__
()
def
build_pattern
(
self
):
"""
构造fc层的模式
。
"""
描述需要替换的fc图结构
。
fc层模式python实现代码示例:
x149 = 2
x151 = x146.shape
...
...
@@ -68,8 +69,8 @@ class LinearPass(Pass):
outputs
=
[
gen_name
(
3
)])
self
.
pattern
.
add_layer
(
"prim.if"
,
{
'input'
:
gen_name
(
3
)},
[
gen_name
(
4
)])
self
.
pattern
.
outputs
.
append
(
gen_name
(
4
))
if_layer
_a
=
self
.
pattern
.
layers
[
list
(
self
.
pattern
.
layers
.
keys
())[
-
1
]]
pattern_block0
=
PaddleGraph
(
if_layer
_a
)
if_layer
1
=
self
.
pattern
.
layers
[
list
(
self
.
pattern
.
layers
.
keys
())[
-
1
]]
pattern_block0
=
PaddleGraph
(
if_layer
1
)
pattern_block0
.
add_layer
(
"fluid.dygraph.base.to_variable"
,
inputs
=
{},
...
...
@@ -93,12 +94,12 @@ class LinearPass(Pass):
outputs
=
[
gen_name
(
8
)],
beta
=
1
,
alpha
=
1
)
if_layer
_a
.
inputs
[
"input-0"
]
=
"fc-input-0"
if_layer
1
.
inputs
[
"input-0"
]
=
"fc-input-0"
self
.
pattern
.
inputs
.
append
(
"fc-input-0"
)
pattern_block0
.
add_layer
(
"prim.equal"
,
inputs
=
{
'input'
:
gen_name
(
8
)},
outputs
=
[
gen_name
(
4
)])
if_layer
_a
.
add_block
(
pattern_block0
)
pattern_block1
=
PaddleGraph
(
if_layer
_a
)
if_layer
1
.
add_block
(
pattern_block0
)
pattern_block1
=
PaddleGraph
(
if_layer
1
)
pattern_block1
.
add_layer
(
"fluid.dygraph.base.to_variable"
,
inputs
=
{},
...
...
@@ -114,84 +115,75 @@ class LinearPass(Pass):
inputs
=
{
"x"
:
"fc-input-0"
,
"y"
:
gen_name
(
6
)},
outputs
=
[
gen_name
(
9
)])
if_layer
_a
.
inputs
[
"input-1"
]
=
"fc-input-0"
if_layer
1
.
inputs
[
"input-1"
]
=
"fc-input-0"
pattern_block1
.
add_layer
(
"prim.constant"
,
inputs
=
{},
outputs
=
[
gen_name
(
10
)],
value
=
True
)
pattern_block1
.
add_layer
(
"prim.if"
,
{
'input'
:
gen_name
(
10
)},
[
gen_name
(
11
)])
if_layer
_b
=
pattern_block1
.
layers
[
list
(
pattern_block1
.
layers
.
keys
())[
if_layer
2
=
pattern_block1
.
layers
[
list
(
pattern_block1
.
layers
.
keys
())[
-
1
]]
pattern_block1_block0
=
PaddleGraph
(
if_layer
_b
)
pattern_block1_block0
=
PaddleGraph
(
if_layer
2
)
pattern_block1_block0
.
add_layer
(
"fluid.dygraph.base.to_variable"
,
inputs
=
{},
outputs
=
[
gen_name
(
12
)],
value
=
"params[{}]"
.
format
(
string
(
gen_name
(
12
))))
pattern_block1_block0
.
add_layer
(
"prim.add"
,
"prim.add
_
"
,
inputs
=
{
"x"
:
gen_name
(
9
),
"y"
:
gen_name
(
12
)},
outputs
=
[
gen_name
(
13
)],
alpha
=
1
)
if_layer
_b
.
inputs
[
"input-0"
]
=
gen_name
(
9
)
if_layer
2
.
inputs
[
"input-0"
]
=
gen_name
(
9
)
pattern_block1_block0
.
add_layer
(
"prim.equal"
,
inputs
=
{
'input'
:
gen_name
(
13
)},
outputs
=
[
gen_name
(
11
)])
if_layer
_b
.
add_block
(
pattern_block1_block0
)
pattern_block1_block1
=
PaddleGraph
(
if_layer
_b
)
if_layer
2
.
add_block
(
pattern_block1_block0
)
pattern_block1_block1
=
PaddleGraph
(
if_layer
2
)
pattern_block1_block1
.
add_layer
(
"prim.equal"
,
inputs
=
{
'input'
:
gen_name
(
9
)},
outputs
=
[
gen_name
(
11
)])
if_layer
_b
.
inputs
[
"input-1"
]
=
gen_name
(
9
)
if_layer
2
.
inputs
[
"input-1"
]
=
gen_name
(
9
)
pattern_block1
.
add_layer
(
"prim.equal"
,
inputs
=
{
'input'
:
gen_name
(
11
)},
outputs
=
[
gen_name
(
4
)])
if_layer
_b
.
add_block
(
pattern_block1_block1
)
if_layer
_a
.
add_block
(
pattern_block1
)
if_layer
2
.
add_block
(
pattern_block1_block1
)
if_layer
1
.
add_block
(
pattern_block1
)
self
.
pattern
.
build
(
inputs
=
{
"input-0"
:
"fc-input-0"
,
"input-1"
:
"fc-input-0"
})
def
insert_new_layer
(
self
,
graph
,
matches
):
parameters
=
graph
.
parameters
new_layer
=
self
.
gen_new_layer
(
parameters
,
matches
)
new_layer_id
=
list
(
matches
.
keys
())[
0
]
graph
.
layers
[
new_layer_id
]
=
new_layer
matches
.
pop
(
new_layer_id
)
class
LinearMatcher
(
PyTorchMatcher
):
def
__init__
(
self
):
self
.
linear_index
=
0
super
(
LinearMatcher
,
self
).
__init__
()
def
replace_layer
(
self
,
graph
,
subgraph_global_layers
):
subgraph_global_layers_id
=
list
(
subgraph_global_layers
.
keys
())
layer
=
subgraph_global_layers
[
subgraph_global_layers_id
[
2
]]
def
gen_new_layer
(
self
,
parameters
,
matches
):
layers_id
=
list
(
matches
.
keys
())
layer
=
matches
[
layers_id
[
2
]]
input_name
=
layer
.
inputs
[
"input"
]
layer
=
subgraph_global_layers
[
subgraph_global_
layers_id
[
5
]]
layer
=
matches
[
layers_id
[
5
]]
output_name
=
layer
.
outputs
[
0
]
layer
=
subgraph_global_layers
[
subgraph_global_
layers_id
[
6
]]
layer
=
matches
[
layers_id
[
6
]]
weight_name
=
layer
.
attrs
[
"value"
][
8
:
-
2
]
layer
=
subgraph_global_layers
[
subgraph_global_
layers_id
[
8
]]
layer
=
matches
[
layers_id
[
8
]]
bias_name
=
layer
.
attrs
[
"value"
][
8
:
-
2
]
attrs
=
{}
attrs
[
"input_dim"
]
=
graph
.
parameters
[
weight_name
].
shape
[
1
]
attrs
[
"output_dim"
]
=
graph
.
parameters
[
weight_name
].
shape
[
0
]
attrs
[
"input_dim"
]
=
parameters
[
weight_name
].
shape
[
1
]
attrs
[
"output_dim"
]
=
parameters
[
weight_name
].
shape
[
0
]
linear_name
=
"linear{}"
.
format
(
self
.
linear_index
)
self
.
linear_index
+=
1
graph
.
parameters
[
"{}.weight"
.
format
(
linear_name
)]
=
graph
.
parameters
[
parameters
[
"{}.weight"
.
format
(
linear_name
)]
=
parameters
[
weight_name
].
transpose
((
1
,
0
))
graph
.
parameters
[
"{}.bias"
.
format
(
linear_name
)]
=
np
.
squeeze
(
graph
.
parameters
[
bias_name
])
graph
.
parameters
.
pop
(
weight_name
)
graph
.
parameters
.
pop
(
bias_name
)
for
i
,
layer_id
in
enumerate
(
subgraph_global_layers
):
if
layer_id
in
graph
.
layers
:
layer
=
graph
.
layers
[
layer_id
]
if
i
==
0
:
parameters
[
"{}.bias"
.
format
(
linear_name
)]
=
np
.
squeeze
(
parameters
[
bias_name
])
new_layer
=
PaddleLayer
(
layer_id
,
layers_id
[
0
]
,
"fluid.dygraph.Linear"
,
inputs
=
{
"input"
:
input_name
},
outputs
=
[
linear_name
,
output_name
],
**
attrs
)
graph
.
layers
[
layer_id
]
=
new_layer
else
:
graph
.
layers
.
pop
(
layer_id
)
graph
.
build
()
return
graph
return
new_layer
x2paddle/optimizer/optimizer.py
浏览文件 @
5f0f476f
...
...
@@ -12,38 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
x2paddle.optimizer.linear_pass
import
LinearPass
,
LinearMatcher
from
x2paddle.optimizer.fusion
import
*
from
x2paddle.optimizer.pass_manager
import
PassManager
class
GraphOptimizer
(
object
):
def
__init__
(
self
):
linear_pass
=
LinearPass
()
linear_matcher
=
LinearMatcher
()
self
.
passes
=
{
linear_pass
:
linear_matcher
}
def
run
(
self
,
graph
):
is_update_graph
=
False
while
True
:
for
i
,
(
layer_id
,
layer
)
in
enumerate
(
graph
.
layers
.
items
()):
is_match
=
self
.
current_matcher
.
match_pattern
(
self
.
current_pass
.
pattern
,
graph
,
i
)
if
is_match
:
is_update_graph
=
True
graph
=
self
.
current_matcher
.
replace_layer
(
graph
,
is_match
)
break
for
j
,
block
in
enumerate
(
layer
.
blocks
):
if
len
(
block
.
layers
)
>
0
:
layer
.
blocks
[
j
],
is_update_block
=
self
.
run
(
block
)
if
is_update_block
:
break
if
i
+
1
==
len
(
graph
.
layers
):
return
graph
,
is_update_graph
self
.
passes
=
[
"fc_fuse_pass"
]
def
optimize
(
self
,
graph
):
# 开始优化
for
_pass
,
matcher
in
self
.
passes
.
items
():
self
.
current_pass
=
_pass
self
.
current_matcher
=
matcher
graph
,
_
=
self
.
run
(
graph
)
print
(
"{} done!"
.
format
(
_pass
.
__class__
.
__name__
))
for
pass_name
in
self
.
passes
:
pass_
=
PassManager
.
lookup
(
pass_name
)()
pass_
.
apply
(
graph
)
print
(
"{} done!"
.
format
(
pass_name
))
return
graph
x2paddle/optimizer/pass_.py
0 → 100644
浏览文件 @
5f0f476f
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
enum
import
Enum
class
Kind
(
Enum
):
Program
=
1
Code
=
2
class
Pass
(
object
):
name
=
"pass"
def
__init__
(
self
,
kind
):
self
.
kind
=
kind
def
apply
(
self
,
graph
):
raise
NotImplementedError
(
"The apply function must be implemented!"
)
@
classmethod
def
get_name
(
cls
):
return
cls
.
name
class
ProgramPass
(
Pass
):
def
__init__
(
self
):
super
(
ProgramPass
,
self
).
__init__
(
Kind
.
Program
)
class
CodePass
(
Pass
):
def
__init__
(
self
):
super
(
CodePass
,
self
).
__init__
(
Kind
.
Code
)
x2paddle/optimizer/pass_manager.py
0 → 100644
浏览文件 @
5f0f476f
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class
PassManager
(
object
):
""" pass管理器。
"""
# pass_map存储name与其对应的pass
pass_map
=
dict
()
def
__init__
(
self
):
pass
@
staticmethod
def
add_new_pass
(
name
,
pass_
):
if
name
not
in
PassManager
.
pass_map
:
PassManager
.
pass_map
[
name
]
=
pass_
@
staticmethod
def
clear
():
PassManager
.
passes
=
list
()
@
staticmethod
def
lookup
(
name
):
return
PassManager
.
pass_map
[
name
]
def
pass_register
(
cls
):
name
=
cls
.
get_name
()
PassManager
.
add_new_pass
(
name
,
cls
)
return
cls
x2paddle/optimizer/passes.py
已删除
100644 → 0
浏览文件 @
cbc3efdb
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
x2paddle.core.program
import
PaddleGraph
class
Pass
(
object
):
def
__init__
(
self
):
self
.
pattern
=
PaddleGraph
()
self
.
build_pattern
()
class
Matcher
(
object
):
def
__init__
(
self
):
pass
class
PyTorchMatcher
(
Matcher
):
def
__init__
(
self
):
super
(
PyTorchMatcher
,
self
).
__init__
()
def
match_pattern
(
self
,
pattern
,
graph
,
start_index
):
pattern_index
=
0
pattern_global_layers
=
pattern
.
get_global_layers
()
subgraph_global_layers
=
dict
()
graph_layers
=
dict
(
list
(
graph
.
layers
.
items
())[
start_index
:])
for
layer_id
,
layer
in
graph_layers
.
items
():
pattern_layer
=
pattern
.
layers
[
list
(
pattern
.
layers
.
keys
())[
pattern_index
]]
if
layer
.
kernel
==
pattern_layer
.
kernel
:
subgraph_global_layers
[
layer_id
]
=
layer
pattern_layer_id
=
pattern_layer
.
id
if
layer
.
kernel
==
"prim.constant"
:
if
layer
.
attrs
[
"value"
]
!=
pattern_layer
.
attrs
[
"value"
]:
return
False
elif
layer
.
kernel
==
"fluid.layers.addmm"
:
if
layer
.
attrs
[
"beta"
]
!=
pattern_layer
.
attrs
[
"beta"
]:
return
False
if
layer
.
attrs
[
"alpha"
]
!=
pattern_layer
.
attrs
[
"alpha"
]:
return
False
if
layer_id
in
graph
.
edges_in
:
if
pattern_layer_id
not
in
pattern
.
edges_in
:
return
False
else
:
if
len
(
graph
.
edges_in
[
layer_id
])
!=
len
(
pattern
.
edges_in
[
pattern_layer_id
]):
return
False
layer_in
=
graph
.
edges_in
[
layer_id
]
pattern_layer_in
=
pattern
.
edges_in
[
pattern_layer_id
]
for
i
in
range
(
len
(
layer_in
)):
layer_id_in
=
layer_in
[
i
]
pattern_layer_id_in
=
pattern_layer_in
[
i
]
if
pattern_layer_id_in
!=
-
1
:
pattern_global_layers_id
=
list
(
pattern_global_layers
.
keys
())
subgraph_global_layers_id
=
list
(
subgraph_global_layers
.
keys
())
if
pattern_global_layers_id
.
index
(
pattern_layer_id_in
)
==
\
subgraph_global_layers_id
.
index
(
layer_id_in
):
# 判断pattern输入在pattern_global_layers_id的索引
# 和graph输入在subgraph_global_layers_id的索引一致
continue
return
False
if
layer_id
in
graph
.
edges_out
:
if
pattern_layer_id
not
in
pattern
.
edges_out
:
if
not
set
(
pattern_layer
.
outputs
).
issubset
(
pattern
.
outputs
):
# 若pattern当前layer的输出是pattern的输出,则是正确的
return
False
else
:
if
len
(
graph
.
edges_out
[
layer_id
])
!=
len
(
pattern
.
edges_out
[
pattern_layer_id
]):
# 如果在每个节点edges_in相同的情况下,edges_out数目相同则说明无节点在subgraph外被用到
if
not
set
(
pattern_layer
.
outputs
).
issubset
(
pattern
.
outputs
):
# 若pattern当前layer的输出是pattern的输出,则是正确的
return
False
if
layer
.
kernel
==
"prim.if"
:
res
=
self
.
match_pattern
(
pattern_layer
.
blocks
[
0
],
layer
.
blocks
[
0
],
0
)
if
res
:
subgraph_global_layers
.
update
(
res
)
else
:
return
False
res
=
self
.
match_pattern
(
pattern_layer
.
blocks
[
1
],
layer
.
blocks
[
1
],
0
)
if
res
:
subgraph_global_layers
.
update
(
res
)
else
:
return
False
pattern_index
+=
1
if
pattern_index
==
len
(
pattern
.
layers
):
return
subgraph_global_layers
else
:
return
False
x2paddle/optimizer/pattern_matcher.py
0 → 100644
浏览文件 @
5f0f476f
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
x2paddle.core.program
import
PaddleGraph
class
PatternMatcher
(
object
):
def
__init__
(
self
,
pattern
):
self
.
pattern
=
pattern
self
.
subgraphs
=
list
()
def
operate
(
self
,
graph
):
self
.
detect_patterns
(
graph
)
self
.
remove_overlapped_match
()
return
self
.
subgraphs
def
detect_patterns
(
self
,
graph
):
""" 找到与模式匹配的子图,
并将子图的id以拓扑排序存放到subgraph_id2layers。
"""
def
get_subgraph
(
pattern
,
graph
,
start_index
):
pattern_index
=
0
pattern_id2layers
=
pattern
.
get_global_layers
()
pattern_ids
=
list
(
pattern_id2layers
.
keys
())
subgraph_id2layers
=
dict
()
graph_layers
=
dict
(
list
(
graph
.
layers
.
items
())[
start_index
:])
for
layer_id
,
layer
in
graph_layers
.
items
():
pattern_layer
=
pattern
.
layers
[
list
(
pattern
.
layers
.
keys
())[
pattern_index
]]
if
layer
.
kernel
==
pattern_layer
.
kernel
:
subgraph_id2layers
[
layer_id
]
=
layer
pattern_layer_id
=
pattern_layer
.
id
# 判断输入连接是否一致
if
layer_id
in
graph
.
edges_in
:
if
pattern_layer_id
not
in
pattern
.
edges_in
:
return
False
else
:
if
len
(
graph
.
edges_in
[
layer_id
])
!=
len
(
pattern
.
edges_in
[
pattern_layer_id
]):
return
False
layer_in
=
graph
.
edges_in
[
layer_id
]
pattern_layer_in
=
pattern
.
edges_in
[
pattern_layer_id
]
for
i
in
range
(
len
(
layer_in
)):
layer_id_in
=
layer_in
[
i
]
pattern_layer_id_in
=
pattern_layer_in
[
i
]
if
pattern_layer_id_in
!=
-
1
:
subgraph_ids
=
list
(
subgraph_id2layers
.
keys
())
if
pattern_ids
.
index
(
pattern_layer_id_in
)
==
\
subgraph_ids
.
index
(
layer_id_in
):
# 判断pattern输入在pattern_ids的索引
# 和graph输入在subgraph_ids的索引一致
continue
return
False
# 判断subgraph中的节点是否被外部图使用到(如若被使用到则无效)
if
layer_id
in
graph
.
edges_out
:
if
pattern_layer_id
not
in
pattern
.
edges_out
:
if
not
set
(
pattern_layer
.
outputs
).
issubset
(
pattern
.
outputs
):
# 若pattern当前layer的输出是pattern的输出,则是正确的
return
False
else
:
if
len
(
graph
.
edges_out
[
layer_id
])
!=
len
(
pattern
.
edges_out
[
pattern_layer_id
]):
# 如果在每个节点edges_in相同的情况下,edges_out数目相同则说明无节点在subgraph外被用到
if
not
set
(
pattern_layer
.
outputs
).
issubset
(
pattern
.
outputs
):
# 若pattern当前layer的输出是pattern的输出,则是正确的
return
False
# 当为控制流时的处理
if
layer
.
kernel
==
"prim.if"
:
match_info
=
get_subgraph
(
pattern_layer
.
blocks
[
0
],
layer
.
blocks
[
0
],
0
)
if
match_info
:
subgraph_id2layers
.
update
(
match_info
)
else
:
return
False
match_info
=
get_subgraph
(
pattern_layer
.
blocks
[
1
],
layer
.
blocks
[
1
],
0
)
if
match_info
:
subgraph_id2layers
.
update
(
match_info
)
else
:
return
False
pattern_index
+=
1
if
pattern_index
==
len
(
pattern
.
layers
):
return
subgraph_id2layers
else
:
return
False
for
i
,
(
layer_id
,
layer
)
in
enumerate
(
graph
.
layers
.
items
()):
match_info
=
get_subgraph
(
self
.
pattern
,
graph
,
i
)
if
match_info
:
self
.
subgraphs
.
append
(
match_info
)
for
j
,
block
in
enumerate
(
layer
.
blocks
):
if
len
(
block
.
layers
)
>
0
:
self
.
detect_patterns
(
layer
.
blocks
[
j
])
def
remove_overlapped_match
(
self
):
""" 如果2个子图有重叠,只取前一个子图。
"""
match_ids
=
[]
for
i
,
subgraph
in
enumerate
(
self
.
subgraphs
):
is_overlapped
=
False
for
id
in
subgraph
.
keys
():
if
id
in
match_ids
:
self
.
subgraphs
.
pop
(
i
)
is_overlapped
=
True
break
if
not
is_overlapped
:
match_ids
.
extend
(
list
(
subgraph
.
keys
()))
class
FuseBase
(
object
):
def
__init__
(
self
):
self
.
pattern
=
PaddleGraph
()
def
operate
(
self
,
graph
):
self
.
build_pattern
()
self
.
perform_pattern_matcher
(
graph
)
for
subgraph
in
self
.
subgraphs
:
self
.
insert_new_layer
(
graph
,
subgraph
)
self
.
delete_inter_layer
(
graph
)
graph
.
build
()
def
perform_pattern_matcher
(
self
,
graph
):
""" 执行模式匹配,找到匹配的子图。
"""
pattern_matcher
=
PatternMatcher
(
self
.
pattern
)
self
.
subgraphs
=
pattern_matcher
.
operate
(
graph
)
def
delete_inter_layer
(
self
,
graph
):
""" 删除不需要的中间layer及其对应参数。
"""
for
subgraph
in
self
.
subgraphs
:
for
layer_id
,
layer
in
subgraph
.
items
():
if
layer
.
kernel
==
"fluid.dygraph.base.to_variable"
and
\
layer
.
attrs
[
"value"
].
startswith
(
"params["
):
param_name
=
layer
.
attrs
[
"value"
][
8
:
-
2
]
if
param_name
in
graph
.
parameters
:
graph
.
parameters
.
pop
(
param_name
)
if
layer_id
in
graph
.
layers
:
# layer_id可能是属于子图的,此时删除父layer,即删除整个子图
graph
.
layers
.
pop
(
layer_id
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录