Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
60765f75
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 2 年 前同步成功
通知
329
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
60765f75
编写于
1月 25, 2022
作者:
W
wjj19950828
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fixed bug
上级
1b5969d9
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
68 addition
and
88 deletion
+68
-88
x2paddle/optimizer/pytorch_code_optimizer/module_graph.py
x2paddle/optimizer/pytorch_code_optimizer/module_graph.py
+68
-88
未找到文件。
x2paddle/optimizer/pytorch_code_optimizer/module_graph.py
浏览文件 @
60765f75
...
@@ -21,8 +21,8 @@ from x2paddle.optimizer.pytorch_code_optimizer.subgraphs_union import construct_
...
@@ -21,8 +21,8 @@ from x2paddle.optimizer.pytorch_code_optimizer.subgraphs_union import construct_
from
x2paddle.optimizer.pytorch_code_optimizer.layer_code_generator
import
gen_layer_code
,
rename_layers
from
x2paddle.optimizer.pytorch_code_optimizer.layer_code_generator
import
gen_layer_code
,
rename_layers
from
x2paddle.optimizer.pytorch_code_optimizer.parameter_tree
import
PamareterNode
,
PamareterTree
from
x2paddle.optimizer.pytorch_code_optimizer.parameter_tree
import
PamareterNode
,
PamareterTree
NoModuleStart
=
[
"paddle.nn.ReLU"
]
NoModuleStart
=
[
"paddle.nn.ReLU"
]
class
Apriori
(
object
):
class
Apriori
(
object
):
""" 使用Apriori算法挖掘频繁子图
""" 使用Apriori算法挖掘频繁子图
...
@@ -33,7 +33,6 @@ class Apriori(object):
...
@@ -33,7 +33,6 @@ class Apriori(object):
Args:
Args:
min_support (int): 子图出现次数的最小值。
min_support (int): 子图出现次数的最小值。
"""
"""
def
__init__
(
self
,
min_support
):
def
__init__
(
self
,
min_support
):
self
.
min_support
=
min_support
self
.
min_support
=
min_support
...
@@ -50,9 +49,9 @@ class Apriori(object):
...
@@ -50,9 +49,9 @@ class Apriori(object):
if
layer
.
kernel
==
"paddle.to_tensor"
or
\
if
layer
.
kernel
==
"paddle.to_tensor"
or
\
layer
.
kernel
==
"prim.if"
or
\
layer
.
kernel
==
"prim.if"
or
\
layer
.
kernel
==
"prim.loop"
:
#or \
layer
.
kernel
==
"prim.loop"
:
#or \
# layer.kernel == "prim.list" or \
# layer.kernel == "prim.list" or \
# layer.kernel == "prim.tuple" or \
# layer.kernel == "prim.tuple" or \
# layer.kernel == "prim.dict_construct":
# layer.kernel == "prim.dict_construct":
continue
continue
if
self
.
pd_graph
.
edges_in
.
get
(
layer_id
,
0
)
==
0
and
\
if
self
.
pd_graph
.
edges_in
.
get
(
layer_id
,
0
)
==
0
and
\
self
.
pd_graph
.
edges_out
.
get
(
layer_id
,
0
)
==
0
:
self
.
pd_graph
.
edges_out
.
get
(
layer_id
,
0
)
==
0
:
...
@@ -103,7 +102,6 @@ class Apriori(object):
...
@@ -103,7 +102,6 @@ class Apriori(object):
class
DP
(
object
):
class
DP
(
object
):
""" 使用动态规划找到使代码最短的组合方式。
""" 使用动态规划找到使代码最短的组合方式。
"""
"""
def
__init__
(
self
,
combination_itemset
):
def
__init__
(
self
,
combination_itemset
):
self
.
combination_itemset
=
combination_itemset
self
.
combination_itemset
=
combination_itemset
...
@@ -147,8 +145,7 @@ class DP(object):
...
@@ -147,8 +145,7 @@ class DP(object):
if
j
-
1
<
0
:
if
j
-
1
<
0
:
last_itemset
=
list
()
last_itemset
=
list
()
else
:
else
:
last_itemset
=
copy
.
deepcopy
(
layer_combination_list
[
last_itemset
=
copy
.
deepcopy
(
layer_combination_list
[
j
-
1
])
j
-
1
])
else
:
else
:
if
j
==
prefix_ids
[
0
]:
if
j
==
prefix_ids
[
0
]:
min_count
=
len
(
layer_combination_list
[
j
])
+
1
min_count
=
len
(
layer_combination_list
[
j
])
+
1
...
@@ -166,7 +163,6 @@ class DP(object):
...
@@ -166,7 +163,6 @@ class DP(object):
class
ModuleGraph
(
object
):
class
ModuleGraph
(
object
):
""" 更新PaddleGraph,生成代码。
""" 更新PaddleGraph,生成代码。
"""
"""
def
__init__
(
self
,
graph
):
def
__init__
(
self
,
graph
):
self
.
pd_graph
=
graph
self
.
pd_graph
=
graph
self
.
global_layers
=
graph
.
get_global_layers
()
self
.
global_layers
=
graph
.
get_global_layers
()
...
@@ -200,7 +196,7 @@ class ModuleGraph(object):
...
@@ -200,7 +196,7 @@ class ModuleGraph(object):
if
len
(
elements_list
)
>
1
:
if
len
(
elements_list
)
>
1
:
max_ct
=
0
max_ct
=
0
for
k
,
v
in
zip
(
elements_list
,
count_list
):
for
k
,
v
in
zip
(
elements_list
,
count_list
):
if
v
>
max_ct
and
str
(
k
)
!=
"nan"
:
if
v
>
max_ct
and
str
(
k
)
!=
"nan"
:
max_ele
=
k
max_ele
=
k
max_ct
=
v
max_ct
=
v
diff_attrs_column
[
column
]
=
max_ele
diff_attrs_column
[
column
]
=
max_ele
...
@@ -218,34 +214,24 @@ class ModuleGraph(object):
...
@@ -218,34 +214,24 @@ class ModuleGraph(object):
layer_id2
=
layer_id_list2
[
i
]
layer_id2
=
layer_id_list2
[
i
]
if
layer_id2
not
in
self
.
pd_graph
.
edges_in
:
if
layer_id2
not
in
self
.
pd_graph
.
edges_in
:
return
False
return
False
if
len
(
self
.
pd_graph
.
edges_in
[
layer_id1
])
!=
len
(
if
len
(
self
.
pd_graph
.
edges_in
[
layer_id1
])
!=
len
(
self
.
pd_graph
.
edges_in
[
layer_id2
]):
self
.
pd_graph
.
edges_in
[
layer_id2
]):
return
False
return
False
for
j
,
ipt_layer_id1
in
enumerate
(
self
.
pd_graph
.
edges_in
[
for
j
,
ipt_layer_id1
in
enumerate
(
self
.
pd_graph
.
edges_in
[
layer_id1
]):
layer_id1
]):
ipt_layer_id2
=
self
.
pd_graph
.
edges_in
[
layer_id2
][
j
]
ipt_layer_id2
=
self
.
pd_graph
.
edges_in
[
layer_id2
][
j
]
if
(
ipt_layer_id1
in
layer_id_list1
)
^
(
if
(
ipt_layer_id1
in
layer_id_list1
)
^
(
ipt_layer_id2
in
layer_id_list2
):
ipt_layer_id2
in
layer_id_list2
):
return
False
return
False
if
(
layer_id1
in
self
.
pd_graph
.
edges_out
)
^
(
if
(
layer_id1
in
self
.
pd_graph
.
edges_out
)
^
(
layer_id2
in
self
.
pd_graph
.
edges_out
):
layer_id2
in
self
.
pd_graph
.
edges_out
):
return
False
return
False
if
(
layer_id1
in
self
.
pd_graph
.
edges_out
)
and
(
if
(
layer_id1
in
self
.
pd_graph
.
edges_out
)
and
(
layer_id2
in
self
.
pd_graph
.
edges_out
):
layer_id2
in
self
.
pd_graph
.
edges_out
):
if
(
len
(
self
.
pd_graph
.
edges_out
[
layer_id1
])
>
1
and
len
(
self
.
pd_graph
.
edges_out
[
layer_id2
])
==
1
)
or
\
if
(
len
(
self
.
pd_graph
.
edges_out
[
layer_id1
])
>
1
and
len
(
self
.
pd_graph
.
edges_out
[
layer_id2
])
==
1
)
or
\
(
len
(
self
.
pd_graph
.
edges_out
[
layer_id1
])
==
1
and
len
(
self
.
pd_graph
.
edges_out
[
layer_id2
])
>
1
):
(
len
(
self
.
pd_graph
.
edges_out
[
layer_id1
])
==
1
and
len
(
self
.
pd_graph
.
edges_out
[
layer_id2
])
>
1
):
return
False
return
False
for
j
,
opt_layer_id1
in
enumerate
(
self
.
pd_graph
.
edges_out
[
for
j
,
opt_layer_id1
in
enumerate
(
self
.
pd_graph
.
edges_out
[
layer_id1
]):
layer_id1
]):
if
len
(
self
.
pd_graph
.
edges_out
[
layer_id1
])
==
1
and
len
(
self
.
pd_graph
.
edges_out
[
layer_id2
])
==
1
:
if
len
(
self
.
pd_graph
.
edges_out
[
layer_id1
])
==
1
and
len
(
opt_layer_id2
=
self
.
pd_graph
.
edges_out
[
layer_id2
][
j
]
self
.
pd_graph
.
edges_out
[
layer_id2
])
==
1
:
if
(
opt_layer_id1
in
layer_id_list1
)
^
(
opt_layer_id2
in
layer_id_list2
):
opt_layer_id2
=
self
.
pd_graph
.
edges_out
[
layer_id2
][
j
]
if
(
opt_layer_id1
in
layer_id_list1
)
^
(
opt_layer_id2
in
layer_id_list2
):
return
False
return
False
return
True
return
True
sub_layers_list_list
=
list
()
sub_layers_list_list
=
list
()
id_list
=
list
()
id_list
=
list
()
ipt_opt_list
=
list
()
ipt_opt_list
=
list
()
...
@@ -266,12 +252,12 @@ class ModuleGraph(object):
...
@@ -266,12 +252,12 @@ class ModuleGraph(object):
id_list
.
append
(
i
)
id_list
.
append
(
i
)
return
sub_layers_list_list
return
sub_layers_list_list
def
merge_node
(
self
,
sub_layers_list
,
attrs_table
,
module_name
):
def
merge_node
(
self
,
sub_layers_list
,
attrs_table
,
module_name
):
sub_layers
=
sub_layers_list
[
0
]
sub_layers
=
sub_layers_list
[
0
]
diff_attrs_column
=
self
.
analyze_attrs_table
(
attrs_table
)
diff_attrs_column
=
self
.
analyze_attrs_table
(
attrs_table
)
sub_layers
,
_
,
_
=
rename_layers
(
sub_layers
)
sub_layers
,
_
,
_
=
rename_layers
(
sub_layers
)
code_str
=
gen_layer_code
(
code_str
=
gen_layer_code
(
self
.
pd_graph
,
self
.
pd_graph
,
sub_layers
,
sub_layers
,
module_name
,
module_name
,
different_attrs
=
diff_attrs_column
)
different_attrs
=
diff_attrs_column
)
...
@@ -289,8 +275,7 @@ class ModuleGraph(object):
...
@@ -289,8 +275,7 @@ class ModuleGraph(object):
current_element
=
attrs_table
.
get
(
column
).
loc
[
node_name
]
current_element
=
attrs_table
.
get
(
column
).
loc
[
node_name
]
if
current_element
!=
element
:
if
current_element
!=
element
:
diff_attrs
[
column
]
=
current_element
diff_attrs
[
column
]
=
current_element
new_layer
=
PaddleLayer
(
new_layer
=
PaddleLayer
(
id
=
list
(
sub_layers
.
keys
())[
-
1
],
id
=
list
(
sub_layers
.
keys
())[
-
1
],
kernel
=
"module"
,
kernel
=
"module"
,
inputs
=
inputs_dict
,
inputs
=
inputs_dict
,
outputs
=
outputs
,
outputs
=
outputs
,
...
@@ -333,13 +318,12 @@ class ModuleGraph(object):
...
@@ -333,13 +318,12 @@ class ModuleGraph(object):
else
:
else
:
real_module_name
=
module_name
+
"__{}"
.
format
(
i
)
real_module_name
=
module_name
+
"__{}"
.
format
(
i
)
if
len
(
sub_layers_list
)
>
1
:
if
len
(
sub_layers_list
)
>
1
:
attrs_table
=
construct_attrs_table
(
attrs_table
=
construct_attrs_table
(
sub_layers_list
,
module_name
=
real_module_name
)
sub_layers_list
,
module_name
=
real_module_name
)
self
.
merge_node
(
sub_layers_list
,
attrs_table
,
real_module_name
)
self
.
merge_node
(
sub_layers_list
,
attrs_table
,
layers
,
nn_param_nodes
,
_
=
rename_layers
(
self
.
pd_graph
.
layers
,
self
.
param_tree
,
is_rename_module
=
True
)
real_module_name
)
code_str
=
gen_layer_code
(
self
.
pd_graph
,
layers
,
nn_param_nodes
,
_
=
rename_layers
(
layers
,
self
.
pd_graph
.
layers
,
self
.
param_tree
,
is_rename_module
=
True
)
self
.
pd_graph
.
name
)
code_str
=
gen_layer_code
(
self
.
pd_graph
,
layers
,
self
.
pd_graph
.
name
)
self
.
codes
.
append
(
code_str
)
self
.
codes
.
append
(
code_str
)
param_node
=
PamareterNode
(
old_name
=
"Module"
)
param_node
=
PamareterNode
(
old_name
=
"Module"
)
for
node
in
nn_param_nodes
:
for
node
in
nn_param_nodes
:
...
@@ -350,13 +334,11 @@ class ModuleGraph(object):
...
@@ -350,13 +334,11 @@ class ModuleGraph(object):
""" 更新参数。
""" 更新参数。
"""
"""
self
.
param_tree
.
traverse
()
self
.
param_tree
.
traverse
()
full_old_name_list
=
copy
.
deepcopy
(
full_old_name_list
=
copy
.
deepcopy
(
list
(
self
.
pd_graph
.
parameters
.
keys
()))
list
(
self
.
pd_graph
.
parameters
.
keys
()))
for
old_name
,
new_name
in
self
.
param_tree
.
old2new
.
items
():
for
old_name
,
new_name
in
self
.
param_tree
.
old2new
.
items
():
for
full_old_name
in
full_old_name_list
:
for
full_old_name
in
full_old_name_list
:
if
full_old_name
.
startswith
(
"{}."
.
format
(
old_name
)):
if
full_old_name
.
startswith
(
"{}."
.
format
(
old_name
)):
full_new_name
=
full_old_name
.
replace
(
full_new_name
=
full_old_name
.
replace
(
"{}."
.
format
(
old_name
),
"{}."
.
format
(
new_name
))
"{}."
.
format
(
old_name
),
"{}."
.
format
(
new_name
))
params
=
self
.
pd_graph
.
parameters
.
pop
(
full_old_name
)
params
=
self
.
pd_graph
.
parameters
.
pop
(
full_old_name
)
self
.
pd_graph
.
parameters
[
full_new_name
]
=
params
self
.
pd_graph
.
parameters
[
full_new_name
]
=
params
if
full_old_name
==
old_name
:
if
full_old_name
==
old_name
:
...
@@ -369,21 +351,18 @@ class ModuleGraph(object):
...
@@ -369,21 +351,18 @@ class ModuleGraph(object):
input_data_name
=
', '
.
join
(
self
.
pd_graph
.
inputs
)
input_data_name
=
', '
.
join
(
self
.
pd_graph
.
inputs
)
run_func_list
=
list
()
run_func_list
=
list
()
run_func_list
.
append
(
"def main({}):"
.
format
(
input_data_name
))
run_func_list
.
append
(
"def main({}):"
.
format
(
input_data_name
))
run_func_list
.
append
(
" # There are {} inputs."
.
format
(
run_func_list
.
append
(
" # There are {} inputs."
.
format
(
len
(
self
.
pd_graph
.
inputs_info
)))
len
(
self
.
pd_graph
.
inputs_info
)))
for
k
,
v
in
self
.
pd_graph
.
inputs_info
.
items
():
for
k
,
v
in
self
.
pd_graph
.
inputs_info
.
items
():
run_func_list
.
append
(
" # {}: shape-{}, type-{}."
.
format
(
k
,
v
[
run_func_list
.
append
(
" # {}: shape-{}, type-{}."
.
format
(
k
,
v
[
0
],
v
[
1
]))
0
],
v
[
1
]))
run_func_list
.
extend
(
run_func_list
.
extend
([
[
" paddle.disable_static()"
,
" paddle.disable_static()"
,
" params = paddle.load('{}')"
.
format
(
osp
.
join
(
osp
.
abspath
(
save_dir
),
"model.pdparams"
)),
" params = paddle.load('{}')"
.
format
(
osp
.
join
(
osp
.
abspath
(
save_dir
),
"model.pdparams"
)),
" model = {}()"
.
format
(
self
.
pd_graph
.
name
),
" model = {}()"
.
format
(
self
.
pd_graph
.
name
),
" model.set_dict(params)"
,
" model.eval()"
,
" model.set_dict(params)"
,
" out = model({})"
.
format
(
input_data_name
),
" return out"
" model.eval()"
,
])
" out = model({})"
.
format
(
input_data_name
),
" return out"
])
return
"
\n
"
.
join
(
run_func_list
)
return
"
\n
"
.
join
(
run_func_list
)
combination
,
combination_id
=
self
.
get_updation_information
()
combination
,
combination_id
=
self
.
get_updation_information
()
self
.
convert_subgraph_to_layer
(
combination
,
combination_id
)
self
.
convert_subgraph_to_layer
(
combination
,
combination_id
)
self
.
update_parameters
()
self
.
update_parameters
()
...
@@ -403,3 +382,4 @@ class ModuleGraph(object):
...
@@ -403,3 +382,4 @@ class ModuleGraph(object):
run_func
=
gen_main_code
()
run_func
=
gen_main_code
()
f
.
write
(
run_func
)
f
.
write
(
run_func
)
f
.
close
()
f
.
close
()
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录