Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c09de13e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c09de13e
编写于
8月 21, 2020
作者:
C
Chen Weihang
提交者:
GitHub
8月 21, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine jit load model by extra_var_info (#26461)
* refine load model by extra_var_info * polish unittest for coverage
上级
83cd1859
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
98 addition
and
16 deletion
+98
-16
python/paddle/fluid/dygraph/io.py
python/paddle/fluid/dygraph/io.py
+25
-15
python/paddle/fluid/tests/unittests/test_jit_save_load.py
python/paddle/fluid/tests/unittests/test_jit_save_load.py
+73
-1
未找到文件。
python/paddle/fluid/dygraph/io.py
浏览文件 @
c09de13e
...
...
@@ -437,8 +437,16 @@ def _load_persistable_vars(model_path,
value
:
key
for
key
,
value
in
program_holder
.
_suffix_varname_dict
.
items
()
}
# NOTE: some var may not be Parameter
for
name
in
sorted
(
extra_var_info
):
# NOTE(chenweihang): we need load persistable vars based the program,
# because the program may be pruned when `save_inference_model`, some
# var in `extra_var_info` may have been pruned
for
name
in
sorted
(
inv_suffix_varname_dict
):
if
name
not
in
extra_var_info
:
raise
RuntimeError
(
"The model to be loaded is not complete."
"The variable `%s` of program cannot be found in loaded model."
,
name
)
# get suffix var name, see [why need to append suffix to persistable vars]
new_name
=
inv_suffix_varname_dict
[
name
]
# create output varbase
...
...
@@ -641,19 +649,21 @@ class TranslatedLayer(layers.Layer):
# name contains `.` originally, such as `linear_0.w_0`, so here
# need to generate new var name for each var
self
.
_persistable_var_name_dict
=
dict
()
for
name
,
var
in
persistable_vars
.
items
():
if
isinstance
(
var
,
framework
.
ParamBase
):
dy_name
=
_generate_unique_var_name
(
PARAMETER_NAME_PREFIX
)
self
.
_persistable_var_name_dict
[
name
]
=
dy_name
self
.
add_parameter
(
dy_name
,
var
)
elif
isinstance
(
var
,
core
.
VarBase
):
dy_name
=
_generate_unique_var_name
(
BUFFER_NAME_PREFIX
)
self
.
_persistable_var_name_dict
[
name
]
=
dy_name
self
.
register_buffer
(
dy_name
,
var
)
else
:
raise
TypeError
(
"Adding persistent variable which to layer is not supported now"
)
# the TranslatedLayer object holded var names count started from 0
with
unique_name
.
guard
():
for
name
,
var
in
persistable_vars
.
items
():
if
isinstance
(
var
,
framework
.
ParamBase
):
dy_name
=
_generate_unique_var_name
(
PARAMETER_NAME_PREFIX
)
self
.
_persistable_var_name_dict
[
name
]
=
dy_name
self
.
add_parameter
(
dy_name
,
var
)
elif
isinstance
(
var
,
core
.
VarBase
):
dy_name
=
_generate_unique_var_name
(
BUFFER_NAME_PREFIX
)
self
.
_persistable_var_name_dict
[
name
]
=
dy_name
self
.
register_buffer
(
dy_name
,
var
)
else
:
raise
TypeError
(
"Adding persistent variable which to layer is not supported now"
)
self
.
_is_test
=
True
...
...
python/paddle/fluid/tests/unittests/test_jit_save_load.py
浏览文件 @
c09de13e
...
...
@@ -15,6 +15,7 @@
from
__future__
import
print_function
import
os
import
pickle
import
unittest
import
numpy
as
np
...
...
@@ -25,7 +26,7 @@ from paddle.fluid.dygraph import declarative, ProgramTranslator
from
paddle.fluid.dygraph.io
import
VARIABLE_FILENAME
,
EXTRA_VAR_INFO_FILENAME
BATCH_SIZE
=
32
BATCH_NUM
=
2
0
BATCH_NUM
=
1
0
SEED
=
10
...
...
@@ -318,5 +319,76 @@ class TestJitMultipleLoading(unittest.TestCase):
name_set
.
add
(
var
.
name
)
class
LinearNetReturnHidden
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
in_size
,
out_size
):
super
(
LinearNetReturnHidden
,
self
).
__init__
()
self
.
_linear_1
=
Linear
(
in_size
,
out_size
)
self
.
_linear_2
=
Linear
(
in_size
,
out_size
)
@
declarative
def
forward
(
self
,
x
):
y
=
self
.
_linear_1
(
x
)
z
=
self
.
_linear_2
(
y
)
loss
=
fluid
.
layers
.
mean
(
z
)
return
y
,
loss
class
TestJitPruneModelAndLoad
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
linear_size
=
4
self
.
model_path
=
"model.jit_prune_model_and_load"
# enable dygraph mode
fluid
.
enable_dygraph
()
# config seed
fluid
.
default_main_program
().
random_seed
=
SEED
def
train_and_save
(
self
):
train_layer
=
LinearNetReturnHidden
(
8
,
8
)
adam
=
fluid
.
optimizer
.
AdamOptimizer
(
learning_rate
=
0.1
,
parameter_list
=
train_layer
.
parameters
())
x
=
fluid
.
dygraph
.
to_variable
(
np
.
random
.
random
((
4
,
8
)).
astype
(
'float32'
))
for
i
in
range
(
10
):
hidden
,
loss
=
train_layer
(
x
)
loss
.
backward
()
adam
.
minimize
(
loss
)
train_layer
.
clear_gradients
()
configs
=
fluid
.
dygraph
.
jit
.
SaveLoadConfig
()
configs
.
output_spec
=
[
hidden
]
fluid
.
dygraph
.
jit
.
save
(
layer
=
train_layer
,
model_path
=
self
.
model_path
,
input_spec
=
[
x
],
configs
=
configs
)
return
train_layer
def
test_load_pruned_model
(
self
):
train_layer
=
self
.
train_and_save
()
train_layer
.
eval
()
infer_layer
=
fluid
.
dygraph
.
jit
.
load
(
self
.
model_path
)
x
=
fluid
.
dygraph
.
to_variable
(
np
.
random
.
random
((
4
,
8
)).
astype
(
'float32'
))
self
.
assertTrue
(
np
.
array_equal
(
train_layer
(
x
)[
0
].
numpy
(),
infer_layer
(
x
).
numpy
()))
def
test_load_var_not_in_extra_var_info
(
self
):
self
.
train_and_save
()
# chage extra var info
var_info_path
=
os
.
path
.
join
(
self
.
model_path
,
EXTRA_VAR_INFO_FILENAME
)
with
open
(
var_info_path
,
'rb'
)
as
f
:
extra_var_info
=
pickle
.
load
(
f
)
extra_var_info
.
clear
()
with
open
(
var_info_path
,
'wb'
)
as
f
:
pickle
.
dump
(
extra_var_info
,
f
,
protocol
=
2
)
with
self
.
assertRaises
(
RuntimeError
):
fluid
.
dygraph
.
jit
.
load
(
self
.
model_path
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录