diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py
index ac1b3eb84afc3241ccd4ef53980527d8f72a1a3d..8194a4e331c2a3fda13dc515d2129e43b7723d52 100644
--- a/python/paddle/fluid/trainer.py
+++ b/python/paddle/fluid/trainer.py
@@ -134,8 +134,6 @@ class CheckpointConfig(object):
         self.epoch_id = 0
         self.step_id = 0
         self.load_serial = None
-        self.pserver_id = None
-        self.lookup_table_name = None
 
 
 def check_and_get_place(place):
@@ -351,11 +349,9 @@ class Trainer(object):
             t.transpile(
                 self.trainer_id, pservers=pserver_endpoints, trainers=trainers)
             if training_role == "PSERVER":
-                if self.checkpoint_cfg:
-                    pserver_id = eplist.index(current_endpoint)
-                    self.checkpoint_cfg.pserver_id = pserver_id
-                    if t.has_distributed_lookup_table:
-                        self.checkpoint_cfg.lookup_table_name = t.table_name
+                self.pserver_id = eplist.index(current_endpoint)
+                self.pserver_endpoints = pserver_endpoints
+                self.lookup_table_name = t.table_name if t.has_distributed_lookup_table else None
 
                 self.train_program = t.get_pserver_program(current_endpoint)
                 self.startup_program = t.get_startup_program(current_endpoint,
@@ -417,6 +413,11 @@ class Trainer(object):
     def save_params(self, param_path):
         """
         Save all parameters into :code:`param_path`.
+        Only trainer No.0 saves the dense params.
+        In standalone PaddlePaddle, the only running trainer saves the dense params.
+        In distributed PaddlePaddle, trainer No.0 saves the dense params.
+        If a distributed lookup table has to be saved as well, trainer No.0 notifies
+        all parameter servers, and each of them saves its own shard locally.
 
         Args:
             param_path(str): The path to save parameters.
@@ -424,9 +425,19 @@ class Trainer(object):
         Returns:
             None
        """
+
+        if self.trainer_id != 0:
+            return
+
         with self._prog_and_scope_guard():
+            # save params on trainer
             exe = executor.Executor(self.place)
             io.save_persistables(exe, dirname=param_path)
+            # save params on pserver
+            if self.lookup_table_name:
+                _save_pserver_vars_by_notify(exe, param_path,
+                                             self.lookup_table_name,
+                                             self.pserver_endpoints)
 
     @contextlib.contextmanager
     def _prog_and_scope_guard(self):
@@ -560,15 +571,16 @@ class Trainer(object):
 
         if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \
                 and step_id % self.checkpoint_cfg.step_interval == 0:
-            print("_save_checkpoint ...")
-
             exe = executor.Executor(self.place)
             save_checkpoint(
                 executor=exe,
                 checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
-                trainer_id=self.trainer_id,
-                trainer_args=self._get_checkpoint_save_args(epoch_id, step_id),
                 main_program=self.train_program,
+                trainer_id=self.trainer_id,
+                save_trainer_args=self._get_checkpoint_save_args(epoch_id,
+                                                                 step_id),
+                save_lookup_table=self.lookup_table_name,
+                pserver_endpoints=self.pserver_endpoints,
                 max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)
 
     def _load_checkpoint(self):
@@ -579,7 +591,7 @@ class Trainer(object):
                 self.checkpoint_cfg.load_serial)
 
             # Trainer Load
-            if self.checkpoint_cfg.pserver_id is None:
+            if self.pserver_id is None:
                 # load model
                 load_checkpoint(
                     executor=exe,
@@ -608,15 +620,25 @@ class Trainer(object):
 
             # Pserver Load
             else:
+                # load model
+                load_checkpoint(
+                    executor=exe,
+                    checkpoint_dir=checkpoint_dir,
+                    main_program=self.startup_program,
+                    role_id=self.pserver_id,
+                    is_trainer=False,
+                    load_models=True,
+                    load_lookup_table=self.lookup_table_name)
+
                 # load lookup table
-                if self.checkpoint_cfg.lookup_table_name:
+                if self.lookup_table_name:
                     load_checkpoint(
                         executor=exe,
                         checkpoint_dir=checkpoint_dir,
                         main_program=self.startup_program,
-                        role_id=self.checkpoint_cfg.pserver_id,
+                        role_id=self.pserver_id,
                         is_trainer=False,
-                        load_lookup_table=self.checkpoint_cfg.lookup_table_name)
+                        load_lookup_table=self.lookup_table_name)
 
 
 def build_feed_var_list(program, feed_order):
@@ -813,13 +835,21 @@ def load_checkpoint(executor,
    if is_trainer:
        if load_models:
            _load_persistable_vars(executor, checkpoint_dir, main_program, True)
-            return
+
        if load_trainer_args:
            trainer_args_ret = _load_trainer_args(checkpoint_dir, role_id,
                                                  load_trainer_args)
            return trainer_args_ret
    # pserver load
    else:
+        if load_models:
+            if load_lookup_table:
+                _load_persistable_vars(executor, checkpoint_dir, main_program,
+                                       True, [load_lookup_table])
+            else:
+                _load_persistable_vars(executor, checkpoint_dir, main_program,
+                                       True)
+
        if load_lookup_table:
            _load_lookup_table_vars(executor, checkpoint_dir, main_program,
                                    role_id, load_lookup_table)
@@ -843,7 +873,11 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
        os.rmdir(checkpoint_dir)
 
 
-def _load_persistable_vars(executor, dirname, program, has_model_dir=False):
+def _load_persistable_vars(executor,
+                           dirname,
+                           program,
+                           has_model_dir=False,
+                           except_vars=None):
    """
    This function filters out all checkpoint variables from the given program
    and then tries to load these variables from the given directory.
@@ -888,7 +922,7 @@ def _load_persistable_vars(executor, dirname, program, has_model_dir=False):
    io.load_vars(
        executor,
        dirname=dirname,
        main_program=program,
-        predicate=_is_checkpoint_var,
+        predicate=_is_checkpoint_var(except_vars),
        filename=None)
 
@@ -983,13 +1017,13 @@
        dirname=cur_dir,
        main_program=program,
        vars=None,
-        predicate=_is_checkpoint_var,
+        predicate=_is_checkpoint_var(),
        filename=None)
    _write_success(cur_dir)
 
 
 def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
-                                 ps_endpoint_list):
+                                 pserver_endpoints):
    """
    This function will send checkpoint notify message from Trainer 0
    to all the pservers.
@@ -1002,8 +1036,8 @@ def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
        lookup_table(string): the lookup table name, when use distribute
            lookup table, we can get lookup table name by DistributeTranspiler.
            table_name
-        ps_endpoint_list(list): the parameter server ip:port list.
-            when use distribute lookup table, we can get ps_endpoint_list by
+        pserver_endpoints(list): the parameter server ip:port list.
+            when using a distributed lookup table, we can get pserver_endpoints from the
            distribute arguments.
    Return:
        None
@@ -1027,7 +1061,7 @@
    checkpoint_notify_block = checkpoint_notify_program.global_block()
 
    attrs = {}
-    attrs['epmap'] = ps_endpoint_list
+    attrs['epmap'] = pserver_endpoints.split(",")
    attrs['dir'] = cur_dir
    attrs['lookup_table'] = lookup_table
 
@@ -1086,29 +1120,37 @@ def _load_trainer_args(checkpoint_dir, trainer_id, trainer_args):
    return ret_values
 
 
-def _is_checkpoint_var(var):
-    """
-    the checkpoint will not save or load all the variables.
-    var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
+def _is_checkpoint_var(except_vars=None):
+    except_vars = [] if except_vars is None else except_vars
 
-    : param var(Variable)
-    """
-    if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
-            var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
-            var.desc.type() == core.VarDesc.VarType.RAW:
-        return False
-    # @GRAD are named for gradient variables, checkpoint will not save it.
-    if "@GRAD" in var.name:
-        return False
-    # .trainer_ are named for distribute train variables, checkpoint will not save it.
-    if ".trainer_" in var.name:
-        return False
-
-    # .block is named for distribute train variables, checkpoint will not save it.
-    if ".block" in var.name:
-        return False
-
-    return var.persistable
+    def _except_vars(var):
+        """
+        The checkpoint skips vars of type FEED_MINIBATCH/FETCH_LIST/RAW,
+        vars whose names contain @GRAD, and vars named in except_vars.
+
+        :param var: the Variable to check
+        """
+        if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
+                var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
+                var.desc.type() == core.VarDesc.VarType.RAW:
+            return False
+        # @GRAD marks gradient variables; the checkpoint will not save them.
+        if "@GRAD" in var.name:
+            return False
+        # .trainer_ marks distributed-training variables; not checkpointed.
+        if ".trainer_" in var.name:
+            return False
+
+        # .block marks distributed-training variables; not checkpointed.
+        if ".block" in var.name:
+            return False
+
+        if var.name in except_vars:
+            return False
+
+        return var.persistable
+
+    return _except_vars
 
 
 def _make_chekcpoint_dirs(dirs):
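The save_params change above encodes a small protocol: only trainer No.0 writes the dense parameters, and a distributed lookup table cannot be saved from the trainer at all, because the table is sharded across the parameter servers; trainer No.0 therefore only sends a notification, and each pserver persists its own shard locally. Below is a schematic, dependency-free sketch of that control flow; every name in it is a placeholder for illustration, not part of the fluid API.

    def save_params(trainer_id, param_path, lookup_table_name,
                    pserver_endpoints, save_dense, notify_pservers):
        # Non-zero trainers return immediately; only trainer No.0 saves.
        if trainer_id != 0:
            return

        # Dense parameters live on the trainer and are saved locally.
        save_dense(param_path)

        # The lookup table is sharded across pservers, so each shard must be
        # saved where it lives; the trainer only broadcasts a notification.
        if lookup_table_name:
            endpoints = pserver_endpoints.split(",")
            notify_pservers(endpoints, param_path, lookup_table_name)


    save_params(
        trainer_id=0,
        param_path="/tmp/params",
        lookup_table_name="dist_lookup_table",
        pserver_endpoints="127.0.0.1:6000,127.0.0.1:6001",
        save_dense=lambda path: print("save dense params to", path),
        notify_pservers=lambda eps, path, table: print("notify", eps, "to save", table))

In the actual patch the notification is the checkpoint_notify op built by _save_pserver_vars_by_notify, with the endpoint list passed through the epmap attribute, which is why attrs['epmap'] is set to pserver_endpoints.split(",").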
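Note on the _is_checkpoint_var refactor: the function changes from a plain predicate into a predicate factory, which is why the call sites in _load_persistable_vars and _save_persistable_vars change from predicate=_is_checkpoint_var to predicate=_is_checkpoint_var(except_vars) and predicate=_is_checkpoint_var(). The snippet below is a minimal, self-contained sketch of the same pattern; the Var stub and all sample names are illustrative stand-ins, not Paddle's actual classes.

    from collections import namedtuple

    # Illustrative stand-in for a framework variable; Paddle's real Variable
    # also carries a type descriptor, omitted here for brevity.
    Var = namedtuple("Var", ["name", "persistable"])


    def is_checkpoint_var(except_vars=None):
        """Build a predicate that keeps persistable vars, minus exclusions."""
        except_vars = [] if except_vars is None else except_vars

        def _predicate(var):
            if "@GRAD" in var.name:  # gradients are never checkpointed
                return False
            if var.name in except_vars:  # caller-supplied exclusions, by name
                return False
            return var.persistable

        return _predicate


    all_vars = [
        Var("fc_0.w_0", True),
        Var("fc_0.w_0@GRAD", False),
        Var("dist_lookup_table", True),
    ]

    keep_all = list(filter(is_checkpoint_var(), all_vars))
    keep_no_table = list(filter(is_checkpoint_var(["dist_lookup_table"]), all_vars))

    print([v.name for v in keep_all])       # ['fc_0.w_0', 'dist_lookup_table']
    print([v.name for v in keep_no_table])  # ['fc_0.w_0']

Returning a closure keeps the exclusion list out of the predicate's signature, so io.load_vars can keep invoking it with a single Variable argument; on the pserver side this is what lets _load_persistable_vars restore everything except the lookup table, which _load_lookup_table_vars then loads separately.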