diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py
index 573e0cdab68598ab2c1e12747ab9b712e8561131..2285df4c63b3a16d8439e56a14e8243e09617b6b 100644
--- a/python/paddle/fluid/trainer.py
+++ b/python/paddle/fluid/trainer.py
@@ -73,7 +73,7 @@ class BeginStepEvent(object):
         self.step = step_id
         self.fetch_metrics = True
         """
-        If fetch_metrics is true, the metrics will be fetched at the 
+        If fetch_metrics is true, the metrics will be fetched at the
         EndStepEvent. Default is True.
         """
 
@@ -560,6 +560,9 @@ class Trainer(object):
 
         if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \
                 and step_id % self.checkpoint_cfg.step_interval == 0:
+
+            print("_save_checkpoint ...")
+
             exe = executor.Executor(self.place)
             save_checkpoint(
                 executor=exe,
@@ -604,7 +607,7 @@ class Trainer(object):
                 self.checkpoint_cfg.epoch_id = int(trainer_args_ret[0])
                 self.checkpoint_cfg.step_id = int(trainer_args_ret[1])
 
-        # Pserver Load 
+        # Pserver Load
         else:
             # load slice_vars
             if self.slice_vars != None and len(self.slice_vars) != 0:
@@ -661,22 +664,22 @@ CHECKPOINT_SEPARATOR = "_"
 
 def save_checkpoint(executor,
                     checkpoint_dir,
-                    trainer_id,
-                    main_program,
-                    trainer_args=None,
-                    max_num_checkpoints=3,
+                    main_program=None,
+                    trainer_id=0,
+                    save_trainer_args=None,
                     save_lookup_table=None,
-                    pserver_endpoints=None):
+                    pserver_endpoints=None,
+                    max_num_checkpoints=3):
     """
     This function filters out all checkpoint variables from the give
-    main_program and then saves these variables to the `checkpoint_dir` 
+    main_program and then saves these variables to the `checkpoint_dir`
     directory.
 
     In the training precess, we generally save a checkpoint in each
-    iteration. So there might be a lot of checkpoints in the 
-    `checkpoint_dir`. To avoid them taking too much disk space, the 
-    `max_num_checkpoints` are introduced to limit the total number of 
-    checkpoints. If the number of existing checkpints is greater than 
+    iteration. So there might be a lot of checkpoints in the
+    `checkpoint_dir`. To avoid them taking too much disk space, the
+    `max_num_checkpoints` is introduced to limit the total number of
+    checkpoints. If the number of existing checkpoints is greater than
     the `max_num_checkpoints`, oldest ones will be scroll deleted.
 
     A variable is a checkpoint variable and will be saved if it meets
@@ -688,21 +691,21 @@ def save_checkpoint(executor,
     Args:
         executor(Executor): The executor to run for save checkpoint.
         checkpoint_dir(str): The folder where to save checkpoints.
-        trainer_id(int): currect trainer id, if id is equal to 0, the trainer 
+        trainer_id(int): current trainer id, if id is equal to 0, the trainer
             is chief.
-        trainer_args(dict|None): Current training arguments. Such as 'epoch_id' 
+        save_trainer_args(dict|None): Current training arguments. Such as 'epoch_id'
            and 'step_id'. Defaut: None
        main_program(Program): The program whose checkpoint variables will
            be saved.
-        max_num_checkpoints(int): The max number of total number of existing 
+        max_num_checkpoints(int): The maximum number of existing
            checkpoints. Default: 3
        save_lookup_table(string|None): the lookup table name, when use distribute
            lookup table, we can get lookup table name by DistributeTranspiler.
-            table_name 
-        pserver_endpoints(list|None): the parameter server ip:port list. 
-            when use distribute lookup table, we can get pserver_endpoints by 
+            table_name
+        pserver_endpoints(list|None): the parameter server ip:port list.
+            when using a distributed lookup table, we can get pserver_endpoints by
            distribute arguments.
 
    Returns:
@@ -735,21 +738,18 @@ def save_checkpoint(executor,
     if checkpoint_dir is None:
         raise ValueError("'checkpoint_dir' should not be None")
 
-    if main_program is None:
-        raise ValueError('main_program should not be None.')
-
-    if trainer_args:
-        assert isinstance(trainer_args, dict)
-
-    is_chief = trainer_id == 0
-
     _make_chekcpoint_dirs(checkpoint_dir)
     serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1
     cur_dir = _get_serial_dir(checkpoint_dir, serial, True)
 
-    _save_trainer_args(cur_dir, trainer_id, trainer_args)
+    is_chief = trainer_id == 0
+
+    if save_trainer_args is not None:
+        _save_trainer_args(cur_dir, trainer_id, save_trainer_args)
 
     if is_chief:
+        if main_program is None:
+            raise ValueError('main_program should not be None.')
         _save_persistable_vars(executor, cur_dir, main_program)
 
     if is_chief and save_lookup_table and pserver_endpoints:
@@ -764,7 +764,7 @@ def load_checkpoint(executor,
                     main_program=None,
                     role_id=0,
                     is_trainer=True,
-                    load_models=True,
+                    load_models=False,
                     load_trainer_args=None,
                     load_slice_up_vars=None,
                     load_lookup_table=None):
@@ -774,8 +774,8 @@ def load_checkpoint(executor,
     `checkpoint_dir` directory.
 
     In the training precess, we generally save a checkpoint in each
-    iteration. So there are more than one checkpoint in the 
-    `checkpoint_dir` (each checkpoint has its own sub folder), use 
+    iteration. So there may be more than one checkpoint in the
+    `checkpoint_dir` (each checkpoint has its own sub-folder); use
     `serial` to specify which serial of checkpoint you would like to
     load.
 
@@ -827,6 +827,10 @@ def load_checkpoint(executor,
         _load_persistable_vars(executor, checkpoint_dir, main_program, True)
         return
     if load_trainer_args:
+
+        print("checkpoint_dir: {}, role_id: {}, load_trainer_args: {}".
+              format(checkpoint_dir, role_id, load_trainer_args))
+
         trainer_args_ret = _load_trainer_args(checkpoint_dir, role_id,
                                               load_trainer_args)
         return trainer_args_ret
@@ -842,9 +846,9 @@ def load_checkpoint(executor,
 
 def clean_checkpoint(checkpoint_dir, delete_dir=False):
     """
-    clean the checkpoint dir, when the train exits normally, 
+    clean the checkpoint dir, when the training exits normally,
     the trainer will call clean_checkpoint to delete checkpoint directory saved before.
-    delete_dir only works when the directory is empty, otherwise, OSError is raised. 
+    delete_dir only works when the directory is empty; otherwise, an OSError is raised.
 
     : param checkpoint_dir
     : param delete_dir
@@ -954,7 +958,7 @@ def _load_slice_up_vars(executor, dirname, slice_vars):
 def _load_lookup_table_vars(executor, dirname, program, pserver_id,
                             table_name):
     """
-    The parameter server will load lookup table's local file in 
+    The parameter server will load the lookup table's local file into a
     selectedrows variable.
 
     Args:
@@ -1005,7 +1009,7 @@ def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
 def _save_persistable_vars(executor, dirname, program):
     """
     This function filters out all checkpoint variables from the give
-    program and then save these variables to a sub-folder '__model__' of 
+    program and then saves these variables to a sub-folder '__model__' of
     the given directory.
 
     A variable is a checkpoint variable if it meets all following
@@ -1034,7 +1038,7 @@ def _save_persistable_vars(executor, dirname, program):
 
             # In this example, `_save_persistable_vars` function
             # will first filters out all checkpoint variables in the default
-            # main program, and then saves these variables to the folder 
+            # main program, and then saves these variables to the folder
             # "./my_paddle_model/__model__".
     """
     cur_dir = _get_model_dir(dirname)
@@ -1053,7 +1057,7 @@ def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
     """
     This function will send checkpoint notify message from Trainer 0
     to all the pservers.
-    The checkpoint notify message contains lookup table name, 
+    The checkpoint notify message contains the lookup table name,
     the absolute path on pserver to save lookup_table.
 
     Args:
@@ -1061,13 +1065,13 @@ def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
         dirname(str): The folder where to save checkpoints.
         lookup_table(string): the lookup table name, when use distribute
            lookup table, we can get lookup table name by DistributeTranspiler.
-            table_name 
-        ps_endpoint_list(list): the parameter server ip:port list. 
-            when use distribute lookup table, we can get ps_endpoint_list by 
+            table_name
+        ps_endpoint_list(list): the parameter server ip:port list.
+            when using a distributed lookup table, we can get ps_endpoint_list by
            distribute arguments.
 
    Return:
        None
-    
+
    Examples:
        .. code-block:: python
@@ -1078,7 +1082,7 @@ def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
            ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
 
            _save_pserver_vars_by_notify(executor=exe,
-                    dirname=param_path, lookup_table=table_name, 
+                    dirname=param_path, lookup_table=table_name,
                    ps_endpoint_list=ps_endpoints)
    """
    cur_dir = _get_lookuptable_dir(dirname)
@@ -1110,7 +1114,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args):
 
 def _load_trainer_args(checkpoint_dir, trainer_id, trainer_args):
    """
-    trainer will load some args from it's independent directory, 
+    trainer will load some args from its independent directory,
    such as epoch_id and step_id.
 
    Args:
@@ -1264,8 +1268,6 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
 
    : param checkpoint_dir
    """
-    if not checkpoint_dir:
-        return -1
 
    def has_success(checkpoint_dir, cur_dir):
        """
@@ -1273,8 +1275,8 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
        """
        serial = _get_dir_serial(cur_dir)
 
-        if serial == -1 or not os.path.isdir(
-                os.path.join(checkpoint_dir, cur_dir)):
+        if serial == -1 or \
+                not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
            return -1
 
        success_path = os.path.join(
@@ -1283,10 +1285,11 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
        if os.path.isfile(success_path):
            return serial
 
-    if not os.path.isdir(checkpoint_dir):
-        return -1
-
    current_dir = -1
+
+    if not checkpoint_dir or not os.path.isdir(checkpoint_dir):
+        return current_dir
+
    dirs = os.listdir(checkpoint_dir)
    for cur_dir in dirs:
        success_num = has_success(checkpoint_dir, cur_dir)
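
Note for reviewers: a minimal sketch of how the reworked entry points would be called after this patch. The checkpoint path, the epoch/step values, and the use of `fluid.default_main_program()` are illustrative assumptions, not part of the diff; only the keyword names and the new defaults (`trainer_id=0`, `save_trainer_args=None`, `load_models=False`, `max_num_checkpoints` moved last) come from the signatures above.

```python
import paddle.fluid as fluid
from paddle.fluid.trainer import save_checkpoint, load_checkpoint

exe = fluid.Executor(fluid.CPUPlace())

# Only the chief (trainer_id == 0) must pass main_program now, and
# trainer args are persisted only when save_trainer_args is given.
save_checkpoint(
    executor=exe,
    checkpoint_dir="/tmp/ckpt",  # assumed path, not from the patch
    main_program=fluid.default_main_program(),
    trainer_id=0,
    save_trainer_args={"epoch_id": 1, "step_id": 100})

# load_models now defaults to False, so restoring the persistable
# variables has to be requested explicitly.
load_checkpoint(
    executor=exe,
    checkpoint_dir="/tmp/ckpt",
    main_program=fluid.default_main_program(),
    load_models=True)
```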