Commit 95545f76 authored by tangwei12

Checkpoint API optimized

Parent 436bb450
@@ -25,9 +25,7 @@ __all__ = [
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables', 'save_inference_model', 'load_inference_model',
'get_inference_program', 'save_checkpoint', 'load_checkpoint',
'clean_checkpoint', 'load_persist_vars_without_grad',
'load_lookup_table_vars', 'save_persist_vars_without_grad',
'get_latest_checkpoint_serial'
'clean_checkpoint'
]
@@ -805,11 +803,11 @@ CHECKPOINT_SEPARATOR = "_"
def save_checkpoint(executor,
checkpoint_dir,
trainer_id,
main_program,
trainer_args=None,
main_program=None,
max_num_checkpoints=3,
lookup_table=None,
ps_endpoint_list=None):
pserver_endpoints=None):
"""
This function filters out all checkpoint variables from the given
main_program and then saves these variables to the `checkpoint_dir`
@@ -836,16 +834,16 @@ def save_checkpoint(executor,
trainer_args(dict|None): Current training arguments, such as 'epoch_id'
and 'step_id'.
Default: None
main_program(Program|None): The program whose checkpoint variables will
be saved. If it is None, the default main program will be used.
main_program(Program): The program whose checkpoint variables will
be saved.
max_num_checkpoints(int): The maximum number of checkpoints to keep.
Default: 3
lookup_table(string|None): the lookup table name; when using a distributed
lookup table, the name can be obtained from DistributeTranspiler.table_name.
ps_endpoint_list(list|None): the parameter server ip:port list;
when using a distributed lookup table, ps_endpoint_list can be obtained from the
pserver_endpoints(list|None): the parameter server ip:port list;
when using a distributed lookup table, pserver_endpoints can be obtained from the
distribute arguments.
Returns:
@@ -873,11 +871,13 @@ def save_checkpoint(executor,
main_program=prog,
max_num_checkpoints=3,
lookup_table=table_name,
ps_endpoint_list = ps_endpoints)
pserver_endpoints = ps_endpoints)
"""
if checkpoint_dir is None:
raise ValueError("'checkpoint_dir' should not be None")
assert checkpoint_dir
if main_program is None:
raise ValueError('main_program should not be None.')
if trainer_args:
assert isinstance(trainer_args, dict)
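For the common single-machine case, a minimal sketch of a call against the updated signature (the checkpoint directory and the trainer_args values are placeholders, not taken from this diff):

```python
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
prog = fluid.default_main_program()

# 'main_program' is now a required argument; 'pserver_endpoints' replaces
# the old 'ps_endpoint_list' keyword.
fluid.io.save_checkpoint(
    executor=exe,
    checkpoint_dir="./checkpoints",  # placeholder directory
    trainer_id=0,
    main_program=prog,
    trainer_args={"epoch_id": 1, "step_id": 10},
    max_num_checkpoints=3)
```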
@@ -885,22 +885,28 @@ def save_checkpoint(executor,
is_chief = trainer_id == 0
_make_chekcpoint_dirs(checkpoint_dir)
serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1
cur_dir = _get_serial_dir(checkpoint_dir, serial)
save_trainer_args(cur_dir, trainer_id, trainer_args)
_save_trainer_args(cur_dir, trainer_id, trainer_args)
if is_chief:
save_persist_vars_without_grad(executor, cur_dir, main_program)
_save_persist_vars_without_grad(executor, cur_dir, main_program)
if is_chief and lookup_table and ps_endpoint_list:
save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
ps_endpoint_list)
if is_chief and lookup_table and pserver_endpoints:
_save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
pserver_endpoints)
_scroll_delete(checkpoint_dir, max_num_checkpoints)
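The save path ends with `_scroll_delete(checkpoint_dir, max_num_checkpoints)`, which keeps only the newest checkpoints. A minimal sketch of that kind of rolling cleanup, assuming serial sub-directories named `checkpoint_<serial>`; this helper is illustrative only, not the library implementation:

```python
import os
import shutil

def scroll_delete_sketch(checkpoint_dir, max_num_checkpoints=3):
    # Collect (serial, path) pairs for sub-directories named like "checkpoint_<serial>".
    serials = []
    for name in os.listdir(checkpoint_dir):
        prefix, _, suffix = name.partition("_")
        if prefix == "checkpoint" and suffix.isdigit():
            serials.append((int(suffix), os.path.join(checkpoint_dir, name)))
    # Keep only the newest `max_num_checkpoints` serials, delete the rest.
    serials.sort(reverse=True)
    for _, stale_dir in serials[max_num_checkpoints:]:
        shutil.rmtree(stale_dir)
```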
def load_checkpoint(executor, checkpoint_dir, serial, main_program):
def load_checkpoint(executor,
checkpoint_dir,
main_program,
role_id=0,
is_trainer=True,
load_trainer_args=None,
load_lookup_table=None):
"""
This function filters out all checkpoint variables from the given
main_program and then tries to load these variables from the
@@ -924,13 +930,16 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
serial(int): The serial of checkpoint you would like to load.
main_program(Program): The program whose checkpoint variables will
be loaded.
role_id(int): the trainer id or the parameter server id.
is_trainer(bool): True for a trainer, False for a parameter server.
load_trainer_args(list|None): the list of trainer argument names to load.
load_lookup_table(str|None): the lookup table name
Returns:
None
Raises:
ValueError: If `checkpoint_dir` is None.
ValueError: If `serial` is None or `serial` is less than 0.
ValueError: If `main_program` is None.
Examples:
@@ -951,14 +960,27 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
if checkpoint_dir is None:
raise ValueError("'checkpoint_dir' should not be None")
serial = _get_latest_checkpoint_serial(checkpoint_dir)
# there is nothing to be loaded
if serial is None or serial < 0:
raise ValueError("'serial' should not be None or <0 ")
return
if main_program is None:
raise ValueError('main_program should not be None.')
if is_trainer and load_trainer_args is None:
cur_dir = _get_serial_dir(checkpoint_dir, serial)
load_persist_vars_without_grad(executor, cur_dir, main_program, True)
_load_persist_vars_without_grad(executor, cur_dir, main_program, True)
return
if is_trainer and load_trainer_args:
return _load_trainer_args(checkpoint_dir, serial, role_id,
load_trainer_args)
if not is_trainer and load_lookup_table:
_load_lookup_table_vars(executor, checkpoint_dir, main_program, role_id,
load_lookup_table)
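The reworked `load_checkpoint` now covers three cases, mirroring the branches above; a minimal usage sketch (the directory is a placeholder and the table name is reused from the docstring examples):

```python
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
prog = fluid.default_main_program()
ckpt_dir = "./checkpoints"  # placeholder directory

# 1. Trainer restoring persistable variables from the latest serial.
fluid.io.load_checkpoint(exe, ckpt_dir, prog)

# 2. Trainer restoring its own training arguments; values come back as a list.
trainer_args = fluid.io.load_checkpoint(
    exe, ckpt_dir, prog,
    role_id=0,
    is_trainer=True,
    load_trainer_args=["epoch_id", "step_id"])
epoch_id, step_id = int(trainer_args[0]), int(trainer_args[1])

# 3. Parameter server restoring a distributed lookup table.
fluid.io.load_checkpoint(
    exe, ckpt_dir, prog,
    role_id=1,
    is_trainer=False,
    load_lookup_table="share_w")
```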
def clean_checkpoint(checkpoint_dir, delete_dir=False):
@@ -979,7 +1001,7 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
os.rmdir(checkpoint_dir)
def load_persist_vars_without_grad(executor,
def _load_persist_vars_without_grad(executor,
dirname,
program,
has_model_dir=False):
@@ -1011,10 +1033,10 @@ def load_persist_vars_without_grad(executor,
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
fluid.io.load_persist_vars_without_grad(executor=exe,
fluid.io._load_persist_vars_without_grad(executor=exe,
dirname=param_path, program=prog, has_model_dir=True)
# In this example, `load_persist_vars_without_grad` function
# In this example, `_load_persist_vars_without_grad` function
# will first filter out all checkpoint variables in the default
# main program, and then try to load these variables from the
# folder "./my_paddle_model/__model__".
@@ -1031,7 +1053,7 @@ def load_persist_vars_without_grad(executor,
filename=None)
def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
"""
The parameter server will load the lookup table's local file into a
SelectedRows variable.
@@ -1050,11 +1072,11 @@ def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
dirname = "./checkpoints/checkpoint_9/__model__"
dirname = "./checkpoints/checkpoint_9/"
prog = fluid.default_main_program()
pserver_id = 1
table_name = "share_w"
fluid.io.load_lookup_table_vars(executor=exe,
fluid.io._load_lookup_table_vars(executor=exe,
dirname=dirname, program=prog, pserver_id=pserver_id,
table_name=table_name)
"""
@@ -1081,7 +1103,7 @@ def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
executor.run(load_prog)
def save_persist_vars_without_grad(executor, dirname, program):
def _save_persist_vars_without_grad(executor, dirname, program):
"""
This function filters out all checkpoint variables from the given
program and then saves these variables to a sub-folder '__model__' of
@@ -1108,10 +1130,10 @@ def save_persist_vars_without_grad(executor, dirname, program):
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
fluid.io.save_persist_vars_without_grad(executor=exe,
fluid.io._save_persist_vars_without_grad(executor=exe,
dirname=param_path, program=prog)
# In this example, `save_persist_vars_without_grad` function
# In this example, `_save_persist_vars_without_grad` function
# will first filter out all checkpoint variables in the default
# main program, and then saves these variables to the folder
# "./my_paddle_model/__model__".
@@ -1127,7 +1149,7 @@ def save_persist_vars_without_grad(executor, dirname, program):
_write_success(cur_dir)
def save_pserver_vars_by_notify(executor, dirname, lookup_table,
def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
ps_endpoint_list):
"""
This function will send a checkpoint notify message from Trainer 0
@@ -1156,7 +1178,7 @@ def save_pserver_vars_by_notify(executor, dirname, lookup_table,
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
fluid.io.save_pserver_vars_by_notify(executor=exe,
fluid.io._save_pserver_vars_by_notify(executor=exe,
dirname=param_path, lookup_table=table_name,
ps_endpoint_list=ps_endpoints)
"""
@@ -1175,7 +1197,7 @@ def save_pserver_vars_by_notify(executor, dirname, lookup_table,
executor.run(checkpoint_notify_program)
def save_trainer_args(dirname, trainer_id, trainer_args):
def _save_trainer_args(dirname, trainer_id, trainer_args):
assert isinstance(trainer_args, dict)
cur_dir = _get_trainer_dir(dirname, trainer_id)
@@ -1187,7 +1209,7 @@ def save_trainer_args(dirname, trainer_id, trainer_args):
_write_success(cur_dir)
def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
def _load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
"""
The trainer will load some args from its independent directory,
such as epoch_id and step_id.
@@ -1208,7 +1230,7 @@ def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
trainer_id = 2
trainer_args = ["epoch_id", "step_id"]
fluid.io.load_trainer_args(checkpoint_dir=param_path, serial=serial,
fluid.io._load_trainer_args(checkpoint_dir=param_path, serial=serial,
trainer_id=trainer_id, trainer_args=trainer_args)
"""
assert isinstance(trainer_args, list)
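`_save_trainer_args` takes a dict while `_load_trainer_args` takes a list of argument names and returns their values in the same order (the Trainer code later expects epoch_id followed by step_id). A small sketch of one possible file-per-argument layout; it only illustrates that contract, not the on-disk format actually used here:

```python
import os

def save_trainer_args_sketch(trainer_dir, trainer_args):
    # trainer_args is a dict, e.g. {"epoch_id": 1, "step_id": 10}.
    assert isinstance(trainer_args, dict)
    for name, value in trainer_args.items():
        with open(os.path.join(trainer_dir, name), "w") as f:
            f.write(str(value))

def load_trainer_args_sketch(trainer_dir, arg_names):
    # arg_names is a list, e.g. ["epoch_id", "step_id"]; values come back in order.
    assert isinstance(arg_names, list)
    values = []
    for name in arg_names:
        with open(os.path.join(trainer_dir, name)) as f:
            values.append(f.read())
    return values
```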
@@ -1339,7 +1361,7 @@ def _write_success(dirname):
f.write(now)
def get_latest_checkpoint_serial(checkpoint_dir):
def _get_latest_checkpoint_serial(checkpoint_dir):
"""
Get the latest checkpoint serial in the checkpoint directory; the _SUCCESS file must exist in that directory.
@@ -277,31 +277,14 @@ class Trainer(object):
exe.run(self.startup_program)
if self.checkpoint_cfg and self.checkpoint_cfg.load_serial:
with self._prog_and_scope_guard():
exe = executor.Executor(place)
io.load_checkpoint(exe, self.checkpoint_cfg.checkpoint_dir,
self.checkpoint_cfg.load_serial,
self.startup_program)
if not self.checkpoint_cfg.pserver_id:
epoch_id, step_id = io.load_trainer_args(
self.checkpoint_cfg.checkpoint_dir,
self.checkpoint_cfg.load_serial, self.trainer_id,
self._get_checkpoint_load_args())
self.checkpoint_cfg.epoch_id = int(epoch_id)
self.checkpoint_cfg.step_id = int(step_id)
else:
if self.checkpoint_cfg.lookup_table_name:
io.load_lookup_table_vars(
exe, self.checkpoint_cfg.checkpoint_dir,
self.startup_program,
self.checkpoint_cfg.pserver_id,
self.checkpoint_cfg.lookup_table_name)
self._load_checkpoint()
if param_path and os.path.isdir(param_path):
# load params from param_path into scope
io.load_persist_vars_without_grad(
exe, dirname=param_path, program=self.startup_program)
io.load_persistables(
executor=exe,
dirname=param_path,
main_program=self.startup_program)
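The `param_path` branch now goes through the public `load_persistables` instead of the removed `load_persist_vars_without_grad`; a minimal standalone sketch of that call (the path is a placeholder):

```python
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"  # placeholder directory holding persistable variables

fluid.io.load_persistables(
    executor=exe,
    dirname=param_path,
    main_program=fluid.default_main_program())
```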
def _transpile_nccl2_dist(self):
# PADDLE_TRAINER_IPS
@@ -580,6 +563,42 @@ class Trainer(object):
main_program=self.train_program,
max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)
def _load_checkpoint(self):
with self._prog_and_scope_guard():
exe = executor.Executor(self.place)
io.load_checkpoint(
executor=exe,
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
main_program=self.startup_program)
if not self.checkpoint_cfg.pserver_id:
load_trainer_args = self._get_checkpoint_load_args()
trainer_args = io.load_checkpoint(
executor=exe,
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
main_program=self.startup_program,
role_id=self.trainer_id,
is_trainer=True,
load_trainer_args=load_trainer_args)
if len(trainer_args) != 2:
raise ValueError(
"the return trainer_args length do not equal _get_checkpoint_load_args"
)
self.checkpoint_cfg.epoch_id = int(trainer_args[0])
self.checkpoint_cfg.step_id = int(trainer_args[1])
else:
if self.checkpoint_cfg.lookup_table_name:
io.load_checkpoint(
executor=exe,
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
main_program=self.startup_program,
role_id=self.checkpoint_cfg.pserver_id,
is_trainer=False,
load_trainer_args=None,
load_lookup_table=self.checkpoint_cfg.lookup_table_name)
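`_load_checkpoint` is driven by the trainer's checkpoint configuration; a hedged sketch of wiring it up. The `CheckpointConfig` argument names, the `fluid.Trainer` constructor keywords, and the toy `train_program`/`optimizer_func` builders below are assumptions based on the fluid Trainer API of this period, not part of this diff:

```python
import paddle.fluid as fluid

def train_program():
    # Placeholder network: linear regression with a mean squared error loss.
    x = fluid.layers.data(name='x', shape=[1], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    y_pred = fluid.layers.fc(input=x, size=1)
    return fluid.layers.mean(fluid.layers.square_error_cost(input=y_pred, label=y))

def optimizer_func():
    return fluid.optimizer.SGD(learning_rate=0.01)

# Assumed constructor arguments; names may differ in the actual release.
ckpt_cfg = fluid.CheckpointConfig(
    checkpoint_dir="./checkpoints",
    max_num_checkpoints=3,
    epoch_interval=1,
    step_interval=10)

trainer = fluid.Trainer(
    train_func=train_program,
    optimizer_func=optimizer_func,
    place=fluid.CPUPlace(),
    checkpoint_config=ckpt_cfg)
# During setup the trainer calls _load_checkpoint(), which restores persistables
# and, on trainers, epoch_id/step_id from the latest checkpoint serial.
```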
def build_feed_var_list(program, feed_order):
if not isinstance(program, framework.Program):