Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
5bc026f0
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5bc026f0
编写于
8月 11, 2020
作者:
S
SunAhong1993
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update the program.py
上级
d8bb8920
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
208 addition
and
181 deletion
+208
-181
x2paddle/core/convert_prim.py
x2paddle/core/convert_prim.py
+136
-0
x2paddle/core/program.py
x2paddle/core/program.py
+72
-181
未找到文件。
x2paddle/core/convert_prim.py
0 → 100644
浏览文件 @
5bc026f0
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def
convert_prim
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
def
gen_codes
(
code_list
,
indent
=
0
):
indent_blank
=
" "
*
indent
codes
=
[]
for
code_line
in
code_list
:
if
code_line
.
strip
()
==
""
:
codes
.
append
(
'
\n
'
)
else
:
codes
.
append
(
indent_blank
+
code_line
+
'
\n
'
)
return
codes
if
layer
.
kernel
==
"prim.if"
:
line
=
"if {} :"
.
format
(
list
(
layer
.
inputs
.
values
())[
0
])
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
block
=
layer
.
blocks
[
0
]
b_init_lines
,
b_forward_lines
=
block
.
gen_dygraph_code
(
indent
=
indent
+
1
)
init_func
.
extend
(
b_init_lines
)
forward_func
.
extend
(
b_forward_lines
)
block
=
layer
.
blocks
[
1
]
if
len
(
block
.
layers
)
>
0
:
line
=
"else:"
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
b_init_lines
,
b_forward_lines
=
block
.
gen_dygraph_code
(
indent
=
indent
+
1
)
init_func
.
extend
(
b_init_lines
)
forward_func
.
extend
(
b_forward_lines
)
return
elif
layer
.
kernel
==
"prim.loop"
:
loop_range
=
list
(
layer
.
inputs
.
values
())[
0
]
if
list
(
layer
.
inputs
.
values
())[
0
]
is
None
:
loop_range
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
0
]])
line
=
"for {} in range({}):"
.
format
(
layer
.
outputs
[
1
],
loop_range
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
block
=
layer
.
blocks
[
0
]
b_init_lines
,
b_forward_lines
=
block
.
gen_dygraph_code
(
indent
=
indent
+
1
)
init_func
.
extend
(
b_init_lines
)
forward_func
.
extend
(
b_forward_lines
)
return
elif
layer
.
kernel
==
"prim.equal"
:
line
=
"{} = {}"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
])
elif
layer
.
kernel
==
"prim.constant"
:
line
=
"{} = {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
attrs
[
"value"
])
elif
layer
.
kernel
==
"prim.list"
:
inputs_list
=
list
(
layer
.
inputs
.
values
())
for
i
,
input
in
enumerate
(
inputs_list
):
if
input
is
None
:
inputs_list
[
i
]
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
i
]])
inputs_str
=
', '
.
join
(
inputs_list
)
line
=
"{} = [{}]"
.
format
(
layer
.
outputs
[
0
],
inputs_str
)
elif
layer
.
kernel
==
"prim.exception"
:
exception
=
list
(
layer
.
inputs
.
values
())[
0
]
if
list
(
layer
.
inputs
.
values
())[
0
]
is
None
:
exception
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
0
]])
line
=
"raise RaiseException({})"
.
format
(
exception
)
elif
layer
.
kernel
==
"prim.min"
:
line
=
"{} = min({})"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
])
elif
layer
.
kernel
==
"prim.add"
:
line
=
"{} = {} + {} * {}"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
],
layer
.
attrs
[
"alpha"
],
list
(
layer
.
inputs
.
values
())[
1
])
elif
layer
.
kernel
==
"prim.append"
:
line
=
"{} = {}.append({})"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
],
list
(
layer
.
inputs
.
values
())[
1
])
elif
layer
.
kernel
==
"prim.shape"
:
line
=
"{} = {}.shape"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
])
elif
layer
.
kernel
==
"prim.len"
:
line
=
"{} = len({})"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
])
elif
layer
.
kernel
==
"prim.eq"
:
line
=
"{} = {} == {}"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
],
list
(
layer
.
inputs
.
values
())[
1
])
elif
layer
.
kernel
==
"prim.assert"
:
if
layer
.
attrs
[
"type"
]
==
"eq"
:
if
isinstance
(
layer
.
attrs
[
"value"
],
list
):
s
=
""
for
v
in
layer
.
attrs
[
"value"
]:
s
+=
"{} == {} or "
.
format
(
layer
.
attrs
[
"key"
],
v
)
if
len
(
s
)
>
0
:
s
=
s
[:
-
4
]
line
=
"assert {},
\'
The {} must be {}!
\'
"
.
format
(
s
,
layer
.
attrs
[
"key"
],
layer
.
attrs
[
"value"
])
else
:
line
=
"assert {} == {},
\'
The {} must be {}!
\'
"
.
format
(
layer
.
attrs
[
"key"
],
layer
.
attrs
[
"value"
],
layer
.
attrs
[
"key"
],
layer
.
attrs
[
"value"
])
else
:
raise
Exception
(
"Not implement yet!"
)
elif
layer
.
kernel
==
"prim.getitem"
:
item0
=
list
(
layer
.
inputs
.
values
())[
0
]
if
list
(
layer
.
inputs
.
values
())[
0
]
is
None
:
item0
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
0
]])
item1
=
list
(
layer
.
inputs
.
values
())[
1
]
if
list
(
layer
.
inputs
.
values
())[
1
]
is
None
:
item1
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
1
]])
line
=
"{} = {}[{}]"
.
format
(
layer
.
outputs
[
0
],
item0
,
item1
)
elif
layer
.
kernel
==
"prim.le"
:
item0
=
list
(
layer
.
inputs
.
values
())[
0
]
if
list
(
layer
.
inputs
.
values
())[
0
]
is
None
:
item0
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
0
]])
item1
=
list
(
layer
.
inputs
.
values
())[
1
]
if
list
(
layer
.
inputs
.
values
())[
1
]
is
None
:
item1
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
1
]])
line
=
"{} = {} < {}"
.
format
(
layer
.
outputs
[
0
],
item0
,
item1
)
elif
layer
.
kernel
==
"prim.slice"
:
attrs_str
=
""
for
k
,
v
in
layer
.
attrs
.
items
():
attrs_str
+=
"{}:"
.
format
(
v
)
attrs_str
=
attrs_str
[:
-
1
]
line
=
"{} = {}[{}]"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
],
attrs_str
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
x2paddle/core/program.py
浏览文件 @
5bc026f0
...
...
@@ -59,7 +59,7 @@ class PaddleLayer(object):
class
PaddleGraph
(
object
):
def
__init__
(
self
,
father_layer
=
None
):
def
__init__
(
self
,
father_layer
=
None
,
graph_type
=
"dygraph"
):
self
.
layers
=
OrderedDict
()
self
.
edges_out
=
dict
()
self
.
edges_in
=
dict
()
...
...
@@ -67,6 +67,7 @@ class PaddleGraph(object):
self
.
outputs
=
list
()
self
.
parameters
=
dict
()
self
.
father_layer
=
father_layer
self
.
graph_type
=
graph_type
def
set_name
(
self
,
name
):
self
.
name
=
name
...
...
@@ -128,6 +129,10 @@ class PaddleGraph(object):
if
len
(
layer
.
blocks
)
>
0
:
for
block
in
layer
.
blocks
:
block
.
build
(
layer
.
inputs
,
layer
.
outputs
)
if
self
.
graph_type
==
"dygraph"
:
self
.
get_dygraph_inputs
()
self
.
get_dygraph_outputs
()
def
get_global_layers
(
self
):
# 该全局layers的信息是按住奥拓扑排序组成的
...
...
@@ -265,145 +270,25 @@ class PaddleGraph(object):
param
.
tofile
(
fp
)
fp
.
close
()
def
convert_prim
(
self
,
layer
,
indent
=
1
):
def
gen_lines
(
code_list
,
indent
=
0
):
indent_blank
=
" "
*
indent
lines
=
[]
for
code_line
in
code_list
:
if
code_line
.
strip
()
==
""
:
lines
.
append
(
'
\n
'
)
else
:
lines
.
append
(
indent_blank
+
code_line
+
'
\n
'
)
return
lines
if
layer
.
kernel
==
"prim.if"
:
line
=
"if {} :"
.
format
(
list
(
layer
.
inputs
.
values
())[
0
])
self
.
forward_lines
.
extend
(
gen_lines
([
line
],
indent
=
indent
))
block
=
layer
.
blocks
[
0
]
b_init_lines
,
b_forward_lines
=
block
.
gen_dygraph_code
(
indent
=
indent
+
1
)
self
.
init_lines
.
extend
(
b_init_lines
)
self
.
forward_lines
.
extend
(
b_forward_lines
)
block
=
layer
.
blocks
[
1
]
if
len
(
block
.
layers
)
>
0
:
line
=
"else:"
self
.
forward_lines
.
extend
(
gen_lines
([
line
],
indent
=
indent
))
b_init_lines
,
b_forward_lines
=
block
.
gen_dygraph_code
(
indent
=
indent
+
1
)
self
.
init_lines
.
extend
(
b_init_lines
)
self
.
forward_lines
.
extend
(
b_forward_lines
)
return
elif
layer
.
kernel
==
"prim.loop"
:
loop_range
=
list
(
layer
.
inputs
.
values
())[
0
]
if
list
(
layer
.
inputs
.
values
())[
0
]
is
None
:
loop_range
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
0
]])
line
=
"for {} in range({}):"
.
format
(
layer
.
outputs
[
1
],
loop_range
)
self
.
forward_lines
.
extend
(
gen_lines
([
line
],
indent
=
indent
))
block
=
layer
.
blocks
[
0
]
b_init_lines
,
b_forward_lines
=
block
.
gen_dygraph_code
(
indent
=
indent
+
1
)
self
.
init_lines
.
extend
(
b_init_lines
)
self
.
forward_lines
.
extend
(
b_forward_lines
)
return
elif
layer
.
kernel
==
"prim.equal"
:
line
=
"{} = {}"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
])
elif
layer
.
kernel
==
"prim.constant"
:
line
=
"{} = {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
attrs
[
"value"
])
elif
layer
.
kernel
==
"prim.list"
:
inputs_list
=
list
(
layer
.
inputs
.
values
())
for
i
,
input
in
enumerate
(
inputs_list
):
if
input
is
None
:
inputs_list
[
i
]
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
i
]])
inputs_str
=
', '
.
join
(
inputs_list
)
line
=
"{} = [{}]"
.
format
(
layer
.
outputs
[
0
],
inputs_str
)
elif
layer
.
kernel
==
"prim.exception"
:
exception
=
list
(
layer
.
inputs
.
values
())[
0
]
if
list
(
layer
.
inputs
.
values
())[
0
]
is
None
:
exception
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
0
]])
line
=
"raise RaiseException({})"
.
format
(
exception
)
elif
layer
.
kernel
==
"prim.min"
:
line
=
"{} = min({})"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
])
elif
layer
.
kernel
==
"prim.add"
:
line
=
"{} = {} + {} * {}"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
],
layer
.
attrs
[
"alpha"
],
list
(
layer
.
inputs
.
values
())[
1
])
elif
layer
.
kernel
==
"prim.append"
:
line
=
"{} = {}.append({})"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
],
list
(
layer
.
inputs
.
values
())[
1
])
elif
layer
.
kernel
==
"prim.shape"
:
line
=
"{} = {}.shape"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
])
elif
layer
.
kernel
==
"prim.len"
:
line
=
"{} = len({})"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
])
elif
layer
.
kernel
==
"prim.eq"
:
line
=
"{} = {} == {}"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
],
list
(
layer
.
inputs
.
values
())[
1
])
elif
layer
.
kernel
==
"prim.assert"
:
if
layer
.
attrs
[
"type"
]
==
"eq"
:
if
isinstance
(
layer
.
attrs
[
"value"
],
list
):
s
=
""
for
v
in
layer
.
attrs
[
"value"
]:
s
+=
"{} == {} or "
.
format
(
layer
.
attrs
[
"key"
],
v
)
if
len
(
s
)
>
0
:
s
=
s
[:
-
4
]
line
=
"assert {},
\'
The {} must be {}!
\'
"
.
format
(
s
,
layer
.
attrs
[
"key"
],
layer
.
attrs
[
"value"
])
else
:
line
=
"assert {} == {},
\'
The {} must be {}!
\'
"
.
format
(
layer
.
attrs
[
"key"
],
layer
.
attrs
[
"value"
],
layer
.
attrs
[
"key"
],
layer
.
attrs
[
"value"
])
else
:
raise
Exception
(
"Not implement yet!"
)
elif
layer
.
kernel
==
"prim.getitem"
:
item0
=
list
(
layer
.
inputs
.
values
())[
0
]
if
list
(
layer
.
inputs
.
values
())[
0
]
is
None
:
item0
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
0
]])
item1
=
list
(
layer
.
inputs
.
values
())[
1
]
if
list
(
layer
.
inputs
.
values
())[
1
]
is
None
:
item1
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
1
]])
line
=
"{} = {}[{}]"
.
format
(
layer
.
outputs
[
0
],
item0
,
item1
)
elif
layer
.
kernel
==
"prim.le"
:
item0
=
list
(
layer
.
inputs
.
values
())[
0
]
if
list
(
layer
.
inputs
.
values
())[
0
]
is
None
:
item0
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
0
]])
item1
=
list
(
layer
.
inputs
.
values
())[
1
]
if
list
(
layer
.
inputs
.
values
())[
1
]
is
None
:
item1
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
1
]])
line
=
"{} = {} < {}"
.
format
(
layer
.
outputs
[
0
],
item0
,
item1
)
elif
layer
.
kernel
==
"prim.slice"
:
attrs_str
=
""
for
k
,
v
in
layer
.
attrs
.
items
():
attrs_str
+=
"{}:"
.
format
(
v
)
attrs_str
=
attrs_str
[:
-
1
]
line
=
"{} = {}[{}]"
.
format
(
layer
.
outputs
[
0
],
list
(
layer
.
inputs
.
values
())[
0
],
attrs_str
)
self
.
forward_lines
.
extend
(
gen_lines
([
line
],
indent
=
indent
))
return
def
get_dygraph_inputs
(
self
,
layers
):
for
layer_id
,
layer
in
layers
.
items
():
if
self
.
edges_in
.
get
(
layer_id
,
0
)
==
0
and
self
.
edges_out
.
get
(
layer_id
,
0
)
==
0
:
continue
if
layer
.
kernel
==
"fluid.dygraph.base.to_variable"
:
value
=
layer
.
attrs
[
"value"
]
if
not
value
.
startswith
(
"params["
):
self
.
inputs
.
append
(
value
)
if
len
(
layer
.
blocks
)
>
0
:
for
block
in
layer
.
blocks
:
block
.
get_dygraph_inputs
(
block
.
layers
)
self
.
inputs
.
extend
(
block
.
inputs
)
def
get_dygraph_outputs
(
self
,
layers
):
for
layer_id
,
layer
in
layers
.
items
():
def
get_dygraph_inputs
(
self
):
def
update
(
layers
):
for
layer_id
,
layer
in
layers
.
items
():
if
self
.
edges_in
.
get
(
layer_id
,
0
)
==
0
and
self
.
edges_out
.
get
(
layer_id
,
0
)
==
0
:
continue
if
layer
.
kernel
==
"fluid.dygraph.base.to_variable"
:
value
=
layer
.
attrs
[
"value"
]
if
not
value
.
startswith
(
"params["
):
self
.
inputs
.
append
(
value
)
if
len
(
layer
.
blocks
)
>
0
:
for
block
in
layer
.
blocks
:
block
.
get_dygraph_inputs
()
self
.
inputs
.
extend
(
block
.
inputs
)
update
(
self
.
layers
)
self
.
inputs
=
list
(
set
(
self
.
inputs
))
def
get_dygraph_outputs
(
self
):
for
layer_id
,
layer
in
self
.
layers
.
items
():
if
self
.
edges_in
.
get
(
layer_id
,
0
)
==
0
and
self
.
edges_out
.
get
(
layer_id
,
0
)
==
0
:
continue
...
...
@@ -413,24 +298,21 @@ class PaddleGraph(object):
"_assert"
)
or
not
output_name
.
startswith
(
"x"
):
continue
self
.
outputs
.
append
(
output_name
)
self
.
outputs
=
list
(
set
(
self
.
outputs
))
def
gen_dygraph_code
(
self
,
code_dir
=
None
,
indent
=
2
):
def
gen_
lin
es
(
code_list
,
indent
=
0
):
def
gen_
cod
es
(
code_list
,
indent
=
0
):
indent_blank
=
" "
*
indent
lin
es
=
[]
cod
es
=
[]
for
code_line
in
code_list
:
if
code_line
.
strip
()
==
""
:
lin
es
.
append
(
'
\n
'
)
cod
es
.
append
(
'
\n
'
)
else
:
lines
.
append
(
indent_blank
+
code_line
+
'
\n
'
)
return
lines
self
.
init_lines
=
[]
# forward_func
self
.
forward_lines
=
[]
# def gen_head
if
indent
==
2
and
code_dir
is
not
None
:
start_lines
=
gen_lines
(
codes
.
append
(
indent_blank
+
code_line
+
'
\n
'
)
return
codes
def
gen_head
():
self
.
head
=
gen_codes
(
[
"from paddle.fluid.initializer import Constant"
,
"from paddle.fluid.param_attr import ParamAttr"
,
...
...
@@ -439,18 +321,39 @@ class PaddleGraph(object):
"class {}(fluid.dygraph.Layer):"
.
format
(
self
.
name
),
],
indent
=
0
)
self
.
get_dygraph_inputs
(
self
.
layers
)
input_data_name
=
', '
.
join
(
self
.
inputs
)
self
.
init_
lines
.
extend
(
gen_
lin
es
(
self
.
init_
func
.
extend
(
gen_
cod
es
(
[
"def __init__(self, params):"
],
indent
=
1
))
self
.
init_
lines
.
extend
(
gen_
lin
es
(
self
.
init_
func
.
extend
(
gen_
cod
es
(
[
"super({}, self).__init__()"
.
format
(
self
.
name
)],
indent
=
2
))
self
.
forward_
lines
.
extend
(
gen_
lin
es
(
self
.
forward_
func
.
extend
(
gen_
cod
es
(
[
"def forward(self, {}):"
.
format
(
input_data_name
)],
indent
=
1
))
def
write_code
(
code_dir
):
f
=
open
(
os
.
path
.
join
(
code_dir
,
'code.py'
),
'w'
)
for
code_line
in
self
.
head
:
f
.
write
(
code_line
)
init_writen_codes
=
[]
for
code_line
in
self
.
init_func
:
if
code_line
in
init_writen_codes
:
continue
f
.
write
(
code_line
)
init_writen_codes
.
append
(
code_line
)
f
.
write
(
"
\n
"
)
return_code
=
"return {}"
.
format
(
", "
.
join
(
self
.
outputs
))
self
.
forward_func
.
extend
(
gen_codes
([
return_code
],
indent
=
2
))
for
code_line
in
self
.
forward_func
:
f
.
write
(
code_line
)
f
.
close
()
self
.
init_func
=
[]
self
.
forward_func
=
[]
if
indent
==
2
and
code_dir
is
not
None
:
gen_head
()
for
layer_id
,
layer
in
self
.
layers
.
items
():
if
self
.
edges_in
.
get
(
layer_id
,
0
)
==
0
and
self
.
edges_out
.
get
(
...
...
@@ -470,10 +373,10 @@ class PaddleGraph(object):
if
layer
.
kernel
==
"fluid.dygraph.base.to_variable"
and
not
layer
.
attrs
[
"value"
].
startswith
(
"params["
):
self
.
forward_
lines
.
extend
(
gen_lin
es
([
line
],
indent
=
indent
))
self
.
forward_
func
.
extend
(
gen_cod
es
([
line
],
indent
=
indent
))
continue
else
:
self
.
init_
lines
.
extend
(
gen_lin
es
([
line
],
indent
=
2
))
self
.
init_
func
.
extend
(
gen_cod
es
([
line
],
indent
=
2
))
if
len
(
layer
.
outputs
)
==
1
:
line
=
layer
.
outputs
[
0
]
...
...
@@ -490,9 +393,12 @@ class PaddleGraph(object):
line
+=
"{}, "
.
format
(
v
)
line
=
line
.
strip
(
", "
)
line
+=
")"
self
.
forward_
lines
.
extend
(
gen_lin
es
([
line
],
indent
=
indent
))
self
.
forward_
func
.
extend
(
gen_cod
es
([
line
],
indent
=
indent
))
elif
"prim"
in
layer
.
kernel
:
self
.
convert_prim
(
layer
,
indent
=
indent
)
from
.convert_prim
import
convert_prim
convert_prim
(
layer
,
indent
=
indent
,
init_func
=
self
.
init_func
,
forward_func
=
self
.
forward_func
)
else
:
if
len
(
layer
.
outputs
)
==
1
:
line
=
layer
.
outputs
[
0
]
...
...
@@ -505,26 +411,11 @@ class PaddleGraph(object):
line
+=
"{}={}, "
.
format
(
k
,
v
)
line
=
line
.
strip
(
", "
)
line
+=
")"
self
.
forward_
lines
.
extend
(
gen_lin
es
([
line
],
indent
=
indent
))
self
.
forward_
func
.
extend
(
gen_cod
es
([
line
],
indent
=
indent
))
if
indent
==
2
:
f
=
open
(
os
.
path
.
join
(
code_dir
,
'code.py'
),
'w'
)
for
line
in
start_lines
:
f
.
write
(
line
)
init_writen_line
=
[]
for
line
in
self
.
init_lines
:
if
line
in
init_writen_line
:
continue
f
.
write
(
line
)
init_writen_line
.
append
(
line
)
f
.
write
(
"
\n
"
)
self
.
get_dygraph_outputs
(
self
.
layers
)
return_line
=
"return {}"
.
format
(
", "
.
join
(
self
.
outputs
))
self
.
forward_lines
.
extend
(
gen_lines
([
return_line
],
indent
=
2
))
for
line
in
self
.
forward_lines
:
f
.
write
(
line
)
f
.
close
()
write_code
(
code_dir
)
else
:
return
self
.
init_
lines
,
self
.
forward_lines
return
self
.
init_
func
,
self
.
forward_func
def
dump_dygraph_parameter
(
self
,
code_dir
):
params_output
=
open
(
os
.
path
.
join
(
code_dir
,
'model.pdparams'
),
'wb'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录