Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
91879f50
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
91879f50
编写于
4月 16, 2021
作者:
S
SunAhong1993
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix fro pre-commit
上级
2fc9ffd0
变更
3
展开全部
显示空白变更内容
内联
并排
Showing
3 changed file
with
613 addition
and
217 deletion
+613
-217
x2paddle/op_mapper/dygraph/pytorch2paddle/prim.py
x2paddle/op_mapper/dygraph/pytorch2paddle/prim.py
+123
-51
x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py
x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py
+460
-143
x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py
...dle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py
+30
-23
未找到文件。
x2paddle/op_mapper/dygraph/pytorch2paddle/prim.py
浏览文件 @
91879f50
...
...
@@ -37,20 +37,24 @@ def prim_Constant(mapper, graph, node):
tensor_value
=
value
value
=
"{}"
.
format
(
value
)
if
"tensor"
in
value
:
if
isinstance
(
tensor_value
,
list
)
or
isinstance
(
tensor_value
,
tuple
):
if
isinstance
(
tensor_value
,
list
)
or
isinstance
(
tensor_value
,
tuple
):
name_dict
=
dict
()
for
i
,
tv
in
enumerate
(
tensor_value
):
output_name_i
=
"{}_p{}"
.
format
(
output_name
,
i
)
output_name_i
=
"{}_p{}"
.
format
(
output_name
,
i
)
key_i
=
"input{}"
.
format
(
i
)
mapper
.
paddle_params
[
output_name_i
]
=
tv
.
cpu
().
detach
().
numpy
()
mapper
.
paddle_params
[
output_name_i
]
=
tv
.
cpu
().
detach
(
).
numpy
()
graph
.
add_layer
(
"self.create_parameter"
,
inputs
=
{},
outputs
=
[
output_name_i
],
scope_name
=
scope_name
,
dtype
=
string
(
str
(
mapper
.
paddle_params
[
output_name_i
].
dtype
)),
shape
=
mapper
.
paddle_params
[
output_name_i
].
shape
,
default_initializer
=
"paddle.nn.initializer.Constant(value=0.0)"
)
dtype
=
string
(
str
(
mapper
.
paddle_params
[
output_name_i
].
dtype
)),
shape
=
mapper
.
paddle_params
[
output_name_i
].
shape
,
default_initializer
=
"paddle.nn.initializer.Constant(value=0.0)"
)
name_dict
[
key_i
]
=
output_name_i
graph
.
add_layer
(
"prim.list"
,
...
...
@@ -59,16 +63,18 @@ def prim_Constant(mapper, graph, node):
scope_name
=
scope_name
)
return
[],
[
output_name
]
else
:
# mapper.pytorch_params[output_name] = tensor_value.cpu().detach().numpy()
mapper
.
paddle_params
[
output_name
]
=
tensor_value
.
cpu
().
detach
().
numpy
()
# mapper.pytorch_params[output_name] = tensor_value.cpu().detach().numpy()
mapper
.
paddle_params
[
output_name
]
=
tensor_value
.
cpu
().
detach
(
).
numpy
()
graph
.
add_layer
(
"self.create_parameter"
,
inputs
=
{},
outputs
=
[
output_name
],
scope_name
=
scope_name
,
dtype
=
string
(
str
(
mapper
.
paddle_params
[
output_name
].
dtype
)),
shape
=
mapper
.
paddle_params
[
output_name
].
shape
,
default_initializer
=
"paddle.nn.initializer.Constant(value=0.0)"
)
shape
=
mapper
.
paddle_params
[
output_name
].
shape
,
default_initializer
=
"paddle.nn.initializer.Constant(value=0.0)"
)
return
[],
[
output_name
]
if
"inf"
in
str
(
value
):
t
=
str
(
type
(
value
)).
split
(
"'"
)[
1
]
...
...
@@ -81,7 +87,11 @@ def prim_Constant(mapper, graph, node):
value
=
int
(
math
.
pow
(
2
,
31
)
-
1
)
mapper
.
attrs
[
output_name
]
=
value
graph
.
add_layer
(
"prim.constant"
,
inputs
=
{},
outputs
=
[
output_name
],
scope_name
=
scope_name
,
value
=
value
)
"prim.constant"
,
inputs
=
{},
outputs
=
[
output_name
],
scope_name
=
scope_name
,
value
=
value
)
return
[],
[
output_name
]
...
...
@@ -105,12 +115,17 @@ def prim_data(mapper, graph, node):
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理输入0,即%4336
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.equal"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
graph
.
add_layer
(
"prim.equal"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
return
current_inputs
,
current_outputs
...
...
@@ -136,14 +151,15 @@ def prim_DictConstruct(mapper, graph, node):
current_outputs
=
[
output_name
]
# 处理每个输入
for
i
,
input_name
in
enumerate
(
inputs_name
):
if
i
%
2
==
0
:
layer_attrs
[
"key{}"
.
format
(
int
(
i
/
2
))]
=
mapper
.
attrs
[
input_name
]
if
i
%
2
==
0
:
layer_attrs
[
"key{}"
.
format
(
int
(
i
/
2
))]
=
mapper
.
attrs
[
input_name
]
else
:
layer_inputs
[
"value{}"
.
format
(
int
(
i
/
2
))]
=
input_name
layer_inputs
[
"value{}"
.
format
(
int
(
i
/
2
))]
=
input_name
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.dict_construct"
,
graph
.
add_layer
(
"prim.dict_construct"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
,
...
...
@@ -151,7 +167,6 @@ def prim_DictConstruct(mapper, graph, node):
return
current_inputs
,
current_outputs
def
prim_GetAttr
(
mapper
,
graph
,
node
):
""" 获取attribute信息。
...
...
@@ -212,8 +227,13 @@ def prim_If(mapper, graph, node):
input_node
=
list
(
node
.
inputs
())[
0
].
node
()
script_input_unique_id
=
list
(
node
.
inputs
())[
0
].
unique
()
input_node_name
=
mapper
.
outputs_info
[
script_input_unique_id
]
mapper
.
_check_input
(
graph
,
input_node
,
input_node_name
,
current_outputs
,
scope_name
)
graph
.
add_layer
(
"prim.if"
,
inputs
=
{
'input'
:
input_node_name
},
outputs
=
node_outputs
,
scope_name
=
scope_name
)
mapper
.
_check_input
(
graph
,
input_node
,
input_node_name
,
current_outputs
,
scope_name
)
graph
.
add_layer
(
"prim.if"
,
inputs
=
{
'input'
:
input_node_name
},
outputs
=
node_outputs
,
scope_name
=
scope_name
)
current_layer
=
list
(
graph
.
layers
.
values
())[
-
1
]
block0
=
list
(
node
.
blocks
())[
0
]
block0_graph
,
graph_inputs0
=
mapper
.
traverse
(
block0
,
current_layer
)
...
...
@@ -249,12 +269,17 @@ def prim_ListConstruct(mapper, graph, node):
current_outputs
=
[
output_name
]
# 处理每个输入
for
i
,
input_name
in
enumerate
(
inputs_name
):
mapper
.
_check_input
(
graph
,
inputs_node
[
i
],
input_name
,
current_outputs
,
scope_name
)
mapper
.
_check_input
(
graph
,
inputs_node
[
i
],
input_name
,
current_outputs
,
scope_name
)
layer_inputs
[
"input{}"
.
format
(
i
)]
=
input_name
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
layer_id
=
graph
.
add_layer
(
"prim.list"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
layer_id
=
graph
.
add_layer
(
"prim.list"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
mapper
.
output2id
[
output_name
]
=
layer_id
return
current_inputs
,
current_outputs
...
...
@@ -277,13 +302,17 @@ def prim_ListUnpack(mapper, graph, node):
# 获取当前节点输出的list
current_outputs
=
layer_outputs
.
copy
()
# 处理输入0,即%4354
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.list_unpack"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
"prim.list_unpack"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
mapper
.
split_len
[
list
(
layer_inputs
.
values
())[
0
]]
=
len
(
layer_outputs
)
return
current_inputs
,
current_outputs
...
...
@@ -342,7 +371,11 @@ def prim_Loop(mapper, graph, node):
scope_name
=
scope_name
)
node_outputs
.
append
(
block_input_node_name
)
graph
.
add_layer
(
"prim.loop"
,
inputs
=
loop_inputs
,
outputs
=
loop_outputs
,
scope_name
=
scope_name
)
graph
.
add_layer
(
"prim.loop"
,
inputs
=
loop_inputs
,
outputs
=
loop_outputs
,
scope_name
=
scope_name
)
current_layer
=
list
(
graph
.
layers
.
values
())[
-
1
]
block_graph
,
graph_inputs
=
mapper
.
traverse
(
block
,
current_layer
)
for
i
,
input_name
in
enumerate
(
graph_inputs
):
...
...
@@ -370,12 +403,17 @@ def prim_min(mapper, graph, node):
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理输入0,即%86
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.min"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
graph
.
add_layer
(
"prim.min"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
return
current_inputs
,
current_outputs
...
...
@@ -397,14 +435,19 @@ def prim_NumToTensor(mapper, graph, node):
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理输入0,即%86
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
inputs_inputs_name
,
inputs_inputs_node
=
mapper
.
_get_inputs_name
(
inputs_node
[
0
])
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
inputs_inputs_name
,
inputs_inputs_node
=
mapper
.
_get_inputs_name
(
inputs_node
[
0
])
if
inputs_node
[
0
].
kind
()
==
"aten::size"
and
len
(
inputs_inputs_name
)
>
1
:
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim_equal"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
"prim_equal"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
else
:
layer_inputs
[
"fill_value"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
...
...
@@ -437,13 +480,17 @@ def prim_RaiseException(mapper, graph, node):
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理输入0,即%76
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.exception"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
"prim.exception"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
return
current_inputs
,
current_outputs
...
...
@@ -464,13 +511,17 @@ def prim_requires_grad(mapper, graph, node):
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理输入0,即%86
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.requires_grad"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
"prim.requires_grad"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
return
current_inputs
,
current_outputs
...
...
@@ -527,13 +578,17 @@ def prim_shape(mapper, graph, node):
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理输入0,即%input.8
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"paddle.shape"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
"paddle.shape"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
return
current_inputs
,
current_outputs
...
...
@@ -560,7 +615,11 @@ def prim_TupleConstruct(mapper, graph, node):
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.tuple"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
graph
.
add_layer
(
"prim.tuple"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
return
current_inputs
,
current_outputs
...
...
@@ -590,7 +649,11 @@ def prim_TupleUnpack(mapper, graph, node):
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.tuple_unpack"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
,
**
layer_attrs
)
"prim.tuple_unpack"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
,
**
layer_attrs
)
return
current_inputs
,
current_outputs
...
...
@@ -614,12 +677,17 @@ def prim_unchecked_cast(mapper, graph, node):
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理输入0,即%size.63
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"prim.equal"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
graph
.
add_layer
(
"prim.equal"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
scope_name
=
scope_name
)
return
current_inputs
,
current_outputs
...
...
@@ -636,5 +704,9 @@ def prim_Uninitialized(mapper, graph, node):
output
=
list
(
node
.
outputs
())[
0
]
mapper
.
attrs
[
output_name
]
=
None
graph
.
add_layer
(
"prim.constant"
,
inputs
=
{},
outputs
=
[
output_name
],
scope_name
=
scope_name
,
value
=
None
)
"prim.constant"
,
inputs
=
{},
outputs
=
[
output_name
],
scope_name
=
scope_name
,
value
=
None
)
return
[],
[
output_name
]
x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py
浏览文件 @
91879f50
此差异已折叠。
点击以展开。
x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py
浏览文件 @
91879f50
...
...
@@ -50,6 +50,7 @@ class PyTorchOpMapper(OpMapper):
op_list
.
append
(
node
.
kind
())
for
block
in
node
.
blocks
():
_update_op_list
(
block
)
op_list
=
list
()
_update_op_list
(
script_graph
)
op_list
=
list
(
set
(
op_list
))
...
...
@@ -62,8 +63,8 @@ class PyTorchOpMapper(OpMapper):
return
True
else
:
if
len
(
unsupported_ops
)
>
0
:
print
(
"
\n
========= {} OPs are not supported yet ==========="
.
format
(
len
(
unsupported_ops
)))
print
(
"
\n
========= {} OPs are not supported yet ==========="
.
format
(
len
(
unsupported_ops
)))
for
op
in
unsupported_ops
:
print
(
"========== {} ============"
.
format
(
op
))
return
False
...
...
@@ -85,7 +86,10 @@ class PyTorchOpMapper(OpMapper):
current_node_outputs
.
extend
(
outputs
)
# 初始化
graph
=
PaddleGraph
(
source_type
=
"pytorch"
,
parent_layer
=
parent_layer
,
graph_type
=
"dygraph"
)
graph
=
PaddleGraph
(
source_type
=
"pytorch"
,
parent_layer
=
parent_layer
,
graph_type
=
"dygraph"
)
if
"TopLevelTracedModule"
in
str
(
type
(
self
.
script
)):
graph
.
set_script
(
self
.
script
)
current_node_outputs
=
[]
...
...
@@ -98,7 +102,8 @@ class PyTorchOpMapper(OpMapper):
if
str
(
ivalue
.
type
())
not
in
[
"Tensor"
,
"Dict[str, Tensor]"
]:
graph
.
set_name
(
str
(
ivalue
.
type
()).
split
(
"."
)[
-
1
])
continue
inputs
,
outputs
=
self
.
data
(
graph
,
node
,
ivalue
.
unique
(),
input_ct
)
inputs
,
outputs
=
self
.
data
(
graph
,
node
,
ivalue
.
unique
(),
input_ct
)
input_ct
+=
1
# 转换中间节点
for
node
in
script_graph
.
nodes
():
...
...
@@ -183,8 +188,9 @@ class PyTorchOpMapper(OpMapper):
outputs
=
[
output_name
],
scope_name
=
scope_name
,
dtype
=
string
(
str
(
param
.
dtype
)),
shape
=
param
.
shape
,
default_initializer
=
"paddle.nn.initializer.Constant(value=0.0)"
)
shape
=
param
.
shape
,
default_initializer
=
"paddle.nn.initializer.Constant(value=0.0)"
)
self
.
output2id
[
output_name
]
=
layer_id
else
:
if
isinstance
(
param
,
dict
)
and
"Tensor"
in
param
and
\
...
...
@@ -211,8 +217,9 @@ class PyTorchOpMapper(OpMapper):
outputs
=
[
output_name
],
scope_name
=
scope_name
,
dtype
=
string
(
str
(
param
.
dtype
)),
shape
=
param
.
shape
,
default_initializer
=
"paddle.nn.initializer.Constant(value=0.0)"
)
shape
=
param
.
shape
,
default_initializer
=
"paddle.nn.initializer.Constant(value=0.0)"
)
node_outputs
.
append
(
output_name
)
self
.
output2id
[
output_name
]
=
layer_id
return
...
...
@@ -232,7 +239,8 @@ class PyTorchOpMapper(OpMapper):
value
=
string
(
param
)
if
isinstance
(
param
,
str
)
else
param
)
node_outputs
.
append
(
output_name
)
elif
node
.
kind
()
==
"prim::Constant"
and
output_name
in
self
.
pytorch_params
:
elif
node
.
kind
(
)
==
"prim::Constant"
and
output_name
in
self
.
pytorch_params
:
param
=
self
.
pytorch_params
[
output_name
]
self
.
paddle_params
[
output_name
]
=
param
layer_id
=
graph
.
add_layer
(
...
...
@@ -241,11 +249,10 @@ class PyTorchOpMapper(OpMapper):
outputs
=
[
output_name
],
scope_name
=
scope_name
,
dtype
=
string
(
str
(
param
.
dtype
)),
shape
=
param
.
shape
,
shape
=
param
.
shape
,
default_initializer
=
"paddle.nn.initializer.Constant(value=0.0)"
)
self
.
output2id
[
output_name
]
=
layer_id
def
_get_inputs_name
(
self
,
node
):
inputs_name
=
[]
inputs_node
=
[]
...
...
@@ -257,7 +264,6 @@ class PyTorchOpMapper(OpMapper):
inputs_name
.
append
(
input_name
)
return
inputs_name
,
inputs_node
def
data
(
self
,
graph
,
node
,
uid
,
input_ct
):
scope_name
=
self
.
normalize_scope_name
(
node
)
for
output_ivalue
in
node
.
outputs
():
...
...
@@ -276,7 +282,8 @@ class PyTorchOpMapper(OpMapper):
data
=
output_name
)
if
self
.
input_examples
is
not
None
:
input_np
=
self
.
input_examples
[
input_ct
].
detach
().
numpy
()
self
.
inputs_info
[
output_name
]
=
[
list
(
input_np
.
shape
),
str
(
input_np
.
dtype
)]
self
.
inputs_info
[
output_name
]
=
[
list
(
input_np
.
shape
),
str
(
input_np
.
dtype
)]
return
[],
[
output_name
]
def
equal
(
self
,
graph
,
node
,
uid
=
None
,
parent_layer
=
None
,
index
=
None
):
...
...
@@ -289,7 +296,8 @@ class PyTorchOpMapper(OpMapper):
control_output_id
=
index
-
1
output_node_name
=
parent_layer
.
outputs
[
control_output_id
]
current_outputs
=
[
output_node_name
]
self
.
_check_input
(
graph
,
node
,
input_node_name
,
current_outputs
,
scope_name
)
self
.
_check_input
(
graph
,
node
,
input_node_name
,
current_outputs
,
scope_name
)
graph
.
add_layer
(
"prim.equal"
,
inputs
=
{
'input'
:
input_node_name
},
...
...
@@ -328,10 +336,10 @@ class PyTorchOpMapper(OpMapper):
if
self
.
scope_name2id
[
i
][
ns
]
!=
0
:
name_segments
[
i
]
=
name_segments
[
i
]
+
\
"__{}"
.
format
(
self
.
scope_name2id
[
i
][
ns
])
prefix_scope_name
=
"/"
.
join
(
name_segments
[
1
:
i
+
1
])
prefix_scope_name
=
"/"
.
join
(
name_segments
[
1
:
i
+
1
])
is_found
=
False
for
j
in
range
(
len
(
self
.
scope_name_list
)):
last_scope_name
=
self
.
scope_name_list
[
-
1
-
j
]
last_scope_name
=
self
.
scope_name_list
[
-
1
-
j
]
if
last_scope_name
.
startswith
(
prefix_scope_name
+
"/"
)
\
or
last_scope_name
==
prefix_scope_name
:
if
j
!=
0
:
# and i != len(name_segments) - 1:
...
...
@@ -346,4 +354,3 @@ class PyTorchOpMapper(OpMapper):
real_scope_name
=
"/"
.
join
(
name_segments
[
1
:])
self
.
scope_name_list
.
append
(
real_scope_name
)
return
real_scope_name
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录