Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
cbc3efdb
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cbc3efdb
编写于
8月 18, 2020
作者:
S
SunAhong1993
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add densenet
上级
5bcd803c
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
666 addition
and
158 deletion
+666
-158
x2paddle/core/convert_prim.py
x2paddle/core/convert_prim.py
+211
-135
x2paddle/core/program.py
x2paddle/core/program.py
+13
-7
x2paddle/op_mapper/pytorch2paddle/aten.py
x2paddle/op_mapper/pytorch2paddle/aten.py
+332
-9
x2paddle/op_mapper/pytorch2paddle/prim.py
x2paddle/op_mapper/pytorch2paddle/prim.py
+94
-1
x2paddle/op_mapper/pytorch2paddle/pytorch_op_mapper.py
x2paddle/op_mapper/pytorch2paddle/pytorch_op_mapper.py
+16
-6
未找到文件。
x2paddle/core/convert_prim.py
浏览文件 @
cbc3efdb
...
...
@@ -13,144 +13,220 @@
# limitations under the License.
def
convert_prim
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
def
gen_codes
(
code_list
,
indent
=
0
):
indent_blank
=
" "
*
indent
codes
=
[]
for
code_line
in
code_list
:
if
code_line
.
strip
()
==
""
:
codes
.
append
(
'
\n
'
)
else
:
codes
.
append
(
indent_blank
+
code_line
+
'
\n
'
)
return
codes
if
layer
.
kernel
==
"prim.if"
:
line
=
"if {} :"
.
format
(
list
(
layer
.
inputs
.
values
())[
0
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
block
=
layer
.
blocks
[
0
]
b_init_lines
,
b_forward_lines
=
block
.
gen_dygraph_code
(
indent
=
indent
+
1
)
init_func
.
extend
(
b_init_lines
)
forward_func
.
extend
(
b_forward_lines
)
block
=
layer
.
blocks
[
1
]
if
len
(
block
.
layers
)
>
0
:
line
=
"else:"
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
b_init_lines
,
b_forward_lines
=
block
.
gen_dygraph_code
(
indent
=
indent
+
1
)
init_func
.
extend
(
b_init_lines
)
forward_func
.
extend
(
b_forward_lines
)
return
elif
layer
.
kernel
==
"prim.loop"
:
loop_range
=
list
(
layer
.
inputs
.
values
())[
0
]
if
list
(
layer
.
inputs
.
values
())[
0
]
is
None
:
loop_range
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
0
]])
line
=
"for {} in range({}):"
.
format
(
layer
.
outputs
[
1
],
loop_range
)
def
gen_codes
(
code_list
,
indent
=
0
):
indent_blank
=
" "
*
indent
codes
=
[]
for
code_line
in
code_list
:
if
code_line
.
strip
()
==
""
:
codes
.
append
(
'
\n
'
)
else
:
codes
.
append
(
indent_blank
+
code_line
+
'
\n
'
)
return
codes
def
prim_add
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {} + {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"x"
],
layer
.
inputs
[
"y"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_add_
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {} + {} * {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"x"
],
layer
.
attrs
[
"alpha"
],
layer
.
inputs
[
"y"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_and
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {} and {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"x"
],
layer
.
inputs
[
"y"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_append
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{}.append({})"
.
format
(
layer
.
inputs
[
"list"
],
layer
.
inputs
[
"element"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_assert
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
if
layer
.
attrs
[
"type"
]
==
"eq"
:
if
isinstance
(
layer
.
attrs
[
"value"
],
list
):
s
=
""
for
v
in
layer
.
attrs
[
"value"
]:
s
+=
"{} == {} or "
.
format
(
layer
.
attrs
[
"key"
],
v
)
if
len
(
s
)
>
0
:
s
=
s
[:
-
4
]
line
=
"assert {},
\'
The {} must be {}!
\'
"
.
format
(
s
,
layer
.
attrs
[
"key"
],
layer
.
attrs
[
"value"
])
else
:
line
=
"assert {} == {},
\'
The {} must be {}!
\'
"
.
format
(
layer
.
attrs
[
"key"
],
layer
.
attrs
[
"value"
],
layer
.
attrs
[
"key"
],
layer
.
attrs
[
"value"
])
else
:
raise
Exception
(
"Not implement yet!"
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_constant
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
attrs
[
"value"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_eq
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {} == {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"x"
],
layer
.
inputs
[
"y"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_equal
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"input"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_exception
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"raise RaiseException({})"
.
format
(
layer
.
inputs
[
"input"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_if
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"if {} :"
.
format
(
list
(
layer
.
inputs
.
values
())[
0
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
block
=
layer
.
blocks
[
0
]
b_init_lines
,
b_forward_lines
=
block
.
gen_dygraph_code
(
indent
=
indent
+
1
)
init_func
.
extend
(
b_init_lines
)
forward_func
.
extend
(
b_forward_lines
)
block
=
layer
.
blocks
[
1
]
if
len
(
block
.
layers
)
>
0
:
line
=
"else:"
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
block
=
layer
.
blocks
[
0
]
b_init_lines
,
b_forward_lines
=
block
.
gen_dygraph_code
(
indent
=
indent
+
1
)
init_func
.
extend
(
b_init_lines
)
forward_func
.
extend
(
b_forward_lines
)
return
elif
layer
.
kernel
==
"prim.equal"
:
line
=
"{} = {}"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
])
elif
layer
.
kernel
==
"prim.constant"
:
line
=
"{} = {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
attrs
[
"value"
])
elif
layer
.
kernel
==
"prim.list"
:
inputs_list
=
list
(
layer
.
inputs
.
values
())
for
i
,
input
in
enumerate
(
inputs_list
):
if
input
is
None
:
inputs_list
[
i
]
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
i
]])
inputs_str
=
', '
.
join
(
inputs_list
)
line
=
"{} = [{}]"
.
format
(
layer
.
outputs
[
0
],
inputs_str
)
elif
layer
.
kernel
==
"prim.exception"
:
exception
=
list
(
layer
.
inputs
.
values
())[
0
]
if
list
(
layer
.
inputs
.
values
())[
0
]
is
None
:
exception
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
0
]])
line
=
"raise RaiseException({})"
.
format
(
exception
)
elif
layer
.
kernel
==
"prim.min"
:
line
=
"{} = min({})"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
])
elif
layer
.
kernel
==
"prim.add_"
:
line
=
"{} = {} + {} * {}"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
],
layer
.
attrs
[
"alpha"
],
list
(
layer
.
inputs
.
values
())[
1
])
elif
layer
.
kernel
==
"prim.append"
:
line
=
"{} = {}.append({})"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
],
list
(
layer
.
inputs
.
values
())[
1
])
elif
layer
.
kernel
==
"prim.shape"
:
line
=
"{} = {}.shape"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
])
elif
layer
.
kernel
==
"prim.len"
:
line
=
"{} = len({})"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
])
elif
layer
.
kernel
==
"prim.eq"
:
line
=
"{} = {} == {}"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
],
list
(
layer
.
inputs
.
values
())[
1
])
elif
layer
.
kernel
==
"prim.assert"
:
if
layer
.
attrs
[
"type"
]
==
"eq"
:
if
isinstance
(
layer
.
attrs
[
"value"
],
list
):
s
=
""
for
v
in
layer
.
attrs
[
"value"
]:
s
+=
"{} == {} or "
.
format
(
layer
.
attrs
[
"key"
],
v
)
if
len
(
s
)
>
0
:
s
=
s
[:
-
4
]
line
=
"assert {},
\'
The {} must be {}!
\'
"
.
format
(
s
,
layer
.
attrs
[
"key"
],
layer
.
attrs
[
"value"
])
else
:
line
=
"assert {} == {},
\'
The {} must be {}!
\'
"
.
format
(
layer
.
attrs
[
"key"
],
layer
.
attrs
[
"value"
],
layer
.
attrs
[
"key"
],
layer
.
attrs
[
"value"
])
else
:
raise
Exception
(
"Not implement yet!"
)
elif
layer
.
kernel
==
"prim.getitem"
:
item0
=
list
(
layer
.
inputs
.
values
())[
0
]
if
list
(
layer
.
inputs
.
values
())[
0
]
is
None
:
item0
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
0
]])
item1
=
list
(
layer
.
inputs
.
values
())[
1
]
if
list
(
layer
.
inputs
.
values
())[
1
]
is
None
:
item1
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
1
]])
line
=
"{} = {}[{}]"
.
format
(
layer
.
outputs
[
0
],
item0
,
item1
)
elif
layer
.
kernel
==
"prim.le"
:
item0
=
list
(
layer
.
inputs
.
values
())[
0
]
if
list
(
layer
.
inputs
.
values
())[
0
]
is
None
:
item0
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
0
]])
item1
=
list
(
layer
.
inputs
.
values
())[
1
]
if
list
(
layer
.
inputs
.
values
())[
1
]
is
None
:
item1
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
1
]])
line
=
"{} = {} < {}"
.
format
(
layer
.
outputs
[
0
],
item0
,
item1
)
elif
layer
.
kernel
==
"prim.ne"
:
item0
=
list
(
layer
.
inputs
.
values
())[
0
]
item1
=
list
(
layer
.
inputs
.
values
())[
1
]
line
=
"{} = {} < {}"
.
format
(
layer
.
outputs
[
0
],
item0
,
item1
)
elif
layer
.
kernel
==
"prim.slice"
:
inputs_str
=
""
for
v
in
list
(
layer
.
inputs
.
values
())[
1
:]:
inputs_str
+=
"{}:"
.
format
(
v
)
inputs_str
=
inputs_str
[:
-
1
]
line
=
"{} = {}[{}]"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
],
inputs_str
)
elif
layer
.
kernel
==
"prim.add"
:
line
=
"{} = {} + {}"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
],
list
(
layer
.
inputs
.
values
())[
1
])
elif
layer
.
kernel
==
"prim.sub"
:
line
=
"{} = {} - {}"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
],
list
(
layer
.
inputs
.
values
())[
1
])
elif
layer
.
kernel
==
"prim.mul"
:
line
=
"{} = {} * {}"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
],
list
(
layer
.
inputs
.
values
())[
1
])
elif
layer
.
kernel
==
"prim.neg"
:
line
=
"{} = -{}"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
])
else
:
print
(
layer
.
kernel
)
line
=
""
def
prim_getitem
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {}[{}]"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"list"
],
layer
.
inputs
[
"index"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_gt
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {} > {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"x"
],
layer
.
inputs
[
"y"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_le
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {} <= {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"x"
],
layer
.
inputs
[
"y"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_len
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = len({})"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"input"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_lt
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {} < {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"x"
],
layer
.
inputs
[
"y"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_list
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
inputs_list
=
list
(
layer
.
inputs
.
values
())
inputs_str
=
', '
.
join
(
inputs_list
)
line
=
"{} = [{}]"
.
format
(
layer
.
outputs
[
0
],
inputs_str
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_loop
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
loop_range
=
list
(
layer
.
inputs
.
values
())[
0
]
if
list
(
layer
.
inputs
.
values
())[
0
]
is
None
:
loop_range
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
0
]])
line
=
"for {} in range({}):"
.
format
(
layer
.
outputs
[
1
],
loop_range
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
block
=
layer
.
blocks
[
0
]
b_init_lines
,
b_forward_lines
=
block
.
gen_dygraph_code
(
indent
=
indent
+
1
)
init_func
.
extend
(
b_init_lines
)
forward_func
.
extend
(
b_forward_lines
)
def
prim_min
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = min({})"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"input"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_mul
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {} * {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"x"
],
layer
.
inputs
[
"y"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_ne
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {} < {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"x"
],
layer
.
inputs
[
"y"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_neg
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = -{}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"input"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_not
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = not {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"input"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_requires_grad
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = not {}.stop_gradient"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"input"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_select
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {}["
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"input"
])
for
dim
in
range
(
layer
.
attrs
[
"dim"
]):
line
+=
":, "
line
+=
(
layer
.
inputs
[
"index"
]
+
"]"
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_shape
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {}.shape"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"input"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_slice
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {}[{}: {}: {}]"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"input"
],
layer
.
inputs
[
"start"
],
layer
.
inputs
[
"end"
],
layer
.
inputs
[
"step"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_sub
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {} - {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"x"
],
layer
.
inputs
[
"y"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_tuple
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
inputs_list
=
list
(
layer
.
inputs
.
values
())
inputs_str
=
', '
.
join
(
inputs_list
)
line
=
"{} = ({})"
.
format
(
layer
.
outputs
[
0
],
inputs_str
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_tuple_unpack
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
outputs_str
=
', '
.
join
(
layer
.
outputs
)
line
=
"{} = {}"
.
format
(
outputs_str
,
layer
.
inputs
[
"input"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_warnings
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
lines
=
[
"import warnings"
]
line
=
"warnings.warn({}, stacklevel={})"
.
format
(
layer
.
inputs
[
"input"
],
layer
.
attrs
[
"stacklevel"
])
lines
.
append
(
line
)
forward_func
.
extend
(
gen_codes
(
lines
,
indent
=
indent
))
x2paddle/core/program.py
浏览文件 @
cbc3efdb
...
...
@@ -297,7 +297,6 @@ class PaddleGraph(object):
for
output_name
in
layer
.
outputs
:
if
not
output_name
.
startswith
(
"x"
):
continue
print
(
layer
.
kernel
)
self
.
outputs
.
append
(
output_name
)
self
.
outputs
=
list
(
set
(
self
.
outputs
))
...
...
@@ -396,12 +395,19 @@ class PaddleGraph(object):
line
+=
")"
self
.
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
elif
"prim"
in
layer
.
kernel
:
from
.convert_prim
import
convert_prim
convert_prim
(
layer
,
indent
=
indent
,
init_func
=
self
.
init_func
,
forward_func
=
self
.
forward_func
)
func_name
=
layer
.
kernel
.
replace
(
"."
,
"_"
)
from
.
import
convert_prim
if
hasattr
(
convert_prim
,
func_name
):
func
=
getattr
(
convert_prim
,
func_name
)
func
(
layer
,
indent
=
indent
,
init_func
=
self
.
init_func
,
forward_func
=
self
.
forward_func
)
else
:
raise
Exception
(
"The kind {} in paddle model is not supported yet."
.
format
(
layer
.
kernel
))
else
:
if
len
(
layer
.
outputs
)
==
1
:
line
=
layer
.
outputs
[
0
]
...
...
x2paddle/op_mapper/pytorch2paddle/aten.py
浏览文件 @
cbc3efdb
...
...
@@ -183,6 +183,36 @@ def aten_add_(mapper, graph, node):
return
current_inputs
,
current_outputs
def
aten___and__
(
mapper
,
graph
,
node
):
""" 构造与计算的PaddleLayer。
TorchScript示例:
%361 : bool = aten::__and__(%360, %358)
参数含义:
%361 (bool): 输出,与计算结果。
%360 (-): 输入 x。
%358 (-): 输入 y。
"""
output_name
=
mapper
.
_get_outputs_name
(
node
)[
0
]
layer_outputs
=
[
output_name
]
layer_inputs
=
{}
inputs_name
,
inputs_node
=
mapper
.
_get_inputs_name
(
node
)
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理输入0,即%i.12
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
)
layer_inputs
[
"x"
]
=
inputs_name
[
0
]
# 处理输入1,即%288
mapper
.
_check_input
(
graph
,
inputs_node
[
1
],
inputs_name
[
1
],
current_outputs
,
add_dim
=
True
)
layer_inputs
[
"y"
]
=
inputs_name
[
1
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.and"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
)
return
current_inputs
,
current_outputs
def
aten_append
(
mapper
,
graph
,
node
):
""" 构造对list进行append的PaddleLayer。
...
...
@@ -193,12 +223,11 @@ def aten_append(mapper, graph, node):
%_output_size.1 (list): 需要进行append的list。
%v.1 (-): append的元素。
"""
output_name
=
mapper
.
_get_outputs_name
(
node
)[
0
]
layer_outputs
=
[
output_name
]
layer_inputs
=
{}
inputs_name
,
inputs_node
=
mapper
.
_get_inputs_name
(
node
)
layer_outputs
=
[
inputs_name
[
0
]]
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
current_outputs
=
[
inputs_name
[
0
]
]
# 处理输入0,即_output_size.1
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
)
layer_inputs
[
"list"
]
=
inputs_name
[
0
]
...
...
@@ -212,6 +241,66 @@ def aten_append(mapper, graph, node):
return
current_inputs
,
current_outputs
def
aten_avg_pool2d
(
mapper
,
graph
,
node
):
""" 构造最大池化的PaddleLayer。
TorchScript示例:
%branch_pool.2 : Tensor = aten::avg_pool2d(%x.43, %538, %539, %540, %273, %272, %271)
参数含义:
%branch_pool.2 (Tensor): 输出,池化后的结果。
%x.43 (Tensor): 需要池化的Tensor。
%538 (list): 池化kernel的大小。
%539 (list): 步长大小。
%540 (list): 填充大小。
%273 (bool): 是否用ceil函数计算输出高度和宽度。
%272 (bool): 是否在平均池化模式不忽略填充值,False为忽略。
%271 (int): 如果指定,它将用作除数,否则将使用池化区域的大小。
"""
if
"pool"
in
mapper
.
dygraph_name_id
:
mapper
.
dygraph_name_id
[
"pool"
]
+=
1
else
:
mapper
.
dygraph_name_id
[
"pool"
]
=
0
pool_name
=
"pool"
+
str
(
mapper
.
dygraph_name_id
[
"pool"
])
output_name
=
mapper
.
_get_outputs_name
(
node
)[
0
]
layer_outputs
=
[
pool_name
,
output_name
]
layer_inputs
=
{}
layer_attrs
=
{}
inputs_name
,
inputs_node
=
mapper
.
_get_inputs_name
(
node
)
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理输入0,即%x.34
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
# 处理输入1,即%538
layer_attrs
[
"pool_size"
]
=
mapper
.
attrs
[
inputs_name
[
1
]]
# 处理输入2,即%539
layer_attrs
[
"pool_stride"
]
=
mapper
.
attrs
[
inputs_name
[
2
]]
# 处理输入3,即%540
layer_attrs
[
"pool_padding"
]
=
mapper
.
attrs
[
inputs_name
[
3
]]
# 处理输入4,即%273
layer_attrs
[
"ceil_mode"
]
=
mapper
.
attrs
[
inputs_name
[
4
]]
# 处理输入5,即%272
layer_attrs
[
"exclusive"
]
=
not
mapper
.
attrs
[
inputs_name
[
5
]]
# 处理输入6,即%271
graph
.
add_layer
(
"prim.assert"
,
inputs
=
{},
outputs
=
[
inputs_name
[
6
]],
type
=
"eq"
,
key
=
mapper
.
attrs
[
inputs_name
[
6
]],
value
=
None
)
layer_attrs
[
"pool_type"
]
=
string
(
"avg"
)
graph
.
add_layer
(
"fluid.dygraph.Pool2D"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
**
layer_attrs
)
return
current_inputs
,
current_outputs
def
aten_batch_norm
(
mapper
,
graph
,
node
):
""" 构造BatchNorm的PaddleLayer。
...
...
@@ -278,6 +367,44 @@ def aten_batch_norm(mapper, graph, node):
return
current_inputs
,
current_outputs
def
aten_cat
(
mapper
,
graph
,
node
):
""" 构造连接Tensor的PaddleLayer。
TorchScript示例:
%x.222 : Tensor = aten::cat(%32, %7)
参数含义:
%x.222 (Tensor): 输出,连接后的结果。
%i.12 (list): 需要连接的Tensor组成的list。
%7 (int): 连接的轴。
"""
output_name
=
mapper
.
_get_outputs_name
(
node
)[
0
]
layer_outputs
=
[
output_name
]
layer_inputs
=
{}
layer_attrs
=
{}
inputs_name
,
inputs_node
=
mapper
.
_get_inputs_name
(
node
)
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理输入0,即%13
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
# 处理输入1,即%12
if
inputs_name
[
1
]
in
mapper
.
attrs
:
layer_attrs
[
"axis"
]
=
mapper
.
attrs
[
inputs_name
[
1
]]
else
:
mapper
.
_check_input
(
graph
,
inputs_node
[
1
],
inputs_name
[
1
],
current_outputs
)
layer_attrs
[
"axis"
]
=
inputs_name
[
1
]
current_inputs
.
append
(
inputs_name
[
1
])
graph
.
add_layer
(
"fluid.layers.concat"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
**
layer_attrs
)
return
current_inputs
,
current_outputs
def
aten_conv2d
(
mapper
,
graph
,
node
):
""" 构造conv2d的PaddleLayer。
...
...
@@ -512,6 +639,35 @@ def aten___getitem__(mapper, graph, node):
return
current_inputs
,
current_outputs
def
aten_gt
(
mapper
,
graph
,
node
):
""" 构造对比大小的PaddleLayer。
TorchScript示例:
%83 : bool = aten::gt(%82, %78)
参数含义:
%83 (bool): 输出,第一个元素是否大于第二个元素。
%82 (-): 需对比的输入1。
%78 (-): 需对比的输入2。
"""
output_name
=
mapper
.
_get_outputs_name
(
node
)[
0
]
layer_outputs
=
[
output_name
]
layer_inputs
=
{}
inputs_name
,
inputs_node
=
mapper
.
_get_inputs_name
(
node
)
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理输入0,即%82
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
)
layer_inputs
[
"x"
]
=
inputs_name
[
0
]
# 处理输入1,即%78
mapper
.
_check_input
(
graph
,
inputs_node
[
1
],
inputs_name
[
1
],
current_outputs
)
layer_inputs
[
"y"
]
=
inputs_name
[
1
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.gt"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
)
return
current_inputs
,
current_outputs
def
aten_hardtanh_
(
mapper
,
graph
,
node
):
""" 构造hardtanh激活的PaddleLayer。
...
...
@@ -565,7 +721,7 @@ def aten_le(mapper, graph, node):
TorchScript示例:
%80 : bool = aten::le(%78, %79)
参数含义:
%80 (bool): 输出,第一个元素是否小于第二个元素。
%80 (bool): 输出,第一个元素是否小于
等于
第二个元素。
%78 (-): 需对比的输入1。
%79 (-): 需对比的输入2。
"""
...
...
@@ -577,10 +733,10 @@ def aten_le(mapper, graph, node):
current_outputs
=
[
output_name
]
# 处理输入0,即%78
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
)
layer_inputs
[
"
input0
"
]
=
inputs_name
[
0
]
layer_inputs
[
"
x
"
]
=
inputs_name
[
0
]
# 处理输入1,即%79
mapper
.
_check_input
(
graph
,
inputs_node
[
1
],
inputs_name
[
1
],
current_outputs
)
layer_inputs
[
"
input1
"
]
=
inputs_name
[
1
]
layer_inputs
[
"
y
"
]
=
inputs_name
[
1
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
...
...
@@ -613,6 +769,35 @@ def aten_len(mapper, graph, node):
return
current_inputs
,
current_outputs
def
aten_lt
(
mapper
,
graph
,
node
):
""" 构造对比大小的PaddleLayer。
TorchScript示例:
%80 : bool = aten::lt(%78, %79)
参数含义:
%80 (bool): 输出,第一个元素是否小于第二个元素。
%78 (-): 需对比的输入1。
%79 (-): 需对比的输入2。
"""
output_name
=
mapper
.
_get_outputs_name
(
node
)[
0
]
layer_outputs
=
[
output_name
]
layer_inputs
=
{}
inputs_name
,
inputs_node
=
mapper
.
_get_inputs_name
(
node
)
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理输入0,即%78
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
)
layer_inputs
[
"x"
]
=
inputs_name
[
0
]
# 处理输入1,即%79
mapper
.
_check_input
(
graph
,
inputs_node
[
1
],
inputs_name
[
1
],
current_outputs
)
layer_inputs
[
"y"
]
=
inputs_name
[
1
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.lt"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
)
return
current_inputs
,
current_outputs
def
aten_max_pool2d
(
mapper
,
graph
,
node
):
""" 构造最大池化的PaddleLayer。
...
...
@@ -784,6 +969,31 @@ def aten_neg(mapper, graph, node):
return
current_inputs
,
current_outputs
def
aten___not__
(
mapper
,
graph
,
node
):
""" 构造对bool型取负的PaddleLayer。
TorchScript示例:
%4498 : bool = aten::__not__(%aux_defined.2)
参数含义:
%4498 (bool): 取负后结果。
%aux_defined.2 (bool): 需取负的输入。
"""
output_name
=
mapper
.
_get_outputs_name
(
node
)[
0
]
layer_outputs
=
[
output_name
]
layer_inputs
=
{}
inputs_name
,
inputs_node
=
mapper
.
_get_inputs_name
(
node
)
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理输入0,即%124
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.not"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
)
return
current_inputs
,
current_outputs
def
aten_relu_
(
mapper
,
graph
,
node
):
""" 构造ReLU激活的PaddleLayer。
...
...
@@ -874,6 +1084,43 @@ def aten_reshape(mapper, graph, node):
return
current_inputs
,
current_outputs
def
aten_select
(
mapper
,
graph
,
node
):
""" 构造选取特定维度Variable的PaddleLayer。
TorchScript示例:
%19 : Tensor = aten::select(%18, %8, %7)
参数含义:
%19 (Tensor): 输出,选取的Tensor。
%18 (Tensor): 需要选取的Tensor。
%8 (int): select的维度。
%7 (int): select的第n个向量。
"""
output_name
=
mapper
.
_get_outputs_name
(
node
)[
0
]
layer_outputs
=
[
output_name
]
layer_inputs
=
{}
layer_attrs
=
{}
inputs_name
,
inputs_node
=
mapper
.
_get_inputs_name
(
node
)
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理输入0,即%18
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 处理输入1,即%8
layer_attrs
[
"dim"
]
=
mapper
.
attrs
[
inputs_name
[
1
]]
# 处理输入2,即%75
mapper
.
_check_input
(
graph
,
inputs_node
[
2
],
inputs_name
[
2
],
current_outputs
)
layer_inputs
[
"index"
]
=
inputs_name
[
2
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.select"
,
inputs
=
layer_inputs
,
outputs
=
current_outputs
,
**
layer_attrs
)
return
current_inputs
,
current_outputs
def
aten_size
(
mapper
,
graph
,
node
):
""" 构造获取shape的PaddleLayer。
...
...
@@ -900,13 +1147,13 @@ def aten_size(mapper, graph, node):
def
aten_slice
(
mapper
,
graph
,
node
):
""" 构造切分list的PaddleLayer。
""" 构造切分list
或Variable
的PaddleLayer。
TorchScript示例:
%83 : int[] = aten::slice(%73, %82, %75, %77)
参数含义:
%83 (list): 输出,切分后的list。
%73 (list): 需要切分的list。
%83 (list
/Tensor
): 输出,切分后的list。
%73 (list
/Tensor
): 需要切分的list。
%82 (int): 切分的开始索引。
%75 (int): 切分的结束索引。
%77 (int): 切分的步长。
...
...
@@ -993,3 +1240,79 @@ def aten_t(mapper, graph, node):
outputs
=
layer_outputs
,
perm
=
[
1
,
0
])
return
current_inputs
,
current_outputs
def
aten_unsqueeze
(
mapper
,
graph
,
node
):
""" 构造插入维度的PaddleLayer。
TorchScript示例:
%13 : Tensor = aten::unsqueeze(%12, %7)
参数含义:
%13 (Tensor): 输出,插入维度后的Tensor。
%12 (Tensor): 需要插入维度的Tensor。
%7 (int): 维度。
"""
output_name
=
mapper
.
_get_outputs_name
(
node
)[
0
]
layer_outputs
=
[
output_name
]
layer_inputs
=
{}
layer_attrs
=
{}
inputs_name
,
inputs_node
=
mapper
.
_get_inputs_name
(
node
)
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理输入0,即%13
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
# 处理输入1,即%12
if
inputs_name
[
1
]
in
mapper
.
attrs
:
layer_attrs
[
"axes"
]
=
mapper
.
attrs
[
inputs_name
[
1
]]
else
:
mapper
.
_check_input
(
graph
,
inputs_node
[
1
],
inputs_name
[
1
],
current_outputs
)
layer_attrs
[
"axes"
]
=
inputs_name
[
1
]
current_inputs
.
append
(
inputs_name
[
1
])
graph
.
add_layer
(
"fluid.layers.unsqueeze"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
**
layer_attrs
)
return
current_inputs
,
current_outputs
def
aten_warn
(
mapper
,
graph
,
node
):
""" 构造warning的PaddleLayer。
TorchScript示例:
= aten::warn(%3, %2)
参数含义:
%3 (str): warning的提示字符串。
%2 (int): warning的stacklevel。
"""
output_name
=
mapper
.
_get_outputs_name
(
node
)[
0
]
layer_outputs
=
[
output_name
]
layer_inputs
=
{}
layer_attrs
=
{}
inputs_name
,
inputs_node
=
mapper
.
_get_inputs_name
(
node
)
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理输入0,即%3
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
# 处理输入1,即%2
if
inputs_name
[
1
]
in
mapper
.
attrs
:
layer_attrs
[
"stacklevel"
]
=
mapper
.
attrs
[
inputs_name
[
1
]]
else
:
mapper
.
_check_input
(
graph
,
inputs_node
[
1
],
inputs_name
[
1
],
current_outputs
)
layer_attrs
[
"stacklevel"
]
=
inputs_name
[
1
]
current_inputs
.
append
(
inputs_name
[
1
])
graph
.
add_layer
(
"prim.warnings"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
**
layer_attrs
)
return
current_inputs
,
current_outputs
x2paddle/op_mapper/pytorch2paddle/prim.py
浏览文件 @
cbc3efdb
...
...
@@ -74,9 +74,9 @@ def prim_ListConstruct(mapper, graph, node):
TorchScript示例:
%86 : int[] = prim::ListConstruct(%84, %85)
参数含义:
%86 (list): list节点输出。
%84 (int/其他): list第一个元素信息。
%85 (int/其他): list第二个元素信息。
%86 (list): list节点输出。
"""
output_name
=
mapper
.
_get_outputs_name
(
node
)[
0
]
layer_outputs
=
[
output_name
]
...
...
@@ -247,6 +247,32 @@ def prim_min(mapper, graph, node):
return
current_inputs
,
current_outputs
def
prim_requires_grad
(
mapper
,
graph
,
node
):
""" 构造是否计算梯度的PaddleLayer。
TorchScript示例:
%356 : bool = prim::requires_grad(%tensor.31)
参数含义:
%356 (bool): 输出,当前Tensor是否计算梯度。
%tensor.31 (Tensor): 输入的Tensor。
"""
output_name
=
mapper
.
_get_outputs_name
(
node
)[
0
]
layer_outputs
=
[
output_name
]
layer_inputs
=
{}
inputs_name
,
inputs_node
=
mapper
.
_get_inputs_name
(
node
)
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理输入0,即%86
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.requires_grad"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
)
return
current_inputs
,
current_outputs
def
prim_SetAttr
(
mapper
,
graph
,
node
):
""" 设置attribute信息。
...
...
@@ -297,3 +323,70 @@ def prim_shape(mapper, graph, node):
graph
.
add_layer
(
"prim.shape"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
)
return
current_inputs
,
current_outputs
def
prim_TupleConstruct
(
mapper
,
graph
,
node
):
""" 构造tuple的PaddleLayer。
TorchScript示例:
%4492 : (Tensor, Tensor?) = prim::TupleConstruct(%x.46, %aux)
参数含义:
%4492 (tuple): 输出,tuple。
%x.46 (Tensor/其他): tuple第一个元素信息。
%aux (Tensor/其他): tuple第二个元素信息。
"""
output_name
=
mapper
.
_get_outputs_name
(
node
)[
0
]
layer_outputs
=
[
output_name
]
layer_inputs
=
{}
inputs_name
,
inputs_node
=
mapper
.
_get_inputs_name
(
node
)
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理每个输入
for
i
,
input_name
in
enumerate
(
inputs_name
):
layer_inputs
[
"input{}"
.
format
(
i
)]
=
input_name
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.tuple"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
)
return
current_inputs
,
current_outputs
def
prim_TupleUnpack
(
mapper
,
graph
,
node
):
""" 构造获取tuple元素的PaddleLayer。
TorchScript示例:
%x.223 : Tensor, %aux.3 : Tensor? = prim::TupleUnpack(%4492)
参数含义:
%x.223 (Tensor/其他): 输出,tuple第一个元素信息。
%aux.3 (Tensor/其他): 输出,tuple第二个元素信息。
%4492 (tuple): 需要获取元素的tuple。
"""
outputs_name
=
mapper
.
_get_outputs_name
(
node
)
layer_outputs
=
outputs_name
layer_inputs
=
{}
inputs_name
,
inputs_node
=
mapper
.
_get_inputs_name
(
node
)
# 获取当前节点输出的list
current_outputs
=
outputs_name
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.tuple_unpack"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
)
return
current_inputs
,
current_outputs
def
prim_Uninitialized
(
mapper
,
graph
,
node
):
""" 构造表示编译器永远不会使用的值的PaddleLayer,该节点转换为None。
TorchScript示例:
%345 : bool = prim::Uninitialized()
参数含义:
%345 (bool): 输出,为赋值的bool。
"""
output_name
=
mapper
.
_get_outputs_name
(
node
)[
0
]
output
=
list
(
node
.
outputs
())[
0
]
mapper
.
attrs
[
output_name
]
=
None
graph
.
add_layer
(
"prim.constant"
,
inputs
=
{},
outputs
=
[
output_name
],
value
=
None
)
return
[],
[
output_name
]
x2paddle/op_mapper/pytorch2paddle/pytorch_op_mapper.py
浏览文件 @
cbc3efdb
...
...
@@ -56,11 +56,19 @@ class PyTorchOpMapper(OpMapper):
def
traverse
(
self
,
script_graph
,
parent_layer
=
None
):
# 用于获取graph的输入
def
_update_graph_inputs
(
inputs
,
outputs
):
current_node_outputs
.
extend
(
outputs
)
def
_update_graph_inputs
(
kind
,
inputs
,
outputs
):
# extend只能放更新graph_inputs之前的情况:
# 1. loop的输出i也是输入;i是输入的原因是:子图中为父图得到的。
# 2. 在_check_input中需要使用to_variable。
# extend只能放更新graph_inputs之后的情况:
# 使用了append。
if
kind
!=
"aten::append"
:
current_node_outputs
.
extend
(
outputs
)
for
name
in
inputs
:
if
name
not
in
current_node_outputs
:
graph_inputs
.
append
(
name
)
if
kind
==
"aten::append"
:
current_node_outputs
.
extend
(
outputs
)
# 初始化
graph
=
PaddleGraph
(
parent_layer
)
...
...
@@ -80,11 +88,11 @@ class PyTorchOpMapper(OpMapper):
if
hasattr
(
prim
,
func_name
):
func
=
getattr
(
prim
,
func_name
)
inputs
,
outputs
=
func
(
self
,
graph
,
node
)
_update_graph_inputs
(
inputs
,
outputs
)
_update_graph_inputs
(
kind
,
inputs
,
outputs
)
elif
hasattr
(
aten
,
func_name
):
func
=
getattr
(
aten
,
func_name
)
inputs
,
outputs
=
func
(
self
,
graph
,
node
)
_update_graph_inputs
(
inputs
,
outputs
)
_update_graph_inputs
(
kind
,
inputs
,
outputs
)
# 转换输出节点
if
hasattr
(
script_graph
,
'returnNode'
):
...
...
@@ -99,7 +107,7 @@ class PyTorchOpMapper(OpMapper):
uid
=
script_unique_id
,
parent_layer
=
parent_layer
,
index
=
i
)
_update_graph_inputs
(
inputs
,
outputs
)
_update_graph_inputs
(
"equal"
,
inputs
,
outputs
)
# 设置graph的参数
if
isinstance
(
script_graph
,
torch
.
_C
.
Graph
):
graph
.
set_parameters
(
self
.
paddle_params
)
...
...
@@ -190,8 +198,10 @@ class PyTorchOpMapper(OpMapper):
if
parent_layer
.
kernel
==
"prim.loop"
:
control_output_id
=
index
-
1
output_node_name
=
parent_layer
.
outputs
[
control_output_id
]
current_outputs
=
[
output_node_name
]
self
.
_check_input
(
graph
,
node
,
input_node_name
,
current_outputs
)
graph
.
add_layer
(
"prim.equal"
,
inputs
=
{
'input'
:
input_node_name
},
outputs
=
[
output_node_name
])
return
[
input_node_name
],
[
output_node_name
]
return
[
input_node_name
],
current_outputs
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录