Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
795b3c1b
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
795b3c1b
编写于
8月 11, 2020
作者:
S
SunAhong1993
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix the code style
上级
a7fdf1da
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
13 addition
and
12 deletion
+13
-12
x2paddle/core/convert_prim.py
x2paddle/core/convert_prim.py
+2
-4
x2paddle/core/program.py
x2paddle/core/program.py
+9
-6
x2paddle/optimizer/passes.py
x2paddle/optimizer/passes.py
+2
-2
未找到文件。
x2paddle/core/convert_prim.py
浏览文件 @
795b3c1b
...
@@ -62,8 +62,7 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]):
...
@@ -62,8 +62,7 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]):
inputs_list
=
list
(
layer
.
inputs
.
values
())
inputs_list
=
list
(
layer
.
inputs
.
values
())
for
i
,
input
in
enumerate
(
inputs_list
):
for
i
,
input
in
enumerate
(
inputs_list
):
if
input
is
None
:
if
input
is
None
:
inputs_list
[
i
]
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
inputs_list
[
i
]
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
i
]])
i
]])
inputs_str
=
', '
.
join
(
inputs_list
)
inputs_str
=
', '
.
join
(
inputs_list
)
line
=
"{} = [{}]"
.
format
(
layer
.
outputs
[
0
],
inputs_str
)
line
=
"{} = [{}]"
.
format
(
layer
.
outputs
[
0
],
inputs_str
)
elif
layer
.
kernel
==
"prim.exception"
:
elif
layer
.
kernel
==
"prim.exception"
:
...
@@ -131,6 +130,5 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]):
...
@@ -131,6 +130,5 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]):
attrs_str
+=
"{}:"
.
format
(
v
)
attrs_str
+=
"{}:"
.
format
(
v
)
attrs_str
=
attrs_str
[:
-
1
]
attrs_str
=
attrs_str
[:
-
1
]
line
=
"{} = {}[{}]"
.
format
(
layer
.
outputs
[
0
],
line
=
"{} = {}[{}]"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
],
list
(
layer
.
inputs
.
values
())[
0
],
attrs_str
)
attrs_str
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
x2paddle/core/program.py
浏览文件 @
795b3c1b
...
@@ -129,7 +129,7 @@ class PaddleGraph(object):
...
@@ -129,7 +129,7 @@ class PaddleGraph(object):
if
len
(
layer
.
blocks
)
>
0
:
if
len
(
layer
.
blocks
)
>
0
:
for
block
in
layer
.
blocks
:
for
block
in
layer
.
blocks
:
block
.
build
(
layer
.
inputs
,
layer
.
outputs
)
block
.
build
(
layer
.
inputs
,
layer
.
outputs
)
if
self
.
graph_type
==
"dygraph"
:
if
self
.
graph_type
==
"dygraph"
:
self
.
get_dygraph_inputs
()
self
.
get_dygraph_inputs
()
self
.
get_dygraph_outputs
()
self
.
get_dygraph_outputs
()
...
@@ -284,6 +284,7 @@ class PaddleGraph(object):
...
@@ -284,6 +284,7 @@ class PaddleGraph(object):
for
block
in
layer
.
blocks
:
for
block
in
layer
.
blocks
:
block
.
get_dygraph_inputs
()
block
.
get_dygraph_inputs
()
self
.
inputs
.
extend
(
block
.
inputs
)
self
.
inputs
.
extend
(
block
.
inputs
)
update
(
self
.
layers
)
update
(
self
.
layers
)
self
.
inputs
=
list
(
set
(
self
.
inputs
))
self
.
inputs
=
list
(
set
(
self
.
inputs
))
...
@@ -310,7 +311,7 @@ class PaddleGraph(object):
...
@@ -310,7 +311,7 @@ class PaddleGraph(object):
else
:
else
:
codes
.
append
(
indent_blank
+
code_line
+
'
\n
'
)
codes
.
append
(
indent_blank
+
code_line
+
'
\n
'
)
return
codes
return
codes
def
gen_head
():
def
gen_head
():
self
.
head
=
gen_codes
(
self
.
head
=
gen_codes
(
[
[
...
@@ -332,7 +333,7 @@ class PaddleGraph(object):
...
@@ -332,7 +333,7 @@ class PaddleGraph(object):
gen_codes
(
gen_codes
(
[
"def forward(self, {}):"
.
format
(
input_data_name
)],
[
"def forward(self, {}):"
.
format
(
input_data_name
)],
indent
=
1
))
indent
=
1
))
def
write_code
(
code_dir
):
def
write_code
(
code_dir
):
f
=
open
(
os
.
path
.
join
(
code_dir
,
'code.py'
),
'w'
)
f
=
open
(
os
.
path
.
join
(
code_dir
,
'code.py'
),
'w'
)
for
code_line
in
self
.
head
:
for
code_line
in
self
.
head
:
...
@@ -396,9 +397,11 @@ class PaddleGraph(object):
...
@@ -396,9 +397,11 @@ class PaddleGraph(object):
self
.
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
self
.
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
elif
"prim"
in
layer
.
kernel
:
elif
"prim"
in
layer
.
kernel
:
from
.convert_prim
import
convert_prim
from
.convert_prim
import
convert_prim
convert_prim
(
layer
,
indent
=
indent
,
convert_prim
(
init_func
=
self
.
init_func
,
layer
,
forward_func
=
self
.
forward_func
)
indent
=
indent
,
init_func
=
self
.
init_func
,
forward_func
=
self
.
forward_func
)
else
:
else
:
if
len
(
layer
.
outputs
)
==
1
:
if
len
(
layer
.
outputs
)
==
1
:
line
=
layer
.
outputs
[
0
]
line
=
layer
.
outputs
[
0
]
...
...
x2paddle/optimizer/passes.py
浏览文件 @
795b3c1b
...
@@ -30,11 +30,11 @@ class PyTorchMatcher(Matcher):
...
@@ -30,11 +30,11 @@ class PyTorchMatcher(Matcher):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
PyTorchMatcher
,
self
).
__init__
()
super
(
PyTorchMatcher
,
self
).
__init__
()
def
match_pattern
(
self
,
pattern
,
graph
,
start_i
d
):
def
match_pattern
(
self
,
pattern
,
graph
,
start_i
ndex
):
pattern_index
=
0
pattern_index
=
0
pattern_global_layers
=
pattern
.
get_global_layers
()
pattern_global_layers
=
pattern
.
get_global_layers
()
subgraph_global_layers
=
dict
()
subgraph_global_layers
=
dict
()
graph_layers
=
dict
(
list
(
graph
.
layers
.
items
())[
start_i
d
:])
graph_layers
=
dict
(
list
(
graph
.
layers
.
items
())[
start_i
ndex
:])
for
layer_id
,
layer
in
graph_layers
.
items
():
for
layer_id
,
layer
in
graph_layers
.
items
():
pattern_layer
=
pattern
.
layers
[
list
(
pattern
.
layers
.
keys
())[
pattern_layer
=
pattern
.
layers
[
list
(
pattern
.
layers
.
keys
())[
pattern_index
]]
pattern_index
]]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录