Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
21ff465e
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
21ff465e
编写于
8月 02, 2021
作者:
W
WJJ1995
提交者:
GitHub
8月 02, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fixed pytorch codegen bug (#650)
上级
379ce426
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
82 addition
and
59 deletion
+82
-59
x2paddle/op_mapper/pytorch2paddle/aten.py
x2paddle/op_mapper/pytorch2paddle/aten.py
+2
-2
x2paddle/optimizer/pytorch_code_optimizer/layer_code_generator.py
.../optimizer/pytorch_code_optimizer/layer_code_generator.py
+80
-57
未找到文件。
x2paddle/op_mapper/pytorch2paddle/aten.py
浏览文件 @
21ff465e
...
@@ -5388,7 +5388,7 @@ def aten_upsample_bilinear2d(mapper, graph, node):
...
@@ -5388,7 +5388,7 @@ def aten_upsample_bilinear2d(mapper, graph, node):
%4963 (list): 上采样后的大小。
%4963 (list): 上采样后的大小。
%5421 (bool): 若为True,则将输入和输出张量的4个角落像素的中心对齐,并保留角点像素的值。
%5421 (bool): 若为True,则将输入和输出张量的4个角落像素的中心对齐,并保留角点像素的值。
%4995 (float): 高度的乘数因子。
%4995 (float): 高度的乘数因子。
%499
5
(float): 宽度的乘数因子。
%499
6
(float): 宽度的乘数因子。
"""
"""
scope_name
=
mapper
.
normalize_scope_name
(
node
)
scope_name
=
mapper
.
normalize_scope_name
(
node
)
output_name
=
mapper
.
_get_outputs_name
(
node
)[
0
]
output_name
=
mapper
.
_get_outputs_name
(
node
)[
0
]
...
@@ -5465,7 +5465,7 @@ def aten_upsample_nearest2d(mapper, graph, node):
...
@@ -5465,7 +5465,7 @@ def aten_upsample_nearest2d(mapper, graph, node):
%4997 (Tensor): 输出,上采样后的Tensor。
%4997 (Tensor): 输出,上采样后的Tensor。
%x.13 (Tensor): 需要上采样的Tensor。
%x.13 (Tensor): 需要上采样的Tensor。
%4963 (list): 上采样后的大小。
%4963 (list): 上采样后的大小。
%
4995
(float): 高度的乘数因子。
%
5421
(float): 高度的乘数因子。
%4995 (float): 宽度的乘数因子。
%4995 (float): 宽度的乘数因子。
"""
"""
scope_name
=
mapper
.
normalize_scope_name
(
node
)
scope_name
=
mapper
.
normalize_scope_name
(
node
)
...
...
x2paddle/optimizer/pytorch_code_optimizer/layer_code_generator.py
浏览文件 @
21ff465e
...
@@ -19,33 +19,35 @@ import x2paddle
...
@@ -19,33 +19,35 @@ import x2paddle
from
x2paddle.optimizer.pytorch_code_optimizer.parameter_tree
import
PamareterNode
from
x2paddle.optimizer.pytorch_code_optimizer.parameter_tree
import
PamareterNode
from
x2paddle.core.util
import
*
from
x2paddle.core.util
import
*
NN_KERNEL_NAME
=
{
NN_KERNEL_NAME
=
{
"paddle.nn.BatchNorm"
:
"bn"
,
"paddle.nn.BatchNorm"
:
"bn"
,
"paddle.nn.LayerNorm"
:
"layernorm"
,
"paddle.nn.LayerNorm"
:
"layernorm"
,
"paddle.nn.Conv2D"
:
"conv"
,
"paddle.nn.Conv2D"
:
"conv"
,
"paddle.nn.Embedding"
:
"embedding"
,
"paddle.nn.Embedding"
:
"embedding"
,
"paddle.nn.Linear"
:
"linear"
,
"paddle.nn.Linear"
:
"linear"
,
"paddle.nn.Conv2DTranspose"
:
"conv"
,
"paddle.nn.Conv2DTranspose"
:
"conv"
,
"paddle.nn.LSTM"
:
"lstm"
,
"paddle.nn.LSTM"
:
"lstm"
,
"paddle.nn.GRU"
:
"gru"
,
"paddle.nn.GRU"
:
"gru"
,
"custom_layer:InstanceNorm"
:
"instance_norm"
,
"custom_layer:InstanceNorm"
:
"instance_norm"
,
"paddle.nn.PReLU"
:
"prelu"
,
"paddle.nn.PReLU"
:
"prelu"
,
"paddle.nn.ReLU"
:
"relu"
,
"paddle.nn.ReLU"
:
"relu"
,
"paddle.nn.ReLU6"
:
"relu"
,
"paddle.nn.ReLU6"
:
"relu"
,
"paddle.nn.Softmax"
:
"softmax"
,
"paddle.nn.Softmax"
:
"softmax"
,
"paddle.nn.Softplus"
:
"softplus"
,
"paddle.nn.Softplus"
:
"softplus"
,
"paddle.nn.Tanh"
:
"tanh"
,
"paddle.nn.Tanh"
:
"tanh"
,
"paddle.nn.AvgPool2D"
:
"avgpool"
,
"paddle.nn.AvgPool2D"
:
"avgpool"
,
"paddle.nn.MaxPool2D"
:
"maxpool"
,
"paddle.nn.MaxPool2D"
:
"maxpool"
,
"paddle.nn.Pad1D"
:
"pad1d"
,
"paddle.nn.Pad1D"
:
"pad1d"
,
"paddle.nn.Pad2D"
:
"pad2d"
,
"paddle.nn.Pad2D"
:
"pad2d"
,
"paddle.nn.Pad3D"
:
"pad3d"
,
"paddle.nn.Pad3D"
:
"pad3d"
,
"paddle.nn.Dropout"
:
"dropout"
,
"paddle.nn.Dropout"
:
"dropout"
,
"paddle.nn.GELU"
:
"gelu"
,
"paddle.nn.GELU"
:
"gelu"
,
"paddle.nn.Hardtanh"
:
"tanh"
,
"paddle.nn.Hardtanh"
:
"tanh"
,
"paddle.nn.LeakyReLU"
:
"leakly_relu"
}
"paddle.nn.LeakyReLU"
:
"leakly_relu"
}
NN_KERNEL_WITH_PARAMS
=
list
(
NN_KERNEL_NAME
.
keys
())[:
10
]
NN_KERNEL_WITH_PARAMS
=
list
(
NN_KERNEL_NAME
.
keys
())[:
10
]
def
rename_layers
(
layers
,
param_tree
=
None
,
is_rename_module
=
False
):
def
rename_layers
(
layers
,
param_tree
=
None
,
is_rename_module
=
False
):
""" 对子模块的输入输出等进行重命名。
""" 对子模块的输入输出等进行重命名。
"""
"""
...
@@ -58,6 +60,7 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
...
@@ -58,6 +60,7 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
new_names
=
list
()
new_names
=
list
()
for
kernel
in
NN_KERNEL_NAME
.
keys
():
for
kernel
in
NN_KERNEL_NAME
.
keys
():
nn_count_dict
[
kernel
]
=
0
nn_count_dict
[
kernel
]
=
0
def
rename_sub_layers
(
sub_layers
,
count
,
is_block
=
False
):
def
rename_sub_layers
(
sub_layers
,
count
,
is_block
=
False
):
for
layer_id
,
layer
in
sub_layers
.
items
():
for
layer_id
,
layer
in
sub_layers
.
items
():
# 对输入重命名
# 对输入重命名
...
@@ -69,10 +72,9 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
...
@@ -69,10 +72,9 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
count
+=
1
count
+=
1
layer
.
inputs
[
input_k
]
=
new_name
layer
.
inputs
[
input_k
]
=
new_name
name_dict
[
input_v
]
=
new_name
name_dict
[
input_v
]
=
new_name
# 对block重命名
# 对block重命名
for
block
in
layer
.
blocks
:
for
block
in
layer
.
blocks
:
count
=
rename_sub_layers
(
block
.
layers
,
count
=
rename_sub_layers
(
block
.
layers
,
count
,
is_block
=
True
)
count
,
is_block
=
True
)
# 对输出重命名
# 对输出重命名
if
len
(
layer
.
outputs
)
==
0
and
not
is_block
:
if
len
(
layer
.
outputs
)
==
0
and
not
is_block
:
new_names
.
append
(
"layer_id/{}"
.
format
(
layer_id
))
new_names
.
append
(
"layer_id/{}"
.
format
(
layer_id
))
...
@@ -83,9 +85,10 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
...
@@ -83,9 +85,10 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
new_names
.
append
(
name_dict
[
output_v
])
new_names
.
append
(
name_dict
[
output_v
])
else
:
else
:
if
i
==
0
and
layer
.
kernel
in
NN_KERNEL_NAME
.
keys
():
if
i
==
0
and
layer
.
kernel
in
NN_KERNEL_NAME
.
keys
():
new_name
=
NN_KERNEL_NAME
[
layer
.
kernel
]
+
str
(
nn_count_dict
[
layer
.
kernel
])
new_name
=
NN_KERNEL_NAME
[
layer
.
kernel
]
+
str
(
param_node
=
PamareterNode
(
old_name
=
layer
.
outputs
[
0
],
nn_count_dict
[
layer
.
kernel
])
new_name
=
new_name
)
param_node
=
PamareterNode
(
old_name
=
layer
.
outputs
[
0
],
new_name
=
new_name
)
nn_param_nodes
.
append
(
param_node
)
nn_param_nodes
.
append
(
param_node
)
if
param_tree
is
not
None
:
if
param_tree
is
not
None
:
param_tree
.
add_node
(
param_node
)
param_tree
.
add_node
(
param_node
)
...
@@ -94,7 +97,8 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
...
@@ -94,7 +97,8 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
elif
i
==
0
and
layer
.
kernel
==
"module"
:
elif
i
==
0
and
layer
.
kernel
==
"module"
:
if
is_rename_module
:
if
is_rename_module
:
if
param_tree
is
not
None
:
if
param_tree
is
not
None
:
param_node
=
param_tree
.
get_node
(
layer
.
outputs
[
0
])
param_node
=
param_tree
.
get_node
(
layer
.
outputs
[
0
])
nn_param_nodes
.
append
(
param_node
)
nn_param_nodes
.
append
(
param_node
)
param_node
.
new_name
=
layer
.
outputs
[
0
]
param_node
.
new_name
=
layer
.
outputs
[
0
]
else
:
else
:
...
@@ -105,7 +109,8 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
...
@@ -105,7 +109,8 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
nn_count_dict
[
old_name
]
+=
1
nn_count_dict
[
old_name
]
+=
1
new_name
=
old_name
+
str
(
nn_count_dict
[
old_name
])
new_name
=
old_name
+
str
(
nn_count_dict
[
old_name
])
if
param_tree
is
not
None
:
if
param_tree
is
not
None
:
param_node
=
param_tree
.
get_node
(
layer
.
outputs
[
0
])
param_node
=
param_tree
.
get_node
(
layer
.
outputs
[
0
])
nn_param_nodes
.
append
(
param_node
)
nn_param_nodes
.
append
(
param_node
)
param_node
.
new_name
=
new_name
param_node
.
new_name
=
new_name
layer
.
outputs
[
0
]
=
new_name
layer
.
outputs
[
0
]
=
new_name
...
@@ -116,8 +121,8 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
...
@@ -116,8 +121,8 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
layer
.
outputs
[
i
]
=
new_name
layer
.
outputs
[
i
]
=
new_name
name_dict
[
output_v
]
=
new_name
name_dict
[
output_v
]
=
new_name
if
layer
.
kernel
==
"self.create_parameter"
:
if
layer
.
kernel
==
"self.create_parameter"
:
param_node
=
PamareterNode
(
old_name
=
old_name
,
param_node
=
PamareterNode
(
new_name
=
new_name
)
old_name
=
old_name
,
new_name
=
new_name
)
nn_param_nodes
.
append
(
param_node
)
nn_param_nodes
.
append
(
param_node
)
if
param_tree
is
not
None
:
if
param_tree
is
not
None
:
param_tree
.
add_node
(
param_node
)
param_tree
.
add_node
(
param_node
)
...
@@ -129,6 +134,7 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
...
@@ -129,6 +134,7 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
and
attr_v
in
name_dict
:
and
attr_v
in
name_dict
:
layer
.
attrs
[
attr_k
]
=
name_dict
[
attr_v
]
layer
.
attrs
[
attr_k
]
=
name_dict
[
attr_v
]
return
count
return
count
rename_sub_layers
(
layers_cp
,
count
)
rename_sub_layers
(
layers_cp
,
count
)
return
layers_cp
,
nn_param_nodes
,
new_names
return
layers_cp
,
nn_param_nodes
,
new_names
...
@@ -152,22 +158,24 @@ def _update_attrs(layer, different_attrs):
...
@@ -152,22 +158,24 @@ def _update_attrs(layer, different_attrs):
common_attrs
.
update
(
special_attrs
)
common_attrs
.
update
(
special_attrs
)
layer
.
attrs
=
common_attrs
layer
.
attrs
=
common_attrs
def
gen_layer_code
(
graph
,
sub_layers
,
sub_layers_name
,
different_attrs
=
dict
()):
def
gen_layer_code
(
graph
,
sub_layers
,
sub_layers_name
,
different_attrs
=
dict
()):
""" 根据sub_layers生成对应的Module代码。
""" 根据sub_layers生成对应的Module代码。
Args:
Args:
graph (x2paddle.core.program.PaddleGraph): 整个Paddle图。
graph (x2paddle.core.program.PaddleGraph): 整个Paddle图。
sub_layers (dict): 子图的id和其对应layer组成的字典。
sub_layers (dict): 子图的id和其对应layer组成的字典。
sub_layers_name (str): 子图的名字。
sub_layers_name (str): 子图的名字。
different_attrs (dict/list): 属性字典/列表,这些属性表明在被调用时赋予不同值。
different_attrs (dict/list): 属性字典/列表,这些属性表明在被调用时赋予不同值。
"""
"""
def
gen_codes
(
code_list
,
indent
=
0
):
def
gen_codes
(
code_list
,
indent
=
0
):
""" 根据code_list生成代码段。
""" 根据code_list生成代码段。
Args:
Args:
code_list (list): 代码行组成的list。
code_list (list): 代码行组成的list。
indent (int): 每行空格的数量。
indent (int): 每行空格的数量。
Returns:
Returns:
str: 代码段。
str: 代码段。
"""
"""
...
@@ -179,10 +187,11 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
...
@@ -179,10 +187,11 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
else
:
else
:
codes
.
append
(
indent_blank
+
code_line
+
'
\n
'
)
codes
.
append
(
indent_blank
+
code_line
+
'
\n
'
)
return
codes
return
codes
def
gen_head
(
inputs
,
different_attrs
):
def
gen_head
(
inputs
,
different_attrs
):
# 生成Layer的头部代码
# 生成Layer的头部代码
head
=
gen_codes
([
"class {}(paddle.nn.Layer):"
.
format
(
sub_layers_name
)],
indent
=
0
)
head
=
gen_codes
(
[
"class {}(paddle.nn.Layer):"
.
format
(
sub_layers_name
)],
indent
=
0
)
# 生成init函数的头部代码
# 生成init函数的头部代码
diff_str_list
=
list
()
diff_str_list
=
list
()
if
isinstance
(
different_attrs
,
dict
):
if
isinstance
(
different_attrs
,
dict
):
...
@@ -199,8 +208,7 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
...
@@ -199,8 +208,7 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
forward_func_head
=
\
forward_func_head
=
\
gen_codes
([
"def forward(self, {}):"
.
format
(
input_data_name
)],
indent
=
1
)
gen_codes
([
"def forward(self, {}):"
.
format
(
input_data_name
)],
indent
=
1
)
return
head
,
init_func_head
,
forward_func_head
return
head
,
init_func_head
,
forward_func_head
init_func
=
[]
init_func
=
[]
forward_func
=
[]
forward_func
=
[]
cur_outputs
=
list
()
cur_outputs
=
list
()
...
@@ -211,7 +219,9 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
...
@@ -211,7 +219,9 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
for
layer_id
,
layer
in
sub_layers
.
items
():
for
layer_id
,
layer
in
sub_layers
.
items
():
if
layer_id
not
in
graph
.
edges_out
:
if
layer_id
not
in
graph
.
edges_out
:
for
index
,
output_name
in
enumerate
(
layer
.
outputs
):
for
index
,
output_name
in
enumerate
(
layer
.
outputs
):
if
layer
.
kernel
.
startswith
(
"paddle.nn"
)
and
index
==
0
:
if
layer
.
kernel
.
startswith
(
"paddle.nn"
)
and
index
==
0
and
"functional"
not
in
layer
.
kernel
:
continue
continue
if
not
output_name
.
startswith
(
"x"
)
or
output_name
in
outputs
\
if
not
output_name
.
startswith
(
"x"
)
or
output_name
in
outputs
\
or
layer
.
kernel
==
"prim.assert"
:
or
layer
.
kernel
==
"prim.assert"
:
...
@@ -225,7 +235,9 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
...
@@ -225,7 +235,9 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
for
out_layer_id
in
graph
.
edges_out
[
layer_id
]:
for
out_layer_id
in
graph
.
edges_out
[
layer_id
]:
if
out_layer_id
not
in
sub_layers
:
if
out_layer_id
not
in
sub_layers
:
for
index
,
output_name
in
enumerate
(
layer
.
outputs
):
for
index
,
output_name
in
enumerate
(
layer
.
outputs
):
if
layer
.
kernel
.
startswith
(
"paddle.nn"
)
and
index
==
0
and
"functional"
not
in
layer
.
kernel
:
if
layer
.
kernel
.
startswith
(
"paddle.nn"
)
and
index
==
0
and
"functional"
not
in
layer
.
kernel
:
continue
continue
if
not
output_name
.
startswith
(
"x"
)
or
output_name
in
outputs
\
if
not
output_name
.
startswith
(
"x"
)
or
output_name
in
outputs
\
or
layer
.
kernel
==
"prim.assert"
:
or
layer
.
kernel
==
"prim.assert"
:
...
@@ -263,17 +275,18 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
...
@@ -263,17 +275,18 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
line
=
line
.
strip
(
", "
)
line
=
line
.
strip
(
", "
)
line
+=
")"
line
+=
")"
init_func
.
extend
(
gen_codes
([
line
],
indent
=
2
))
init_func
.
extend
(
gen_codes
([
line
],
indent
=
2
))
if
len
(
layer
.
outputs
)
==
1
:
if
len
(
layer
.
outputs
)
==
1
:
line
=
layer
.
outputs
[
0
]
line
=
layer
.
outputs
[
0
]
elif
len
(
layer
.
outputs
)
==
2
:
elif
len
(
layer
.
outputs
)
==
2
:
line
=
layer
.
outputs
[
1
]
line
=
layer
.
outputs
[
1
]
else
:
else
:
if
layer
.
kernel
==
"paddle.nn.LSTM"
:
if
layer
.
kernel
==
"paddle.nn.LSTM"
:
line
=
"{}, ({})"
.
format
(
layer
.
outputs
[
1
],
', '
.
join
(
layer
.
outputs
[
-
2
:]))
line
=
"{}, ({})"
.
format
(
layer
.
outputs
[
1
],
', '
.
join
(
layer
.
outputs
[
-
2
:]))
else
:
else
:
line
=
','
.
join
(
layer
.
outputs
[
1
:])
line
=
','
.
join
(
layer
.
outputs
[
1
:])
line
+=
" = self.{}("
.
format
(
layer
.
outputs
[
0
])
line
+=
" = self.{}("
.
format
(
layer
.
outputs
[
0
])
for
k
,
v
in
layer
.
inputs
.
items
():
for
k
,
v
in
layer
.
inputs
.
items
():
if
v
not
in
cur_outputs
and
v
not
in
inputs
:
if
v
not
in
cur_outputs
and
v
not
in
inputs
:
...
@@ -299,15 +312,17 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
...
@@ -299,15 +312,17 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
indent
=
2
,
indent
=
2
,
init_func
=
init_func
,
init_func
=
init_func
,
forward_func
=
forward_func
,
forward_func
=
forward_func
,
layer_id
=
layer_id
,
layer_id
=
layer_id
,
different_attrs
=
list
(
different_attrs
.
keys
())
if
isinstance
(
different_attrs
,
dict
)
else
different_attrs
)
different_attrs
=
list
(
different_attrs
.
keys
())
if
isinstance
(
different_attrs
,
dict
)
else
different_attrs
)
cur_outputs
.
extend
(
layer
.
outputs
)
cur_outputs
.
extend
(
layer
.
outputs
)
else
:
else
:
raise
Exception
(
raise
Exception
(
"The kind {} in paddle model is not supported yet."
.
"The kind {} in paddle model is not supported yet."
.
format
(
format
(
layer
.
kernel
))
layer
.
kernel
))
elif
layer
.
kernel
==
"module"
:
elif
layer
.
kernel
==
"module"
:
line
=
"self.{} = {}("
.
format
(
layer
.
outputs
[
0
],
layer
.
attrs
[
"module"
])
line
=
"self.{} = {}("
.
format
(
layer
.
outputs
[
0
],
layer
.
attrs
[
"module"
])
layer
.
attrs
.
pop
(
"module"
)
layer
.
attrs
.
pop
(
"module"
)
for
k
,
v
in
layer
.
attrs
.
items
():
for
k
,
v
in
layer
.
attrs
.
items
():
key_name
=
"{}_{}"
.
format
(
layer
.
outputs
[
0
],
k
)
key_name
=
"{}_{}"
.
format
(
layer
.
outputs
[
0
],
k
)
...
@@ -358,23 +373,31 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
...
@@ -358,23 +373,31 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
key_name
=
"{}_{}"
.
format
(
layer
.
outputs
[
0
],
k
)
key_name
=
"{}_{}"
.
format
(
layer
.
outputs
[
0
],
k
)
if
key_name
in
different_attrs
:
if
key_name
in
different_attrs
:
line
+=
"{}=self.{}, "
.
format
(
k
,
key_name
)
line
+=
"{}=self.{}, "
.
format
(
k
,
key_name
)
init_func
.
extend
(
gen_codes
([
"self.{} = {}"
.
format
(
key_name
,
key_name
)],
indent
=
2
))
init_func
.
extend
(
gen_codes
(
[
"self.{} = {}"
.
format
(
key_name
,
key_name
)],
indent
=
2
))
else
:
else
:
line
+=
"{}={}, "
.
format
(
k
,
v
)
line
+=
"{}={}, "
.
format
(
k
,
v
)
line
=
line
.
strip
(
", "
)
line
=
line
.
strip
(
", "
)
line
+=
")"
line
+=
")"
if
layer
.
kernel
==
"self.create_parameter"
:
if
layer
.
kernel
==
"self.create_parameter"
:
init_func
.
extend
(
gen_codes
([
"self."
+
line
],
indent
=
2
))
init_func
.
extend
(
gen_codes
([
"self."
+
line
],
indent
=
2
))
forward_func
.
extend
(
gen_codes
([
"{} = self.{}"
.
format
(
layer
.
outputs
[
0
],
forward_func
.
extend
(
layer
.
outputs
[
0
])],
indent
=
2
))
gen_codes
(
[
"{} = self.{}"
.
format
(
layer
.
outputs
[
0
],
layer
.
outputs
[
0
])
],
indent
=
2
))
else
:
else
:
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
2
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
2
))
cur_outputs
.
extend
(
layer
.
outputs
)
cur_outputs
.
extend
(
layer
.
outputs
)
head
,
init_func_head
,
forward_func_head
=
gen_head
(
inputs
,
different_attrs
)
head
,
init_func_head
,
forward_func_head
=
gen_head
(
inputs
,
different_attrs
)
output_data_name
=
", "
.
join
(
outputs
)
output_data_name
=
", "
.
join
(
outputs
)
code_list
=
head
+
init_func_head
+
init_func
+
\
code_list
=
head
+
init_func_head
+
init_func
+
\
forward_func_head
+
forward_func
+
\
forward_func_head
+
forward_func
+
\
gen_codes
([
"return {}"
.
format
(
output_data_name
)],
indent
=
2
)
gen_codes
([
"return {}"
.
format
(
output_data_name
)],
indent
=
2
)
code_str
=
""
.
join
(
code_list
)
code_str
=
""
.
join
(
code_list
)
return
code_str
return
code_str
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录