Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
775d6e6e
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
775d6e6e
编写于
11月 13, 2020
作者:
S
SunAhong1993
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rename delete_layer
上级
934ee6a8
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
4 addition
and
52 deletion
+4
-52
x2paddle/optimizer/elimination/dygraph/transpose_elimination.py
...le/optimizer/elimination/dygraph/transpose_elimination.py
+1
-1
x2paddle/optimizer/pattern_matcher.py
x2paddle/optimizer/pattern_matcher.py
+3
-51
未找到文件。
x2paddle/optimizer/elimination/dygraph/transpose_elimination.py
浏览文件 @
775d6e6e
...
@@ -252,7 +252,7 @@ class DygraphTransposeElimination(FuseBase):
...
@@ -252,7 +252,7 @@ class DygraphTransposeElimination(FuseBase):
continue
continue
for
l
in
transpose_layers
:
for
l
in
transpose_layers
:
self
.
delete_layer_with_associated
(
_graph
,
l
)
_graph
.
delete_layer
(
l
)
optimized_transpose_layers
.
extend
(
transpose_layers
)
optimized_transpose_layers
.
extend
(
transpose_layers
)
optimized_reduce_layers
.
extend
(
reduce_layers
)
optimized_reduce_layers
.
extend
(
reduce_layers
)
...
...
x2paddle/optimizer/pattern_matcher.py
浏览文件 @
775d6e6e
...
@@ -268,7 +268,7 @@ class FuseBase(object):
...
@@ -268,7 +268,7 @@ class FuseBase(object):
first_layer_id
=
list
(
match
.
keys
())[
0
]
first_layer_id
=
list
(
match
.
keys
())[
0
]
subgraph
=
get_subgraph
(
""
,
first_layer_id
,
graph
)
subgraph
=
get_subgraph
(
""
,
first_layer_id
,
graph
)
self
.
insert_new_layer
(
subgraph
,
parameters
,
match
)
self
.
insert_new_layer
(
subgraph
,
parameters
,
match
)
self
.
delete_
layer
(
graph
)
self
.
delete_
match
(
graph
)
graph
.
build
()
graph
.
build
()
def
perform_pattern_matcher
(
self
,
graph
,
match_kind
=
"topo"
):
def
perform_pattern_matcher
(
self
,
graph
,
match_kind
=
"topo"
):
...
@@ -283,7 +283,7 @@ class FuseBase(object):
...
@@ -283,7 +283,7 @@ class FuseBase(object):
pattern_matcher
=
PatternMatcher
(
self
.
pattern
)
pattern_matcher
=
PatternMatcher
(
self
.
pattern
)
self
.
matches
=
pattern_matcher
.
operate
(
graph
,
match_kind
)
self
.
matches
=
pattern_matcher
.
operate
(
graph
,
match_kind
)
def
delete_
layer
(
self
,
graph
):
def
delete_
match
(
self
,
graph
):
""" 删除不需要的中间layer及其对应参数。
""" 删除不需要的中间layer及其对应参数。
"""
"""
for
match
in
self
.
matches
:
for
match
in
self
.
matches
:
...
@@ -298,52 +298,4 @@ class FuseBase(object):
...
@@ -298,52 +298,4 @@ class FuseBase(object):
if
layer_id
in
subgraph
.
layers
:
if
layer_id
in
subgraph
.
layers
:
# layer_id可能是属于子图的,此时删除父layer,即删除整个子图
# layer_id可能是属于子图的,此时删除父layer,即删除整个子图
subgraph
.
layers
.
pop
(
layer_id
)
subgraph
.
layers
.
pop
(
layer_id
)
def
delete_layer_with_associated
(
self
,
graph
,
layer_id
):
\ No newline at end of file
""" 删除不需要的中间layer及其相关连接点。
"""
layer
=
graph
.
layers
[
layer_id
]
outputs
=
graph
.
edges_out
.
get
(
layer_id
,
[])
inputs
=
graph
.
edges_in
.
get
(
layer_id
,
[])
assert
len
(
inputs
)
<=
1
,
"There should be 0 or 1 input for deleted layer."
if
len
(
inputs
)
==
0
:
for
out
in
outputs
:
while
layer_id
in
graph
.
edges_in
[
out
]:
index
=
graph
.
edges_in
[
out
].
index
(
layer_id
)
del
graph
.
edges_in
[
out
][
index
]
input_keys
=
list
(
graph
.
layers
[
out
].
inputs
.
keys
())
for
k
in
input_keys
:
if
graph
.
layers
[
out
].
inputs
[
k
]
==
layer
.
outputs
[
0
]:
del
graph
.
layers
[
out
].
inputs
[
k
]
del
graph
.
layers
[
layer_id
]
if
layer_id
in
graph
.
edges_in
:
del
graph
.
edges_in
[
layer_id
]
if
layer_id
in
graph
.
edges_out
:
del
graph
.
edges_out
[
layer_id
]
return
# 将所有输出layer的输入layer进行替换
for
out
in
outputs
:
for
i
in
range
(
len
(
graph
.
edges_in
[
out
])):
if
graph
.
edges_in
[
out
][
i
]
==
layer_id
:
graph
.
edges_in
[
out
][
i
]
=
inputs
[
0
]
# 将输出layer赋给输入layer的输出
replace_index
=
graph
.
edges_out
[
inputs
[
0
]].
index
(
layer_id
)
del
graph
.
edges_out
[
inputs
[
0
]][
replace_index
]
for
i
,
out
in
enumerate
(
outputs
):
graph
.
edges_out
[
inputs
[
0
]].
insert
(
replace_index
+
i
,
out
)
for
k
,
v
in
graph
.
layers
[
out
].
inputs
.
items
():
if
v
==
layer
.
outputs
[
0
]:
graph
.
layers
[
out
].
inputs
[
k
]
=
list
(
layer
.
inputs
.
values
())[
0
]
del
graph
.
layers
[
layer_id
]
if
layer_id
in
graph
.
edges_out
:
del
graph
.
edges_out
[
layer_id
]
if
layer_id
in
graph
.
edges_in
:
del
graph
.
edges_in
[
layer_id
]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录