提交 ad9dfeb0 编写于 作者: T tangwei12

bug fix and optimize

上级 5f5d6a9d
...@@ -456,40 +456,18 @@ def get_parameter_value_by_name(name, executor, program=None): ...@@ -456,40 +456,18 @@ def get_parameter_value_by_name(name, executor, program=None):
return get_parameter_value(var, executor) return get_parameter_value(var, executor)
def load_persist_vars_without_grad(executor, dirname, program):
    """
    Load the checkpoint variables of `program` from `dirname` with `executor`.

    Variables are selected by `_is_checkpoint_var`; per the original note,
    variables whose name ends with "@GRAD" are not loaded.

    :param executor: executor handed to `load_vars`
    :param dirname: directory the variables are read from
    :param program: program passed to `load_vars` as `main_program`
    """
    load_options = dict(
        dirname=dirname,
        main_program=program,
        predicate=_is_checkpoint_var,
        filename=None)
    load_vars(executor, **load_options)
def save_persist_vars_without_grad(executor, dirname, program):
    """
    Save the checkpoint variables of `program` into `dirname` with `executor`.

    Variables are selected by `_is_checkpoint_var`; per the original note,
    variables whose name ends with "@GRAD" are not saved.

    :param executor: executor handed to `save_vars`
    :param dirname: directory the variables are written to
    :param program: program passed to `save_vars` as `main_program`
    """
    save_options = dict(
        dirname=dirname,
        main_program=program,
        vars=None,
        predicate=_is_checkpoint_var,
        filename=None)
    save_vars(executor, **save_options)
SUCCESS_MARK_FILENAME = "_SUCCESS" SUCCESS_MARK_FILENAME = "_SUCCESS"
CHECKPOINT_PREFIX = "checkpoint" CHECKPOINT_PREFIX = "checkpoint"
MODEL_DIR = "__model__"
TRAINER_PREFIX = "trainer"
CHECKPOINT_SEPARATOR = "_" CHECKPOINT_SEPARATOR = "_"
def save_checkpoint(executor, def save_checkpoint(executor,
checkpoint_dir, checkpoint_dir,
trainer_id,
is_chief=False,
trainer_args=None,
main_program=None, main_program=None,
max_num_checkpoints=3): max_num_checkpoints=3):
""" """
...@@ -502,22 +480,35 @@ def save_checkpoint(executor, ...@@ -502,22 +480,35 @@ def save_checkpoint(executor,
:param checkpoint_dir :param checkpoint_dir
:param main_program :param main_program
:param max_num_checkpoints :param max_num_checkpoints
:param is_chief
""" """
if checkpoint_dir is None: if checkpoint_dir is None:
raise ValueError("The values of 'checkpoint_dir' should not be None") raise ValueError("The values of 'checkpoint_dir' should not be None")
if trainer_args and not isinstance(trainer_args, dict):
raise TypeError("The type of 'trainer_args' should be dict")
if not os.path.isdir(checkpoint_dir): if not os.path.isdir(checkpoint_dir):
os.makedirs(checkpoint_dir) os.makedirs(checkpoint_dir)
serial = _get_lastest_checkpoint_dir(checkpoint_dir) + 1 serial = _get_lastest_checkpoint_dir(checkpoint_dir) + 1
cur_dir = _get_serial_dir(checkpoint_dir, serial) cur_dir = _get_serial_dir(checkpoint_dir, serial)
if is_chief:
save_persist_vars_without_grad(executor, cur_dir, main_program) save_persist_vars_without_grad(executor, cur_dir, main_program)
_write_success(cur_dir)
save_trainer_args(cur_dir, trainer_id, trainer_args)
_lru_delete(checkpoint_dir, max_num_checkpoints) _lru_delete(checkpoint_dir, max_num_checkpoints)
def load_checkpoint(executor, checkpoint_dir, main_program=None): def need_load_checkpoint(checkpoint_dir):
serial = _get_lastest_checkpoint_dir(checkpoint_dir)
if serial < 0:
return None
return serial
def load_checkpoint(executor, checkpoint_dir, serial, main_program):
""" """
Load checkpoint from a directory by executor, Load checkpoint from a directory by executor,
it will find the most recent saved checkpoint file and load it auto. it will find the most recent saved checkpoint file and load it auto.
...@@ -528,14 +519,17 @@ def load_checkpoint(executor, checkpoint_dir, main_program=None): ...@@ -528,14 +519,17 @@ def load_checkpoint(executor, checkpoint_dir, main_program=None):
""" """
if checkpoint_dir is None: if checkpoint_dir is None:
raise ValueError("The values of 'checkpoint_dir' should not be None") raise ValueError(
"The values of 'checkpoint_dir' or 'serial' should not be None")
serial = _get_lastest_checkpoint_dir(checkpoint_dir) if serial is None or serial < 0:
raise ValueError("The values of 'serial' should not be None or <0 ")
if serial < 0: if main_program is None:
return raise ValueError("The values of 'main_program'should not be None")
cur_dir = _get_serial_dir(checkpoint_dir, serial) cur_dir = _get_serial_dir(checkpoint_dir, serial)
cur_dir = _get_model_dir(cur_dir)
load_persist_vars_without_grad(executor, cur_dir, main_program) load_persist_vars_without_grad(executor, cur_dir, main_program)
...@@ -552,6 +546,68 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False): ...@@ -552,6 +546,68 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
os.rmdir(checkpoint_dir) os.rmdir(checkpoint_dir)
def load_persist_vars_without_grad(executor, dirname, program, nest=True):
    """
    Load the checkpoint variables of `program` from a directory with `executor`.

    Variables are selected by `_is_checkpoint_var`; per the original note,
    variables whose name ends with "@GRAD" are not loaded.

    :param executor: executor handed to `load_vars`
    :param dirname: base directory of the saved variables
    :param program: program passed to `load_vars` as `main_program`
    :param nest: when true, variables are read from the model sub-directory
        returned by `_get_model_dir(dirname)` instead of `dirname` itself
    """
    source_dir = _get_model_dir(dirname) if nest else dirname

    load_options = dict(
        dirname=source_dir,
        main_program=program,
        predicate=_is_checkpoint_var,
        filename=None)
    load_vars(executor, **load_options)
def save_persist_vars_without_grad(executor, dirname, program):
    """
    Save the checkpoint variables of `program` under `dirname` with `executor`.

    Variables are selected by `_is_checkpoint_var`; per the original note,
    variables whose name ends with "@GRAD" are not saved. The files go into
    the model sub-directory returned by `_get_model_dir(dirname)`.

    :param executor: executor handed to `save_vars`
    :param dirname: base directory of the checkpoint
    :param program: program passed to `save_vars` as `main_program`
    """
    model_dir = _get_model_dir(dirname)

    save_options = dict(
        dirname=model_dir,
        main_program=program,
        vars=None,
        predicate=_is_checkpoint_var,
        filename=None)
    save_vars(executor, **save_options)

    # `_write_success` is invoked only after `save_vars` has returned.
    _write_success(model_dir)
def save_trainer_args(dirname, trainer_id, trainer_args):
    """
    Write every trainer argument to its own file in the trainer's directory.

    :param dirname: directory that contains (or will contain) the trainer folder
    :param trainer_id: id used by `_get_trainer_dir` to pick the trainer folder
    :param trainer_args: dict mapping argument name to value; each value is
        written as `str(value)` to a file named after its key
    :raises TypeError: if `trainer_args` is not a dict
    """
    if not isinstance(trainer_args, dict):
        raise TypeError("The type of 'trainer_args' should be dict")

    cur_dir = _get_trainer_dir(dirname, trainer_id)

    # items() is available on both Python 2 and Python 3, whereas the
    # previously used iteritems() exists on Python 2 only.
    for name, value in trainer_args.items():
        args_file = os.path.join(cur_dir, name)
        with open(args_file, 'w') as f:
            f.write(str(value))

    # Called last, after all argument files have been written.
    _write_success(cur_dir)
def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
    """
    Read saved trainer arguments back from a checkpoint.

    :param checkpoint_dir: root directory of the checkpoints
    :param serial: serial number used by `_get_serial_dir` to pick the checkpoint
    :param trainer_id: id used by `_get_trainer_dir` to pick the trainer folder
    :param trainer_args: list of argument names; each name is the file to read
    :return: list of the stripped file contents (strings), in the same order
        as `trainer_args`
    :raises TypeError: if `trainer_args` is not a list
    """
    # Validate before resolving directories: the directory helpers create
    # missing folders, so a bad call must not leave stray directories behind.
    if not isinstance(trainer_args, list):
        raise TypeError("The type of 'trainer_args' should be list")

    cur_dir = _get_serial_dir(checkpoint_dir, serial)
    cur_dir = _get_trainer_dir(cur_dir, trainer_id)

    ret_values = []
    for arg in trainer_args:
        cur_file = os.path.join(cur_dir, arg)
        with open(cur_file, 'r') as f:
            contents = f.read()
        ret_values.append(contents.strip())
    return ret_values
def _is_checkpoint_var(var): def _is_checkpoint_var(var):
""" """
the checkpoint will not save or load all the variables. the checkpoint will not save or load all the variables.
...@@ -583,7 +639,31 @@ def _get_dir_serial(dirname): ...@@ -583,7 +639,31 @@ def _get_dir_serial(dirname):
def _get_serial_dir(dirname, serial): def _get_serial_dir(dirname, serial):
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial) serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
return os.path.join(dirname, serial_folder) serial_dir = os.path.join(dirname, serial_folder)
if not os.path.isdir(serial_dir):
os.makedirs(serial_dir)
return serial_dir
def _get_model_dir(dirname):
    """
    Return the `MODEL_DIR` sub-directory of `dirname`, creating it if missing.
    """
    target = os.path.join(dirname, MODEL_DIR)
    if os.path.isdir(target):
        return target

    os.makedirs(target)
    return target
def _get_trainer_dir(dirname, trainer_id):
    """
    Return the per-trainer sub-directory of `dirname`, creating it if missing.

    The folder name is `TRAINER_PREFIX`, `CHECKPOINT_SEPARATOR` and
    `str(trainer_id)` concatenated in that order.
    """
    folder_name = CHECKPOINT_SEPARATOR.join([TRAINER_PREFIX, str(trainer_id)])
    target = os.path.join(dirname, folder_name)
    if os.path.isdir(target):
        return target

    os.makedirs(target)
    return target
def _lru_delete(dirname, max_num_checkpoints=3): def _lru_delete(dirname, max_num_checkpoints=3):
...@@ -638,7 +718,8 @@ def _get_lastest_checkpoint_dir(checkpoint_dir): ...@@ -638,7 +718,8 @@ def _get_lastest_checkpoint_dir(checkpoint_dir):
return -1 return -1
success_path = os.path.join( success_path = os.path.join(
_get_serial_dir(checkpoint_dir, serial), SUCCESS_MARK_FILENAME) _get_serial_dir(checkpoint_dir, serial), MODEL_DIR,
SUCCESS_MARK_FILENAME)
if os.path.isfile(success_path): if os.path.isfile(success_path):
return serial return serial
......
...@@ -79,6 +79,9 @@ class CheckpointConfig(object): ...@@ -79,6 +79,9 @@ class CheckpointConfig(object):
else: else:
self.step_interval = step_interval self.step_interval = step_interval
self.epoch_id = 0
self.step_id = 0
def check_and_get_place(place): def check_and_get_place(place):
""" """
...@@ -132,6 +135,7 @@ class Trainer(object): ...@@ -132,6 +135,7 @@ class Trainer(object):
# config for checkpoint # config for checkpoint
# only chief worker will save variables # only chief worker will save variables
self.trainer_id = 0
self.chief = True self.chief = True
self.checkpoint = checkpoint_config self.checkpoint = checkpoint_config
if self.checkpoint and \ if self.checkpoint and \
...@@ -139,6 +143,8 @@ class Trainer(object): ...@@ -139,6 +143,8 @@ class Trainer(object):
raise TypeError( raise TypeError(
"The checkpoint_config shoule be an instance of CheckpointConfig" "The checkpoint_config shoule be an instance of CheckpointConfig"
) )
self.load_checkpoint_serial = io.need_load_checkpoint(
self.checkpoint.checkpoint_dir)
self.scope = core.Scope() self.scope = core.Scope()
...@@ -168,15 +174,25 @@ class Trainer(object): ...@@ -168,15 +174,25 @@ class Trainer(object):
exe = executor.Executor(place) exe = executor.Executor(place)
exe.run(self.startup_program) exe.run(self.startup_program)
if self.checkpoint: if self.load_checkpoint_serial:
exe = executor.Executor(place) exe = executor.Executor(place)
io.load_checkpoint(exe, self.checkpoint.checkpoint_dir, io.load_checkpoint(exe, self.checkpoint.checkpoint_dir,
self.load_checkpoint_serial,
self.startup_program) self.startup_program)
if param_path: epoch_id, step_id = io.load_trainer_args(
self.checkpoint.checkpoint_dir, self.load_checkpoint_serial,
self.trainer_id, ["epoch_id", "step_id"])
self.checkpoint.epoch_id = int(epoch_id)
self.checkpoint.step_id = int(step_id)
if param_path and os.path.isdir(param_path):
# load params from param_path into scope # load params from param_path into scope
io.load_persist_vars_without_grad( io.load_persist_vars_without_grad(
exe, dirname=param_path, program=self.startup_program) exe,
dirname=param_path,
program=self.startup_program,
nest=False)
def _transpile_nccl2_dist(self): def _transpile_nccl2_dist(self):
# PADDLE_TRAINER_IPS # PADDLE_TRAINER_IPS
...@@ -333,11 +349,20 @@ class Trainer(object): ...@@ -333,11 +349,20 @@ class Trainer(object):
self._train_by_any_executor(event_handler, exe, num_epochs, reader) self._train_by_any_executor(event_handler, exe, num_epochs, reader)
def _train_by_any_executor(self, event_handler, exe, num_epochs, reader): def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
for epoch_id in range(num_epochs): epochs = [
epoch_id for epoch_id in range(num_epochs)
if epoch_id >= self.checkpoint.epoch_id
]
for epoch_id in epochs:
event_handler(BeginEpochEvent(epoch_id)) event_handler(BeginEpochEvent(epoch_id))
for step_id, data in enumerate(reader()): for step_id, data in enumerate(reader()):
if self.__stop: if self.__stop:
self._clean_checkpoint()
return return
if self.checkpoint and self.checkpoint.step_id >= step_id and self.checkpoint.epoch_id == epoch_id:
continue
begin_event = BeginStepEvent(epoch_id, step_id) begin_event = BeginStepEvent(epoch_id, step_id)
event_handler(begin_event) event_handler(begin_event)
if begin_event.fetch_metrics: if begin_event.fetch_metrics:
...@@ -352,6 +377,7 @@ class Trainer(object): ...@@ -352,6 +377,7 @@ class Trainer(object):
event_handler(EndStepEvent(epoch_id, step_id, metrics)) event_handler(EndStepEvent(epoch_id, step_id, metrics))
self._save_checkpoint(epoch_id, step_id) self._save_checkpoint(epoch_id, step_id)
event_handler(EndEpochEvent(epoch_id)) event_handler(EndEpochEvent(epoch_id))
self._clean_checkpoint()
def _test_by_executor(self, reader, feed_order, fetch_list): def _test_by_executor(self, reader, feed_order, fetch_list):
with executor.scope_guard(self.scope): with executor.scope_guard(self.scope):
...@@ -390,17 +416,29 @@ class Trainer(object): ...@@ -390,17 +416,29 @@ class Trainer(object):
loss_name=self.train_func_outputs[0].name) loss_name=self.train_func_outputs[0].name)
return self._get_parallel_executor() return self._get_parallel_executor()
def _clean_checkpoint(self):
    """
    Ask `io.clean_checkpoint` to clean the configured checkpoint directory.

    Does nothing when no checkpoint config is set on this object.
    """
    if self.checkpoint:
        io.clean_checkpoint(checkpoint_dir=self.checkpoint.checkpoint_dir)
def _save_checkpoint(self, epoch_id, step_id): def _save_checkpoint(self, epoch_id, step_id):
if not self.checkpoint or not self.chief: if not self.checkpoint:
return return
if epoch_id % self.checkpoint.epoch_interval == 0 and step_id % self.checkpoint.step_interval == 0: if epoch_id % self.checkpoint.epoch_interval == 0 and step_id % self.checkpoint.step_interval == 0:
trainer_args = {}
trainer_args["epoch_id"] = epoch_id
trainer_args["step_id"] = step_id
exe = executor.Executor(self.place) exe = executor.Executor(self.place)
io.save_checkpoint( io.save_checkpoint(
executor=exe, executor=exe,
checkpoint_dir=self.checkpoint.checkpoint_dir, checkpoint_dir=self.checkpoint.checkpoint_dir,
max_num_checkpoints=self.checkpoint.max_num_checkpoints, trainer_id=self.trainer_id,
main_program=self.train_program) is_chief=self.chief,
trainer_args=trainer_args,
main_program=self.train_program,
max_num_checkpoints=self.checkpoint.max_num_checkpoints)
def build_feed_var_list(program, feed_order): def build_feed_var_list(program, feed_order):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册