提交 23662841 编写于 作者: Y Yu Yang 提交者: fengjiayi

Python API for save/load variables (#5136)

* Python API for save/load variables

* Polish names
上级 8623e48b
......@@ -19,11 +19,16 @@ class Executor(object):
def run(self,
program,
feed,
fetch_list,
feed=None,
fetch_list=None,
feed_var_name='feed',
fetch_var_name='fetch',
scope=None):
if feed is None:
feed = {}
if fetch_list is None:
fetch_list = []
if not isinstance(program, Program):
raise TypeError()
......
......@@ -486,6 +486,11 @@ class Program(object):
for block in self.blocks:
block.sync_with_cpp()
def list_vars(self):
for each_block in self.blocks:
for each_var in each_block.vars.itervalues():
yield each_var
class Parameter(Variable):
def __init__(self, block, shape, dtype, **kwargs):
......
import os
from paddle.v2.framework.framework import Program, Parameter, g_program, \
Variable
__all__ = [
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables'
]
def is_parameter(var):
return isinstance(var, Parameter)
def is_persistable(var):
return var.persistable
def _clone_var_in_block_(block, var):
assert isinstance(var, Variable)
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.data_type,
type=var.type,
lod_level=var.lod_level,
persistable=True)
def save_vars(executor, dirname, program=None, vars=None, predicate=None):
"""
Save variables to directory by executor.
:param executor: executor that save variable
:param dirname: directory path
:param program: program. If vars is None, then filter all variables in this
program which fit `predicate`. Default g_program.
:param predicate: The Predicate describes a callable that returns a variable
as a bool. If it returns true, the variables will be saved.
:param vars: variables need to be saved. If specify vars, program & predicate
will be ignored
:return: None
"""
if vars is None:
if program is None:
program = g_program
if not isinstance(program, Program):
raise TypeError("program should be as Program type or None")
save_vars(
executor,
dirname=dirname,
vars=filter(predicate, program.list_vars()))
else:
save_program = Program()
save_block = save_program.global_block()
for each_var in vars:
new_var = _clone_var_in_block_(save_block, each_var)
save_block.append_op(
type='save',
inputs={'X': [new_var]},
outputs={},
attrs={'file_path': os.path.join(dirname, new_var.name)})
executor.run(save_program)
def save_params(executor, dirname, program=None):
"""
Save all parameters to directory with executor.
"""
save_vars(
executor,
dirname=dirname,
program=program,
vars=None,
predicate=is_parameter)
def save_persistables(executor, dirname, program=None):
"""
Save all persistables to directory with executor.
"""
save_vars(
executor,
dirname=dirname,
program=program,
vars=None,
predicate=is_persistable)
def load_vars(executor, dirname, program=None, vars=None, predicate=None):
"""
Load variables from directory by executor.
:param executor: executor that save variable
:param dirname: directory path
:param program: program. If vars is None, then filter all variables in this
program which fit `predicate`. Default g_program.
:param predicate: The Predicate describes a callable that returns a variable
as a bool. If it returns true, the variables will be loaded.
:param vars: variables need to be loaded. If specify vars, program &
predicate will be ignored
:return: None
"""
if vars is None:
if program is None:
program = g_program
if not isinstance(program, Program):
raise TypeError("program's type should be Program")
load_vars(
executor,
dirname=dirname,
vars=filter(predicate, program.list_vars()))
else:
load_prog = Program()
load_block = load_prog.global_block()
for each_var in vars:
assert isinstance(each_var, Variable)
new_var = _clone_var_in_block_(load_block, each_var)
load_block.append_op(
type='load',
inputs={},
outputs={"Out": [new_var]},
attrs={'file_path': os.path.join(dirname, new_var.name)})
executor.run(load_prog)
def load_params(executor, dirname, program=None):
"""
load all parameters from directory by executor.
"""
load_vars(
executor, dirname=dirname, program=program, predicate=is_parameter)
def load_persistables(executor, dirname, program=None):
"""
load all persistables from directory by executor.
"""
load_vars(
executor, dirname=dirname, program=program, predicate=is_persistable)
......@@ -4,6 +4,7 @@ import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
from paddle.v2.framework.framework import Program, g_program
from paddle.v2.framework.io import save_persistables, load_persistables
from paddle.v2.framework.executor import Executor
import numpy as np
......@@ -51,6 +52,8 @@ exe.run(init_program, feed={}, fetch_list=[])
PASS_NUM = 100
for pass_id in range(PASS_NUM):
save_persistables(exe, "./fit_a_line.model/", program=program)
load_persistables(exe, "./fit_a_line.model/", program=program)
for data in train_reader():
x_data = np.array(map(lambda x: x[0], data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("float32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册