Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
6a78f20a
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6a78f20a
编写于
12月 03, 2020
作者:
S
SunAhong1993
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix the program
上级
55bfc88f
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
20 addition
and
21 deletion
+20
-21
x2paddle/core/program.py
x2paddle/core/program.py
+20
-21
未找到文件。
x2paddle/core/program.py
浏览文件 @
6a78f20a
...
@@ -16,7 +16,6 @@
...
@@ -16,7 +16,6 @@
from
__future__
import
print_function
from
__future__
import
print_function
from
__future__
import
division
from
__future__
import
division
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
os.path
as
osp
import
paddle
import
paddle
from
paddle.fluid.proto
import
framework_pb2
from
paddle.fluid.proto
import
framework_pb2
from
collections
import
OrderedDict
from
collections
import
OrderedDict
...
@@ -26,6 +25,7 @@ import os
...
@@ -26,6 +25,7 @@ import os
import
six
import
six
import
pickle
import
pickle
import
numpy
as
np
import
numpy
as
np
from
os
import
path
as
osp
class
PaddleLayer
(
object
):
class
PaddleLayer
(
object
):
...
@@ -232,7 +232,7 @@ class PaddleGraph(object):
...
@@ -232,7 +232,7 @@ class PaddleGraph(object):
return
update
(
self
.
layers
)
return
update
(
self
.
layers
)
def
gen_model
(
self
,
save_dir
,
jit_type
=
None
):
def
gen_model
(
self
,
save_dir
,
jit_type
=
None
):
if
not
os
.
path
.
exists
(
save_dir
):
if
not
os
p
.
exists
(
save_dir
):
os
.
makedirs
(
save_dir
)
os
.
makedirs
(
save_dir
)
if
self
.
graph_type
==
"static"
:
if
self
.
graph_type
==
"static"
:
self
.
gen_static_model
(
save_dir
)
self
.
gen_static_model
(
save_dir
)
...
@@ -240,8 +240,8 @@ class PaddleGraph(object):
...
@@ -240,8 +240,8 @@ class PaddleGraph(object):
self
.
gen_dygraph_model
(
save_dir
,
jit_type
)
self
.
gen_dygraph_model
(
save_dir
,
jit_type
)
def
gen_static_model
(
self
,
save_dir
):
def
gen_static_model
(
self
,
save_dir
):
code_dir
=
os
.
path
.
join
(
save_dir
,
'model_with_code'
)
code_dir
=
os
p
.
join
(
save_dir
,
'model_with_code'
)
infer_dir
=
os
.
path
.
join
(
save_dir
,
'inference_model'
)
infer_dir
=
os
p
.
join
(
save_dir
,
'inference_model'
)
self
.
gen_static_code
(
code_dir
)
self
.
gen_static_code
(
code_dir
)
sys
.
path
.
append
(
code_dir
)
sys
.
path
.
append
(
code_dir
)
import
x2paddle_model
import
x2paddle_model
...
@@ -254,13 +254,13 @@ class PaddleGraph(object):
...
@@ -254,13 +254,13 @@ class PaddleGraph(object):
inputs
,
outputs
=
x2paddle_model
.
x2paddle_net
()
inputs
,
outputs
=
x2paddle_model
.
x2paddle_net
()
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
exe
.
run
(
startup_program
)
exe
.
run
(
startup_program
)
param_dir
=
os
.
path
.
join
(
code_dir
,
'weights'
)
param_dir
=
os
p
.
join
(
code_dir
,
'weights'
)
for
k
,
v
in
self
.
parameters
.
items
():
for
k
,
v
in
self
.
parameters
.
items
():
if
scope
.
find_var
(
k
):
if
scope
.
find_var
(
k
):
self
.
dump_parameter
(
k
,
v
,
param_dir
)
self
.
dump_parameter
(
k
,
v
,
param_dir
)
def
if_exist
(
var
):
def
if_exist
(
var
):
b
=
os
.
path
.
exists
(
b
=
os
p
.
exists
(
os
.
path
.
join
(
os
.
path
.
join
(
param_dir
,
var
.
name
)))
os
p
.
join
(
osp
.
join
(
param_dir
,
var
.
name
)))
return
b
return
b
fluid
.
io
.
load_vars
(
fluid
.
io
.
load_vars
(
exe
,
param_dir
,
main_program
,
predicate
=
if_exist
)
exe
,
param_dir
,
main_program
,
predicate
=
if_exist
)
...
@@ -282,6 +282,8 @@ class PaddleGraph(object):
...
@@ -282,6 +282,8 @@ class PaddleGraph(object):
self
.
gen_dygraph_code
(
save_dir
)
self
.
gen_dygraph_code
(
save_dir
)
self
.
dump_dygraph_parameter
(
save_dir
)
self
.
dump_dygraph_parameter
(
save_dir
)
# 动转静
# 动转静
code_path
=
osp
.
join
(
osp
.
abspath
(
save_dir
),
"x2paddle_code.py"
)
print
(
"Exporting inference model from python code ('{}')...
\n
"
.
format
(
code_path
))
if
len
(
self
.
inputs_info
)
>
0
:
if
len
(
self
.
inputs_info
)
>
0
:
input_shapes
=
list
()
input_shapes
=
list
()
input_types
=
list
()
input_types
=
list
()
...
@@ -290,13 +292,10 @@ class PaddleGraph(object):
...
@@ -290,13 +292,10 @@ class PaddleGraph(object):
input_types
.
append
(
self
.
inputs_info
[
input_name
][
1
])
input_types
.
append
(
self
.
inputs_info
[
input_name
][
1
])
try
:
try
:
self
.
dygraph2static
(
save_dir
,
input_shapes
,
input_types
)
self
.
dygraph2static
(
save_dir
,
input_shapes
,
input_types
)
except
Error
as
e
:
except
Exception
as
e
:
print
(
"The Dygraph2Static is failed! The possible reason are:
\n
"
+
print
(
"Fail to generate inference model! Problem happend while export inference model from python code '{}';
\n
"
.
format
(
coda_path
))
"1. The convertor of dygraph2static of current model is not supported yet.
\n
"
+
print
(
"===================Error Information==============="
)
"2. The convertor of pytorch2paddle is wrong. You can run the code of x2paddle_model.py in your save_dir to check whether the convertor of pytorch2paddle is wrong.
\n
"
+
raise
e
"The Error is:
\n
"
+
e
)
exit
(
0
)
def
gen_static_code
(
self
,
code_dir
):
def
gen_static_code
(
self
,
code_dir
):
def
write_code
(
f
,
code_list
,
indent
=
0
):
def
write_code
(
f
,
code_list
,
indent
=
0
):
...
@@ -307,9 +306,9 @@ class PaddleGraph(object):
...
@@ -307,9 +306,9 @@ class PaddleGraph(object):
else
:
else
:
f
.
write
(
indent_blank
+
code_line
+
'
\n
'
)
f
.
write
(
indent_blank
+
code_line
+
'
\n
'
)
if
not
os
.
path
.
exists
(
code_dir
):
if
not
os
p
.
exists
(
code_dir
):
os
.
makedirs
(
code_dir
)
os
.
makedirs
(
code_dir
)
f
=
open
(
os
.
path
.
join
(
code_dir
,
'x2paddle_model.py'
),
'w'
)
f
=
open
(
os
p
.
join
(
code_dir
,
'x2paddle_model.py'
),
'w'
)
write_code
(
write_code
(
f
,
[
f
,
[
...
@@ -372,7 +371,7 @@ class PaddleGraph(object):
...
@@ -372,7 +371,7 @@ class PaddleGraph(object):
def
dump_parameter
(
self
,
param_name
,
param
,
save_dir
):
def
dump_parameter
(
self
,
param_name
,
param
,
save_dir
):
if
not
os
.
path
.
exists
(
save_dir
):
if
not
os
p
.
exists
(
save_dir
):
os
.
makedirs
(
save_dir
)
os
.
makedirs
(
save_dir
)
dtype_map
=
{
dtype_map
=
{
"int16"
:
[
framework_pb2
.
VarType
.
INT16
,
'h'
],
"int16"
:
[
framework_pb2
.
VarType
.
INT16
,
'h'
],
...
@@ -392,7 +391,7 @@ class PaddleGraph(object):
...
@@ -392,7 +391,7 @@ class PaddleGraph(object):
assert
str
(
assert
str
(
param
.
dtype
)
in
dtype_map
,
"Unknown dtype {} of params: {}."
.
format
(
param
.
dtype
)
in
dtype_map
,
"Unknown dtype {} of params: {}."
.
format
(
str
(
param
.
dtype
),
param_name
)
str
(
param
.
dtype
),
param_name
)
fp
=
open
(
os
.
path
.
join
(
save_dir
,
param_name
),
'wb'
)
fp
=
open
(
os
p
.
join
(
save_dir
,
param_name
),
'wb'
)
numpy
.
array
([
0
],
dtype
=
'int32'
).
tofile
(
fp
)
numpy
.
array
([
0
],
dtype
=
'int32'
).
tofile
(
fp
)
numpy
.
array
([
0
],
dtype
=
'int64'
).
tofile
(
fp
)
numpy
.
array
([
0
],
dtype
=
'int64'
).
tofile
(
fp
)
numpy
.
array
([
0
],
dtype
=
'int32'
).
tofile
(
fp
)
numpy
.
array
([
0
],
dtype
=
'int32'
).
tofile
(
fp
)
...
@@ -502,7 +501,7 @@ class PaddleGraph(object):
...
@@ -502,7 +501,7 @@ class PaddleGraph(object):
use_structured_name
=
False
if
self
.
source_type
in
[
"tf"
,
"onnx"
]
else
True
use_structured_name
=
False
if
self
.
source_type
in
[
"tf"
,
"onnx"
]
else
True
self
.
run_func
.
extend
(
self
.
run_func
.
extend
(
gen_codes
([
"paddle.disable_static()"
,
gen_codes
([
"paddle.disable_static()"
,
"params = paddle.load('{}/model.pdparams')"
.
format
(
os
.
path
.
abspath
(
code_dir
)),
"params = paddle.load('{}/model.pdparams')"
.
format
(
os
p
.
abspath
(
code_dir
)),
"model = {}()"
.
format
(
self
.
name
),
"model = {}()"
.
format
(
self
.
name
),
"model.set_dict(params, use_structured_name={})"
.
format
(
use_structured_name
),
"model.set_dict(params, use_structured_name={})"
.
format
(
use_structured_name
),
"model.eval()"
,
"model.eval()"
,
...
@@ -510,7 +509,7 @@ class PaddleGraph(object):
...
@@ -510,7 +509,7 @@ class PaddleGraph(object):
"return out"
],
indent
=
1
))
"return out"
],
indent
=
1
))
def
write_code
(
code_dir
):
def
write_code
(
code_dir
):
f
=
open
(
os
.
path
.
join
(
code_dir
,
'x2paddle_code.py'
),
'w'
)
f
=
open
(
os
p
.
join
(
code_dir
,
'x2paddle_code.py'
),
'w'
)
for
code_line
in
self
.
head
:
for
code_line
in
self
.
head
:
f
.
write
(
code_line
)
f
.
write
(
code_line
)
init_writen_codes
=
[]
init_writen_codes
=
[]
...
@@ -622,7 +621,7 @@ class PaddleGraph(object):
...
@@ -622,7 +621,7 @@ class PaddleGraph(object):
return
self
.
init_func
,
self
.
forward_func
return
self
.
init_func
,
self
.
forward_func
def
dump_dygraph_parameter
(
self
,
code_dir
):
def
dump_dygraph_parameter
(
self
,
code_dir
):
save_path
=
os
.
path
.
join
(
code_dir
,
'model.pdparams'
)
save_path
=
os
p
.
join
(
code_dir
,
'model.pdparams'
)
paddle
.
save
(
self
.
parameters
,
save_path
)
paddle
.
save
(
self
.
parameters
,
save_path
)
def
dygraph2static
(
self
,
save_dir
,
input_shapes
=
[],
input_types
=
[]):
def
dygraph2static
(
self
,
save_dir
,
input_shapes
=
[],
input_types
=
[]):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录