Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
23662841
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
23662841
编写于
10月 26, 2017
作者:
Y
Yu Yang
提交者:
fengjiayi
10月 26, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Python API for save/load variables (#5136)
* Python API for save/load variables * Polish names
上级
8623e48b
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
159 addition
and
2 deletion
+159
-2
python/paddle/v2/framework/executor.py
python/paddle/v2/framework/executor.py
+7
-2
python/paddle/v2/framework/framework.py
python/paddle/v2/framework/framework.py
+5
-0
python/paddle/v2/framework/io.py
python/paddle/v2/framework/io.py
+143
-0
python/paddle/v2/framework/tests/.gitignore
python/paddle/v2/framework/tests/.gitignore
+1
-0
python/paddle/v2/framework/tests/test_fit_a_line.py
python/paddle/v2/framework/tests/test_fit_a_line.py
+3
-0
未找到文件。
python/paddle/v2/framework/executor.py
浏览文件 @
23662841
...
@@ -19,11 +19,16 @@ class Executor(object):
...
@@ -19,11 +19,16 @@ class Executor(object):
def
run
(
self
,
def
run
(
self
,
program
,
program
,
feed
,
feed
=
None
,
fetch_list
,
fetch_list
=
None
,
feed_var_name
=
'feed'
,
feed_var_name
=
'feed'
,
fetch_var_name
=
'fetch'
,
fetch_var_name
=
'fetch'
,
scope
=
None
):
scope
=
None
):
if
feed
is
None
:
feed
=
{}
if
fetch_list
is
None
:
fetch_list
=
[]
if
not
isinstance
(
program
,
Program
):
if
not
isinstance
(
program
,
Program
):
raise
TypeError
()
raise
TypeError
()
...
...
python/paddle/v2/framework/framework.py
浏览文件 @
23662841
...
@@ -486,6 +486,11 @@ class Program(object):
...
@@ -486,6 +486,11 @@ class Program(object):
for
block
in
self
.
blocks
:
for
block
in
self
.
blocks
:
block
.
sync_with_cpp
()
block
.
sync_with_cpp
()
def
list_vars
(
self
):
for
each_block
in
self
.
blocks
:
for
each_var
in
each_block
.
vars
.
itervalues
():
yield
each_var
class
Parameter
(
Variable
):
class
Parameter
(
Variable
):
def
__init__
(
self
,
block
,
shape
,
dtype
,
**
kwargs
):
def
__init__
(
self
,
block
,
shape
,
dtype
,
**
kwargs
):
...
...
python/paddle/v2/framework/io.py
0 → 100644
浏览文件 @
23662841
import
os
from
paddle.v2.framework.framework
import
Program
,
Parameter
,
g_program
,
\
Variable
__all__
=
[
'save_vars'
,
'save_params'
,
'save_persistables'
,
'load_vars'
,
'load_params'
,
'load_persistables'
]
def
is_parameter
(
var
):
return
isinstance
(
var
,
Parameter
)
def
is_persistable
(
var
):
return
var
.
persistable
def
_clone_var_in_block_
(
block
,
var
):
assert
isinstance
(
var
,
Variable
)
return
block
.
create_var
(
name
=
var
.
name
,
shape
=
var
.
shape
,
dtype
=
var
.
data_type
,
type
=
var
.
type
,
lod_level
=
var
.
lod_level
,
persistable
=
True
)
def
save_vars
(
executor
,
dirname
,
program
=
None
,
vars
=
None
,
predicate
=
None
):
"""
Save variables to directory by executor.
:param executor: executor that save variable
:param dirname: directory path
:param program: program. If vars is None, then filter all variables in this
program which fit `predicate`. Default g_program.
:param predicate: The Predicate describes a callable that returns a variable
as a bool. If it returns true, the variables will be saved.
:param vars: variables need to be saved. If specify vars, program & predicate
will be ignored
:return: None
"""
if
vars
is
None
:
if
program
is
None
:
program
=
g_program
if
not
isinstance
(
program
,
Program
):
raise
TypeError
(
"program should be as Program type or None"
)
save_vars
(
executor
,
dirname
=
dirname
,
vars
=
filter
(
predicate
,
program
.
list_vars
()))
else
:
save_program
=
Program
()
save_block
=
save_program
.
global_block
()
for
each_var
in
vars
:
new_var
=
_clone_var_in_block_
(
save_block
,
each_var
)
save_block
.
append_op
(
type
=
'save'
,
inputs
=
{
'X'
:
[
new_var
]},
outputs
=
{},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
dirname
,
new_var
.
name
)})
executor
.
run
(
save_program
)
def
save_params
(
executor
,
dirname
,
program
=
None
):
"""
Save all parameters to directory with executor.
"""
save_vars
(
executor
,
dirname
=
dirname
,
program
=
program
,
vars
=
None
,
predicate
=
is_parameter
)
def
save_persistables
(
executor
,
dirname
,
program
=
None
):
"""
Save all persistables to directory with executor.
"""
save_vars
(
executor
,
dirname
=
dirname
,
program
=
program
,
vars
=
None
,
predicate
=
is_persistable
)
def
load_vars
(
executor
,
dirname
,
program
=
None
,
vars
=
None
,
predicate
=
None
):
"""
Load variables from directory by executor.
:param executor: executor that save variable
:param dirname: directory path
:param program: program. If vars is None, then filter all variables in this
program which fit `predicate`. Default g_program.
:param predicate: The Predicate describes a callable that returns a variable
as a bool. If it returns true, the variables will be loaded.
:param vars: variables need to be loaded. If specify vars, program &
predicate will be ignored
:return: None
"""
if
vars
is
None
:
if
program
is
None
:
program
=
g_program
if
not
isinstance
(
program
,
Program
):
raise
TypeError
(
"program's type should be Program"
)
load_vars
(
executor
,
dirname
=
dirname
,
vars
=
filter
(
predicate
,
program
.
list_vars
()))
else
:
load_prog
=
Program
()
load_block
=
load_prog
.
global_block
()
for
each_var
in
vars
:
assert
isinstance
(
each_var
,
Variable
)
new_var
=
_clone_var_in_block_
(
load_block
,
each_var
)
load_block
.
append_op
(
type
=
'load'
,
inputs
=
{},
outputs
=
{
"Out"
:
[
new_var
]},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
dirname
,
new_var
.
name
)})
executor
.
run
(
load_prog
)
def
load_params
(
executor
,
dirname
,
program
=
None
):
"""
load all parameters from directory by executor.
"""
load_vars
(
executor
,
dirname
=
dirname
,
program
=
program
,
predicate
=
is_parameter
)
def
load_persistables
(
executor
,
dirname
,
program
=
None
):
"""
load all persistables from directory by executor.
"""
load_vars
(
executor
,
dirname
=
dirname
,
program
=
program
,
predicate
=
is_persistable
)
python/paddle/v2/framework/tests/.gitignore
浏览文件 @
23662841
image/
image/
fit_a_line.model/
python/paddle/v2/framework/tests/test_fit_a_line.py
浏览文件 @
23662841
...
@@ -4,6 +4,7 @@ import paddle.v2.framework.core as core
...
@@ -4,6 +4,7 @@ import paddle.v2.framework.core as core
import
paddle.v2.framework.optimizer
as
optimizer
import
paddle.v2.framework.optimizer
as
optimizer
from
paddle.v2.framework.framework
import
Program
,
g_program
from
paddle.v2.framework.framework
import
Program
,
g_program
from
paddle.v2.framework.io
import
save_persistables
,
load_persistables
from
paddle.v2.framework.executor
import
Executor
from
paddle.v2.framework.executor
import
Executor
import
numpy
as
np
import
numpy
as
np
...
@@ -51,6 +52,8 @@ exe.run(init_program, feed={}, fetch_list=[])
...
@@ -51,6 +52,8 @@ exe.run(init_program, feed={}, fetch_list=[])
PASS_NUM
=
100
PASS_NUM
=
100
for
pass_id
in
range
(
PASS_NUM
):
for
pass_id
in
range
(
PASS_NUM
):
save_persistables
(
exe
,
"./fit_a_line.model/"
,
program
=
program
)
load_persistables
(
exe
,
"./fit_a_line.model/"
,
program
=
program
)
for
data
in
train_reader
():
for
data
in
train_reader
():
x_data
=
np
.
array
(
map
(
lambda
x
:
x
[
0
],
data
)).
astype
(
"float32"
)
x_data
=
np
.
array
(
map
(
lambda
x
:
x
[
0
],
data
)).
astype
(
"float32"
)
y_data
=
np
.
array
(
map
(
lambda
x
:
x
[
1
],
data
)).
astype
(
"float32"
)
y_data
=
np
.
array
(
map
(
lambda
x
:
x
[
1
],
data
)).
astype
(
"float32"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录