Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
91879f50
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
91879f50
编写于
4月 16, 2021
作者:
S
SunAhong1993
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix fro pre-commit
上级
2fc9ffd0
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
613 addition
and
217 deletion
+613
-217
x2paddle/op_mapper/dygraph/pytorch2paddle/prim.py
x2paddle/op_mapper/dygraph/pytorch2paddle/prim.py
+123
-51
x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py
x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py
+460
-143
x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py
...dle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py
+30
-23
未找到文件。
x2paddle/op_mapper/dygraph/pytorch2paddle/prim.py
浏览文件 @
91879f50
...
@@ -37,20 +37,24 @@ def prim_Constant(mapper, graph, node):
...
@@ -37,20 +37,24 @@ def prim_Constant(mapper, graph, node):
tensor_value
=
value
tensor_value
=
value
value
=
"{}"
.
format
(
value
)
value
=
"{}"
.
format
(
value
)
if
"tensor"
in
value
:
if
"tensor"
in
value
:
if
isinstance
(
tensor_value
,
list
)
or
isinstance
(
tensor_value
,
tuple
):
if
isinstance
(
tensor_value
,
list
)
or
isinstance
(
tensor_value
,
tuple
):
name_dict
=
dict
()
name_dict
=
dict
()
for
i
,
tv
in
enumerate
(
tensor_value
):
for
i
,
tv
in
enumerate
(
tensor_value
):
output_name_i
=
"{}_p{}"
.
format
(
output_name
,
i
)
output_name_i
=
"{}_p{}"
.
format
(
output_name
,
i
)
key_i
=
"input{}"
.
format
(
i
)
key_i
=
"input{}"
.
format
(
i
)
mapper
.
paddle_params
[
output_name_i
]
=
tv
.
cpu
().
detach
().
numpy
()
mapper
.
paddle_params
[
output_name_i
]
=
tv
.
cpu
().
detach
(
).
numpy
()
graph
.
add_layer
(
graph
.
add_layer
(
"self.create_parameter"
,
"self.create_parameter"
,
inputs
=
{},
inputs
=
{},
outputs
=
[
output_name_i
],
outputs
=
[
output_name_i
],
scope_name
=
scope_name
,
scope_name
=
scope_name
,
dtype
=
string
(
str
(
mapper
.
paddle_params
[
output_name_i
].
dtype
)),
dtype
=
string
(
shape
=
mapper
.
paddle_params
[
output_name_i
].
shape
,
str
(
mapper
.
paddle_params
[
output_name_i
].
dtype
)),
default_initializer
=
"paddle.nn.initializer.Constant(value=0.0)"
)
shape
=
mapper
.
paddle_params
[
output_name_i
].
shape
,
default_initializer
=
"paddle.nn.initializer.Constant(value=0.0)"
)
name_dict
[
key_i
]
=
output_name_i
name_dict
[
key_i
]
=
output_name_i
graph
.
add_layer
(
graph
.
add_layer
(
"prim.list"
,
"prim.list"
,
...
@@ -59,16 +63,18 @@ def prim_Constant(mapper, graph, node):
...
@@ -59,16 +63,18 @@ def prim_Constant(mapper, graph, node):
scope_name
=
scope_name
)
scope_name
=
scope_name
)
return
[],
[
output_name
]
return
[],
[
output_name
]
else
:
else
:
# mapper.pytorch_params[output_name] = tensor_value.cpu().detach().numpy()
# mapper.pytorch_params[output_name] = tensor_value.cpu().detach().numpy()
mapper
.
paddle_params
[
output_name
]
=
tensor_value
.
cpu
().
detach
().
numpy
()
mapper
.
paddle_params
[
output_name
]
=
tensor_value
.
cpu
().
detach
(
).
numpy
()
graph
.
add_layer
(
graph
.
add_layer
(
"self.create_parameter"
,
"self.create_parameter"
,
inputs
=
{},
inputs
=
{},
outputs
=
[
output_name
],
outputs
=
[
output_name
],
scope_name
=
scope_name
,
scope_name
=
scope_name
,
dtype
=
string
(
str
(
mapper
.
paddle_params
[
output_name
].
dtype
)),
dtype
=
string
(
str
(
mapper
.
paddle_params
[
output_name
].
dtype
)),
shape
=
mapper
.
paddle_params
[
output_name
].
shape
,
shape
=
mapper
.
paddle_params
[
output_name
].
shape
,
default_initializer
=
"paddle.nn.initializer.Constant(value=0.0)"
)
default_initializer
=
"paddle.nn.initializer.Constant(value=0.0)"
)
return
[],
[
output_name
]
return
[],
[
output_name
]
if
"inf"
in
str
(
value
):
if
"inf"
in
str
(
value
):
t
=
str
(
type
(
value
)).
split
(
"'"
)[
1
]
t
=
str
(
type
(
value
)).
split
(
"'"
)[
1
]
...
@@ -81,7 +87,11 @@ def prim_Constant(mapper, graph, node):
...
@@ -81,7 +87,11 @@ def prim_Constant(mapper, graph, node):
value
=
int
(
math
.
pow
(
2
,
31
)
-
1
)
value
=
int
(
math
.
pow
(
2
,
31
)
-
1
)
mapper
.
attrs
[
output_name
]
=
value
mapper
.
attrs
[
output_name
]
=
value
graph
.
add_layer
(
graph
.
add_layer
(
"prim.constant"
,
inputs
=
{},
outputs
=
[
output_name
],
scope_name
=
scope_name
,
value
=
value
)
"prim.constant"
,
inputs
=
{},
outputs
=
[
output_name
],
scope_name
=
scope_name
,
value
=
value
)
return
[],
[
output_name
]
return
[],
[
output_name
]
...
@@ -105,18 +115,23 @@ def prim_data(mapper, graph, node):
...
@@ -105,18 +115,23 @@ def prim_data(mapper, graph, node):
# 获取当前节点输出的list
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
current_outputs
=
[
output_name
]
# 处理输入0,即%4336
# 处理输入0,即%4336
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.equal"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
graph
.
add_layer
(
"prim.equal"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
return
current_inputs
,
current_outputs
return
current_inputs
,
current_outputs
def
prim_DictConstruct
(
mapper
,
graph
,
node
):
def
prim_DictConstruct
(
mapper
,
graph
,
node
):
""" 构建dict。
""" 构建dict。
TorchScript示例:
TorchScript示例:
%32 : Dict(str, Tensor) = prim::DictConstruct(%30, %23, %31, %29)
%32 : Dict(str, Tensor) = prim::DictConstruct(%30, %23, %31, %29)
参数含义:
参数含义:
...
@@ -136,22 +151,22 @@ def prim_DictConstruct(mapper, graph, node):
...
@@ -136,22 +151,22 @@ def prim_DictConstruct(mapper, graph, node):
current_outputs
=
[
output_name
]
current_outputs
=
[
output_name
]
# 处理每个输入
# 处理每个输入
for
i
,
input_name
in
enumerate
(
inputs_name
):
for
i
,
input_name
in
enumerate
(
inputs_name
):
if
i
%
2
==
0
:
if
i
%
2
==
0
:
layer_attrs
[
"key{}"
.
format
(
int
(
i
/
2
))]
=
mapper
.
attrs
[
input_name
]
layer_attrs
[
"key{}"
.
format
(
int
(
i
/
2
))]
=
mapper
.
attrs
[
input_name
]
else
:
else
:
layer_inputs
[
"value{}"
.
format
(
int
(
i
/
2
))]
=
input_name
layer_inputs
[
"value{}"
.
format
(
int
(
i
/
2
))]
=
input_name
# 获取当前节点输入的list
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.dict_construct"
,
graph
.
add_layer
(
inputs
=
layer_inputs
,
"prim.dict_construct"
,
outputs
=
layer_outputs
,
inputs
=
layer_inputs
,
scope_name
=
scope_name
,
outputs
=
layer_outputs
,
**
layer_attrs
)
scope_name
=
scope_name
,
**
layer_attrs
)
return
current_inputs
,
current_outputs
return
current_inputs
,
current_outputs
def
prim_GetAttr
(
mapper
,
graph
,
node
):
def
prim_GetAttr
(
mapper
,
graph
,
node
):
""" 获取attribute信息。
""" 获取attribute信息。
...
@@ -212,8 +227,13 @@ def prim_If(mapper, graph, node):
...
@@ -212,8 +227,13 @@ def prim_If(mapper, graph, node):
input_node
=
list
(
node
.
inputs
())[
0
].
node
()
input_node
=
list
(
node
.
inputs
())[
0
].
node
()
script_input_unique_id
=
list
(
node
.
inputs
())[
0
].
unique
()
script_input_unique_id
=
list
(
node
.
inputs
())[
0
].
unique
()
input_node_name
=
mapper
.
outputs_info
[
script_input_unique_id
]
input_node_name
=
mapper
.
outputs_info
[
script_input_unique_id
]
mapper
.
_check_input
(
graph
,
input_node
,
input_node_name
,
current_outputs
,
scope_name
)
mapper
.
_check_input
(
graph
,
input_node
,
input_node_name
,
current_outputs
,
graph
.
add_layer
(
"prim.if"
,
inputs
=
{
'input'
:
input_node_name
},
outputs
=
node_outputs
,
scope_name
=
scope_name
)
scope_name
)
graph
.
add_layer
(
"prim.if"
,
inputs
=
{
'input'
:
input_node_name
},
outputs
=
node_outputs
,
scope_name
=
scope_name
)
current_layer
=
list
(
graph
.
layers
.
values
())[
-
1
]
current_layer
=
list
(
graph
.
layers
.
values
())[
-
1
]
block0
=
list
(
node
.
blocks
())[
0
]
block0
=
list
(
node
.
blocks
())[
0
]
block0_graph
,
graph_inputs0
=
mapper
.
traverse
(
block0
,
current_layer
)
block0_graph
,
graph_inputs0
=
mapper
.
traverse
(
block0
,
current_layer
)
...
@@ -249,12 +269,17 @@ def prim_ListConstruct(mapper, graph, node):
...
@@ -249,12 +269,17 @@ def prim_ListConstruct(mapper, graph, node):
current_outputs
=
[
output_name
]
current_outputs
=
[
output_name
]
# 处理每个输入
# 处理每个输入
for
i
,
input_name
in
enumerate
(
inputs_name
):
for
i
,
input_name
in
enumerate
(
inputs_name
):
mapper
.
_check_input
(
graph
,
inputs_node
[
i
],
input_name
,
current_outputs
,
scope_name
)
mapper
.
_check_input
(
graph
,
inputs_node
[
i
],
input_name
,
current_outputs
,
scope_name
)
layer_inputs
[
"input{}"
.
format
(
i
)]
=
input_name
layer_inputs
[
"input{}"
.
format
(
i
)]
=
input_name
# 获取当前节点输入的list
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
current_inputs
=
list
(
layer_inputs
.
values
())
layer_id
=
graph
.
add_layer
(
"prim.list"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
layer_id
=
graph
.
add_layer
(
"prim.list"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
mapper
.
output2id
[
output_name
]
=
layer_id
mapper
.
output2id
[
output_name
]
=
layer_id
return
current_inputs
,
current_outputs
return
current_inputs
,
current_outputs
...
@@ -277,13 +302,17 @@ def prim_ListUnpack(mapper, graph, node):
...
@@ -277,13 +302,17 @@ def prim_ListUnpack(mapper, graph, node):
# 获取当前节点输出的list
# 获取当前节点输出的list
current_outputs
=
layer_outputs
.
copy
()
current_outputs
=
layer_outputs
.
copy
()
# 处理输入0,即%4354
# 处理输入0,即%4354
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
graph
.
add_layer
(
"prim.list_unpack"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
"prim.list_unpack"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
mapper
.
split_len
[
list
(
layer_inputs
.
values
())[
0
]]
=
len
(
layer_outputs
)
mapper
.
split_len
[
list
(
layer_inputs
.
values
())[
0
]]
=
len
(
layer_outputs
)
return
current_inputs
,
current_outputs
return
current_inputs
,
current_outputs
...
@@ -342,7 +371,11 @@ def prim_Loop(mapper, graph, node):
...
@@ -342,7 +371,11 @@ def prim_Loop(mapper, graph, node):
scope_name
=
scope_name
)
scope_name
=
scope_name
)
node_outputs
.
append
(
block_input_node_name
)
node_outputs
.
append
(
block_input_node_name
)
graph
.
add_layer
(
"prim.loop"
,
inputs
=
loop_inputs
,
outputs
=
loop_outputs
,
scope_name
=
scope_name
)
graph
.
add_layer
(
"prim.loop"
,
inputs
=
loop_inputs
,
outputs
=
loop_outputs
,
scope_name
=
scope_name
)
current_layer
=
list
(
graph
.
layers
.
values
())[
-
1
]
current_layer
=
list
(
graph
.
layers
.
values
())[
-
1
]
block_graph
,
graph_inputs
=
mapper
.
traverse
(
block
,
current_layer
)
block_graph
,
graph_inputs
=
mapper
.
traverse
(
block
,
current_layer
)
for
i
,
input_name
in
enumerate
(
graph_inputs
):
for
i
,
input_name
in
enumerate
(
graph_inputs
):
...
@@ -370,12 +403,17 @@ def prim_min(mapper, graph, node):
...
@@ -370,12 +403,17 @@ def prim_min(mapper, graph, node):
# 获取当前节点输出的list
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
current_outputs
=
[
output_name
]
# 处理输入0,即%86
# 处理输入0,即%86
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.min"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
graph
.
add_layer
(
"prim.min"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
return
current_inputs
,
current_outputs
return
current_inputs
,
current_outputs
...
@@ -397,14 +435,19 @@ def prim_NumToTensor(mapper, graph, node):
...
@@ -397,14 +435,19 @@ def prim_NumToTensor(mapper, graph, node):
# 获取当前节点输出的list
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
current_outputs
=
[
output_name
]
# 处理输入0,即%86
# 处理输入0,即%86
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
inputs_inputs_name
,
inputs_inputs_node
=
mapper
.
_get_inputs_name
(
inputs_node
[
0
])
scope_name
)
inputs_inputs_name
,
inputs_inputs_node
=
mapper
.
_get_inputs_name
(
inputs_node
[
0
])
if
inputs_node
[
0
].
kind
()
==
"aten::size"
and
len
(
inputs_inputs_name
)
>
1
:
if
inputs_node
[
0
].
kind
()
==
"aten::size"
and
len
(
inputs_inputs_name
)
>
1
:
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
graph
.
add_layer
(
"prim_equal"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
"prim_equal"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
else
:
else
:
layer_inputs
[
"fill_value"
]
=
inputs_name
[
0
]
layer_inputs
[
"fill_value"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
# 获取当前节点输入的list
...
@@ -437,13 +480,17 @@ def prim_RaiseException(mapper, graph, node):
...
@@ -437,13 +480,17 @@ def prim_RaiseException(mapper, graph, node):
# 获取当前节点输出的list
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
current_outputs
=
[
output_name
]
# 处理输入0,即%76
# 处理输入0,即%76
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
graph
.
add_layer
(
"prim.exception"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
"prim.exception"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
return
current_inputs
,
current_outputs
return
current_inputs
,
current_outputs
...
@@ -464,13 +511,17 @@ def prim_requires_grad(mapper, graph, node):
...
@@ -464,13 +511,17 @@ def prim_requires_grad(mapper, graph, node):
# 获取当前节点输出的list
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
current_outputs
=
[
output_name
]
# 处理输入0,即%86
# 处理输入0,即%86
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
graph
.
add_layer
(
"prim.requires_grad"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
"prim.requires_grad"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
return
current_inputs
,
current_outputs
return
current_inputs
,
current_outputs
...
@@ -527,13 +578,17 @@ def prim_shape(mapper, graph, node):
...
@@ -527,13 +578,17 @@ def prim_shape(mapper, graph, node):
# 获取当前节点输出的list
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
current_outputs
=
[
output_name
]
# 处理输入0,即%input.8
# 处理输入0,即%input.8
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
graph
.
add_layer
(
"paddle.shape"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
"paddle.shape"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
return
current_inputs
,
current_outputs
return
current_inputs
,
current_outputs
...
@@ -560,7 +615,11 @@ def prim_TupleConstruct(mapper, graph, node):
...
@@ -560,7 +615,11 @@ def prim_TupleConstruct(mapper, graph, node):
# 获取当前节点输入的list
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.tuple"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
graph
.
add_layer
(
"prim.tuple"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
return
current_inputs
,
current_outputs
return
current_inputs
,
current_outputs
...
@@ -590,7 +649,11 @@ def prim_TupleUnpack(mapper, graph, node):
...
@@ -590,7 +649,11 @@ def prim_TupleUnpack(mapper, graph, node):
current_inputs
=
list
(
layer_inputs
.
values
())
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
graph
.
add_layer
(
"prim.tuple_unpack"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
,
**
layer_attrs
)
"prim.tuple_unpack"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
,
**
layer_attrs
)
return
current_inputs
,
current_outputs
return
current_inputs
,
current_outputs
...
@@ -614,12 +677,17 @@ def prim_unchecked_cast(mapper, graph, node):
...
@@ -614,12 +677,17 @@ def prim_unchecked_cast(mapper, graph, node):
# 获取当前节点输出的list
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
current_outputs
=
[
output_name
]
# 处理输入0,即%size.63
# 处理输入0,即%size.63
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.equal"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
graph
.
add_layer
(
"prim.equal"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
return
current_inputs
,
current_outputs
return
current_inputs
,
current_outputs
...
@@ -636,5 +704,9 @@ def prim_Uninitialized(mapper, graph, node):
...
@@ -636,5 +704,9 @@ def prim_Uninitialized(mapper, graph, node):
output
=
list
(
node
.
outputs
())[
0
]
output
=
list
(
node
.
outputs
())[
0
]
mapper
.
attrs
[
output_name
]
=
None
mapper
.
attrs
[
output_name
]
=
None
graph
.
add_layer
(
graph
.
add_layer
(
"prim.constant"
,
inputs
=
{},
outputs
=
[
output_name
],
scope_name
=
scope_name
,
value
=
None
)
"prim.constant"
,
inputs
=
{},
outputs
=
[
output_name
],
scope_name
=
scope_name
,
value
=
None
)
return
[],
[
output_name
]
return
[],
[
output_name
]
x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py
浏览文件 @
91879f50
...
@@ -14,7 +14,8 @@
...
@@ -14,7 +14,8 @@
# limitations under the License.
# limitations under the License.
NO_OUTPUT_COUNT
=
0
NO_OUTPUT_COUNT
=
0
def
gen_codes
(
code_list
,
indent
=
0
):
def
gen_codes
(
code_list
,
indent
=
0
):
indent_blank
=
" "
*
indent
indent_blank
=
" "
*
indent
codes
=
[]
codes
=
[]
...
@@ -53,13 +54,24 @@ def get_value(layer, key, layer_id=None, different_attrs=None):
...
@@ -53,13 +54,24 @@ def get_value(layer, key, layer_id=None, different_attrs=None):
return
str
(
layer
.
attrs
[
key
])
return
str
(
layer
.
attrs
[
key
])
def
prim_add
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_add
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = {} + {}"
.
format
(
layer
.
outputs
[
0
],
line
=
"{} = {} + {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"y"
,
different_attrs
))
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"y"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_add_
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_add_
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = {} + {} * {}"
.
format
(
layer
.
outputs
[
0
],
line
=
"{} = {} + {} * {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"x"
,
different_attrs
),
layer
.
attrs
[
"alpha"
],
layer
.
attrs
[
"alpha"
],
...
@@ -67,22 +79,39 @@ def prim_add_(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif
...
@@ -67,22 +79,39 @@ def prim_add_(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_and
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
def
prim_and
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
line
=
"{} = {} and {}"
.
format
(
layer
.
outputs
[
0
],
line
=
"{} = {} and {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"y"
,
different_attrs
))
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"y"
,
different_attrs
))
if
is_return_line
:
if
is_return_line
:
return
line
.
split
(
" = "
)[
1
]
return
line
.
split
(
" = "
)[
1
]
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_append
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_append
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{}.append({})"
.
format
(
line
=
"{}.append({})"
.
format
(
get_value
(
layer
,
"list"
,
layer_id
,
different_attrs
),
get_value
(
layer
,
"list"
,
layer_id
,
different_attrs
),
get_value
(
layer
,
"element"
,
layer_id
,
different_attrs
))
get_value
(
layer
,
"element"
,
layer_id
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_assert
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_assert
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
if
layer
.
attrs
[
"type"
]
==
"eq"
:
if
layer
.
attrs
[
"type"
]
==
"eq"
:
values
=
get_value
(
layer
,
"key"
)
values
=
get_value
(
layer
,
"key"
)
if
"value"
in
layer
.
attrs
:
if
"value"
in
layer
.
attrs
:
...
@@ -93,15 +122,15 @@ def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, d
...
@@ -93,15 +122,15 @@ def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, d
s
+=
"{} == {} or "
.
format
(
get_value
(
layer
,
"key"
),
v
)
s
+=
"{} == {} or "
.
format
(
get_value
(
layer
,
"key"
),
v
)
if
len
(
s
)
>
0
:
if
len
(
s
)
>
0
:
s
=
s
[:
-
4
]
s
=
s
[:
-
4
]
lc
=
locals
()
lc
=
locals
()
exec
(
"assert_result = {}"
.
format
(
s
))
exec
(
"assert_result = {}"
.
format
(
s
))
assert_result
=
lc
[
'assert_result'
]
assert_result
=
lc
[
'assert_result'
]
line
=
"assert {},
\'
The {} must be {}!
\'
"
.
format
(
line
=
"assert {},
\'
The {} must be {}!
\'
"
.
format
(
s
,
get_value
(
layer
,
"key"
),
get_value
(
layer
,
"value"
))
s
,
get_value
(
layer
,
"key"
),
get_value
(
layer
,
"value"
))
else
:
else
:
s
=
"{} == {}"
.
format
(
get_value
(
layer
,
"key"
),
s
=
"{} == {}"
.
format
(
get_value
(
layer
,
"value"
))
get_value
(
layer
,
"key"
),
get_value
(
layer
,
"value"
))
lc
=
locals
()
lc
=
locals
()
exec
(
"assert_result = {}"
.
format
(
s
))
exec
(
"assert_result = {}"
.
format
(
s
))
assert_result
=
lc
[
'assert_result'
]
assert_result
=
lc
[
'assert_result'
]
line
=
"assert {},
\'
The {} must be {}!
\'
"
.
format
(
line
=
"assert {},
\'
The {} must be {}!
\'
"
.
format
(
...
@@ -112,7 +141,12 @@ def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, d
...
@@ -112,7 +141,12 @@ def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, d
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_check_dim
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_check_dim
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
lines
=
[]
lines
=
[]
dim
=
get_value
(
layer
,
"dim"
,
different_attrs
)
dim
=
get_value
(
layer
,
"dim"
,
different_attrs
)
lines
.
append
(
"if {} < 0:"
.
format
(
dim
))
lines
.
append
(
"if {} < 0:"
.
format
(
dim
))
...
@@ -123,12 +157,23 @@ def prim_check_dim(layer, indent=1, init_func=[], forward_func=[], layer_id=None
...
@@ -123,12 +157,23 @@ def prim_check_dim(layer, indent=1, init_func=[], forward_func=[], layer_id=None
forward_func
.
extend
(
gen_codes
(
lines
,
indent
=
indent
))
forward_func
.
extend
(
gen_codes
(
lines
,
indent
=
indent
))
def
prim_constant
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_constant
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
attrs
[
"value"
])
line
=
"{} = {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
attrs
[
"value"
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_contain
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
def
prim_contain
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
line
=
"{} = {} in {}"
.
format
(
layer
.
outputs
[
0
],
line
=
"{} = {} in {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"element"
,
different_attrs
),
get_value
(
layer
,
"element"
,
different_attrs
),
get_value
(
layer
,
"input"
,
different_attrs
))
get_value
(
layer
,
"input"
,
different_attrs
))
...
@@ -137,108 +182,182 @@ def prim_contain(layer, indent=1, init_func=[], forward_func=[], layer_id=None,
...
@@ -137,108 +182,182 @@ def prim_contain(layer, indent=1, init_func=[], forward_func=[], layer_id=None,
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_dict
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_dict
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = dict()"
.
format
(
layer
.
outputs
[
0
])
line
=
"{} = dict()"
.
format
(
layer
.
outputs
[
0
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_dict_construct
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_dict_construct
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
lines
=
list
()
lines
=
list
()
line
=
"{} = dict()"
.
format
(
layer
.
outputs
[
0
])
line
=
"{} = dict()"
.
format
(
layer
.
outputs
[
0
])
lines
.
append
(
line
)
lines
.
append
(
line
)
for
i
in
range
(
len
(
layer
.
inputs
)):
for
i
in
range
(
len
(
layer
.
inputs
)):
line
=
"{}[{}] = {}"
.
format
(
layer
.
outputs
[
0
],
line
=
"{}[{}] = {}"
.
format
(
get_value
(
layer
,
"key{}"
.
format
(
i
),
different_attrs
),
layer
.
outputs
[
0
],
get_value
(
layer
,
"value{}"
.
format
(
i
),
different_attrs
))
get_value
(
layer
,
"key{}"
.
format
(
i
),
different_attrs
),
get_value
(
layer
,
"value{}"
.
format
(
i
),
different_attrs
))
lines
.
append
(
line
)
lines
.
append
(
line
)
forward_func
.
extend
(
gen_codes
(
lines
,
indent
=
indent
))
forward_func
.
extend
(
gen_codes
(
lines
,
indent
=
indent
))
def
prim_dict2values
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_dict2values
(
layer
,
line
=
"{} = list({}.values())"
.
format
(
layer
.
outputs
[
0
],
indent
=
1
,
get_value
(
layer
,
"x"
,
different_attrs
))
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = list({}.values())"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_div
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_div
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = {} / {}"
.
format
(
layer
.
outputs
[
0
],
line
=
"{} = {} / {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"y"
,
different_attrs
))
get_value
(
layer
,
"y"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_eq
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
def
prim_eq
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
line
=
"{} = {} == {}"
.
format
(
layer
.
outputs
[
0
],
line
=
"{} = {} == {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"y"
,
different_attrs
))
get_value
(
layer
,
"y"
,
different_attrs
))
if
is_return_line
:
if
is_return_line
:
return
line
.
split
(
" = "
)[
1
]
return
line
.
split
(
" = "
)[
1
]
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_equal
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_equal
(
layer
,
line
=
"{} = {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_exception
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_exception
(
layer
,
line
=
"raise Exception({})"
.
format
(
get_value
(
layer
,
"input"
,
different_attrs
))
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"raise Exception({})"
.
format
(
get_value
(
layer
,
"input"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_float
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_float
(
layer
,
line
=
"{} = float({})"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = float({})"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_floor
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_floor
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = math.floor({})"
.
format
(
layer
.
outputs
[
0
],
line
=
"{} = math.floor({})"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
,
different_attrs
))
get_value
(
layer
,
"x"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_floordiv
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_floordiv
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = {} // {}"
.
format
(
layer
.
outputs
[
0
],
line
=
"{} = {} // {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"y"
,
different_attrs
))
get_value
(
layer
,
"y"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_getitem
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_getitem
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = {}[{}]"
.
format
(
layer
.
outputs
[
0
],
line
=
"{} = {}[{}]"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"list"
,
different_attrs
),
get_value
(
layer
,
"list"
,
different_attrs
),
get_value
(
layer
,
"index"
,
different_attrs
))
get_value
(
layer
,
"index"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_gt
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
def
prim_gt
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
line
=
"{} = {} > {}"
.
format
(
layer
.
outputs
[
0
],
line
=
"{} = {} > {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"y"
,
different_attrs
))
get_value
(
layer
,
"y"
,
different_attrs
))
if
is_return_line
:
if
is_return_line
:
return
line
.
split
(
" = "
)[
1
]
return
line
.
split
(
" = "
)[
1
]
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_if
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_if
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
try
:
try
:
exec_s
=
None
exec_s
=
None
for
line
in
forward_func
:
for
line
in
forward_func
:
s
=
line
.
replace
(
" "
,
""
)
s
=
line
.
replace
(
" "
,
""
)
if
s
.
startswith
(
"{} = "
.
format
(
get_value
(
layer
,
"input"
,
different_attrs
))):
if
s
.
startswith
(
"{} = "
.
format
(
get_value
(
layer
,
"input"
,
different_attrs
))):
exec_s
=
s
.
split
(
" = "
)[
1
]
exec_s
=
s
.
split
(
" = "
)[
1
]
lc
=
locals
()
lc
=
locals
()
if
exec_s
is
not
None
:
if
exec_s
is
not
None
:
exec
(
"if_result = {}"
.
format
(
exec_s
))
exec
(
"if_result = {}"
.
format
(
exec_s
))
else
:
else
:
exec
(
"if_result = {}"
.
format
(
get_value
(
layer
,
"input"
,
different_attrs
)))
exec
(
"if_result = {}"
.
format
(
get_value
(
layer
,
"input"
,
different_attrs
)))
if_result
=
lc
[
'if_result'
]
if_result
=
lc
[
'if_result'
]
if
if_result
:
if
if_result
:
block
=
layer
.
blocks
[
0
]
block
=
layer
.
blocks
[
0
]
else
:
else
:
block
=
layer
.
blocks
[
1
]
block
=
layer
.
blocks
[
1
]
if
len
(
block
.
layers
)
>
0
:
if
len
(
block
.
layers
)
>
0
:
b_init_lines
,
b_forward_lines
=
block
.
gen_dygraph_code
(
indent
=
indent
)
b_init_lines
,
b_forward_lines
=
block
.
gen_dygraph_code
(
indent
=
indent
)
init_func
.
extend
(
b_init_lines
)
init_func
.
extend
(
b_init_lines
)
forward_func
.
extend
(
b_forward_lines
)
forward_func
.
extend
(
b_forward_lines
)
except
:
except
:
...
@@ -249,7 +368,8 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diffe
...
@@ -249,7 +368,8 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diffe
line
=
"pass"
line
=
"pass"
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
+
1
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
+
1
))
else
:
else
:
b_init_lines
,
b_forward_lines
=
block
.
gen_dygraph_code
(
indent
=
indent
+
1
)
b_init_lines
,
b_forward_lines
=
block
.
gen_dygraph_code
(
indent
=
indent
+
1
)
init_func
.
extend
(
b_init_lines
)
init_func
.
extend
(
b_init_lines
)
forward_func
.
extend
(
b_forward_lines
)
forward_func
.
extend
(
b_forward_lines
)
block
=
layer
.
blocks
[
1
]
block
=
layer
.
blocks
[
1
]
...
@@ -263,30 +383,54 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diffe
...
@@ -263,30 +383,54 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diffe
forward_func
.
extend
(
b_forward_lines
)
forward_func
.
extend
(
b_forward_lines
)
def
prim_int
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_int
(
layer
,
line
=
"{} = int({})"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = int({})"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_is
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
def
prim_is
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
line
=
"{} = {} is {}"
.
format
(
layer
.
outputs
[
0
],
line
=
"{} = {} is {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"y"
,
different_attrs
))
get_value
(
layer
,
"y"
,
different_attrs
))
if
is_return_line
:
if
is_return_line
:
return
line
.
split
(
" = "
)[
1
]
return
line
.
split
(
" = "
)[
1
]
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_isinstance
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
def
prim_isinstance
(
layer
,
line
=
"{} = isinstance({}, {})"
.
format
(
layer
.
outputs
[
0
],
indent
=
1
,
get_value
(
layer
,
"input"
,
different_attrs
),
init_func
=
[],
layer
.
attrs
[
"cls"
])
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
line
=
"{} = isinstance({}, {})"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
),
layer
.
attrs
[
"cls"
])
if
is_return_line
:
if
is_return_line
:
return
line
.
split
(
" = "
)[
1
]
return
line
.
split
(
" = "
)[
1
]
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_isnot
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
def
prim_isnot
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
line
=
"{} = {} is not {}"
.
format
(
layer
.
outputs
[
0
],
line
=
"{} = {} is not {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"y"
,
different_attrs
))
get_value
(
layer
,
"y"
,
different_attrs
))
...
@@ -295,53 +439,94 @@ def prim_isnot(layer, indent=1, init_func=[], forward_func=[], layer_id=None, di
...
@@ -295,53 +439,94 @@ def prim_isnot(layer, indent=1, init_func=[], forward_func=[], layer_id=None, di
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_le
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
def
prim_le
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
line
=
"{} = {} <= {}"
.
format
(
layer
.
outputs
[
0
],
line
=
"{} = {} <= {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"y"
,
different_attrs
))
get_value
(
layer
,
"y"
,
different_attrs
))
if
is_return_line
:
if
is_return_line
:
return
line
.
split
(
" = "
)[
1
]
return
line
.
split
(
" = "
)[
1
]
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_len
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_len
(
layer
,
line
=
"{} = len({})"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = len({})"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_len2list
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_len2list
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
lines
=
[]
lines
=
[]
lines
.
append
(
"{} = []"
.
format
(
layer
.
outputs
[
0
]))
lines
.
append
(
"{} = []"
.
format
(
layer
.
outputs
[
0
]))
lines
.
append
(
"for i in range({}):"
.
format
(
get_value
(
layer
,
"len"
,
different_attrs
)))
lines
.
append
(
"for i in range({}):"
.
format
(
get_value
(
layer
,
"len"
,
different_attrs
)))
lines
.
append
(
" {}.append(i)"
.
format
(
layer
.
outputs
[
0
]))
lines
.
append
(
" {}.append(i)"
.
format
(
layer
.
outputs
[
0
]))
forward_func
.
extend
(
gen_codes
(
lines
,
indent
=
indent
))
forward_func
.
extend
(
gen_codes
(
lines
,
indent
=
indent
))
def
prim_lt
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
def
prim_lt
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
line
=
"{} = {} < {}"
.
format
(
layer
.
outputs
[
0
],
line
=
"{} = {} < {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"y"
,
different_attrs
))
get_value
(
layer
,
"y"
,
different_attrs
))
if
is_return_line
:
if
is_return_line
:
return
line
.
split
(
" = "
)[
1
]
return
line
.
split
(
" = "
)[
1
]
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_list
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_list
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
input_len
=
len
(
layer
.
inputs
)
+
len
(
layer
.
attrs
)
input_len
=
len
(
layer
.
inputs
)
+
len
(
layer
.
attrs
)
inputs_list
=
list
()
inputs_list
=
list
()
for
i
in
range
(
input_len
):
for
i
in
range
(
input_len
):
inputs_list
.
append
(
get_value
(
layer
,
"input{}"
.
format
(
i
),
different_attrs
))
inputs_list
.
append
(
get_value
(
layer
,
"input{}"
.
format
(
i
),
different_attrs
))
inputs_str
=
', '
.
join
(
inputs_list
)
inputs_str
=
', '
.
join
(
inputs_list
)
line
=
"{} = [{}]"
.
format
(
layer
.
outputs
[
0
],
inputs_str
)
line
=
"{} = [{}]"
.
format
(
layer
.
outputs
[
0
],
inputs_str
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_list_unpack
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_list_unpack
(
layer
,
line
=
"{} = {}"
.
format
(
", "
.
join
(
layer
.
outputs
),
get_value
(
layer
,
"input"
,
different_attrs
))
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = {}"
.
format
(
", "
.
join
(
layer
.
outputs
),
get_value
(
layer
,
"input"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_loop
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_loop
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
loop_range
=
get_value
(
layer
,
"input"
,
different_attrs
)
loop_range
=
get_value
(
layer
,
"input"
,
different_attrs
)
line
=
"for {} in range({}):"
.
format
(
layer
.
outputs
[
1
],
loop_range
)
line
=
"for {} in range({}):"
.
format
(
layer
.
outputs
[
1
],
loop_range
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
...
@@ -351,171 +536,303 @@ def prim_loop(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif
...
@@ -351,171 +536,303 @@ def prim_loop(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif
forward_func
.
extend
(
b_forward_lines
)
forward_func
.
extend
(
b_forward_lines
)
def
prim_min
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_min
(
layer
,
line
=
"{} = min({})"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = min({})"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_mul
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_mul
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = {} * {}"
.
format
(
layer
.
outputs
[
0
],
line
=
"{} = {} * {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"y"
,
different_attrs
))
get_value
(
layer
,
"y"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_ne
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
def
prim_ne
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
line
=
"{} = {} != {}"
.
format
(
layer
.
outputs
[
0
],
line
=
"{} = {} != {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"y"
,
different_attrs
))
get_value
(
layer
,
"y"
,
different_attrs
))
if
is_return_line
:
if
is_return_line
:
return
line
.
split
(
" = "
)[
1
]
return
line
.
split
(
" = "
)[
1
]
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_neg
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_neg
(
layer
,
line
=
"{} = -{}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = -{}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_not
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
def
prim_not
(
layer
,
line
=
"{} = not {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
line
=
"{} = not {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
if
is_return_line
:
if
is_return_line
:
return
line
.
split
(
" = "
)[
1
]
return
line
.
split
(
" = "
)[
1
]
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_or
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
def
prim_or
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
line
=
"{} = {} or {}"
.
format
(
layer
.
outputs
[
0
],
line
=
"{} = {} or {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"y"
,
different_attrs
))
get_value
(
layer
,
"y"
,
different_attrs
))
if
is_return_line
:
if
is_return_line
:
return
line
.
split
(
" = "
)[
1
]
return
line
.
split
(
" = "
)[
1
]
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_replaceitem
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_replaceitem
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{}[{}] = {}"
.
format
(
line
=
"{}[{}] = {}"
.
format
(
get_value
(
layer
,
"list"
,
layer_id
,
different_attrs
),
get_value
(
layer
,
"list"
,
layer_id
,
different_attrs
),
get_value
(
layer
,
"index"
,
layer_id
,
different_attrs
),
get_value
(
layer
,
"index"
,
layer_id
,
different_attrs
),
get_value
(
layer
,
"item"
,
layer_id
,
different_attrs
))
get_value
(
layer
,
"item"
,
layer_id
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_requires_grad
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_requires_grad
(
layer
,
line
=
"{} = not {}.stop_gradient"
.
format
(
layer
.
outputs
[
0
],
indent
=
1
,
get_value
(
layer
,
"input"
,
different_attrs
))
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = not {}.stop_gradient"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_rsub
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_rsub
(
layer
,
line
=
"{} = {} - {} * {}"
.
format
(
layer
.
outputs
[
0
],
indent
=
1
,
get_value
(
layer
,
"y"
,
different_attrs
),
init_func
=
[],
get_value
(
layer
,
"x"
,
different_attrs
),
forward_func
=
[],
get_value
(
layer
,
"alpha"
,
different_attrs
))
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = {} - {} * {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"y"
,
different_attrs
),
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"alpha"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_select
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_select
(
layer
,
line
=
"{} = {}["
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = {}["
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
for
dim
in
range
(
layer
.
attrs
[
"dim"
]):
for
dim
in
range
(
layer
.
attrs
[
"dim"
]):
line
+=
":, "
line
+=
":, "
line
+=
(
get_value
(
layer
,
"index"
,
different_attrs
)
+
"]"
)
line
+=
(
get_value
(
layer
,
"index"
,
different_attrs
)
+
"]"
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_set_attr
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_set_attr
(
layer
,
line
=
"{} = {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_set_item
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_set_item
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{}[{}] = {}"
.
format
(
line
=
"{}[{}] = {}"
.
format
(
get_value
(
layer
,
"dict"
,
different_attrs
),
get_value
(
layer
,
"dict"
,
different_attrs
),
get_value
(
layer
,
"key"
,
different_attrs
),
get_value
(
layer
,
"value"
,
different_attrs
))
get_value
(
layer
,
"key"
,
different_attrs
),
get_value
(
layer
,
"value"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_shape
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = {}.shape"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_shape_dim
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_shape
(
layer
,
line
=
"{} = {}.shape[{}]"
.
format
(
layer
.
outputs
[
0
],
indent
=
1
,
get_value
(
layer
,
"input"
,
different_attrs
),
init_func
=
[],
get_value
(
layer
,
"dim"
,
different_attrs
))
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = {}.shape"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_slice
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_shape_dim
(
layer
,
line
=
"{} = {}[{}: {}: {}]"
.
format
(
layer
.
outputs
[
0
],
indent
=
1
,
get_value
(
layer
,
"input"
,
different_attrs
),
init_func
=
[],
get_value
(
layer
,
"start"
,
different_attrs
),
forward_func
=
[],
get_value
(
layer
,
"end"
,
different_attrs
),
layer_id
=
None
,
get_value
(
layer
,
"step"
,
different_attrs
))
different_attrs
=
None
):
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
line
=
"{} = {}.shape[{}]"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
),
def
prim_startswith
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
get_value
(
layer
,
"dim"
,
different_attrs
))
line
=
"{} = {}.startswith({})"
.
format
(
layer
.
outputs
[
0
],
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
get_value
(
layer
,
"input"
,
different_attrs
),
get_value
(
layer
,
"start_str"
,
different_attrs
))
def
prim_slice
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = {}[{}: {}: {}]"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
),
get_value
(
layer
,
"start"
,
different_attrs
),
get_value
(
layer
,
"end"
,
different_attrs
),
get_value
(
layer
,
"step"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_startswith
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
,
is_return_line
=
False
):
line
=
"{} = {}.startswith({})"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
),
get_value
(
layer
,
"start_str"
,
different_attrs
))
if
is_return_line
:
if
is_return_line
:
return
line
.
split
(
" = "
)[
1
]
return
line
.
split
(
" = "
)[
1
]
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_str
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_str
(
layer
,
line
=
"{} = str({})"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = str({})"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_sub
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_sub
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
if
int
(
float
(
get_value
(
layer
,
"alpha"
,
different_attrs
)))
==
1
:
if
int
(
float
(
get_value
(
layer
,
"alpha"
,
different_attrs
)))
==
1
:
line
=
"{} = {} - {}"
.
format
(
layer
.
outputs
[
0
],
line
=
"{} = {} - {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"y"
,
different_attrs
))
get_value
(
layer
,
"y"
,
different_attrs
))
else
:
else
:
line
=
"{} = {} - {} * {}"
.
format
(
layer
.
outputs
[
0
],
line
=
"{} = {} - {} * {}"
.
format
(
get_value
(
layer
,
"x"
,
different_attrs
),
layer
.
outputs
[
0
],
get_value
(
layer
,
"alpha"
,
different_attrs
),
get_value
(
layer
,
"x"
,
different_attrs
),
get_value
(
layer
,
"y"
,
different_attrs
))
get_value
(
layer
,
"alpha"
,
different_attrs
),
get_value
(
layer
,
"y"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_tuple
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_tuple
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
input_len
=
len
(
layer
.
inputs
)
+
len
(
layer
.
attrs
)
input_len
=
len
(
layer
.
inputs
)
+
len
(
layer
.
attrs
)
inputs_list
=
list
()
inputs_list
=
list
()
for
i
in
range
(
input_len
):
for
i
in
range
(
input_len
):
inputs_list
.
append
(
get_value
(
layer
,
"input{}"
.
format
(
i
),
different_attrs
))
inputs_list
.
append
(
get_value
(
layer
,
"input{}"
.
format
(
i
),
different_attrs
))
inputs_str
=
', '
.
join
(
inputs_list
)
inputs_str
=
', '
.
join
(
inputs_list
)
line
=
"{} = ({})"
.
format
(
layer
.
outputs
[
0
],
inputs_str
)
line
=
"{} = ({})"
.
format
(
layer
.
outputs
[
0
],
inputs_str
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_tuple_unpack
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_tuple_unpack
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
outputs_str
=
', '
.
join
(
layer
.
outputs
)
outputs_str
=
', '
.
join
(
layer
.
outputs
)
line
=
"{} = {}"
.
format
(
outputs_str
,
get_value
(
layer
,
"input"
,
different_attrs
))
line
=
"{} = {}"
.
format
(
outputs_str
,
get_value
(
layer
,
"input"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_type
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_type
(
layer
,
line
=
"{} = {}.dtype"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = {}.dtype"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_var2list
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_var2list
(
layer
,
line
=
"{} = {}.numpy().tolist()"
.
format
(
layer
.
outputs
[
0
],
indent
=
1
,
get_value
(
layer
,
"input"
,
different_attrs
))
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
line
=
"{} = {}.numpy().tolist()"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
,
different_attrs
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_warnings
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
def
prim_warnings
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[],
layer_id
=
None
,
different_attrs
=
None
):
lines
=
[
"import warnings"
]
lines
=
[
"import warnings"
]
line
=
"warnings.warn({}, stacklevel={})"
.
format
(
line
=
"warnings.warn({}, stacklevel={})"
.
format
(
get_value
(
layer
,
"input"
,
different_attrs
),
layer
.
attrs
[
"stacklevel"
])
get_value
(
layer
,
"input"
,
different_attrs
),
layer
.
attrs
[
"stacklevel"
])
lines
.
append
(
line
)
lines
.
append
(
line
)
forward_func
.
extend
(
gen_codes
(
lines
,
indent
=
indent
))
forward_func
.
extend
(
gen_codes
(
lines
,
indent
=
indent
))
x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py
浏览文件 @
91879f50
...
@@ -37,7 +37,7 @@ class PyTorchOpMapper(OpMapper):
...
@@ -37,7 +37,7 @@ class PyTorchOpMapper(OpMapper):
self
.
scope_name_list
=
list
()
self
.
scope_name_list
=
list
()
self
.
scope_name2id
=
dict
()
self
.
scope_name2id
=
dict
()
self
.
inputs_info
=
dict
()
self
.
inputs_info
=
dict
()
self
.
output2id
=
dict
()
# output名字和layer_id的关系,用于lstm去除前面的node
self
.
output2id
=
dict
()
# output名字和layer_id的关系,用于lstm去除前面的node
# 转换
# 转换
if
not
self
.
op_checker
(
decoder
.
graph
):
if
not
self
.
op_checker
(
decoder
.
graph
):
raise
Exception
(
"Model is not supported yet."
)
raise
Exception
(
"Model is not supported yet."
)
...
@@ -50,6 +50,7 @@ class PyTorchOpMapper(OpMapper):
...
@@ -50,6 +50,7 @@ class PyTorchOpMapper(OpMapper):
op_list
.
append
(
node
.
kind
())
op_list
.
append
(
node
.
kind
())
for
block
in
node
.
blocks
():
for
block
in
node
.
blocks
():
_update_op_list
(
block
)
_update_op_list
(
block
)
op_list
=
list
()
op_list
=
list
()
_update_op_list
(
script_graph
)
_update_op_list
(
script_graph
)
op_list
=
list
(
set
(
op_list
))
op_list
=
list
(
set
(
op_list
))
...
@@ -62,11 +63,11 @@ class PyTorchOpMapper(OpMapper):
...
@@ -62,11 +63,11 @@ class PyTorchOpMapper(OpMapper):
return
True
return
True
else
:
else
:
if
len
(
unsupported_ops
)
>
0
:
if
len
(
unsupported_ops
)
>
0
:
print
(
"
\n
========= {} OPs are not supported yet ==========="
.
format
(
print
(
"
\n
========= {} OPs are not supported yet ==========="
.
len
(
unsupported_ops
)))
format
(
len
(
unsupported_ops
)))
for
op
in
unsupported_ops
:
for
op
in
unsupported_ops
:
print
(
"========== {} ============"
.
format
(
op
))
print
(
"========== {} ============"
.
format
(
op
))
return
False
return
False
def
traverse
(
self
,
script_graph
,
parent_layer
=
None
):
def
traverse
(
self
,
script_graph
,
parent_layer
=
None
):
# 用于获取graph的输入
# 用于获取graph的输入
...
@@ -85,20 +86,24 @@ class PyTorchOpMapper(OpMapper):
...
@@ -85,20 +86,24 @@ class PyTorchOpMapper(OpMapper):
current_node_outputs
.
extend
(
outputs
)
current_node_outputs
.
extend
(
outputs
)
# 初始化
# 初始化
graph
=
PaddleGraph
(
source_type
=
"pytorch"
,
parent_layer
=
parent_layer
,
graph_type
=
"dygraph"
)
graph
=
PaddleGraph
(
source_type
=
"pytorch"
,
parent_layer
=
parent_layer
,
graph_type
=
"dygraph"
)
if
"TopLevelTracedModule"
in
str
(
type
(
self
.
script
)):
if
"TopLevelTracedModule"
in
str
(
type
(
self
.
script
)):
graph
.
set_script
(
self
.
script
)
graph
.
set_script
(
self
.
script
)
current_node_outputs
=
[]
current_node_outputs
=
[]
graph_inputs
=
[]
graph_inputs
=
[]
# 转换输入节点
# 转换输入节点
if
isinstance
(
script_graph
,
torch
.
_C
.
Graph
):
if
isinstance
(
script_graph
,
torch
.
_C
.
Graph
):
input_ct
=
0
input_ct
=
0
for
i
,
ivalue
in
enumerate
(
script_graph
.
inputs
()):
for
i
,
ivalue
in
enumerate
(
script_graph
.
inputs
()):
node
=
ivalue
.
node
()
node
=
ivalue
.
node
()
if
str
(
ivalue
.
type
())
not
in
[
"Tensor"
,
"Dict[str, Tensor]"
]:
if
str
(
ivalue
.
type
())
not
in
[
"Tensor"
,
"Dict[str, Tensor]"
]:
graph
.
set_name
(
str
(
ivalue
.
type
()).
split
(
"."
)[
-
1
])
graph
.
set_name
(
str
(
ivalue
.
type
()).
split
(
"."
)[
-
1
])
continue
continue
inputs
,
outputs
=
self
.
data
(
graph
,
node
,
ivalue
.
unique
(),
input_ct
)
inputs
,
outputs
=
self
.
data
(
graph
,
node
,
ivalue
.
unique
(),
input_ct
)
input_ct
+=
1
input_ct
+=
1
# 转换中间节点
# 转换中间节点
for
node
in
script_graph
.
nodes
():
for
node
in
script_graph
.
nodes
():
...
@@ -183,8 +188,9 @@ class PyTorchOpMapper(OpMapper):
...
@@ -183,8 +188,9 @@ class PyTorchOpMapper(OpMapper):
outputs
=
[
output_name
],
outputs
=
[
output_name
],
scope_name
=
scope_name
,
scope_name
=
scope_name
,
dtype
=
string
(
str
(
param
.
dtype
)),
dtype
=
string
(
str
(
param
.
dtype
)),
shape
=
param
.
shape
,
shape
=
param
.
shape
,
default_initializer
=
"paddle.nn.initializer.Constant(value=0.0)"
)
default_initializer
=
"paddle.nn.initializer.Constant(value=0.0)"
)
self
.
output2id
[
output_name
]
=
layer_id
self
.
output2id
[
output_name
]
=
layer_id
else
:
else
:
if
isinstance
(
param
,
dict
)
and
"Tensor"
in
param
and
\
if
isinstance
(
param
,
dict
)
and
"Tensor"
in
param
and
\
...
@@ -211,8 +217,9 @@ class PyTorchOpMapper(OpMapper):
...
@@ -211,8 +217,9 @@ class PyTorchOpMapper(OpMapper):
outputs
=
[
output_name
],
outputs
=
[
output_name
],
scope_name
=
scope_name
,
scope_name
=
scope_name
,
dtype
=
string
(
str
(
param
.
dtype
)),
dtype
=
string
(
str
(
param
.
dtype
)),
shape
=
param
.
shape
,
shape
=
param
.
shape
,
default_initializer
=
"paddle.nn.initializer.Constant(value=0.0)"
)
default_initializer
=
"paddle.nn.initializer.Constant(value=0.0)"
)
node_outputs
.
append
(
output_name
)
node_outputs
.
append
(
output_name
)
self
.
output2id
[
output_name
]
=
layer_id
self
.
output2id
[
output_name
]
=
layer_id
return
return
...
@@ -232,7 +239,8 @@ class PyTorchOpMapper(OpMapper):
...
@@ -232,7 +239,8 @@ class PyTorchOpMapper(OpMapper):
value
=
string
(
param
)
value
=
string
(
param
)
if
isinstance
(
param
,
str
)
else
param
)
if
isinstance
(
param
,
str
)
else
param
)
node_outputs
.
append
(
output_name
)
node_outputs
.
append
(
output_name
)
elif
node
.
kind
()
==
"prim::Constant"
and
output_name
in
self
.
pytorch_params
:
elif
node
.
kind
(
)
==
"prim::Constant"
and
output_name
in
self
.
pytorch_params
:
param
=
self
.
pytorch_params
[
output_name
]
param
=
self
.
pytorch_params
[
output_name
]
self
.
paddle_params
[
output_name
]
=
param
self
.
paddle_params
[
output_name
]
=
param
layer_id
=
graph
.
add_layer
(
layer_id
=
graph
.
add_layer
(
...
@@ -241,11 +249,10 @@ class PyTorchOpMapper(OpMapper):
...
@@ -241,11 +249,10 @@ class PyTorchOpMapper(OpMapper):
outputs
=
[
output_name
],
outputs
=
[
output_name
],
scope_name
=
scope_name
,
scope_name
=
scope_name
,
dtype
=
string
(
str
(
param
.
dtype
)),
dtype
=
string
(
str
(
param
.
dtype
)),
shape
=
param
.
shape
,
shape
=
param
.
shape
,
default_initializer
=
"paddle.nn.initializer.Constant(value=0.0)"
)
default_initializer
=
"paddle.nn.initializer.Constant(value=0.0)"
)
self
.
output2id
[
output_name
]
=
layer_id
self
.
output2id
[
output_name
]
=
layer_id
def
_get_inputs_name
(
self
,
node
):
def
_get_inputs_name
(
self
,
node
):
inputs_name
=
[]
inputs_name
=
[]
inputs_node
=
[]
inputs_node
=
[]
...
@@ -256,7 +263,6 @@ class PyTorchOpMapper(OpMapper):
...
@@ -256,7 +263,6 @@ class PyTorchOpMapper(OpMapper):
inputs_node
.
append
(
script_input_node
)
inputs_node
.
append
(
script_input_node
)
inputs_name
.
append
(
input_name
)
inputs_name
.
append
(
input_name
)
return
inputs_name
,
inputs_node
return
inputs_name
,
inputs_node
def
data
(
self
,
graph
,
node
,
uid
,
input_ct
):
def
data
(
self
,
graph
,
node
,
uid
,
input_ct
):
scope_name
=
self
.
normalize_scope_name
(
node
)
scope_name
=
self
.
normalize_scope_name
(
node
)
...
@@ -276,7 +282,8 @@ class PyTorchOpMapper(OpMapper):
...
@@ -276,7 +282,8 @@ class PyTorchOpMapper(OpMapper):
data
=
output_name
)
data
=
output_name
)
if
self
.
input_examples
is
not
None
:
if
self
.
input_examples
is
not
None
:
input_np
=
self
.
input_examples
[
input_ct
].
detach
().
numpy
()
input_np
=
self
.
input_examples
[
input_ct
].
detach
().
numpy
()
self
.
inputs_info
[
output_name
]
=
[
list
(
input_np
.
shape
),
str
(
input_np
.
dtype
)]
self
.
inputs_info
[
output_name
]
=
[
list
(
input_np
.
shape
),
str
(
input_np
.
dtype
)]
return
[],
[
output_name
]
return
[],
[
output_name
]
def
equal
(
self
,
graph
,
node
,
uid
=
None
,
parent_layer
=
None
,
index
=
None
):
def
equal
(
self
,
graph
,
node
,
uid
=
None
,
parent_layer
=
None
,
index
=
None
):
...
@@ -289,7 +296,8 @@ class PyTorchOpMapper(OpMapper):
...
@@ -289,7 +296,8 @@ class PyTorchOpMapper(OpMapper):
control_output_id
=
index
-
1
control_output_id
=
index
-
1
output_node_name
=
parent_layer
.
outputs
[
control_output_id
]
output_node_name
=
parent_layer
.
outputs
[
control_output_id
]
current_outputs
=
[
output_node_name
]
current_outputs
=
[
output_node_name
]
self
.
_check_input
(
graph
,
node
,
input_node_name
,
current_outputs
,
scope_name
)
self
.
_check_input
(
graph
,
node
,
input_node_name
,
current_outputs
,
scope_name
)
graph
.
add_layer
(
graph
.
add_layer
(
"prim.equal"
,
"prim.equal"
,
inputs
=
{
'input'
:
input_node_name
},
inputs
=
{
'input'
:
input_node_name
},
...
@@ -321,20 +329,20 @@ class PyTorchOpMapper(OpMapper):
...
@@ -321,20 +329,20 @@ class PyTorchOpMapper(OpMapper):
self
.
scope_name2id
[
i
][
ns
]
=
0
self
.
scope_name2id
[
i
][
ns
]
=
0
real_scope_name
=
"/"
.
join
(
name_segments
[
1
:])
real_scope_name
=
"/"
.
join
(
name_segments
[
1
:])
real_father_scope_name
=
"/"
.
join
(
name_segments
[
1
:
-
1
])
real_father_scope_name
=
"/"
.
join
(
name_segments
[
1
:
-
1
])
for
i
,
ns
in
enumerate
(
name_segments
):
for
i
,
ns
in
enumerate
(
name_segments
):
if
i
==
0
:
if
i
==
0
:
continue
continue
if
self
.
scope_name2id
[
i
][
ns
]
!=
0
:
if
self
.
scope_name2id
[
i
][
ns
]
!=
0
:
name_segments
[
i
]
=
name_segments
[
i
]
+
\
name_segments
[
i
]
=
name_segments
[
i
]
+
\
"__{}"
.
format
(
self
.
scope_name2id
[
i
][
ns
])
"__{}"
.
format
(
self
.
scope_name2id
[
i
][
ns
])
prefix_scope_name
=
"/"
.
join
(
name_segments
[
1
:
i
+
1
])
prefix_scope_name
=
"/"
.
join
(
name_segments
[
1
:
i
+
1
])
is_found
=
False
is_found
=
False
for
j
in
range
(
len
(
self
.
scope_name_list
)):
for
j
in
range
(
len
(
self
.
scope_name_list
)):
last_scope_name
=
self
.
scope_name_list
[
-
1
-
j
]
last_scope_name
=
self
.
scope_name_list
[
-
1
-
j
]
if
last_scope_name
.
startswith
(
prefix_scope_name
+
"/"
)
\
if
last_scope_name
.
startswith
(
prefix_scope_name
+
"/"
)
\
or
last_scope_name
==
prefix_scope_name
:
or
last_scope_name
==
prefix_scope_name
:
if
j
!=
0
:
# and i != len(name_segments) - 1:
if
j
!=
0
:
# and i != len(name_segments) - 1:
is_found
=
True
is_found
=
True
origin_name_segment_i
=
name_segments
[
i
].
split
(
"__"
)[
0
]
origin_name_segment_i
=
name_segments
[
i
].
split
(
"__"
)[
0
]
self
.
scope_name2id
[
i
][
origin_name_segment_i
]
+=
1
self
.
scope_name2id
[
i
][
origin_name_segment_i
]
+=
1
...
@@ -346,4 +354,3 @@ class PyTorchOpMapper(OpMapper):
...
@@ -346,4 +354,3 @@ class PyTorchOpMapper(OpMapper):
real_scope_name
=
"/"
.
join
(
name_segments
[
1
:])
real_scope_name
=
"/"
.
join
(
name_segments
[
1
:])
self
.
scope_name_list
.
append
(
real_scope_name
)
self
.
scope_name_list
.
append
(
real_scope_name
)
return
real_scope_name
return
real_scope_name
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录