From bf6b9d6d378c1113e981185e1bce04dea597dea9 Mon Sep 17 00:00:00 2001
From: Zhen Wang
Date: Mon, 22 Apr 2019 19:42:05 +0800
Subject: [PATCH] add checkpoint functions for graph. test=develop

---
 .../fluid/contrib/slim/tests/test_graph.py | 41 +++++++--
 python/paddle/fluid/io.py                  | 83 +++++++++++++++++++
 2 files changed, 116 insertions(+), 8 deletions(-)

diff --git a/python/paddle/fluid/contrib/slim/tests/test_graph.py b/python/paddle/fluid/contrib/slim/tests/test_graph.py
index 3629fed160..cb11c21826 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_graph.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_graph.py
@@ -15,6 +15,7 @@ from __future__ import print_function
 import os
 import six
+import numpy as np
 import unittest
 import paddle
 import paddle.fluid as fluid
@@ -53,10 +54,11 @@ class TestGraph(unittest.TestCase):
     def graph_apis(self, use_cuda=False, for_ci=True):
         main = fluid.Program()
         startup = fluid.Program()
-        with fluid.program_guard(main, startup):
-            feeds, loss = conv_block()
-            opt = fluid.optimizer.Adam(learning_rate=0.001)
-            opt.minimize(loss)
+        with fluid.unique_name.guard():
+            with fluid.program_guard(main, startup):
+                feeds, loss = conv_block()
+                opt = fluid.optimizer.Adam(learning_rate=0.001)
+                opt.minimize(loss)
         graph = IrGraph(core.Graph(main.desc), for_test=False)
         backup_graph = graph.clone()
         self.assertEqual(len(graph.all_nodes()), len(backup_graph.all_nodes()))
@@ -77,16 +79,39 @@ class TestGraph(unittest.TestCase):
             paddle.dataset.mnist.train(), batch_size=batch_size)
         feeder = fluid.DataFeeder(feed_list=feeds, place=place)

-        def train(binary):
+        def _train(binary):
             for _ in range(iters):
                 data = next(train_reader())
                 loss_v = exe.run(binary,
                                  feed=feeder.feed(data),
                                  fetch_list=[loss.name])
-                print('{}: {}'.format('loss', loss_v))
+                if not for_ci:
+                    print('{}: {}'.format('loss', loss_v))

-        train(origin_binary)
-        train(backup_binary)
+        _train(origin_binary)
+        _train(backup_binary)
+
+        checkpoint_dir = "checkpoint_gpu" if use_cuda else "checkpoint_cpu"
+
+        def _set_zero(var_name, scope, place):
+            var = scope.find_var(var_name).get_tensor()
+            var_array = np.zeros(var._get_dims()).astype("float32")
+            var.set(var_array, place)
+
+        sum_before = np.sum(
+            np.array(fluid.global_scope().find_var('conv2d_1.w_0').get_tensor(
+            )))
+        fluid.io._save_persistable_nodes(exe, checkpoint_dir, graph)
+        _set_zero('conv2d_1.w_0', fluid.global_scope(), place)
+        set_after = np.sum(
+            np.array(fluid.global_scope().find_var('conv2d_1.w_0').get_tensor(
+            )))
+        self.assertEqual(set_after, 0)
+        fluid.io._load_persistable_nodes(exe, checkpoint_dir, graph)
+        sum_after = np.sum(
+            np.array(fluid.global_scope().find_var('conv2d_1.w_0').get_tensor(
+            )))
+        self.assertEqual(sum_before, sum_after)

         marked_nodes = set()
         for op in graph.all_op_nodes():
diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py
index 4d55236272..16524d385f 100644
--- a/python/paddle/fluid/io.py
+++ b/python/paddle/fluid/io.py
@@ -20,6 +20,7 @@ import warnings
 import time
 import shutil
 import six
+import logging
 from functools import reduce

 from paddle.fluid import layers
@@ -29,12 +30,17 @@ from paddle.fluid.framework import Program, Parameter, default_main_program, def
 from . import reader
 from .reader import *
 from . import core
+from .. import compat as cpt

 __all__ = [
     'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
     'load_persistables', 'save_inference_model', 'load_inference_model'
 ] + reader.__all__

+logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
+_logger = logging.getLogger(__name__)
+_logger.setLevel(logging.INFO)
+

 def is_parameter(var):
     """
@@ -1181,3 +1187,80 @@ def get_parameter_value_by_name(name, executor, program=None):
         program = default_main_program()
     var = program.global_block().var(name)
     return get_parameter_value(var, executor)
+
+
+def _save_persistable_nodes(executor, dirname, graph):
+    """
+    Save the values of all persistable nodes in the graph to the given directory.
+
+    Args:
+        executor(Executor): The executor used to save the node values.
+        dirname(str): The directory path.
+        graph(IrGraph): All persistable nodes in this graph will be saved.
+    """
+    persistable_node_names = set()
+    persistable_nodes = []
+    all_persistable_nodes = graph.all_persistable_nodes()
+    for node in all_persistable_nodes:
+        name = cpt.to_text(node.name())
+        if name not in persistable_node_names:
+            persistable_node_names.add(name)
+            persistable_nodes.append(node)
+    program = Program()
+    var_list = []
+    for node in persistable_nodes:
+        var_desc = node.var()
+        if var_desc.type() == core.VarDesc.VarType.RAW or \
+           var_desc.type() == core.VarDesc.VarType.READER:
+            continue
+        var = program.global_block().create_var(
+            name=var_desc.name(),
+            shape=var_desc.shape(),
+            dtype=var_desc.dtype(),
+            type=var_desc.type(),
+            lod_level=var_desc.lod_level(),
+            persistable=var_desc.persistable())
+        var_list.append(var)
+    save_vars(executor=executor, dirname=dirname, vars=var_list)
+
+
+def _load_persistable_nodes(executor, dirname, graph):
+    """
+    Load the values of all persistable nodes in the graph from the given directory.
+
+    Args:
+        executor(Executor): The executor used to load the node values.
+        dirname(str): The directory path.
+        graph(IrGraph): All persistable nodes in this graph will be loaded.
+    """
+    persistable_node_names = set()
+    persistable_nodes = []
+    all_persistable_nodes = graph.all_persistable_nodes()
+    for node in all_persistable_nodes:
+        name = cpt.to_text(node.name())
+        if name not in persistable_node_names:
+            persistable_node_names.add(name)
+            persistable_nodes.append(node)
+    program = Program()
+    var_list = []
+
+    def _exist(var):
+        return os.path.exists(os.path.join(dirname, var.name))
+
+    for node in persistable_nodes:
+        var_desc = node.var()
+        if var_desc.type() == core.VarDesc.VarType.RAW or \
+           var_desc.type() == core.VarDesc.VarType.READER:
+            continue
+        var = program.global_block().create_var(
+            name=var_desc.name(),
+            shape=var_desc.shape(),
+            dtype=var_desc.dtype(),
+            type=var_desc.type(),
+            lod_level=var_desc.lod_level(),
+            persistable=var_desc.persistable())
+        if _exist(var):
+            var_list.append(var)
+        else:
+            _logger.warning("Cannot find the var %s!" % node.name())
+    load_vars(executor=executor, dirname=dirname, vars=var_list)
--
GitLab
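
A minimal usage sketch for the helpers added above, assuming a Paddle 1.x
build with this patch applied. The small fc network and the
"graph_checkpoint" directory name are illustrative only, not part of the
patch; the entry points exercised are the new fluid.io._save_persistable_nodes
and fluid.io._load_persistable_nodes.

import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph

# Build a small network; any program with persistable variables works.
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    img = fluid.layers.data(name='img', shape=[784], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    prediction = fluid.layers.fc(input=img, size=10, act='softmax')
    loss = fluid.layers.mean(
        fluid.layers.cross_entropy(input=prediction, label=label))
    fluid.optimizer.Adam(learning_rate=0.001).minimize(loss)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)  # initialize weights and optimizer accumulators

# Wrap the program desc in an IrGraph, exactly as the unit test does.
graph = IrGraph(core.Graph(main.desc), for_test=False)

# Save every persistable node (weights, Adam moments, ...) found in the
# graph, then restore the saved values into the global scope.
fluid.io._save_persistable_nodes(exe, 'graph_checkpoint', graph)
fluid.io._load_persistable_nodes(exe, 'graph_checkpoint', graph)

Both helpers stay underscore-prefixed and out of __all__, so
save_persistables/load_persistables remain the public checkpointing API for
plain Programs; the new variants cover IrGraph-based workflows such as the
slim tests.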