提交 2f44585e 编写于 作者: T tangwei12

code optimized

上级 53409a29
...@@ -476,14 +476,14 @@ def save_checkpoint(executor, ...@@ -476,14 +476,14 @@ def save_checkpoint(executor,
to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most, to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most,
The interval between two saved checkpoints must greater than save_interval_secs. The interval between two saved checkpoints must greater than save_interval_secs.
:param executor :param executor executor for save the value
:param checkpoint_dir :param checkpoint_dir the checkpoint directory
:param trainer_id :param trainer_id currect trainer id
:param is_chief :param is_chief if the trainer id equals 0, the is_chief will be true
:param main_program :param main_program will save all variables in program
:param max_num_checkpoints :param max_num_checkpoints will keep numbers of checkpoint serials not bigger than max_num_checkpoints
""" """
if checkpoint_dir.strip() is None: if checkpoint_dir is None:
raise ValueError("'checkpoint_dir' should not be None") raise ValueError("'checkpoint_dir' should not be None")
if trainer_args: if trainer_args:
...@@ -500,7 +500,7 @@ def save_checkpoint(executor, ...@@ -500,7 +500,7 @@ def save_checkpoint(executor,
if is_chief: if is_chief:
save_persist_vars_without_grad(executor, cur_dir, main_program) save_persist_vars_without_grad(executor, cur_dir, main_program)
_lru_delete(checkpoint_dir, max_num_checkpoints) _scroll_delete(checkpoint_dir, max_num_checkpoints)
def load_checkpoint(executor, checkpoint_dir, serial, main_program): def load_checkpoint(executor, checkpoint_dir, serial, main_program):
...@@ -508,13 +508,13 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program): ...@@ -508,13 +508,13 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
Load checkpoint from a directory by executor, Load checkpoint from a directory by executor,
it will find the most recent saved checkpoint file and load it auto. it will find the most recent saved checkpoint file and load it auto.
:param executor :param executor executor for load the value
:param checkpoint_dir :param checkpoint_dir the checkpoint directory
:param serial :param serial the serial folder in checkpoint directory will be load
:param main_program :param main_program will load all variables in program
""" """
if checkpoint_dir.strip() is None: if checkpoint_dir is None:
raise ValueError("'checkpoint_dir' should not be None") raise ValueError("'checkpoint_dir' should not be None")
if serial is None or serial < 0: if serial is None or serial < 0:
...@@ -536,9 +536,9 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False): ...@@ -536,9 +536,9 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
:param delete_dir :param delete_dir
""" """
if checkpoint_dir.strip() is None: if checkpoint_dir is None:
raise ValueError("'checkpoint_dir' should not be None") raise ValueError("'checkpoint_dir' should not be None")
_lru_delete(checkpoint_dir, max_num_checkpoints=0) _scroll_delete(checkpoint_dir, max_num_checkpoints=0)
if delete_dir and not os.listdir(checkpoint_dir): if delete_dir and not os.listdir(checkpoint_dir):
os.rmdir(checkpoint_dir) os.rmdir(checkpoint_dir)
...@@ -681,7 +681,7 @@ def _get_trainer_dir(dirname, trainer_id): ...@@ -681,7 +681,7 @@ def _get_trainer_dir(dirname, trainer_id):
return trainer_dir return trainer_dir
def _lru_delete(dirname, max_num_checkpoints=3): def _scroll_delete(dirname, max_num_checkpoints=3):
dirs = os.listdir(dirname) dirs = os.listdir(dirname)
serial_map = {} serial_map = {}
for serial in dirs: for serial in dirs:
...@@ -717,7 +717,7 @@ def get_latest_checkpoint_serial(checkpoint_dir): ...@@ -717,7 +717,7 @@ def get_latest_checkpoint_serial(checkpoint_dir):
:param checkpoint_dir :param checkpoint_dir
""" """
if not checkpoint_dir.strip(): if not checkpoint_dir:
return -1 return -1
def has_success(checkpoint_dir, cur_dir): def has_success(checkpoint_dir, cur_dir):
...@@ -726,10 +726,8 @@ def get_latest_checkpoint_serial(checkpoint_dir): ...@@ -726,10 +726,8 @@ def get_latest_checkpoint_serial(checkpoint_dir):
""" """
serial = _get_dir_serial(cur_dir) serial = _get_dir_serial(cur_dir)
if serial == -1: if serial == -1 or not os.path.isdir(
return -1 os.path.join(checkpoint_dir, cur_dir)):
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
return -1 return -1
success_path = os.path.join( success_path = os.path.join(
......
...@@ -15,11 +15,12 @@ ...@@ -15,11 +15,12 @@
import paddle.fluid as fluid import paddle.fluid as fluid
import unittest import unittest
import os import os
import tempfile
class TestCheckpoint(unittest.TestCase): class TestCheckpoint(unittest.TestCase):
def setUp(self): def setUp(self):
self.dirname = "/tmp/ckpt" self.dirname = tempfile.mktemp()
self.max_num_checkpoints = 3 self.max_num_checkpoints = 3
self.epoch_interval = 1 self.epoch_interval = 1
self.step_interval = 1 self.step_interval = 1
......
...@@ -132,19 +132,18 @@ class Trainer(object): ...@@ -132,19 +132,18 @@ class Trainer(object):
# 1. we need to generate a framework.Program by calling # 1. we need to generate a framework.Program by calling
# program_func. Reference: fluid.program_guard in # program_func. Reference: fluid.program_guard in
# test_word2vec.py # test_word2vec.py
if not isinstance(optimizer, opt_module.Optimizer): assert isinstance(optimizer, opt_module.Optimizer)
raise TypeError("The optimizer should be an instance of Optimizer")
# config for checkpoint # config for checkpoint
# only chief worker will save variables # only chief worker will save variables
self.trainer_id = 0 self.trainer_id = 0
self.chief = True self.chief = True
self.checkpoint = checkpoint_config self.checkpoint_cfg = checkpoint_config
if self.checkpoint: if self.checkpoint_cfg:
assert isinstance(self.checkpoint, CheckpointConfig) assert isinstance(self.checkpoint_cfg, CheckpointConfig)
serial = io.get_latest_checkpoint_serial( serial = io.get_latest_checkpoint_serial(
self.checkpoint.checkpoint_dir) self.checkpoint_cfg.checkpoint_dir)
self.checkpoint.load_serial = serial if serial >= 0 else None self.checkpoint_cfg.load_serial = serial if serial >= 0 else None
self.scope = core.Scope() self.scope = core.Scope()
...@@ -174,19 +173,20 @@ class Trainer(object): ...@@ -174,19 +173,20 @@ class Trainer(object):
exe = executor.Executor(place) exe = executor.Executor(place)
exe.run(self.startup_program) exe.run(self.startup_program)
if self.checkpoint and self.checkpoint.load_serial: if self.checkpoint_cfg and self.checkpoint_cfg.load_serial:
with self._prog_and_scope_guard(): with self._prog_and_scope_guard():
exe = executor.Executor(place) exe = executor.Executor(place)
io.load_checkpoint(exe, self.checkpoint.checkpoint_dir, io.load_checkpoint(exe, self.checkpoint_cfg.checkpoint_dir,
self.checkpoint.load_serial, self.checkpoint_cfg.load_serial,
self.startup_program) self.startup_program)
if not self.checkpoint.is_pserver: if not self.checkpoint_cfg.is_pserver:
epoch_id, step_id = io.load_trainer_args( epoch_id, step_id = io.load_trainer_args(
self.checkpoint.checkpoint_dir, self.checkpoint.load_serial, self.checkpoint_cfg.checkpoint_dir,
self.trainer_id, self._get_checkpoint_load_args()) self.checkpoint_cfg.load_serial, self.trainer_id,
self.checkpoint.epoch_id = int(epoch_id) self._get_checkpoint_load_args())
self.checkpoint.step_id = int(step_id) self.checkpoint_cfg.epoch_id = int(epoch_id)
self.checkpoint_cfg.step_id = int(step_id)
if param_path and os.path.isdir(param_path): if param_path and os.path.isdir(param_path):
# load params from param_path into scope # load params from param_path into scope
...@@ -256,7 +256,7 @@ class Trainer(object): ...@@ -256,7 +256,7 @@ class Trainer(object):
t.transpile( t.transpile(
self.trainer_id, pservers=pserver_endpoints, trainers=trainers) self.trainer_id, pservers=pserver_endpoints, trainers=trainers)
if training_role == "PSERVER": if training_role == "PSERVER":
if self.checkpoint: if self.checkpoint_cfg:
self.is_pserver = True self.is_pserver = True
self.train_program = t.get_pserver_program(current_endpoint) self.train_program = t.get_pserver_program(current_endpoint)
...@@ -351,10 +351,10 @@ class Trainer(object): ...@@ -351,10 +351,10 @@ class Trainer(object):
self._train_by_any_executor(event_handler, exe, num_epochs, reader) self._train_by_any_executor(event_handler, exe, num_epochs, reader)
def _train_by_any_executor(self, event_handler, exe, num_epochs, reader): def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
if self.checkpoint: if self.checkpoint_cfg:
epochs = [ epochs = [
epoch_id for epoch_id in range(num_epochs) epoch_id for epoch_id in range(num_epochs)
if epoch_id >= self.checkpoint.epoch_id if epoch_id >= self.checkpoint_cfg.epoch_id
] ]
else: else:
epochs = [epoch_id for epoch_id in range(num_epochs)] epochs = [epoch_id for epoch_id in range(num_epochs)]
...@@ -366,8 +366,8 @@ class Trainer(object): ...@@ -366,8 +366,8 @@ class Trainer(object):
self._clean_checkpoint() self._clean_checkpoint()
return return
if self.checkpoint and self.checkpoint.load_serial \ if self.checkpoint_cfg and self.checkpoint_cfg.load_serial \
and self.checkpoint.step_id >= step_id and self.checkpoint.epoch_id == epoch_id: and self.checkpoint_cfg.step_id >= step_id and self.checkpoint_cfg.epoch_id == epoch_id:
continue continue
begin_event = BeginStepEvent(epoch_id, step_id) begin_event = BeginStepEvent(epoch_id, step_id)
...@@ -381,10 +381,12 @@ class Trainer(object): ...@@ -381,10 +381,12 @@ class Trainer(object):
else: else:
metrics = exe.run(feed=data, fetch_list=[]) metrics = exe.run(feed=data, fetch_list=[])
self._save_checkpoint(epoch_id, step_id) if self.checkpoint_cfg:
self._save_checkpoint(epoch_id, step_id)
event_handler(EndStepEvent(epoch_id, step_id, metrics)) event_handler(EndStepEvent(epoch_id, step_id, metrics))
event_handler(EndEpochEvent(epoch_id)) event_handler(EndEpochEvent(epoch_id))
self._clean_checkpoint() if self.checkpoint_cfg:
self._clean_checkpoint()
def _test_by_executor(self, reader, feed_order, fetch_list): def _test_by_executor(self, reader, feed_order, fetch_list):
with executor.scope_guard(self.scope): with executor.scope_guard(self.scope):
...@@ -424,9 +426,8 @@ class Trainer(object): ...@@ -424,9 +426,8 @@ class Trainer(object):
return self._get_parallel_executor() return self._get_parallel_executor()
def _clean_checkpoint(self): def _clean_checkpoint(self):
if not self.checkpoint: assert self.checkpoint_cfg
return io.clean_checkpoint(checkpoint_dir=self.checkpoint_cfg.checkpoint_dir)
io.clean_checkpoint(checkpoint_dir=self.checkpoint.checkpoint_dir)
def _get_checkpoint_load_args(self): def _get_checkpoint_load_args(self):
""" """
...@@ -444,19 +445,18 @@ class Trainer(object): ...@@ -444,19 +445,18 @@ class Trainer(object):
return trainer_args return trainer_args
def _save_checkpoint(self, epoch_id, step_id): def _save_checkpoint(self, epoch_id, step_id):
if not self.checkpoint: assert self.checkpoint_cfg
return
if epoch_id % self.checkpoint.epoch_interval == 0 and step_id % self.checkpoint.step_interval == 0: if epoch_id % self.checkpoint_cfg.epoch_interval == 0 and step_id % self.checkpoint_cfg.step_interval == 0:
exe = executor.Executor(self.place) exe = executor.Executor(self.place)
io.save_checkpoint( io.save_checkpoint(
executor=exe, executor=exe,
checkpoint_dir=self.checkpoint.checkpoint_dir, checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
trainer_id=self.trainer_id, trainer_id=self.trainer_id,
is_chief=self.chief, is_chief=self.chief,
trainer_args=self._get_checkpoint_save_args(epoch_id, step_id), trainer_args=self._get_checkpoint_save_args(epoch_id, step_id),
main_program=self.train_program, main_program=self.train_program,
max_num_checkpoints=self.checkpoint.max_num_checkpoints) max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)
def build_feed_var_list(program, feed_order): def build_feed_var_list(program, feed_order):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册