Commit 95545f76 authored by tangwei12

checkpoint api optimized

Parent 436bb450
@@ -25,9 +25,7 @@ __all__ = [
     'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
     'load_persistables', 'save_inference_model', 'load_inference_model',
     'get_inference_program', 'save_checkpoint', 'load_checkpoint',
-    'clean_checkpoint', 'load_persist_vars_without_grad',
-    'load_lookup_table_vars', 'save_persist_vars_without_grad',
-    'get_latest_checkpoint_serial'
+    'clean_checkpoint'
 ]
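After this hunk only save_checkpoint, load_checkpoint and clean_checkpoint stay exported for checkpointing; the other helpers are renamed to private functions later in this commit. A quick illustration of the surviving public surface, with the import path mirroring the fluid.io examples used in the docstrings:

    # Only these checkpoint entry points remain importable after the change.
    from paddle.fluid.io import save_checkpoint, load_checkpoint, clean_checkpoint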
@@ -805,11 +803,11 @@ CHECKPOINT_SEPARATOR = "_"
 def save_checkpoint(executor,
                     checkpoint_dir,
                     trainer_id,
+                    main_program,
                     trainer_args=None,
-                    main_program=None,
                     max_num_checkpoints=3,
                     lookup_table=None,
-                    ps_endpoint_list=None):
+                    pserver_endpoints=None):
     """
     This function filters out all checkpoint variables from the give
     main_program and then saves these variables to the `checkpoint_dir`
@@ -836,16 +834,16 @@ def save_checkpoint(executor,
         trainer_args(dict|None): Current training arguments. Such as 'epoch_id'
                                  and 'step_id'.
                                  Defaut: None
-        main_program(Program|None): The program whose checkpoint variables will
-                                    be saved. If it is None, the default main program will be used.
+        main_program(Program): The program whose checkpoint variables will
+                               be saved.
         max_num_checkpoints(int): The max number of total number of existing
                                   checkpoints.
                                   Default: 3
         lookup_table(string|None): the lookup table name, when use distribute
                                    lookup table, we can get lookup table name by DistributeTranspiler.
                                    table_name
-        ps_endpoint_list(list|None): the parameter server ip:port list.
-                          when use distribute lookup table, we can get ps_endpoint_list by
+        pserver_endpoints(list|None): the parameter server ip:port list.
+                          when use distribute lookup table, we can get pserver_endpoints by
                           distribute arguments.

     Returns:
@@ -873,11 +871,13 @@ def save_checkpoint(executor,
                                 main_program=prog,
                                 max_num_checkpoints=3,
                                 lookup_table=table_name,
-                                ps_endpoint_list = ps_endpoints)
+                                pserver_endpoints = ps_endpoints)
     """
     if checkpoint_dir is None:
         raise ValueError("'checkpoint_dir' should not be None")
+    assert checkpoint_dir
+    if main_program is None:
+        raise ValueError('main_program should not be None.')

     if trainer_args:
         assert isinstance(trainer_args, dict)
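Taken together, the signature and validation changes above make main_program a required argument and rename the endpoint parameter to pserver_endpoints. A hedged usage sketch of the new call; exe, prog, the directory, the table name and the endpoint values are illustrative placeholders, not part of the diff:

    import paddle.fluid as fluid

    exe = fluid.Executor(fluid.CPUPlace())
    prog = fluid.default_main_program()

    fluid.io.save_checkpoint(executor=exe,
                             checkpoint_dir="./checkpoints",
                             trainer_id=0,
                             main_program=prog,                 # now mandatory
                             trainer_args={"epoch_id": 10, "step_id": 200},
                             max_num_checkpoints=3,
                             lookup_table="share_w",            # only for a distributed lookup table
                             pserver_endpoints=["127.0.0.1:6000", "127.0.0.1:6001"])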
@@ -885,22 +885,28 @@ def save_checkpoint(executor,
     is_chief = trainer_id == 0

     _make_chekcpoint_dirs(checkpoint_dir)
-    serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
+    serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1
     cur_dir = _get_serial_dir(checkpoint_dir, serial)

-    save_trainer_args(cur_dir, trainer_id, trainer_args)
+    _save_trainer_args(cur_dir, trainer_id, trainer_args)

     if is_chief:
-        save_persist_vars_without_grad(executor, cur_dir, main_program)
+        _save_persist_vars_without_grad(executor, cur_dir, main_program)

-    if is_chief and lookup_table and ps_endpoint_list:
-        save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
-                                    ps_endpoint_list)
+    if is_chief and lookup_table and pserver_endpoints:
+        _save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
+                                     pserver_endpoints)

     _scroll_delete(checkpoint_dir, max_num_checkpoints)


-def load_checkpoint(executor, checkpoint_dir, serial, main_program):
+def load_checkpoint(executor,
+                    checkpoint_dir,
+                    main_program,
+                    role_id=0,
+                    is_trainer=True,
+                    load_trainer_args=None,
+                    load_lookup_table=None):
     """
     This function filters out all checkpoint variables from the give
     main_program and then try to load these variables from the
@@ -924,13 +930,16 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
-        serial(int): The serial of checkpoint you would like to load.
         main_program(Program): The program whose checkpoint variables will
                                be loaded.
+        role_id(int): the trainer id or the parameter server id.
+        is_trainer(bool): trainer is True and parameter server is False.
+        load_trainer_args(list|None): list about load trainer args.
+        load_lookup_table(str|None): the lookup table name

     Returns:
         None

     Raises:
         ValueError: If `checkpoint_dir` is None.
-        ValueError: If `serial` is None or `serial` is less than 0.
         ValueError: If `main_program` is None.

     Examples:
@@ -951,14 +960,27 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
     if checkpoint_dir is None:
         raise ValueError("'checkpoint_dir' should not be None")

+    serial = _get_latest_checkpoint_serial(checkpoint_dir)
+
+    # there are nothing need to be loaded
     if serial is None or serial < 0:
-        raise ValueError("'serial' should not be None or <0 ")
+        return

     if main_program is None:
         raise ValueError('main_program should not be None.')

-    cur_dir = _get_serial_dir(checkpoint_dir, serial)
-    load_persist_vars_without_grad(executor, cur_dir, main_program, True)
+    if is_trainer and load_trainer_args is None:
+        cur_dir = _get_serial_dir(checkpoint_dir, serial)
+        _load_persist_vars_without_grad(executor, cur_dir, main_program, True)
+        return
+
+    if is_trainer and load_trainer_args:
+        return _load_trainer_args(checkpoint_dir, serial, role_id,
+                                  load_trainer_args)
+
+    if not is_trainer and load_lookup_table:
+        _load_lookup_table_vars(executor, checkpoint_dir, main_program, role_id,
+                                load_lookup_table)


 def clean_checkpoint(checkpoint_dir, delete_dir=False):
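The rewritten body turns load_checkpoint into a single dispatcher: it resolves the latest serial itself and then serves either a trainer restoring variables, a trainer reading back its saved arguments, or a parameter server reloading a lookup table. A hedged sketch of the three call patterns; exe, prog, the directory and the literal ids are placeholders reused from the earlier examples:

    # 1) Trainer: restore checkpoint variables into main_program.
    fluid.io.load_checkpoint(executor=exe,
                             checkpoint_dir="./checkpoints",
                             main_program=prog)

    # 2) Trainer: read back previously saved trainer args, e.g. epoch_id/step_id.
    #    Returns the stored values in the same order as the requested names.
    trainer_args = fluid.io.load_checkpoint(executor=exe,
                                            checkpoint_dir="./checkpoints",
                                            main_program=prog,
                                            role_id=0,
                                            is_trainer=True,
                                            load_trainer_args=["epoch_id", "step_id"])

    # 3) Parameter server: reload the distributed lookup table.
    fluid.io.load_checkpoint(executor=exe,
                             checkpoint_dir="./checkpoints",
                             main_program=prog,
                             role_id=1,
                             is_trainer=False,
                             load_lookup_table="share_w")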
@@ -979,7 +1001,7 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
         os.rmdir(checkpoint_dir)


-def load_persist_vars_without_grad(executor,
+def _load_persist_vars_without_grad(executor,
                                    dirname,
                                    program,
                                    has_model_dir=False):
@@ -1011,10 +1033,10 @@ def load_persist_vars_without_grad(executor,
             exe = fluid.Executor(fluid.CPUPlace())
             param_path = "./my_paddle_model"
             prog = fluid.default_main_program()
-            fluid.io.load_persist_vars_without_grad(executor=exe,
+            fluid.io._load_persist_vars_without_grad(executor=exe,
                     dirname=param_path, program=prog, has_model_dir=True)

-            # In this example, `load_persist_vars_without_grad` function
+            # In this example, `_load_persist_vars_without_grad` function
             # will first filters out all checkpoint variables in the default
             # main program, and then trys to load these variables form the
             # folder "./my_paddle_model/__model__".
@@ -1031,7 +1053,7 @@ def load_persist_vars_without_grad(executor,
         filename=None)


-def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
+def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
     """
     The parameter server will load lookup table's local file in
     selectedrows variable.
@@ -1050,11 +1072,11 @@ def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
         .. code-block:: python

             exe = fluid.Executor(fluid.CPUPlace())
-            dirname = "./checkpoints/checkpoint_9/__model__"
+            dirname = "./checkpoints/checkpoint_9/"
             prog = fluid.default_main_program()
             pserver_id = 1
             table_name = "share_w"
-            fluid.io.load_lookup_table_vars(executor=exe,
+            fluid.io._load_lookup_table_vars(executor=exe,
                     dirname=dirname, program=prog, pserver_id=pserver_id,
                     table_name=table_name)
     """
@@ -1081,7 +1103,7 @@ def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
     executor.run(load_prog)


-def save_persist_vars_without_grad(executor, dirname, program):
+def _save_persist_vars_without_grad(executor, dirname, program):
     """
     This function filters out all checkpoint variables from the give
     program and then save these variables to a sub-folder '__model__' of
@@ -1108,10 +1130,10 @@ def save_persist_vars_without_grad(executor, dirname, program):
             exe = fluid.Executor(fluid.CPUPlace())
             param_path = "./my_paddle_model"
             prog = fluid.default_main_program()
-            fluid.io.save_persist_vars_without_grad(executor=exe,
+            fluid.io._save_persist_vars_without_grad(executor=exe,
                     dirname=param_path, program=prog)

-            # In this example, `save_persist_vars_without_grad` function
+            # In this example, `_save_persist_vars_without_grad` function
             # will first filters out all checkpoint variables in the default
             # main program, and then saves these variables to the folder
             # "./my_paddle_model/__model__".
@@ -1127,7 +1149,7 @@ def save_persist_vars_without_grad(executor, dirname, program):
     _write_success(cur_dir)


-def save_pserver_vars_by_notify(executor, dirname, lookup_table,
+def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
                                 ps_endpoint_list):
     """
     This function will send checkpoint notify message from Trainer 0
@@ -1156,7 +1178,7 @@ def save_pserver_vars_by_notify(executor, dirname, lookup_table,
             table_name = "share_w"
             ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]

-            fluid.io.save_pserver_vars_by_notify(executor=exe,
+            fluid.io._save_pserver_vars_by_notify(executor=exe,
                     dirname=param_path, lookup_table=table_name,
                     ps_endpoint_list=ps_endpoints)
     """
@@ -1175,7 +1197,7 @@ def save_pserver_vars_by_notify(executor, dirname, lookup_table,
     executor.run(checkpoint_notify_program)


-def save_trainer_args(dirname, trainer_id, trainer_args):
+def _save_trainer_args(dirname, trainer_id, trainer_args):
     assert isinstance(trainer_args, dict)

     cur_dir = _get_trainer_dir(dirname, trainer_id)
@@ -1187,7 +1209,7 @@ def save_trainer_args(dirname, trainer_id, trainer_args):
     _write_success(cur_dir)


-def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
+def _load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
     """
     trainer will load some args from it's independent directory,
     such as epoch_id and step_id.
@@ -1208,7 +1230,7 @@ def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
             trainer_id = 2
             trainer_args = ["epoch_id", "step_id"]

-            fluid.io.load_trainer_args(checkpoint_dir=param_path, serial=serial,
+            fluid.io._load_trainer_args(checkpoint_dir=param_path, serial=serial,
                     trainer_id=trainer_id, trainer_args=trainer_args)
     """
     assert isinstance(trainer_args, list)
@@ -1339,7 +1361,7 @@ def _write_success(dirname):
         f.write(now)


-def get_latest_checkpoint_serial(checkpoint_dir):
+def _get_latest_checkpoint_serial(checkpoint_dir):
     """
     get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
......
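The remaining hunks in this file only prepend an underscore to the checkpoint helpers (persistable-variable save/load, trainer args, pserver notify, lookup-table load, latest-serial lookup), marking them as internal. Callers are expected to go through the public entry points instead; for ordinary persistable variables the Trainer change below switches to load_persistables, roughly as in this hedged sketch (exe, param_path and prog are placeholders):

    # Former direct calls to load_persist_vars_without_grad are replaced by the
    # public APIs: load_persistables for plain persistable variables, or
    # load_checkpoint for a checkpoint directory.
    fluid.io.load_persistables(executor=exe, dirname=param_path, main_program=prog)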
@@ -277,31 +277,14 @@ class Trainer(object):
             exe.run(self.startup_program)

             if self.checkpoint_cfg and self.checkpoint_cfg.load_serial:
-                with self._prog_and_scope_guard():
-                    exe = executor.Executor(place)
-                    io.load_checkpoint(exe, self.checkpoint_cfg.checkpoint_dir,
-                                       self.checkpoint_cfg.load_serial,
-                                       self.startup_program)
-
-                if not self.checkpoint_cfg.pserver_id:
-                    epoch_id, step_id = io.load_trainer_args(
-                        self.checkpoint_cfg.checkpoint_dir,
-                        self.checkpoint_cfg.load_serial, self.trainer_id,
-                        self._get_checkpoint_load_args())
-                    self.checkpoint_cfg.epoch_id = int(epoch_id)
-                    self.checkpoint_cfg.step_id = int(step_id)
-                else:
-                    if self.checkpoint_cfg.lookup_table_name:
-                        io.load_lookup_table_vars(
-                            exe, self.checkpoint_cfg.checkpoint_dir,
-                            self.startup_program,
-                            self.checkpoint_cfg.pserver_id,
-                            self.checkpoint_cfg.lookup_table_name)
+                self._load_checkpoint()

             if param_path and os.path.isdir(param_path):
                 # load params from param_path into scope
-                io.load_persist_vars_without_grad(
-                    exe, dirname=param_path, program=self.startup_program)
+                io.load_persistables(
+                    executor=exe,
+                    dirname=param_path,
+                    main_program=self.startup_program)

     def _transpile_nccl2_dist(self):
         # PADDLE_TRAINER_IPS
@@ -580,6 +563,42 @@ class Trainer(object):
                 main_program=self.train_program,
                 max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)

+    def _load_checkpoint(self):
+        with self._prog_and_scope_guard():
+            exe = executor.Executor(self.place)
+            io.load_checkpoint(
+                executor=exe,
+                checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
+                main_program=self.startup_program)
+
+            if not self.checkpoint_cfg.pserver_id:
+                load_trainer_args = self._get_checkpoint_load_args()
+                trainer_args = io.load_checkpoint(
+                    executor=exe,
+                    checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
+                    main_program=self.startup_program,
+                    role_id=self.trainer_id,
+                    is_trainer=True,
+                    load_trainer_args=load_trainer_args)
+
+                if len(trainer_args) != 2:
+                    raise ValueError(
+                        "the return trainer_args length do not equal _get_checkpoint_load_args"
+                    )
+
+                self.checkpoint_cfg.epoch_id = int(trainer_args[0])
+                self.checkpoint_cfg.step_id = int(trainer_args[1])
+            else:
+                if self.checkpoint_cfg.lookup_table_name:
+                    io.load_checkpoint(
+                        executor=exe,
+                        checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
+                        main_program=self.startup_program,
+                        role_id=self.checkpoint_cfg.pserver_id,
+                        is_trainer=False,
+                        load_trainer_args=None,
+                        load_lookup_table=self.checkpoint_cfg.lookup_table_name)
+

 def build_feed_var_list(program, feed_order):
     if not isinstance(program, framework.Program):
......
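With the new _load_checkpoint method shown above, checkpoint restore in the high-level Trainer is driven entirely by its checkpoint configuration, and users no longer pass a serial themselves. A hedged end-to-end sketch, assuming the CheckpointConfig and Trainer constructors of this codebase; train_network and optimizer_func stand in for user-defined functions, and the exact constructor argument names may differ:

    import paddle.fluid as fluid

    config = fluid.CheckpointConfig(checkpoint_dir="./checkpoints",
                                    max_num_checkpoints=3)
    trainer = fluid.Trainer(train_func=train_network,
                            optimizer_func=optimizer_func,
                            place=fluid.CPUPlace(),
                            checkpoint_config=config)
    # During startup the trainer detects an existing checkpoint (load_serial) and
    # calls self._load_checkpoint(), restoring variables plus epoch_id/step_id for
    # trainers, or the distributed lookup table for parameter servers.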