未验证 提交 b8c166f6 编写于 作者: Z Zhen Wang 提交者: GitHub

Merge pull request #17029 from wzzju/add_graph_checkpoint

add checkpoint functions for graph. test=develop
......@@ -15,6 +15,7 @@
from __future__ import print_function
import os
import six
import numpy as np
import unittest
import paddle
import paddle.fluid as fluid
......@@ -53,10 +54,11 @@ class TestGraph(unittest.TestCase):
def graph_apis(self, use_cuda=False, for_ci=True):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
feeds, loss = conv_block()
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
feeds, loss = conv_block()
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
graph = IrGraph(core.Graph(main.desc), for_test=False)
backup_graph = graph.clone()
self.assertEqual(len(graph.all_nodes()), len(backup_graph.all_nodes()))
......@@ -77,16 +79,39 @@ class TestGraph(unittest.TestCase):
paddle.dataset.mnist.train(), batch_size=batch_size)
feeder = fluid.DataFeeder(feed_list=feeds, place=place)
def train(binary):
def _train(binary):
for _ in range(iters):
data = next(train_reader())
loss_v = exe.run(binary,
feed=feeder.feed(data),
fetch_list=[loss.name])
print('{}: {}'.format('loss', loss_v))
if not for_ci:
print('{}: {}'.format('loss', loss_v))
train(origin_binary)
train(backup_binary)
_train(origin_binary)
_train(backup_binary)
checkponit_dir = "checkpoint_gpu" if use_cuda else "checkpoint_cpu"
def _set_zero(var_name, scope, place):
var = scope.find_var(var_name).get_tensor()
var_array = np.zeros(var._get_dims()).astype("float32")
var.set(var_array, place)
sum_before = np.sum(
np.array(fluid.global_scope().find_var('conv2d_1.w_0').get_tensor(
)))
fluid.io._save_persistable_nodes(exe, checkponit_dir, graph)
_set_zero('conv2d_1.w_0', fluid.global_scope(), place)
set_after = np.sum(
np.array(fluid.global_scope().find_var('conv2d_1.w_0').get_tensor(
)))
self.assertEqual(set_after, 0)
fluid.io._load_persistable_nodes(exe, checkponit_dir, graph)
sum_after = np.sum(
np.array(fluid.global_scope().find_var('conv2d_1.w_0').get_tensor(
)))
self.assertEqual(sum_before, sum_after)
marked_nodes = set()
for op in graph.all_op_nodes():
......
......@@ -20,6 +20,7 @@ import warnings
import time
import shutil
import six
import logging
from functools import reduce
from paddle.fluid import layers
......@@ -29,12 +30,17 @@ from paddle.fluid.framework import Program, Parameter, default_main_program, def
from . import reader
from .reader import *
from . import core
from .. import compat as cpt
# Names exported by `from <this module> import *`: the save/load entry points
# listed below plus everything that `reader` itself exports.
__all__ = [
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables', 'save_inference_model', 'load_inference_model'
] + reader.__all__
# NOTE: this runs at import time and configures the *root* logger's format
# (it has no effect if the root logger already has handlers).
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
# Let INFO-and-above records from this logger through.
_logger.setLevel(logging.INFO)
def is_parameter(var):
"""
......@@ -1181,3 +1187,80 @@ def get_parameter_value_by_name(name, executor, program=None):
program = default_main_program()
var = program.global_block().var(name)
return get_parameter_value(var, executor)
def _save_persistable_nodes(executor, dirname, graph):
    """
    Write the values of a graph's persistable nodes into a directory.

    Nodes are de-duplicated by name (the first node seen for a name wins),
    nodes whose variable type is RAW or READER are left out, and a variable
    mirroring each remaining node is created in a scratch ``Program`` so the
    whole list can be handed to ``save_vars``.

    Args:
        executor(Executor): The executor passed on to ``save_vars``.
        dirname(str): The directory path passed on to ``save_vars``.
        graph(IrGraph): The graph whose persistable nodes are saved.
    """
    # Phase 1: keep one node per distinct (text) name, in iteration order.
    seen_names = set()
    unique_nodes = []
    for candidate in graph.all_persistable_nodes():
        text_name = cpt.to_text(candidate.name())
        if text_name in seen_names:
            continue
        seen_names.add(text_name)
        unique_nodes.append(candidate)

    # Phase 2: mirror every eligible node as a variable of a scratch program.
    scratch_program = Program()
    skipped_types = (core.VarDesc.VarType.RAW, core.VarDesc.VarType.READER)
    vars_to_save = []
    for unique_node in unique_nodes:
        desc = unique_node.var()
        if desc.type() in skipped_types:
            continue
        mirrored = scratch_program.global_block().create_var(
            name=desc.name(),
            shape=desc.shape(),
            dtype=desc.dtype(),
            type=desc.type(),
            lod_level=desc.lod_level(),
            persistable=desc.persistable())
        vars_to_save.append(mirrored)

    save_vars(executor=executor, dirname=dirname, vars=vars_to_save)
def _load_persistable_nodes(executor, dirname, graph):
    """
    Load persistable node values from the given directory by the executor.

    Nodes are de-duplicated by name (the first node seen for a name wins) and
    nodes whose variable type is RAW or READER are ignored. For every other
    node a matching variable is created in a scratch ``Program``; it is
    passed to ``load_vars`` only if a path named after the variable exists
    under ``dirname``. Otherwise a warning is logged and the variable is
    skipped.

    Args:
        executor(Executor): The executor to run for loading node values.
        dirname(str): The directory path.
        graph(IrGraph): All the required persistable nodes in the graph will be loaded.
    """
    # Keep one node per distinct (text) name, preserving iteration order.
    persistable_node_names = set()
    persistable_nodes = []
    for node in graph.all_persistable_nodes():
        name = cpt.to_text(node.name())
        if name not in persistable_node_names:
            persistable_node_names.add(name)
            persistable_nodes.append(node)

    program = Program()
    var_list = []

    def _exist(var):
        # A variable counts as saved when a path with its name exists in dirname.
        return os.path.exists(os.path.join(dirname, var.name))

    for node in persistable_nodes:
        var_desc = node.var()
        if var_desc.type() == core.VarDesc.VarType.RAW or \
                var_desc.type() == core.VarDesc.VarType.READER:
            continue
        var = program.global_block().create_var(
            name=var_desc.name(),
            shape=var_desc.shape(),
            dtype=var_desc.dtype(),
            type=var_desc.type(),
            lod_level=var_desc.lod_level(),
            persistable=var_desc.persistable())
        if _exist(var):
            var_list.append(var)
        else:
            # Logger.warn is a deprecated alias; use warning() with lazy
            # %-formatting. The emitted message text is unchanged.
            _logger.warning("Cannot find the var %s!!!", node.name())

    load_vars(executor=executor, dirname=dirname, vars=var_list)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册