From 514b2427edbd30013ca1783769af18fb96ffb626 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 28 May 2018 20:08:23 +0800 Subject: [PATCH] add save/load persist_vars_without_grad --- python/paddle/fluid/io.py | 46 +++++++++++++++++++++++----------- python/paddle/fluid/trainer.py | 3 ++- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 8e58e5eb7..f62603936 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -24,7 +24,8 @@ __all__ = [ 'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params', 'load_persistables', 'save_inference_model', 'load_inference_model', 'get_inference_program', 'save_checkpoint', 'load_checkpoint', - 'clean_checkpoint' + 'clean_checkpoint', 'load_persist_vars_without_grad', + 'save_persist_vars_without_grad' ] @@ -455,6 +456,33 @@ def get_parameter_value_by_name(name, executor, program=None): return get_parameter_value(var, executor) +def load_persist_vars_without_grad(executor, dirname, program): + """ + load_persist_vars_without_grad will load variables from a directory by an executor, + the variable named end with "@GRAD" will not be loaded. + """ + load_vars( + executor, + dirname=dirname, + main_program=program, + predicate=_is_checkpoint_var, + filename=None) + + +def save_persist_vars_without_grad(executor, dirname, program): + """ + save_persist_vars_without_grad will save variables to a directory by an executor, + the variable named end with "@GRAD" will not be saved. + """ + save_vars( + executor, + dirname=dirname, + main_program=program, + vars=None, + predicate=_is_checkpoint_var, + filename=None) + + SUCCESS_MARK_FILENAME = "_SUCCESS" CHECKPOINT_PREFIX = "checkpoint" CHECKPOINT_SEPARATOR = "_" @@ -491,13 +519,7 @@ def save_checkpoint(executor, serial += 1 cur_dir = _get_serial_dir(serial, checkpoint_dir) - save_vars( - executor, - dirname=cur_dir, - main_program=main_program, - vars=None, - predicate=_is_checkpoint_var, - filename=None) + load_persist_vars_without_grad(executor, cur_dir, main_program) _write_success(cur_dir) _lru_delete(checkpoint_dir, max_num_checkpoints) @@ -521,13 +543,7 @@ def load_checkpoint(executor, checkpoint_dir=None, main_program=None): return cur_dir = _get_serial_dir(serial, checkpoint_dir) - - load_vars( - executor, - dirname=cur_dir, - main_program=main_program, - predicate=_is_checkpoint_var, - filename=None) + load_persist_vars_without_grad(executor, cur_dir, main_program) def clean_checkpoint(checkpoint_dir, delete_dir=False): diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index 24254b498..b4b7b75b9 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -162,7 +162,8 @@ class Trainer(object): if param_path: # load params from param_path into scope - io.load_persistables(exe, dirname=param_path) + io.load_persist_vars_without_grad( + exe, dirname=param_path, program=self.startup_program) def _transpile_nccl2_dist(self): # PADDLE_TRAINER_IPS -- GitLab