Commit 5eea5db9 authored by tangwei12

optimized checkpoint and save_model

Parent 514b2427
......@@ -26,6 +26,7 @@ from trainer import BeginEpochEvent
from trainer import EndEpochEvent
from trainer import BeginStepEvent
from trainer import EndStepEvent
from trainer import CheckpointConfig
import inferencer
from inferencer import Inferencer
......
......@@ -491,7 +491,6 @@ CHECKPOINT_SEPARATOR = "_"
def save_checkpoint(executor,
checkpoint_dir=None,
max_num_checkpoints=3,
save_interval_secs=600,
main_program=None):
"""
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
......@@ -511,15 +510,10 @@ def save_checkpoint(executor,
if not os.path.isdir(checkpoint_dir):
os.makedirs(checkpoint_dir)
serial = _get_lastest_checkpoint_dir(checkpoint_dir)
if serial >= 0 and not _interval_secs_exceed(
_get_serial_dir(serial, checkpoint_dir), save_interval_secs):
return
serial += 1
cur_dir = _get_serial_dir(serial, checkpoint_dir)
serial = _get_lastest_checkpoint_dir(checkpoint_dir) + 1
cur_dir = _get_serial_dir(checkpoint_dir, serial)
load_persist_vars_without_grad(executor, cur_dir, main_program)
save_persist_vars_without_grad(executor, cur_dir, main_program)
_write_success(cur_dir)
_lru_delete(checkpoint_dir, max_num_checkpoints)
......@@ -542,7 +536,7 @@ def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
if serial < 0:
return
cur_dir = _get_serial_dir(serial, checkpoint_dir)
cur_dir = _get_serial_dir(checkpoint_dir, serial)
load_persist_vars_without_grad(executor, cur_dir, main_program)
......@@ -559,11 +553,6 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
os.rmdir(checkpoint_dir)
def _get_serial_dir(serial, checkpoint_dir):
    """Return the folder path for checkpoint number `serial`.

    The folder name is built as CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR +
    str(serial), joined under `checkpoint_dir`.

    NOTE(review): pre-refactor version with (serial, checkpoint_dir) argument
    order; the replacement elsewhere in this diff takes (dirname, serial).
    """
    serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
    return os.path.join(checkpoint_dir, serial_folder)
def _is_checkpoint_var(var):
"""
the checkpoint will not save or load all the variables.
......@@ -582,29 +571,37 @@ def _is_checkpoint_var(var):
return var.persistable
def _interval_secs_exceed(dirname, save_interval_secs):
dir_time = os.path.getmtime(dirname)
if save_interval_secs > (time.time() - dir_time):
return False
return True
def _get_dir_serial(dirname):
    """Parse the trailing serial number out of a checkpoint folder name.

    Folder names look like CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + serial
    (e.g. "checkpoint_3"). Returns the serial as an int, or -1 when the name
    is not a checkpoint folder.

    Fix: the original did `_, serial = dirname.split(CHECKPOINT_SEPARATOR)`
    OUTSIDE the try block, so any folder name with zero or more than one
    separator raised an uncaught ValueError from tuple unpacking. Using
    rsplit(..., 1) inside the try handles both cases and returns -1.
    """
    try:
        _, serial = dirname.rsplit(CHECKPOINT_SEPARATOR, 1)
        return int(serial)
    except ValueError:
        # No separator present, or the suffix is not an integer.
        return -1
def _get_serial_dir(dirname, serial):
    """Build the path of the checkpoint folder for `serial` under `dirname`."""
    folder_name = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
    return os.path.join(dirname, folder_name)
def _lru_delete(dirname, max_num_checkpoints=3):
    """Delete the oldest checkpoint folders under `dirname`, keeping only the
    `max_num_checkpoints` newest (highest serial) ones.

    Fixes over the diff-mangled original:
    - skip entries whose name does not parse to a checkpoint serial
      (`_get_dir_serial` returns -1); the original mapped them all to -1 and
      could then call shutil.rmtree on a nonexistent "checkpoint_-1" path;
    - use sorted() instead of list.sort() on dict keys, which also works on
      Python 3 where .keys() is a view.
    """
    serial_map = {}
    for folder in os.listdir(dirname):
        serial_num = _get_dir_serial(folder)
        if serial_num == -1:
            # Not a checkpoint folder; never touch unrelated directories.
            continue
        serial_map[serial_num] = folder
    if len(serial_map) <= max_num_checkpoints:
        return
    # Newest first; everything past the retention limit is deleted.
    stale_serials = sorted(serial_map, reverse=True)[max_num_checkpoints:]
    for serial in stale_serials:
        shutil.rmtree(_get_serial_dir(dirname, serial))
......@@ -633,20 +630,18 @@ def _get_lastest_checkpoint_dir(checkpoint_dir):
"""
is _SUCCESS in this dir
"""
_, serial = cur_dir.split(CHECKPOINT_SEPARATOR)
try:
int(serial)
except ValueError:
serial = _get_dir_serial(cur_dir)
if serial == -1:
return -1
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
return -1
success_path = os.path.join(
_get_serial_dir(serial, checkpoint_dir), SUCCESS_MARK_FILENAME)
_get_serial_dir(checkpoint_dir, serial), SUCCESS_MARK_FILENAME)
if os.path.isfile(success_path):
return int(serial)
return serial
if not os.path.isdir(checkpoint_dir):
return -1
......
......@@ -60,11 +60,24 @@ class CheckpointConfig(object):
def __init__(self,
checkpoint_dir=None,
max_num_checkpoints=3,
save_interval_secs=600):
epoch_interval=1,
step_interval=10):
if checkpoint_dir is None:
self.checkpoint_dir = os.getcwd()
else:
self.checkpoint_dir = checkpoint_dir
self.max_num_checkpoints = max_num_checkpoints
self.save_interval_secs = save_interval_secs
if epoch_interval < 1:
self.epoch_interval = 1
else:
self.epoch_interval = epoch_interval
if step_interval < 1:
self.step_interval = 10
else:
self.step_interval = step_interval
def check_and_get_place(place):
......@@ -290,14 +303,6 @@ class Trainer(object):
exe = executor.Executor(self.place)
io.save_persistables(exe, dirname=param_path)
def _save_checkpoint(self):
    """Save a training checkpoint via io.save_checkpoint.

    Runs only when a checkpoint config is set and this worker is the chief,
    so exactly one worker writes checkpoints.

    NOTE(review): pre-refactor version — passes save_interval_secs through
    to io.save_checkpoint; the diff replaces it with an
    (epoch_id, step_id)-gated variant.
    """
    if self.checkpoint and self.chief:
        exe = executor.Executor(self.place)
        io.save_checkpoint(exe, self.checkpoint.checkpoint_dir,
                           self.checkpoint.max_num_checkpoints,
                           self.checkpoint.save_interval_secs,
                           self.train_program)
@contextlib.contextmanager
def _prog_and_scope_guard(self):
with framework.program_guard(
......@@ -343,8 +348,9 @@ class Trainer(object):
])
else:
metrics = exe.run(feed=data, fetch_list=[])
event_handler(EndStepEvent(epoch_id, step_id, metrics))
self._save_checkpoint()
self._save_checkpoint(epoch_id, step_id)
event_handler(EndEpochEvent(epoch_id))
def _test_by_executor(self, reader, feed_order, fetch_list):
......@@ -384,6 +390,18 @@ class Trainer(object):
loss_name=self.train_func_outputs[0].name)
return self._get_parallel_executor()
def _save_checkpoint(self, epoch_id, step_id):
if not self.checkpoint or not self.chief:
return
if epoch_id % self.checkpoint.epoch_interval == 0 and step_id % self.checkpoint.step_interval == 0:
exe = executor.Executor(self.place)
io.save_checkpoint(
executor=exe,
checkpoint_dir=self.checkpoint.checkpoint_dir,
max_num_checkpoints=self.checkpoint.max_num_checkpoints,
main_program=self.train_program)
def build_feed_var_list(program, feed_order):
if not isinstance(program, framework.Program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.