Commit 2f44585e authored by T tangwei12

code optimized

Parent 53409a29
@@ -476,14 +476,14 @@ def save_checkpoint(executor,
    to keep the number of checkpoint directories bounded: at most max_num_checkpoints are retained,
    and the interval between two saved checkpoints must be greater than save_interval_secs.
-    :param executor
-    :param checkpoint_dir
-    :param trainer_id
-    :param is_chief
-    :param main_program
-    :param max_num_checkpoints
-    """
-    if checkpoint_dir.strip() is None:
+    :param executor executor used to save the variables
+    :param checkpoint_dir the checkpoint directory
+    :param trainer_id current trainer id
+    :param is_chief true if the trainer id equals 0
+    :param main_program all variables in this program will be saved
+    :param max_num_checkpoints keep at most this many checkpoint serials
+    """
+    if checkpoint_dir is None:
raise ValueError("'checkpoint_dir' should not be None")
if trainer_args:
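Two things happen in this hunk: the parameter list gains real descriptions, and the dead guard `checkpoint_dir.strip() is None` (which can never fire, since `str.strip()` never returns None and raises on None input) becomes a working `is None` check. For readers new to this API, a minimal usage sketch of `save_checkpoint`; the `fluid.io` import path and the `trainer_args` payload are assumptions inferred from the call site in the trainer hunks further down:

```python
import paddle.fluid as fluid
from paddle.fluid import io

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# Mirrors the documented parameters; trainer_args carries the resume
# metadata that load_trainer_args reads back after a restart.
io.save_checkpoint(
    executor=exe,
    checkpoint_dir="/tmp/ckpt",
    trainer_id=0,
    is_chief=True,  # trainer 0 acts as chief and saves the variables
    trainer_args={"epoch_id": "1", "step_id": "10"},
    main_program=fluid.default_main_program(),
    max_num_checkpoints=3)
```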
@@ -500,7 +500,7 @@ def save_checkpoint(executor,
if is_chief:
save_persist_vars_without_grad(executor, cur_dir, main_program)
-    _lru_delete(checkpoint_dir, max_num_checkpoints)
+    _scroll_delete(checkpoint_dir, max_num_checkpoints)
def load_checkpoint(executor, checkpoint_dir, serial, main_program):
@@ -508,13 +508,13 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
    Load checkpoint from a directory by executor,
    it will find the most recently saved checkpoint file and load it automatically.
-    :param executor
-    :param checkpoint_dir
-    :param serial
-    :param main_program
+    :param executor executor used to load the variables
+    :param checkpoint_dir the checkpoint directory
+    :param serial the serial folder inside checkpoint_dir to load
+    :param main_program all variables in this program will be loaded
"""
-    if checkpoint_dir.strip() is None:
+    if checkpoint_dir is None:
raise ValueError("'checkpoint_dir' should not be None")
if serial is None or serial < 0:
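The load path mirrors the save path: find the newest complete serial with `get_latest_checkpoint_serial`, then hand it to `load_checkpoint`. A minimal sketch under the same import assumptions as above:

```python
import paddle.fluid as fluid
from paddle.fluid import io

exe = fluid.Executor(fluid.CPUPlace())
main_program = fluid.default_main_program()

serial = io.get_latest_checkpoint_serial("/tmp/ckpt")
if serial >= 0:  # -1 means no complete checkpoint exists yet
    io.load_checkpoint(exe, "/tmp/ckpt", serial, main_program)
```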
@@ -536,9 +536,9 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
:param delete_dir
"""
-    if checkpoint_dir.strip() is None:
+    if checkpoint_dir is None:
raise ValueError("'checkpoint_dir' should not be None")
-    _lru_delete(checkpoint_dir, max_num_checkpoints=0)
+    _scroll_delete(checkpoint_dir, max_num_checkpoints=0)
if delete_dir and not os.listdir(checkpoint_dir):
os.rmdir(checkpoint_dir)
@@ -681,7 +681,7 @@ def _get_trainer_dir(dirname, trainer_id):
return trainer_dir
-def _lru_delete(dirname, max_num_checkpoints=3):
+def _scroll_delete(dirname, max_num_checkpoints=3):
dirs = os.listdir(dirname)
serial_map = {}
for serial in dirs:
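The hunk is truncated before the body, but the rename suggests the policy: a rolling window that keeps the newest `max_num_checkpoints` serial directories and deletes the rest, rather than true LRU eviction. A self-contained sketch of that policy; the `checkpoint_<serial>` directory naming and the helper names here are assumptions:

```python
import os
import shutil


def _dir_serial(name):
    # Assumed layout: each checkpoint lives in "checkpoint_<serial>".
    try:
        return int(name.rsplit("_", 1)[-1])
    except ValueError:
        return -1


def scroll_delete(dirname, max_num_checkpoints=3):
    serial_map = {_dir_serial(d): d for d in os.listdir(dirname)}
    serial_map.pop(-1, None)  # skip entries without a parsable serial
    if len(serial_map) <= max_num_checkpoints:
        return
    # Sort newest-first and drop everything past the retention window.
    for serial in sorted(serial_map, reverse=True)[max_num_checkpoints:]:
        shutil.rmtree(os.path.join(dirname, serial_map[serial]))
```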
@@ -717,7 +717,7 @@ def get_latest_checkpoint_serial(checkpoint_dir):
:param checkpoint_dir
"""
-    if not checkpoint_dir.strip():
+    if not checkpoint_dir:
return -1
def has_success(checkpoint_dir, cur_dir):
@@ -726,10 +726,8 @@ def get_latest_checkpoint_serial(checkpoint_dir):
"""
serial = _get_dir_serial(cur_dir)
-        if serial == -1:
-            return -1
-        if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
+        if serial == -1 or not os.path.isdir(
+                os.path.join(checkpoint_dir, cur_dir)):
return -1
success_path = os.path.join(
......
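The truncated tail builds a `success_path`, so presumably only serial directories that contain a completion marker count as loadable. A sketch of the whole scan under that assumption; the `_SUCCESS` marker name is a guess, since the hunk ends before it is shown:

```python
import os


def get_latest_serial(checkpoint_dir, success_mark="_SUCCESS"):
    best = -1
    for cur_dir in os.listdir(checkpoint_dir):
        try:
            serial = int(cur_dir.rsplit("_", 1)[-1])
        except ValueError:
            serial = -1
        # Merged check, as in the hunk above: bad serial, or not a directory.
        if serial == -1 or not os.path.isdir(
                os.path.join(checkpoint_dir, cur_dir)):
            continue
        # Only count checkpoints whose write completed (marker file present).
        if os.path.isfile(os.path.join(checkpoint_dir, cur_dir, success_mark)):
            best = max(best, serial)
    return best
```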
@@ -15,11 +15,12 @@
import paddle.fluid as fluid
import unittest
import os
+import tempfile
class TestCheckpoint(unittest.TestCase):
def setUp(self):
self.dirname = "/tmp/ckpt"
self.dirname = tempfile.mktemp()
self.max_num_checkpoints = 3
self.epoch_interval = 1
self.step_interval = 1
......
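Moving the test from a hard-coded `/tmp/ckpt` to `tempfile.mktemp()` avoids collisions between test runs. Note, though, that `mktemp()` only returns an unused name and is documented as race-prone; if the tests needed the directory to exist up front, `tempfile.mkdtemp()` plus explicit cleanup would be the safer pattern, sketched here as an alternative (the `tearDown` is an addition, not part of this diff):

```python
import shutil
import tempfile
import unittest


class TestCheckpoint(unittest.TestCase):
    def setUp(self):
        # mkdtemp creates the directory atomically instead of merely
        # returning an unused name the way mktemp() does.
        self.dirname = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.dirname, ignore_errors=True)
```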
@@ -132,19 +132,18 @@ class Trainer(object):
# 1. we need to generate a framework.Program by calling
# program_func. Reference: fluid.program_guard in
# test_word2vec.py
-        if not isinstance(optimizer, opt_module.Optimizer):
-            raise TypeError("The optimizer should be an instance of Optimizer")
+        assert isinstance(optimizer, opt_module.Optimizer)
# config for checkpoint
# only chief worker will save variables
self.trainer_id = 0
self.chief = True
-        self.checkpoint = checkpoint_config
-        if self.checkpoint:
-            assert isinstance(self.checkpoint, CheckpointConfig)
+        self.checkpoint_cfg = checkpoint_config
+        if self.checkpoint_cfg:
+            assert isinstance(self.checkpoint_cfg, CheckpointConfig)
serial = io.get_latest_checkpoint_serial(
-                self.checkpoint.checkpoint_dir)
-            self.checkpoint.load_serial = serial if serial >= 0 else None
+                self.checkpoint_cfg.checkpoint_dir)
+            self.checkpoint_cfg.load_serial = serial if serial >= 0 else None
self.scope = core.Scope()
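For context, a hedged sketch of how the renamed `checkpoint_cfg` arrives from user code. The `CheckpointConfig` fields mirror the attributes used in this file (`checkpoint_dir`, `max_num_checkpoints`, `epoch_interval`, `step_interval`); the `Trainer` arguments other than `optimizer` and `checkpoint_config`, and the toy network, are assumptions:

```python
import paddle.fluid as fluid


def train_program():
    # Hypothetical one-layer regression network, just to have a loss.
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    y_pred = fluid.layers.fc(input=x, size=1)
    return fluid.layers.mean(fluid.layers.square_error_cost(y_pred, y))


config = fluid.CheckpointConfig(
    checkpoint_dir="/tmp/ckpt",
    max_num_checkpoints=3,
    epoch_interval=1,
    step_interval=10)

trainer = fluid.Trainer(
    train_func=train_program,
    optimizer=fluid.optimizer.SGD(learning_rate=0.001),
    place=fluid.CPUPlace(),
    checkpoint_config=config)  # stored as self.checkpoint_cfg above
```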
@@ -174,19 +173,20 @@ class Trainer(object):
exe = executor.Executor(place)
exe.run(self.startup_program)
-        if self.checkpoint and self.checkpoint.load_serial:
+        if self.checkpoint_cfg and self.checkpoint_cfg.load_serial:
with self._prog_and_scope_guard():
exe = executor.Executor(place)
-                io.load_checkpoint(exe, self.checkpoint.checkpoint_dir,
-                                   self.checkpoint.load_serial,
+                io.load_checkpoint(exe, self.checkpoint_cfg.checkpoint_dir,
+                                   self.checkpoint_cfg.load_serial,
self.startup_program)
-            if not self.checkpoint.is_pserver:
+            if not self.checkpoint_cfg.is_pserver:
epoch_id, step_id = io.load_trainer_args(
-                    self.checkpoint.checkpoint_dir, self.checkpoint.load_serial,
-                    self.trainer_id, self._get_checkpoint_load_args())
-                self.checkpoint.epoch_id = int(epoch_id)
-                self.checkpoint.step_id = int(step_id)
+                    self.checkpoint_cfg.checkpoint_dir,
+                    self.checkpoint_cfg.load_serial, self.trainer_id,
+                    self._get_checkpoint_load_args())
+                self.checkpoint_cfg.epoch_id = int(epoch_id)
+                self.checkpoint_cfg.step_id = int(step_id)
if param_path and os.path.isdir(param_path):
# load params from param_path into scope
@@ -256,7 +256,7 @@ class Trainer(object):
t.transpile(
self.trainer_id, pservers=pserver_endpoints, trainers=trainers)
if training_role == "PSERVER":
-            if self.checkpoint:
+            if self.checkpoint_cfg:
self.is_pserver = True
self.train_program = t.get_pserver_program(current_endpoint)
@@ -351,10 +351,10 @@ class Trainer(object):
self._train_by_any_executor(event_handler, exe, num_epochs, reader)
def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
-        if self.checkpoint:
+        if self.checkpoint_cfg:
epochs = [
epoch_id for epoch_id in range(num_epochs)
-                if epoch_id >= self.checkpoint.epoch_id
+                if epoch_id >= self.checkpoint_cfg.epoch_id
]
else:
epochs = [epoch_id for epoch_id in range(num_epochs)]
@@ -366,8 +366,8 @@ class Trainer(object):
self._clean_checkpoint()
return
-                if self.checkpoint and self.checkpoint.load_serial \
-                        and self.checkpoint.step_id >= step_id and self.checkpoint.epoch_id == epoch_id:
+                if self.checkpoint_cfg and self.checkpoint_cfg.load_serial \
+                        and self.checkpoint_cfg.step_id >= step_id and self.checkpoint_cfg.epoch_id == epoch_id:
continue
begin_event = BeginStepEvent(epoch_id, step_id)
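The two-line condition above is the resume filter: any step already covered by the restored checkpoint is skipped with `continue`. Restated as a standalone predicate (hypothetical helper name, same logic as the hunk):

```python
def already_done(cfg, epoch_id, step_id):
    # Skip steps executed before the checkpoint was written: same epoch
    # as the restored one, at or before the restored step.
    return (cfg is not None and cfg.load_serial is not None
            and cfg.epoch_id == epoch_id and cfg.step_id >= step_id)
```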
@@ -381,10 +381,12 @@ class Trainer(object):
else:
metrics = exe.run(feed=data, fetch_list=[])
-                    self._save_checkpoint(epoch_id, step_id)
+                    if self.checkpoint_cfg:
+                        self._save_checkpoint(epoch_id, step_id)
event_handler(EndStepEvent(epoch_id, step_id, metrics))
event_handler(EndEpochEvent(epoch_id))
-        self._clean_checkpoint()
+        if self.checkpoint_cfg:
+            self._clean_checkpoint()
def _test_by_executor(self, reader, feed_order, fetch_list):
with executor.scope_guard(self.scope):
@@ -424,9 +426,8 @@ class Trainer(object):
return self._get_parallel_executor()
def _clean_checkpoint(self):
-        if not self.checkpoint:
-            return
-        io.clean_checkpoint(checkpoint_dir=self.checkpoint.checkpoint_dir)
+        assert self.checkpoint_cfg
+        io.clean_checkpoint(checkpoint_dir=self.checkpoint_cfg.checkpoint_dir)
def _get_checkpoint_load_args(self):
"""
@@ -444,19 +445,18 @@ class Trainer(object):
return trainer_args
def _save_checkpoint(self, epoch_id, step_id):
-        if not self.checkpoint:
-            return
+        assert self.checkpoint_cfg
-        if epoch_id % self.checkpoint.epoch_interval == 0 and step_id % self.checkpoint.step_interval == 0:
+        if epoch_id % self.checkpoint_cfg.epoch_interval == 0 and step_id % self.checkpoint_cfg.step_interval == 0:
exe = executor.Executor(self.place)
io.save_checkpoint(
executor=exe,
-                checkpoint_dir=self.checkpoint.checkpoint_dir,
+                checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
trainer_id=self.trainer_id,
is_chief=self.chief,
trainer_args=self._get_checkpoint_save_args(epoch_id, step_id),
main_program=self.train_program,
-                max_num_checkpoints=self.checkpoint.max_num_checkpoints)
+                max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)
def build_feed_var_list(program, feed_order):
......
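A closing note on the save predicate in the last trainer hunk: a checkpoint is written only when both modulo checks hit zero. A tiny worked example of that arithmetic, with interval values chosen purely for illustration:

```python
epoch_interval, step_interval = 2, 100

for epoch_id, step_id in [(1, 100), (2, 100), (2, 150), (4, 0)]:
    should_save = (epoch_id % epoch_interval == 0
                   and step_id % step_interval == 0)
    # -> (1, 100) False, (2, 100) True, (2, 150) False, (4, 0) True
    print("epoch=%d step=%d save=%s" % (epoch_id, step_id, should_save))
```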