Commit bf6b9d6d authored by Zhen Wang

add checkpoint functions for graph. test=develop

Parent 27cd3efd
@@ -15,6 +15,7 @@
from __future__ import print_function
import os
import six
import numpy as np
import unittest
import paddle
import paddle.fluid as fluid
@@ -53,6 +54,7 @@ class TestGraph(unittest.TestCase):
    def graph_apis(self, use_cuda=False, for_ci=True):
        main = fluid.Program()
        startup = fluid.Program()
        with fluid.unique_name.guard():
            with fluid.program_guard(main, startup):
                feeds, loss = conv_block()
                opt = fluid.optimizer.Adam(learning_rate=0.001)
@@ -77,16 +79,39 @@ class TestGraph(unittest.TestCase):
            paddle.dataset.mnist.train(), batch_size=batch_size)
        feeder = fluid.DataFeeder(feed_list=feeds, place=place)

        def _train(binary):
            for _ in range(iters):
                data = next(train_reader())
                loss_v = exe.run(binary,
                                 feed=feeder.feed(data),
                                 fetch_list=[loss.name])
                if not for_ci:
                    print('{}: {}'.format('loss', loss_v))

        _train(origin_binary)
        _train(backup_binary)

        checkpoint_dir = "checkpoint_gpu" if use_cuda else "checkpoint_cpu"

        def _set_zero(var_name, scope, place):
            var = scope.find_var(var_name).get_tensor()
            var_array = np.zeros(var._get_dims()).astype("float32")
            var.set(var_array, place)

        sum_before = np.sum(
            np.array(fluid.global_scope().find_var('conv2d_1.w_0').get_tensor()))
        fluid.io._save_persistable_nodes(exe, checkpoint_dir, graph)
        _set_zero('conv2d_1.w_0', fluid.global_scope(), place)
        set_after = np.sum(
            np.array(fluid.global_scope().find_var('conv2d_1.w_0').get_tensor()))
        self.assertEqual(set_after, 0)
        fluid.io._load_persistable_nodes(exe, checkpoint_dir, graph)
        sum_after = np.sum(
            np.array(fluid.global_scope().find_var('conv2d_1.w_0').get_tensor()))
        self.assertEqual(sum_before, sum_after)

        marked_nodes = set()
        for op in graph.all_op_nodes():
...
@@ -20,6 +20,7 @@ import warnings
import time
import shutil
import six
import logging
from functools import reduce
from paddle.fluid import layers
@@ -29,12 +30,17 @@ from paddle.fluid.framework import Program, Parameter, default_main_program, def
from . import reader
from .reader import *
from . import core
from .. import compat as cpt

__all__ = [
    'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
    'load_persistables', 'save_inference_model', 'load_inference_model'
] + reader.__all__

logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)


def is_parameter(var):
    """
@@ -1181,3 +1187,80 @@ def get_parameter_value_by_name(name, executor, program=None):
        program = default_main_program()
    var = program.global_block().var(name)
    return get_parameter_value(var, executor)


def _save_persistable_nodes(executor, dirname, graph):
    """
    Save persistable nodes to the given directory by the executor.

    Args:
        executor(Executor): The executor to run for saving node values.
        dirname(str): The directory path.
        graph(IrGraph): All the required persistable nodes in the graph will be saved.
    """
    persistable_node_names = set()
    persistable_nodes = []
    all_persistable_nodes = graph.all_persistable_nodes()
    for node in all_persistable_nodes:
        name = cpt.to_text(node.name())
        if name not in persistable_node_names:
            persistable_node_names.add(name)
            persistable_nodes.append(node)
    program = Program()
    var_list = []
    for node in persistable_nodes:
        var_desc = node.var()
        if var_desc.type() == core.VarDesc.VarType.RAW or \
                var_desc.type() == core.VarDesc.VarType.READER:
            continue
        var = program.global_block().create_var(
            name=var_desc.name(),
            shape=var_desc.shape(),
            dtype=var_desc.dtype(),
            type=var_desc.type(),
            lod_level=var_desc.lod_level(),
            persistable=var_desc.persistable())
        var_list.append(var)
    save_vars(executor=executor, dirname=dirname, vars=var_list)


def _load_persistable_nodes(executor, dirname, graph):
    """
    Load persistable node values from the given directory by the executor.

    Args:
        executor(Executor): The executor to run for loading node values.
        dirname(str): The directory path.
        graph(IrGraph): All the required persistable nodes in the graph will be loaded.
    """
    persistable_node_names = set()
    persistable_nodes = []
    all_persistable_nodes = graph.all_persistable_nodes()
    for node in all_persistable_nodes:
        name = cpt.to_text(node.name())
        if name not in persistable_node_names:
            persistable_node_names.add(name)
            persistable_nodes.append(node)
    program = Program()
    var_list = []

    def _exist(var):
        return os.path.exists(os.path.join(dirname, var.name))

    for node in persistable_nodes:
        var_desc = node.var()
        if var_desc.type() == core.VarDesc.VarType.RAW or \
                var_desc.type() == core.VarDesc.VarType.READER:
            continue
        var = program.global_block().create_var(
            name=var_desc.name(),
            shape=var_desc.shape(),
            dtype=var_desc.dtype(),
            type=var_desc.type(),
            lod_level=var_desc.lod_level(),
            persistable=var_desc.persistable())
        if _exist(var):
            var_list.append(var)
        else:
            _logger.warn("Cannot find the var %s!!!" % (node.name()))
    load_vars(executor=executor, dirname=dirname, vars=var_list)
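
Taken together, the two new helpers mirror save_vars/load_vars but walk an IrGraph instead of a Program. The snippet below is a minimal sketch of how they could be exercised end to end; it is not part of this commit. The toy fc network, the "checkpoint_dir" path, and the IrGraph construction are illustrative assumptions, using the Paddle 1.x fluid API shown in the diff above.

# A minimal usage sketch, assuming the Paddle 1.x fluid API of this commit.
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph

main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    pred = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))
    fluid.optimizer.Adam(learning_rate=0.001).minimize(loss)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)

# Wrap the main program in an IrGraph so the graph-level helpers can collect
# its persistable nodes (weights, biases, optimizer moments, ...).
graph = IrGraph(core.Graph(main.desc), for_test=False)

# Save every persistable node (one file per variable under checkpoint_dir),
# then restore them; variables without a matching file are logged and skipped.
fluid.io._save_persistable_nodes(exe, "checkpoint_dir", graph)
fluid.io._load_persistable_nodes(exe, "checkpoint_dir", graph)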