Commit 95545f76 authored by tangwei12

checkpoint api optimized

Parent 436bb450
@@ -25,9 +25,7 @@ __all__ = [
     'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
     'load_persistables', 'save_inference_model', 'load_inference_model',
     'get_inference_program', 'save_checkpoint', 'load_checkpoint',
-    'clean_checkpoint', 'load_persist_vars_without_grad',
-    'load_lookup_table_vars', 'save_persist_vars_without_grad',
-    'get_latest_checkpoint_serial'
+    'clean_checkpoint'
 ]
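With the per-variable helpers removed from `__all__`, checkpointing goes through the three remaining public entry points. Below is a minimal sketch of that round trip, assuming the default main program and placeholder directory and argument values (none of these values come from this commit):

import paddle.fluid as fluid

# Hedged sketch: paths, ids and trainer_args below are placeholders.
exe = fluid.Executor(fluid.CPUPlace())
prog = fluid.default_main_program()

fluid.io.save_checkpoint(
    executor=exe,
    checkpoint_dir="./checkpoints",
    trainer_id=0,
    main_program=prog,                              # now a required argument
    trainer_args={"epoch_id": 1, "step_id": 10},
    max_num_checkpoints=3)

# load_checkpoint now picks the latest serial itself; no `serial` argument.
fluid.io.load_checkpoint(
    executor=exe,
    checkpoint_dir="./checkpoints",
    main_program=prog)

fluid.io.clean_checkpoint(checkpoint_dir="./checkpoints", delete_dir=False)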
@@ -805,11 +803,11 @@ CHECKPOINT_SEPARATOR = "_"
 def save_checkpoint(executor,
                     checkpoint_dir,
                     trainer_id,
+                    main_program,
                     trainer_args=None,
-                    main_program=None,
                     max_num_checkpoints=3,
                     lookup_table=None,
-                    ps_endpoint_list=None):
+                    pserver_endpoints=None):
     """
     This function filters out all checkpoint variables from the given
     main_program and then saves these variables to the `checkpoint_dir`
@@ -836,16 +834,16 @@ def save_checkpoint(executor,
         trainer_args(dict|None): Current training arguments. Such as 'epoch_id'
             and 'step_id'.
             Default: None
-        main_program(Program|None): The program whose checkpoint variables will
-            be saved. If it is None, the default main program will be used.
+        main_program(Program): The program whose checkpoint variables will
+            be saved.
         max_num_checkpoints(int): The max number of total number of existing
             checkpoints.
             Default: 3
         lookup_table(string|None): the lookup table name, when use distribute
             lookup table, we can get lookup table name by DistributeTranspiler.
             table_name
-        ps_endpoint_list(list|None): the parameter server ip:port list.
-            when use distribute lookup table, we can get ps_endpoint_list by
+        pserver_endpoints(list|None): the parameter server ip:port list.
+            when use distribute lookup table, we can get pserver_endpoints by
             distribute arguments.

     Returns:
@@ -873,11 +871,13 @@ def save_checkpoint(executor,
                 main_program=prog,
                 max_num_checkpoints=3,
                 lookup_table=table_name,
-                ps_endpoint_list = ps_endpoints)
+                pserver_endpoints = ps_endpoints)
     """
     if checkpoint_dir is None:
         raise ValueError("'checkpoint_dir' should not be None")
-    assert checkpoint_dir
+
+    if main_program is None:
+        raise ValueError('main_program should not be None.')

     if trainer_args:
         assert isinstance(trainer_args, dict)
@@ -885,22 +885,28 @@ def save_checkpoint(executor,
     is_chief = trainer_id == 0

     _make_chekcpoint_dirs(checkpoint_dir)
-    serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
+    serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1
     cur_dir = _get_serial_dir(checkpoint_dir, serial)

-    save_trainer_args(cur_dir, trainer_id, trainer_args)
+    _save_trainer_args(cur_dir, trainer_id, trainer_args)

     if is_chief:
-        save_persist_vars_without_grad(executor, cur_dir, main_program)
+        _save_persist_vars_without_grad(executor, cur_dir, main_program)

-    if is_chief and lookup_table and ps_endpoint_list:
-        save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
-                                    ps_endpoint_list)
+    if is_chief and lookup_table and pserver_endpoints:
+        _save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
+                                     pserver_endpoints)

     _scroll_delete(checkpoint_dir, max_num_checkpoints)


-def load_checkpoint(executor, checkpoint_dir, serial, main_program):
+def load_checkpoint(executor,
+                    checkpoint_dir,
+                    main_program,
+                    role_id=0,
+                    is_trainer=True,
+                    load_trainer_args=None,
+                    load_lookup_table=None):
     """
     This function filters out all checkpoint variables from the given
     main_program and then tries to load these variables from the
@@ -924,13 +930,16 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
         serial(int): The serial of checkpoint you would like to load.
         main_program(Program): The program whose checkpoint variables will
             be loaded.
+        role_id(int): the trainer id or the parameter server id.
+        is_trainer(bool): True for a trainer, False for a parameter server.
+        load_trainer_args(list|None): list of trainer args to load.
+        load_lookup_table(str|None): the lookup table name.

     Returns:
         None

     Raises:
         ValueError: If `checkpoint_dir` is None.
-        ValueError: If `serial` is None or `serial` is less than 0.
         ValueError: If `main_program` is None.

     Examples:
@@ -951,14 +960,27 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
     if checkpoint_dir is None:
         raise ValueError("'checkpoint_dir' should not be None")

+    serial = _get_latest_checkpoint_serial(checkpoint_dir)
+
+    # there is nothing to be loaded
     if serial is None or serial < 0:
-        raise ValueError("'serial' should not be None or <0 ")
+        return

     if main_program is None:
         raise ValueError('main_program should not be None.')

-    cur_dir = _get_serial_dir(checkpoint_dir, serial)
-    load_persist_vars_without_grad(executor, cur_dir, main_program, True)
+    if is_trainer and load_trainer_args is None:
+        cur_dir = _get_serial_dir(checkpoint_dir, serial)
+        _load_persist_vars_without_grad(executor, cur_dir, main_program, True)
+        return
+
+    if is_trainer and load_trainer_args:
+        return _load_trainer_args(checkpoint_dir, serial, role_id,
+                                  load_trainer_args)
+
+    if not is_trainer and load_lookup_table:
+        _load_lookup_table_vars(executor, checkpoint_dir, main_program, role_id,
+                                load_lookup_table)


 def clean_checkpoint(checkpoint_dir, delete_dir=False):
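The rewritten load_checkpoint selects the latest serial itself and branches on `is_trainer`, `load_trainer_args` and `load_lookup_table`. A hedged sketch of the three call patterns, with placeholder directory, ids and table name (not values from this commit):

import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
prog = fluid.default_main_program()

# 1) Trainer: restore persistable variables from the latest serial.
fluid.io.load_checkpoint(
    executor=exe, checkpoint_dir="./checkpoints", main_program=prog)

# 2) Trainer: fetch its saved args; values come back in the requested order.
trainer_args = fluid.io.load_checkpoint(
    executor=exe,
    checkpoint_dir="./checkpoints",
    main_program=prog,
    role_id=0,
    is_trainer=True,
    load_trainer_args=["epoch_id", "step_id"])
epoch_id, step_id = int(trainer_args[0]), int(trainer_args[1])

# 3) Parameter server: reload its shard of a distributed lookup table.
fluid.io.load_checkpoint(
    executor=exe,
    checkpoint_dir="./checkpoints",
    main_program=prog,
    role_id=1,                      # pserver id
    is_trainer=False,
    load_trainer_args=None,
    load_lookup_table="share_w")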
@@ -979,10 +1001,10 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
         os.rmdir(checkpoint_dir)


-def load_persist_vars_without_grad(executor,
-                                   dirname,
-                                   program,
-                                   has_model_dir=False):
+def _load_persist_vars_without_grad(executor,
+                                    dirname,
+                                    program,
+                                    has_model_dir=False):
     """
     This function filters out all checkpoint variables from the given
     program and then tries to load these variables from the given directory.
@@ -1011,10 +1033,10 @@ def load_persist_vars_without_grad(executor,
             exe = fluid.Executor(fluid.CPUPlace())
             param_path = "./my_paddle_model"
             prog = fluid.default_main_program()
-            fluid.io.load_persist_vars_without_grad(executor=exe,
+            fluid.io._load_persist_vars_without_grad(executor=exe,
                 dirname=param_path, program=prog, has_model_dir=True)

-            # In this example, `load_persist_vars_without_grad` function
+            # In this example, `_load_persist_vars_without_grad` function
             # will first filter out all checkpoint variables in the default
             # main program, and then try to load these variables from the
             # folder "./my_paddle_model/__model__".
@@ -1031,7 +1053,7 @@ def load_persist_vars_without_grad(executor,
                     filename=None)


-def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
+def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
     """
     The parameter server will load lookup table's local file in
     selectedrows variable.
@@ -1050,11 +1072,11 @@ def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
         .. code-block:: python

             exe = fluid.Executor(fluid.CPUPlace())
-            dirname = "./checkpoints/checkpoint_9/__model__"
+            dirname = "./checkpoints/checkpoint_9/"
             prog = fluid.default_main_program()
             pserver_id = 1
             table_name = "share_w"
-            fluid.io.load_lookup_table_vars(executor=exe,
+            fluid.io._load_lookup_table_vars(executor=exe,
                 dirname=dirname, program=prog, pserver_id=pserver_id,
                 table_name=table_name)
     """
@@ -1081,7 +1103,7 @@ def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
     executor.run(load_prog)


-def save_persist_vars_without_grad(executor, dirname, program):
+def _save_persist_vars_without_grad(executor, dirname, program):
     """
     This function filters out all checkpoint variables from the given
     program and then saves these variables to a sub-folder '__model__' of
@@ -1108,10 +1130,10 @@ def save_persist_vars_without_grad(executor, dirname, program):
             exe = fluid.Executor(fluid.CPUPlace())
             param_path = "./my_paddle_model"
             prog = fluid.default_main_program()
-            fluid.io.save_persist_vars_without_grad(executor=exe,
+            fluid.io._save_persist_vars_without_grad(executor=exe,
                 dirname=param_path, program=prog)

-            # In this example, `save_persist_vars_without_grad` function
+            # In this example, `_save_persist_vars_without_grad` function
             # will first filter out all checkpoint variables in the default
             # main program, and then saves these variables to the folder
             # "./my_paddle_model/__model__".
@@ -1127,8 +1149,8 @@ def save_persist_vars_without_grad(executor, dirname, program):
     _write_success(cur_dir)


-def save_pserver_vars_by_notify(executor, dirname, lookup_table,
-                                ps_endpoint_list):
+def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
+                                 ps_endpoint_list):
     """
     This function will send checkpoint notify message from Trainer 0
     to all the pservers.
@@ -1156,7 +1178,7 @@ def save_pserver_vars_by_notify(executor, dirname, lookup_table,
             table_name = "share_w"
             ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]

-            fluid.io.save_pserver_vars_by_notify(executor=exe,
+            fluid.io._save_pserver_vars_by_notify(executor=exe,
                 dirname=param_path, lookup_table=table_name,
                 ps_endpoint_list=ps_endpoints)
     """
@@ -1175,7 +1197,7 @@ def save_pserver_vars_by_notify(executor, dirname, lookup_table,
     executor.run(checkpoint_notify_program)


-def save_trainer_args(dirname, trainer_id, trainer_args):
+def _save_trainer_args(dirname, trainer_id, trainer_args):
     assert isinstance(trainer_args, dict)

     cur_dir = _get_trainer_dir(dirname, trainer_id)
@@ -1187,7 +1209,7 @@ def save_trainer_args(dirname, trainer_id, trainer_args):
     _write_success(cur_dir)


-def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
+def _load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
     """
     trainer will load some args from its independent directory,
     such as epoch_id and step_id.
@@ -1208,7 +1230,7 @@ def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
             trainer_id = 2
             trainer_args = ["epoch_id", "step_id"]

-            fluid.io.load_trainer_args(checkpoint_dir=param_path, serial=serial,
+            fluid.io._load_trainer_args(checkpoint_dir=param_path, serial=serial,
                 trainer_id=trainer_id, trainer_args=trainer_args)
     """
     assert isinstance(trainer_args, list)
@@ -1339,7 +1361,7 @@ def _write_success(dirname):
         f.write(now)


-def get_latest_checkpoint_serial(checkpoint_dir):
+def _get_latest_checkpoint_serial(checkpoint_dir):
     """
     get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
...
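For orientation, a small sketch of how the serial directory name appears to be composed, based on CHECKPOINT_SEPARATOR = "_" and the "./checkpoints/checkpoint_9/" example above; the "checkpoint" prefix and the __model__/_SUCCESS layout are inferred for illustration, not quoted from this commit:

import os

# Assumed prefix; the diff only shows the separator and the checkpoint_9 example.
CHECKPOINT_PREFIX = "checkpoint"
CHECKPOINT_SEPARATOR = "_"

def serial_dir(checkpoint_dir, serial):
    # e.g. "./checkpoints" + serial 9 -> "./checkpoints/checkpoint_9"
    return os.path.join(checkpoint_dir,
                        CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial))

# The chief trainer saves persistables under <serial_dir>/__model__, and each
# trainer writes its args plus an _SUCCESS marker into its own sub-directory.
print(serial_dir("./checkpoints", 9))   # ./checkpoints/checkpoint_9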
@@ -277,31 +277,14 @@ class Trainer(object):
                 exe.run(self.startup_program)

         if self.checkpoint_cfg and self.checkpoint_cfg.load_serial:
-            with self._prog_and_scope_guard():
-                exe = executor.Executor(place)
-                io.load_checkpoint(exe, self.checkpoint_cfg.checkpoint_dir,
-                                   self.checkpoint_cfg.load_serial,
-                                   self.startup_program)
-
-                if not self.checkpoint_cfg.pserver_id:
-                    epoch_id, step_id = io.load_trainer_args(
-                        self.checkpoint_cfg.checkpoint_dir,
-                        self.checkpoint_cfg.load_serial, self.trainer_id,
-                        self._get_checkpoint_load_args())
-                    self.checkpoint_cfg.epoch_id = int(epoch_id)
-                    self.checkpoint_cfg.step_id = int(step_id)
-                else:
-                    if self.checkpoint_cfg.lookup_table_name:
-                        io.load_lookup_table_vars(
-                            exe, self.checkpoint_cfg.checkpoint_dir,
-                            self.startup_program,
-                            self.checkpoint_cfg.pserver_id,
-                            self.checkpoint_cfg.lookup_table_name)
+            self._load_checkpoint()

         if param_path and os.path.isdir(param_path):
             # load params from param_path into scope
-            io.load_persist_vars_without_grad(
-                exe, dirname=param_path, program=self.startup_program)
+            io.load_persistables(
+                executor=exe,
+                dirname=param_path,
+                main_program=self.startup_program)

     def _transpile_nccl2_dist(self):
         # PADDLE_TRAINER_IPS
@@ -580,6 +563,42 @@ class Trainer(object):
                 main_program=self.train_program,
                 max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)

+    def _load_checkpoint(self):
+        with self._prog_and_scope_guard():
+            exe = executor.Executor(self.place)
+            io.load_checkpoint(
+                executor=exe,
+                checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
+                main_program=self.startup_program)
+
+            if not self.checkpoint_cfg.pserver_id:
+                load_trainer_args = self._get_checkpoint_load_args()
+                trainer_args = io.load_checkpoint(
+                    executor=exe,
+                    checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
+                    main_program=self.startup_program,
+                    role_id=self.trainer_id,
+                    is_trainer=True,
+                    load_trainer_args=load_trainer_args)
+
+                if len(trainer_args) != 2:
+                    raise ValueError(
+                        "the returned trainer_args length does not equal _get_checkpoint_load_args"
+                    )
+                self.checkpoint_cfg.epoch_id = int(trainer_args[0])
+                self.checkpoint_cfg.step_id = int(trainer_args[1])
+            else:
+                if self.checkpoint_cfg.lookup_table_name:
+                    io.load_checkpoint(
+                        executor=exe,
+                        checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
+                        main_program=self.startup_program,
+                        role_id=self.checkpoint_cfg.pserver_id,
+                        is_trainer=False,
+                        load_trainer_args=None,
+                        load_lookup_table=self.checkpoint_cfg.lookup_table_name)
+

 def build_feed_var_list(program, feed_order):
     if not isinstance(program, framework.Program):
...
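The `len(trainer_args) != 2` check in _load_checkpoint encodes an ordering contract: the values returned for load_trainer_args must correspond, in order, to the names requested by _get_checkpoint_load_args(). A self-contained illustration of that contract using plain-Python stand-ins (this is not fluid code; names and values are placeholders):

# Stand-in for Trainer._get_checkpoint_load_args(): the names to restore.
def get_checkpoint_load_args():
    return ["epoch_id", "step_id"]

# Stand-in for what _save_trainer_args persisted for one trainer.
saved_trainer_args = {"epoch_id": "5", "step_id": "120"}

# Stand-in for _load_trainer_args(): values come back in the requested order.
def load_trainer_args(requested):
    return [saved_trainer_args[name] for name in requested]

trainer_args = load_trainer_args(get_checkpoint_load_args())
assert len(trainer_args) == 2          # same check as in _load_checkpoint
epoch_id, step_id = int(trainer_args[0]), int(trainer_args[1])
print(epoch_id, step_id)               # 5 120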