提交 5eea5db9 编写于 作者: T tangwei12

optimized checkpoint and save_model

上级 514b2427
...@@ -26,6 +26,7 @@ from trainer import BeginEpochEvent ...@@ -26,6 +26,7 @@ from trainer import BeginEpochEvent
from trainer import EndEpochEvent from trainer import EndEpochEvent
from trainer import BeginStepEvent from trainer import BeginStepEvent
from trainer import EndStepEvent from trainer import EndStepEvent
from trainer import CheckpointConfig
import inferencer import inferencer
from inferencer import Inferencer from inferencer import Inferencer
......
...@@ -491,7 +491,6 @@ CHECKPOINT_SEPARATOR = "_" ...@@ -491,7 +491,6 @@ CHECKPOINT_SEPARATOR = "_"
def save_checkpoint(executor, def save_checkpoint(executor,
checkpoint_dir=None, checkpoint_dir=None,
max_num_checkpoints=3, max_num_checkpoints=3,
save_interval_secs=600,
main_program=None): main_program=None):
""" """
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory, Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
...@@ -511,15 +510,10 @@ def save_checkpoint(executor, ...@@ -511,15 +510,10 @@ def save_checkpoint(executor,
if not os.path.isdir(checkpoint_dir): if not os.path.isdir(checkpoint_dir):
os.makedirs(checkpoint_dir) os.makedirs(checkpoint_dir)
serial = _get_lastest_checkpoint_dir(checkpoint_dir) serial = _get_lastest_checkpoint_dir(checkpoint_dir) + 1
if serial >= 0 and not _interval_secs_exceed( cur_dir = _get_serial_dir(checkpoint_dir, serial)
_get_serial_dir(serial, checkpoint_dir), save_interval_secs):
return
serial += 1
cur_dir = _get_serial_dir(serial, checkpoint_dir)
load_persist_vars_without_grad(executor, cur_dir, main_program) save_persist_vars_without_grad(executor, cur_dir, main_program)
_write_success(cur_dir) _write_success(cur_dir)
_lru_delete(checkpoint_dir, max_num_checkpoints) _lru_delete(checkpoint_dir, max_num_checkpoints)
...@@ -542,7 +536,7 @@ def load_checkpoint(executor, checkpoint_dir=None, main_program=None): ...@@ -542,7 +536,7 @@ def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
if serial < 0: if serial < 0:
return return
cur_dir = _get_serial_dir(serial, checkpoint_dir) cur_dir = _get_serial_dir(checkpoint_dir, serial)
load_persist_vars_without_grad(executor, cur_dir, main_program) load_persist_vars_without_grad(executor, cur_dir, main_program)
...@@ -559,11 +553,6 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False): ...@@ -559,11 +553,6 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
os.rmdir(checkpoint_dir) os.rmdir(checkpoint_dir)
def _get_serial_dir(serial, checkpoint_dir):
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
return os.path.join(checkpoint_dir, serial_folder)
def _is_checkpoint_var(var): def _is_checkpoint_var(var):
""" """
the checkpoint will not save or load all the variables. the checkpoint will not save or load all the variables.
...@@ -582,29 +571,37 @@ def _is_checkpoint_var(var): ...@@ -582,29 +571,37 @@ def _is_checkpoint_var(var):
return var.persistable return var.persistable
def _interval_secs_exceed(dirname, save_interval_secs): def _get_dir_serial(dirname):
dir_time = os.path.getmtime(dirname) _, serial = dirname.split(CHECKPOINT_SEPARATOR)
if save_interval_secs > (time.time() - dir_time):
return False serial_num = -1
return True try:
serial_num = int(serial)
except ValueError:
serial_num = -1
return serial_num
def _get_serial_dir(dirname, serial):
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
return os.path.join(dirname, serial_folder)
def _lru_delete(dirname, max_num_checkpoints=3): def _lru_delete(dirname, max_num_checkpoints=3):
dirs = os.listdir(dirname) dirs = os.listdir(dirname)
serials = [] serial_map = {}
for serial in dirs: for serial in dirs:
try: serial_num = _get_dir_serial(serial)
serials.append(int(serial)) serial_map[serial_num] = serial
except ValueError:
continue
if len(serials) <= max_num_checkpoints: if len(serial_map.keys()) <= max_num_checkpoints:
return return
serials = serial_map.keys()
serials.sort(reverse=True) serials.sort(reverse=True)
serials = serials[max_num_checkpoints:] serials = serials[max_num_checkpoints:]
for serial in serials: for serial in serials:
cur_dir = os.path.join(dirname, str(serial)) cur_dir = _get_serial_dir(dirname, serial)
shutil.rmtree(cur_dir) shutil.rmtree(cur_dir)
...@@ -633,20 +630,18 @@ def _get_lastest_checkpoint_dir(checkpoint_dir): ...@@ -633,20 +630,18 @@ def _get_lastest_checkpoint_dir(checkpoint_dir):
""" """
is _SUCCESS in this dir is _SUCCESS in this dir
""" """
_, serial = cur_dir.split(CHECKPOINT_SEPARATOR)
try: serial = _get_dir_serial(cur_dir)
int(serial) if serial == -1:
except ValueError:
return -1 return -1
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)): if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
return -1 return -1
success_path = os.path.join( success_path = os.path.join(
_get_serial_dir(serial, checkpoint_dir), SUCCESS_MARK_FILENAME) _get_serial_dir(checkpoint_dir, serial), SUCCESS_MARK_FILENAME)
if os.path.isfile(success_path): if os.path.isfile(success_path):
return int(serial) return serial
if not os.path.isdir(checkpoint_dir): if not os.path.isdir(checkpoint_dir):
return -1 return -1
......
...@@ -60,11 +60,24 @@ class CheckpointConfig(object): ...@@ -60,11 +60,24 @@ class CheckpointConfig(object):
def __init__(self, def __init__(self,
checkpoint_dir=None, checkpoint_dir=None,
max_num_checkpoints=3, max_num_checkpoints=3,
save_interval_secs=600): epoch_interval=1,
step_interval=10):
if checkpoint_dir is None: if checkpoint_dir is None:
self.checkpoint_dir = os.getcwd() self.checkpoint_dir = os.getcwd()
else:
self.checkpoint_dir = checkpoint_dir
self.max_num_checkpoints = max_num_checkpoints self.max_num_checkpoints = max_num_checkpoints
self.save_interval_secs = save_interval_secs
if epoch_interval < 1:
self.epoch_interval = 1
else:
self.epoch_interval = epoch_interval
if step_interval < 1:
self.step_interval = 10
else:
self.step_interval = step_interval
def check_and_get_place(place): def check_and_get_place(place):
...@@ -290,14 +303,6 @@ class Trainer(object): ...@@ -290,14 +303,6 @@ class Trainer(object):
exe = executor.Executor(self.place) exe = executor.Executor(self.place)
io.save_persistables(exe, dirname=param_path) io.save_persistables(exe, dirname=param_path)
def _save_checkpoint(self):
if self.checkpoint and self.chief:
exe = executor.Executor(self.place)
io.save_checkpoint(exe, self.checkpoint.checkpoint_dir,
self.checkpoint.max_num_checkpoints,
self.checkpoint.save_interval_secs,
self.train_program)
@contextlib.contextmanager @contextlib.contextmanager
def _prog_and_scope_guard(self): def _prog_and_scope_guard(self):
with framework.program_guard( with framework.program_guard(
...@@ -343,8 +348,9 @@ class Trainer(object): ...@@ -343,8 +348,9 @@ class Trainer(object):
]) ])
else: else:
metrics = exe.run(feed=data, fetch_list=[]) metrics = exe.run(feed=data, fetch_list=[])
event_handler(EndStepEvent(epoch_id, step_id, metrics)) event_handler(EndStepEvent(epoch_id, step_id, metrics))
self._save_checkpoint() self._save_checkpoint(epoch_id, step_id)
event_handler(EndEpochEvent(epoch_id)) event_handler(EndEpochEvent(epoch_id))
def _test_by_executor(self, reader, feed_order, fetch_list): def _test_by_executor(self, reader, feed_order, fetch_list):
...@@ -384,6 +390,18 @@ class Trainer(object): ...@@ -384,6 +390,18 @@ class Trainer(object):
loss_name=self.train_func_outputs[0].name) loss_name=self.train_func_outputs[0].name)
return self._get_parallel_executor() return self._get_parallel_executor()
def _save_checkpoint(self, epoch_id, step_id):
if not self.checkpoint or not self.chief:
return
if epoch_id % self.checkpoint.epoch_interval == 0 and step_id % self.checkpoint.step_interval == 0:
exe = executor.Executor(self.place)
io.save_checkpoint(
executor=exe,
checkpoint_dir=self.checkpoint.checkpoint_dir,
max_num_checkpoints=self.checkpoint.max_num_checkpoints,
main_program=self.train_program)
def build_feed_var_list(program, feed_order): def build_feed_var_list(program, feed_order):
if not isinstance(program, framework.Program): if not isinstance(program, framework.Program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册