Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
23662841
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
23662841
编写于
10月 26, 2017
作者:
Y
Yu Yang
提交者:
fengjiayi
10月 26, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Python API for save/load variables (#5136)
* Python API for save/load variables * Polish names
上级
8623e48b
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
159 addition
and
2 deletion
+159
-2
python/paddle/v2/framework/executor.py
python/paddle/v2/framework/executor.py
+7
-2
python/paddle/v2/framework/framework.py
python/paddle/v2/framework/framework.py
+5
-0
python/paddle/v2/framework/io.py
python/paddle/v2/framework/io.py
+143
-0
python/paddle/v2/framework/tests/.gitignore
python/paddle/v2/framework/tests/.gitignore
+1
-0
python/paddle/v2/framework/tests/test_fit_a_line.py
python/paddle/v2/framework/tests/test_fit_a_line.py
+3
-0
未找到文件。
python/paddle/v2/framework/executor.py
浏览文件 @
23662841
...
...
@@ -19,11 +19,16 @@ class Executor(object):
def
run
(
self
,
program
,
feed
,
fetch_list
,
feed
=
None
,
fetch_list
=
None
,
feed_var_name
=
'feed'
,
fetch_var_name
=
'fetch'
,
scope
=
None
):
if
feed
is
None
:
feed
=
{}
if
fetch_list
is
None
:
fetch_list
=
[]
if
not
isinstance
(
program
,
Program
):
raise
TypeError
()
...
...
python/paddle/v2/framework/framework.py
浏览文件 @
23662841
...
...
@@ -486,6 +486,11 @@ class Program(object):
for
block
in
self
.
blocks
:
block
.
sync_with_cpp
()
def
list_vars
(
self
):
for
each_block
in
self
.
blocks
:
for
each_var
in
each_block
.
vars
.
itervalues
():
yield
each_var
class
Parameter
(
Variable
):
def
__init__
(
self
,
block
,
shape
,
dtype
,
**
kwargs
):
...
...
python/paddle/v2/framework/io.py
0 → 100644
浏览文件 @
23662841
import
os
from
paddle.v2.framework.framework
import
Program
,
Parameter
,
g_program
,
\
Variable
__all__
=
[
'save_vars'
,
'save_params'
,
'save_persistables'
,
'load_vars'
,
'load_params'
,
'load_persistables'
]
def
is_parameter
(
var
):
return
isinstance
(
var
,
Parameter
)
def
is_persistable
(
var
):
return
var
.
persistable
def
_clone_var_in_block_
(
block
,
var
):
assert
isinstance
(
var
,
Variable
)
return
block
.
create_var
(
name
=
var
.
name
,
shape
=
var
.
shape
,
dtype
=
var
.
data_type
,
type
=
var
.
type
,
lod_level
=
var
.
lod_level
,
persistable
=
True
)
def
save_vars
(
executor
,
dirname
,
program
=
None
,
vars
=
None
,
predicate
=
None
):
"""
Save variables to directory by executor.
:param executor: executor that save variable
:param dirname: directory path
:param program: program. If vars is None, then filter all variables in this
program which fit `predicate`. Default g_program.
:param predicate: The Predicate describes a callable that returns a variable
as a bool. If it returns true, the variables will be saved.
:param vars: variables need to be saved. If specify vars, program & predicate
will be ignored
:return: None
"""
if
vars
is
None
:
if
program
is
None
:
program
=
g_program
if
not
isinstance
(
program
,
Program
):
raise
TypeError
(
"program should be as Program type or None"
)
save_vars
(
executor
,
dirname
=
dirname
,
vars
=
filter
(
predicate
,
program
.
list_vars
()))
else
:
save_program
=
Program
()
save_block
=
save_program
.
global_block
()
for
each_var
in
vars
:
new_var
=
_clone_var_in_block_
(
save_block
,
each_var
)
save_block
.
append_op
(
type
=
'save'
,
inputs
=
{
'X'
:
[
new_var
]},
outputs
=
{},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
dirname
,
new_var
.
name
)})
executor
.
run
(
save_program
)
def
save_params
(
executor
,
dirname
,
program
=
None
):
"""
Save all parameters to directory with executor.
"""
save_vars
(
executor
,
dirname
=
dirname
,
program
=
program
,
vars
=
None
,
predicate
=
is_parameter
)
def
save_persistables
(
executor
,
dirname
,
program
=
None
):
"""
Save all persistables to directory with executor.
"""
save_vars
(
executor
,
dirname
=
dirname
,
program
=
program
,
vars
=
None
,
predicate
=
is_persistable
)
def
load_vars
(
executor
,
dirname
,
program
=
None
,
vars
=
None
,
predicate
=
None
):
"""
Load variables from directory by executor.
:param executor: executor that save variable
:param dirname: directory path
:param program: program. If vars is None, then filter all variables in this
program which fit `predicate`. Default g_program.
:param predicate: The Predicate describes a callable that returns a variable
as a bool. If it returns true, the variables will be loaded.
:param vars: variables need to be loaded. If specify vars, program &
predicate will be ignored
:return: None
"""
if
vars
is
None
:
if
program
is
None
:
program
=
g_program
if
not
isinstance
(
program
,
Program
):
raise
TypeError
(
"program's type should be Program"
)
load_vars
(
executor
,
dirname
=
dirname
,
vars
=
filter
(
predicate
,
program
.
list_vars
()))
else
:
load_prog
=
Program
()
load_block
=
load_prog
.
global_block
()
for
each_var
in
vars
:
assert
isinstance
(
each_var
,
Variable
)
new_var
=
_clone_var_in_block_
(
load_block
,
each_var
)
load_block
.
append_op
(
type
=
'load'
,
inputs
=
{},
outputs
=
{
"Out"
:
[
new_var
]},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
dirname
,
new_var
.
name
)})
executor
.
run
(
load_prog
)
def
load_params
(
executor
,
dirname
,
program
=
None
):
"""
load all parameters from directory by executor.
"""
load_vars
(
executor
,
dirname
=
dirname
,
program
=
program
,
predicate
=
is_parameter
)
def
load_persistables
(
executor
,
dirname
,
program
=
None
):
"""
load all persistables from directory by executor.
"""
load_vars
(
executor
,
dirname
=
dirname
,
program
=
program
,
predicate
=
is_persistable
)
python/paddle/v2/framework/tests/.gitignore
浏览文件 @
23662841
image/
fit_a_line.model/
python/paddle/v2/framework/tests/test_fit_a_line.py
浏览文件 @
23662841
...
...
@@ -4,6 +4,7 @@ import paddle.v2.framework.core as core
import
paddle.v2.framework.optimizer
as
optimizer
from
paddle.v2.framework.framework
import
Program
,
g_program
from
paddle.v2.framework.io
import
save_persistables
,
load_persistables
from
paddle.v2.framework.executor
import
Executor
import
numpy
as
np
...
...
@@ -51,6 +52,8 @@ exe.run(init_program, feed={}, fetch_list=[])
PASS_NUM
=
100
for
pass_id
in
range
(
PASS_NUM
):
save_persistables
(
exe
,
"./fit_a_line.model/"
,
program
=
program
)
load_persistables
(
exe
,
"./fit_a_line.model/"
,
program
=
program
)
for
data
in
train_reader
():
x_data
=
np
.
array
(
map
(
lambda
x
:
x
[
0
],
data
)).
astype
(
"float32"
)
y_data
=
np
.
array
(
map
(
lambda
x
:
x
[
1
],
data
)).
astype
(
"float32"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录