Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
e5bb4edb
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e5bb4edb
编写于
1月 15, 2021
作者:
W
WeiXin
提交者:
GitHub
1月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perfect 'var_list' of static.load/fluid.load (#30457)
上级
05f06d9a
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
122 addition
and
1 deletion
+122
-1
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+8
-1
python/paddle/fluid/tests/unittests/test_static_save_load.py
python/paddle/fluid/tests/unittests/test_static_save_load.py
+114
-0
未找到文件。
python/paddle/fluid/io.py
浏览文件 @
e5bb4edb
...
...
@@ -1895,6 +1895,12 @@ def load(program, model_path, executor=None, var_list=None):
raise
ValueError
(
"executor is required when loading model file saved with [ save_params, save_persistables, save_vars ]"
)
if
var_list
is
not
None
:
var_list_names
=
[
var
.
name
for
var
in
var_list
]
else
:
var_list_names
=
None
if
os
.
path
.
isdir
(
model_path
):
binary_file_set
=
set
()
for
root
,
dirs
,
files
in
os
.
walk
(
model_path
,
topdown
=
False
):
...
...
@@ -1905,7 +1911,8 @@ def load(program, model_path, executor=None, var_list=None):
loaded_var_list
=
[]
for
var
in
program_var_list
:
var_path
=
os
.
path
.
join
(
model_path
,
var
.
name
).
replace
(
"
\\
"
,
"/"
)
if
var_path
in
binary_file_set
:
load_condition
=
var_list_names
is
None
or
var
.
name
in
var_list_names
if
var_path
in
binary_file_set
and
load_condition
:
loaded_var_list
.
append
(
var
)
binary_file_set
.
remove
(
var_path
)
if
len
(
binary_file_set
)
>
0
:
...
...
python/paddle/fluid/tests/unittests/test_static_save_load.py
浏览文件 @
e5bb4edb
...
...
@@ -794,6 +794,9 @@ class TestLoadFromOldInterface(unittest.TestCase):
if
os
.
path
.
exists
(
"test_path.pdparams"
):
os
.
remove
(
"test_path.pdparams"
)
if
os
.
path
.
exists
(
"test_static_load_var_list.pdparams"
):
os
.
remove
(
"test_static_load_var_list.pdparams"
)
def
test_load_from_old_interface
(
self
):
seed
=
90
hidden_size
=
10
...
...
@@ -910,6 +913,117 @@ class TestLoadFromOldInterface(unittest.TestCase):
fluid
.
load
(
test_clone_program
,
"test_path"
,
exe
)
def
test_load_from_old_interface_var_list
(
self
):
seed
=
90
hidden_size
=
10
vocab_size
=
1000
num_layers
=
1
num_steps
=
3
init_scale
=
0.1
batch_size
=
4
batch_num
=
200
with
new_program_scope
():
fluid
.
default_startup_program
().
random_seed
=
seed
fluid
.
default_main_program
().
random_seed
=
seed
ptb_model
=
PtbModel
(
"ptb_model"
,
hidden_size
=
hidden_size
,
vocab_size
=
vocab_size
,
num_layers
=
num_layers
,
num_steps
=
num_steps
,
init_scale
=
init_scale
)
place
=
fluid
.
CPUPlace
()
if
not
core
.
is_compiled_with_cuda
(
)
else
fluid
.
CUDAPlace
(
0
)
exe
=
fluid
.
Executor
(
place
)
sgd
=
Adam
(
learning_rate
=
1e-3
)
x
=
fluid
.
layers
.
data
(
name
=
"x"
,
shape
=
[
-
1
,
num_steps
],
dtype
=
'int64'
)
y
=
fluid
.
layers
.
data
(
name
=
"y"
,
shape
=
[
-
1
,
1
],
dtype
=
'float32'
)
init_hidden
=
fluid
.
layers
.
data
(
name
=
"init_hidden"
,
shape
=
[
1
],
dtype
=
'float32'
)
init_cell
=
fluid
.
layers
.
data
(
name
=
"init_cell"
,
shape
=
[
1
],
dtype
=
'float32'
)
static_loss
,
static_last_hidden
,
static_last_cell
=
ptb_model
(
x
,
y
,
init_hidden
,
init_cell
)
test_clone_program
=
fluid
.
default_main_program
().
clone
()
sgd
.
minimize
(
static_loss
)
static_param_updated
=
dict
()
static_param_init
=
dict
()
out
=
exe
.
run
(
framework
.
default_startup_program
())
static_loss_value
=
None
static_last_cell_value
=
None
static_last_hidden_value
=
None
for
i
in
range
(
batch_num
):
x_data
=
np
.
arange
(
12
).
reshape
(
4
,
3
).
astype
(
'int64'
)
y_data
=
np
.
arange
(
1
,
13
).
reshape
(
4
,
3
).
astype
(
'int64'
)
x_data
=
x_data
.
reshape
((
-
1
,
num_steps
,
1
))
y_data
=
y_data
.
reshape
((
-
1
,
1
))
init_hidden_data
=
np
.
zeros
(
(
num_layers
,
batch_size
,
hidden_size
),
dtype
=
'float32'
)
init_cell_data
=
np
.
zeros
(
(
num_layers
,
batch_size
,
hidden_size
),
dtype
=
'float32'
)
fetch_list
=
[
static_loss
,
static_last_hidden
,
static_last_cell
]
out
=
exe
.
run
(
fluid
.
default_main_program
(),
feed
=
{
"x"
:
x_data
,
"y"
:
y_data
,
"init_hidden"
:
init_hidden_data
,
"init_cell"
:
init_cell_data
},
fetch_list
=
fetch_list
)
static_loss_value
=
out
[
0
]
static_last_hidden_value
=
out
[
1
]
static_last_cell_value
=
out
[
2
]
# get value before save
main_program
=
framework
.
default_main_program
()
base_map
=
{}
for
var
in
main_program
.
list_vars
():
if
isinstance
(
var
,
framework
.
Parameter
)
or
var
.
persistable
:
t
=
np
.
array
(
fluid
.
global_scope
().
find_var
(
var
.
name
)
.
get_tensor
())
# make sure all the paramerter or optimizer var have been update
self
.
assertTrue
(
np
.
sum
(
np
.
abs
(
t
))
!=
0
)
base_map
[
var
.
name
]
=
t
#fluid.save(main_program, "./test_1")
fluid
.
io
.
save_persistables
(
exe
,
"test_static_load_var_list"
,
main_program
)
# set var to zero
var_list
=
[]
for
i
,
var
in
enumerate
(
main_program
.
list_vars
()):
if
isinstance
(
var
,
framework
.
Parameter
)
or
var
.
persistable
:
if
i
%
2
==
0
:
var_list
.
append
(
var
)
ten
=
fluid
.
global_scope
().
find_var
(
var
.
name
).
get_tensor
()
ten
.
set
(
np
.
zeros_like
(
np
.
array
(
ten
)),
place
)
new_t
=
np
.
array
(
fluid
.
global_scope
().
find_var
(
var
.
name
)
.
get_tensor
())
# make sure all the paramerter or optimizer var have been set to zero
self
.
assertTrue
(
np
.
sum
(
np
.
abs
(
new_t
))
==
0
)
fluid
.
load
(
main_program
,
"test_static_load_var_list"
,
exe
,
var_list
)
var_list_names
=
[
var
.
name
for
var
in
var_list
]
for
var
in
main_program
.
list_vars
():
if
isinstance
(
var
,
framework
.
Parameter
)
or
var
.
persistable
:
new_t
=
np
.
array
(
fluid
.
global_scope
().
find_var
(
var
.
name
)
.
get_tensor
())
if
var
.
name
in
var_list_names
:
# loaded vars
base_t
=
base_map
[
var
.
name
]
self
.
assertTrue
(
np
.
array_equal
(
new_t
,
base_t
))
else
:
#not loaded vars
self
.
assertTrue
(
np
.
sum
(
np
.
abs
(
new_t
))
==
0
)
class
TestLoadFromOldInterfaceSingleFile
(
unittest
.
TestCase
):
def
test_load_from_old_interface
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录