From 95545f7676cd37b39823c2bc4a5106997eaf61a9 Mon Sep 17 00:00:00 2001
From: tangwei12
Date: Mon, 9 Jul 2018 21:31:19 +0800
Subject: [PATCH] checkpoint api optimized

---
 python/paddle/fluid/io.py      | 104 ++++++++++++++++++++-------------
 python/paddle/fluid/trainer.py |  63 +++++++++++++-------
 2 files changed, 104 insertions(+), 63 deletions(-)

diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py
index 5c8f4f6507c..72139f47b62 100644
--- a/python/paddle/fluid/io.py
+++ b/python/paddle/fluid/io.py
@@ -25,9 +25,7 @@
 __all__ = [
     'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
     'load_persistables', 'save_inference_model', 'load_inference_model',
     'get_inference_program', 'save_checkpoint', 'load_checkpoint',
-    'clean_checkpoint', 'load_persist_vars_without_grad',
-    'load_lookup_table_vars', 'save_persist_vars_without_grad',
-    'get_latest_checkpoint_serial'
+    'clean_checkpoint'
 ]

@@ -805,11 +803,11 @@ CHECKPOINT_SEPARATOR = "_"
 def save_checkpoint(executor,
                     checkpoint_dir,
                     trainer_id,
+                    main_program,
                     trainer_args=None,
-                    main_program=None,
                     max_num_checkpoints=3,
                     lookup_table=None,
-                    ps_endpoint_list=None):
+                    pserver_endpoints=None):
     """
     This function filters out all checkpoint variables from the give
     main_program and then saves these variables to the `checkpoint_dir`
@@ -836,16 +834,16 @@ def save_checkpoint,
         trainer_args(dict|None): Current training arguments. Such as 'epoch_id'
             and 'step_id'.
             Defaut: None
-        main_program(Program|None): The program whose checkpoint variables will
-            be saved. If it is None, the default main program will be used.
+        main_program(Program): The program whose checkpoint variables will
+            be saved.
         max_num_checkpoints(int): The max number of total number of existing
            checkpoints.
            Default: 3
         lookup_table(string|None): the lookup table name, when use distribute
            lookup table, we can get lookup table name by DistributeTranspiler.
            table_name
-        ps_endpoint_list(list|None): the parameter server ip:port list.
-            when use distribute lookup table, we can get ps_endpoint_list by
+        pserver_endpoints(list|None): the parameter server ip:port list.
+            when use distribute lookup table, we can get pserver_endpoints by
            distribute arguments.
     Returns:
@@ -873,11 +871,13 @@ def save_checkpoint,
                             main_program=prog,
                             max_num_checkpoints=3,
                             lookup_table=table_name,
-                            ps_endpoint_list = ps_endpoints)
+                            pserver_endpoints = ps_endpoints)
     """
     if checkpoint_dir is None:
         raise ValueError("'checkpoint_dir' should not be None")
-    assert checkpoint_dir
+
+    if main_program is None:
+        raise ValueError('main_program should not be None.')

     if trainer_args:
         assert isinstance(trainer_args, dict)
@@ -885,22 +885,28 @@ def save_checkpoint,
     is_chief = trainer_id == 0

     _make_chekcpoint_dirs(checkpoint_dir)
-    serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
+    serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1
     cur_dir = _get_serial_dir(checkpoint_dir, serial)

-    save_trainer_args(cur_dir, trainer_id, trainer_args)
+    _save_trainer_args(cur_dir, trainer_id, trainer_args)

     if is_chief:
-        save_persist_vars_without_grad(executor, cur_dir, main_program)
+        _save_persist_vars_without_grad(executor, cur_dir, main_program)

-    if is_chief and lookup_table and ps_endpoint_list:
-        save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
-                                    ps_endpoint_list)
+    if is_chief and lookup_table and pserver_endpoints:
+        _save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
+                                     pserver_endpoints)

     _scroll_delete(checkpoint_dir, max_num_checkpoints)


-def load_checkpoint(executor, checkpoint_dir, serial, main_program):
+def load_checkpoint(executor,
+                    checkpoint_dir,
+                    main_program,
+                    role_id=0,
+                    is_trainer=True,
+                    load_trainer_args=None,
+                    load_lookup_table=None):
     """
     This function filters out all checkpoint variables from the give
     main_program and then try to load these variables from the
@@ -924,13 +930,16 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
         serial(int): The serial of checkpoint you would like to load.
         main_program(Program): The program whose checkpoint variables will
             be loaded.
+        role_id(int): the trainer id or the parameter server id.
+        is_trainer(bool): True for the trainer role, False for the parameter server.
+        load_trainer_args(list|None): the list of trainer args to load.
+        load_lookup_table(str|None): the lookup table name.

     Returns:
         None

     Raises:
         ValueError: If `checkpoint_dir` is None.
-        ValueError: If `serial` is None or `serial` is less than 0.
         ValueError: If `main_program` is None.
     Examples:
@@ -951,14 +960,27 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):

     if checkpoint_dir is None:
         raise ValueError("'checkpoint_dir' should not be None")
+    serial = _get_latest_checkpoint_serial(checkpoint_dir)
+
+    # there is nothing to be loaded
     if serial is None or serial < 0:
-        raise ValueError("'serial' should not be None or <0 ")
+        return

     if main_program is None:
         raise ValueError('main_program should not be None.')

-    cur_dir = _get_serial_dir(checkpoint_dir, serial)
-    load_persist_vars_without_grad(executor, cur_dir, main_program, True)
+    if is_trainer and load_trainer_args is None:
+        cur_dir = _get_serial_dir(checkpoint_dir, serial)
+        _load_persist_vars_without_grad(executor, cur_dir, main_program, True)
+        return
+
+    if is_trainer and load_trainer_args:
+        return _load_trainer_args(checkpoint_dir, serial, role_id,
+                                  load_trainer_args)
+
+    if not is_trainer and load_lookup_table:
+        _load_lookup_table_vars(executor, checkpoint_dir, main_program, role_id,
+                                load_lookup_table)


 def clean_checkpoint(checkpoint_dir, delete_dir=False):
@@ -979,10 +1001,10 @@
         os.rmdir(checkpoint_dir)


-def load_persist_vars_without_grad(executor,
-                                   dirname,
-                                   program,
-                                   has_model_dir=False):
+def _load_persist_vars_without_grad(executor,
+                                    dirname,
+                                    program,
+                                    has_model_dir=False):
     """
     This function filters out all checkpoint variables from the give
     program and then trys to load these variables from the given directory.
@@ -1011,10 +1033,10 @@ def load_persist_vars_without_grad(executor,

             exe = fluid.Executor(fluid.CPUPlace())
             param_path = "./my_paddle_model"
             prog = fluid.default_main_program()
-            fluid.io.load_persist_vars_without_grad(executor=exe,
+            fluid.io._load_persist_vars_without_grad(executor=exe,
                     dirname=param_path, program=prog, has_model_dir=True)
-            # In this example, `load_persist_vars_without_grad` function
+            # In this example, `_load_persist_vars_without_grad` function
             # will first filters out all checkpoint variables in the default
             # main program, and then trys to load these variables form the
             # folder "./my_paddle_model/__model__".
@@ -1031,7 +1053,7 @@ def load_persist_vars_without_grad(executor,
         filename=None)


-def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
+def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
     """
     The parameter server will load lookup table's local file in
     selectedrows variable.
@@ -1050,11 +1072,11 @@ def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
             .. code-block:: python
             exe = fluid.Executor(fluid.CPUPlace())
-            dirname = "./checkpoints/checkpoint_9/__model__"
+            dirname = "./checkpoints/checkpoint_9/"
             prog = fluid.default_main_program()
             pserver_id = 1
             table_name = "share_w"
-            fluid.io.load_lookup_table_vars(executor=exe,
+            fluid.io._load_lookup_table_vars(executor=exe,
                     dirname=dirname, program=prog, pserver_id=pserver_id,
                     table_name=table_name)
     """
@@ -1081,7 +1103,7 @@ def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
     executor.run(load_prog)


-def save_persist_vars_without_grad(executor, dirname, program):
+def _save_persist_vars_without_grad(executor, dirname, program):
     """
     This function filters out all checkpoint variables from the give
     program and then save these variables to a sub-folder '__model__' of
@@ -1108,10 +1130,10 @@ def save_persist_vars_without_grad(executor, dirname, program):

             exe = fluid.Executor(fluid.CPUPlace())
             param_path = "./my_paddle_model"
             prog = fluid.default_main_program()
-            fluid.io.save_persist_vars_without_grad(executor=exe,
+            fluid.io._save_persist_vars_without_grad(executor=exe,
                     dirname=param_path, program=prog)
-            # In this example, `save_persist_vars_without_grad` function
+            # In this example, `_save_persist_vars_without_grad` function
            # will first filters out all checkpoint variables in the default
            # main program, and then saves these variables to the folder
            # "./my_paddle_model/__model__".
@@ -1127,8 +1149,8 @@ def save_persist_vars_without_grad(executor, dirname, program):
     _write_success(cur_dir)


-def save_pserver_vars_by_notify(executor, dirname, lookup_table,
-                                ps_endpoint_list):
+def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
+                                 ps_endpoint_list):
     """
     This function will send checkpoint notify message from Trainer 0
     to all the pservers.
@@ -1156,7 +1178,7 @@ def save_pserver_vars_by_notify(executor, dirname, lookup_table,
             table_name = "share_w"
             ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]

-            fluid.io.save_pserver_vars_by_notify(executor=exe,
+            fluid.io._save_pserver_vars_by_notify(executor=exe,
                     dirname=param_path, lookup_table=table_name,
                     ps_endpoint_list=ps_endpoints)
     """
@@ -1175,7 +1197,7 @@ def save_pserver_vars_by_notify(executor, dirname, lookup_table,
     executor.run(checkpoint_notify_program)


-def save_trainer_args(dirname, trainer_id, trainer_args):
+def _save_trainer_args(dirname, trainer_id, trainer_args):
     assert isinstance(trainer_args, dict)

     cur_dir = _get_trainer_dir(dirname, trainer_id)
@@ -1187,7 +1209,7 @@ def save_trainer_args(dirname, trainer_id, trainer_args):
     _write_success(cur_dir)


-def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
+def _load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
     """
     trainer will load some args from it's independent directory,
     such as epoch_id and step_id.
@@ -1208,7 +1230,7 @@ def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
            trainer_id = 2
            trainer_args = ["epoch_id", "step_id"]

-           fluid.io.load_trainer_args(checkpoint_dir=param_path, serial=serial,
+           fluid.io._load_trainer_args(checkpoint_dir=param_path, serial=serial,
            trainer_id=trainer_id, trainer_args=trainer_args)
    """
    assert isinstance(trainer_args, list)
@@ -1339,7 +1361,7 @@ def _write_success(dirname):
         f.write(now)


-def get_latest_checkpoint_serial(checkpoint_dir):
+def _get_latest_checkpoint_serial(checkpoint_dir):
    """
    get the latest file in checkpoint directory, the _SUCCESS file must exist
    in the directory
diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py
index b6e0241265b..3eaf687cf9f 100644
--- a/python/paddle/fluid/trainer.py
+++ b/python/paddle/fluid/trainer.py
@@ -277,31 +277,14 @@ class Trainer(object):
             exe.run(self.startup_program)

             if self.checkpoint_cfg and self.checkpoint_cfg.load_serial:
-                with self._prog_and_scope_guard():
-                    exe = executor.Executor(place)
-                    io.load_checkpoint(exe, self.checkpoint_cfg.checkpoint_dir,
-                                       self.checkpoint_cfg.load_serial,
-                                       self.startup_program)
-
-                if not self.checkpoint_cfg.pserver_id:
-                    epoch_id, step_id = io.load_trainer_args(
-                        self.checkpoint_cfg.checkpoint_dir,
-                        self.checkpoint_cfg.load_serial, self.trainer_id,
-                        self._get_checkpoint_load_args())
-                    self.checkpoint_cfg.epoch_id = int(epoch_id)
-                    self.checkpoint_cfg.step_id = int(step_id)
-                else:
-                    if self.checkpoint_cfg.lookup_table_name:
-                        io.load_lookup_table_vars(
-                            exe, self.checkpoint_cfg.checkpoint_dir,
-                            self.startup_program,
-                            self.checkpoint_cfg.pserver_id,
-                            self.checkpoint_cfg.lookup_table_name)
+                self._load_checkpoint()

             if param_path and os.path.isdir(param_path):
                 # load params from param_path into scope
-                io.load_persist_vars_without_grad(
-                    exe, dirname=param_path, program=self.startup_program)
+                io.load_persistables(
+                    executor=exe,
+                    dirname=param_path,
+                    main_program=self.startup_program)

     def _transpile_nccl2_dist(self):
         # PADDLE_TRAINER_IPS
@@ -580,6 +563,42 @@ class Trainer(object):
                 main_program=self.train_program,
                 max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)

+    def _load_checkpoint(self):
+        with self._prog_and_scope_guard():
+            exe = executor.Executor(self.place)
+            io.load_checkpoint(
+                executor=exe,
+                checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
+                main_program=self.startup_program)
+
+            if not self.checkpoint_cfg.pserver_id:
+                load_trainer_args = self._get_checkpoint_load_args()
+                trainer_args = io.load_checkpoint(
+                    executor=exe,
+                    checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
+                    main_program=self.startup_program,
+                    role_id=self.trainer_id,
+                    is_trainer=True,
+                    load_trainer_args=load_trainer_args)
+
+                if len(trainer_args) != 2:
+                    raise ValueError(
+                        "the length of the returned trainer_args does not match _get_checkpoint_load_args"
+                    )
+
+                self.checkpoint_cfg.epoch_id = int(trainer_args[0])
+                self.checkpoint_cfg.step_id = int(trainer_args[1])
+            else:
+                if self.checkpoint_cfg.lookup_table_name:
+                    io.load_checkpoint(
+                        executor=exe,
+                        checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
+                        main_program=self.startup_program,
+                        role_id=self.checkpoint_cfg.pserver_id,
+                        is_trainer=False,
+                        load_trainer_args=None,
+                        load_lookup_table=self.checkpoint_cfg.lookup_table_name)
+

 def build_feed_var_list(program, feed_order):
     if not isinstance(program, framework.Program):
--
GitLab
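
Usage sketch for the reworked `save_checkpoint` on trainer 0. This is a minimal sketch only; `path`, `trainer_args`, `table_name` and `ps_endpoints` are illustrative values taken from the docstring examples above, and a distributed lookup table is assumed.

.. code-block:: python

    import paddle.fluid as fluid

    exe = fluid.Executor(fluid.CPUPlace())
    path = "./checkpoints"
    prog = fluid.default_main_program()
    trainer_args = {"epoch_id": 200, "step_id": 20}
    table_name = "share_w"
    ps_endpoints = ["127.0.0.1:6000", "127.0.0.1:6001"]

    # main_program is now required, and the endpoint argument is named
    # pserver_endpoints after this patch.
    fluid.io.save_checkpoint(executor=exe,
                             checkpoint_dir=path,
                             trainer_id=0,
                             main_program=prog,
                             trainer_args=trainer_args,
                             max_num_checkpoints=3,
                             lookup_table=table_name,
                             pserver_endpoints=ps_endpoints)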
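
Usage sketch for the new `load_checkpoint` on the trainer side, mirroring `Trainer._load_checkpoint` above. It assumes a checkpoint already exists under `path`; otherwise the function returns without loading anything.

.. code-block:: python

    import paddle.fluid as fluid

    exe = fluid.Executor(fluid.CPUPlace())
    path = "./checkpoints"
    prog = fluid.default_main_program()

    # Restore the persistable variables of the latest checkpoint; the serial
    # is now looked up internally instead of being passed in.
    fluid.io.load_checkpoint(executor=exe,
                             checkpoint_dir=path,
                             main_program=prog)

    # Fetch the trainer args saved together with the checkpoint; the values
    # come back as strings, so convert them as Trainer._load_checkpoint does.
    trainer_args = fluid.io.load_checkpoint(executor=exe,
                                            checkpoint_dir=path,
                                            main_program=prog,
                                            role_id=0,
                                            is_trainer=True,
                                            load_trainer_args=["epoch_id", "step_id"])
    epoch_id, step_id = int(trainer_args[0]), int(trainer_args[1])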
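
Usage sketch for the parameter-server side, loading the distributed lookup table through the same entry point. Illustrative only: `role_id=1` stands in for the pserver id, `"share_w"` for the table name, and in practice `prog` would be the pserver program rather than the default main program.

.. code-block:: python

    import paddle.fluid as fluid

    exe = fluid.Executor(fluid.CPUPlace())
    path = "./checkpoints"
    prog = fluid.default_main_program()

    # With is_trainer=False and a table name, load_checkpoint dispatches to
    # _load_lookup_table_vars for this pserver's shard of the table.
    fluid.io.load_checkpoint(executor=exe,
                             checkpoint_dir=path,
                             main_program=prog,
                             role_id=1,
                             is_trainer=False,
                             load_trainer_args=None,
                             load_lookup_table="share_w")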