diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py
index 1c4565a83c591f8efd56cbdbc843b2b19d233973..eed9b49ef40b591d5b6481846dab714423f57990 100644
--- a/python/paddle/fluid/trainer.py
+++ b/python/paddle/fluid/trainer.py
@@ -133,6 +133,8 @@ class CheckpointConfig(object):
         self.epoch_id = 0
         self.step_id = 0
         self.load_serial = None
+        self.pserver_id = None
+        self.lookup_table_name = None
 
 
 def check_and_get_place(place):
@@ -234,9 +236,6 @@ class Trainer(object):
         # config for checkpoint
         # only chief worker will save variables
         self.trainer_id = 0
-        self.pserver_id = None
-        self.pserver_endpoints = None
-        self.lookup_table_name = None
         self.checkpoint_cfg = checkpoint_config
         if self.checkpoint_cfg:
             assert isinstance(self.checkpoint_cfg, CheckpointConfig)
@@ -284,12 +283,10 @@ class Trainer(object):
 
         if param_path and os.path.isdir(param_path):
             # load params from param_path into scope
-            _load_persistable_vars(exe, param_path, self.startup_program, False,
-                                   [self.lookup_table_name]
-                                   if self.lookup_table_name else [])
-            if self.lookup_table_name and self.pserver_id is not None:
-                _load_lookup_table_vars(exe, param_path, self.startup_program,
-                                        self.pserver_id, self.lookup_table_name)
+            io.load_persistables(
+                executor=exe,
+                dirname=param_path,
+                main_program=self.startup_program)
 
     def _transpile_nccl2_dist(self):
         # PADDLE_TRAINER_IPS
@@ -353,9 +350,11 @@ class Trainer(object):
             t.transpile(
                 self.trainer_id, pservers=pserver_endpoints, trainers=trainers)
             if training_role == "PSERVER":
-                self.pserver_id = eplist.index(current_endpoint)
-                self.pserver_endpoints = pserver_endpoints
-                self.lookup_table_name = t.table_name if t.has_distributed_lookup_table else None
+                if self.checkpoint_cfg:
+                    pserver_id = eplist.index(current_endpoint)
+                    self.checkpoint_cfg.pserver_id = pserver_id
+                    if t.has_distributed_lookup_table:
+                        self.checkpoint_cfg.lookup_table_name = t.table_name
 
                 self.train_program = t.get_pserver_program(current_endpoint)
                 self.startup_program = t.get_startup_program(current_endpoint,
@@ -417,11 +416,6 @@ class Trainer(object):
     def save_params(self, param_path):
         """
         Save all parameters into :code:`param_path`.
-        Only No.0 trainer will save dense params.
-        In standalone PaddlePaddle, the only existing trainer will save dense params.
-        In distributed PaddlePaddle, the No.0 trainer will save dense params,
-        If there have lookup table need to save, No.0 trainer will broadcast notification
-        to all Parameter Servers to save it on Parameter Servers independent.
 
         Args:
             param_path(str): The path to save parameters.
@@ -429,19 +423,9 @@ class Trainer(object):
         Returns:
             None
         """
-
-        if self.trainer_id != 0:
-            return
-
         with self._prog_and_scope_guard():
-            # save params on trainer
             exe = executor.Executor(self.place)
             io.save_persistables(exe, dirname=param_path)
-            # save params on pserver
-            if self.lookup_table_name:
-                _save_pserver_vars_by_notify(exe, param_path,
-                                             self.lookup_table_name,
-                                             self.pserver_endpoints)
 
     @contextlib.contextmanager
     def _prog_and_scope_guard(self):
@@ -489,10 +473,8 @@ class Trainer(object):
                         self._clean_checkpoint()
                     return
 
-                if self.checkpoint_cfg and \
-                    self.checkpoint_cfg.load_serial is not None and \
-                    self.checkpoint_cfg.step_id >= step_id and \
-                    self.checkpoint_cfg.epoch_id == epoch_id:
+                if self.checkpoint_cfg and self.checkpoint_cfg.load_serial \
+                    and self.checkpoint_cfg.step_id >= step_id and self.checkpoint_cfg.epoch_id == epoch_id:
                     continue
 
                 begin_event = BeginStepEvent(epoch_id, step_id)
@@ -574,75 +556,49 @@ class Trainer(object):
         if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \
                 and step_id % self.checkpoint_cfg.step_interval == 0:
-            exe = executor.Executor(self.place)
             save_checkpoint(
                 executor=exe,
                 checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
-                main_program=self.train_program,
                 trainer_id=self.trainer_id,
-                save_trainer_args=self._get_checkpoint_save_args(epoch_id,
-                                                                 step_id),
-                save_lookup_table=self.lookup_table_name,
-                pserver_endpoints=self.pserver_endpoints,
+                trainer_args=self._get_checkpoint_save_args(epoch_id, step_id),
+                main_program=self.train_program,
                 max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)
 
     def _load_checkpoint(self):
         with self._prog_and_scope_guard():
             exe = executor.Executor(self.place)
+            load_checkpoint(
+                executor=exe,
+                checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
+                main_program=self.startup_program)
 
-            checkpoint_dir = _get_serial_dir(self.checkpoint_cfg.checkpoint_dir,
-                                             self.checkpoint_cfg.load_serial)
-
-            # Trainer Load
-            if self.pserver_id is None:
-                # load model
-                load_checkpoint(
-                    executor=exe,
-                    checkpoint_dir=checkpoint_dir,
-                    main_program=self.startup_program,
-                    role_id=self.trainer_id,
-                    is_trainer=True,
-                    load_models=True)
-
-                # load trainer_args
-                trainer_args = self._get_checkpoint_load_args()
-                trainer_args_ret = load_checkpoint(
+            if not self.checkpoint_cfg.pserver_id:
+                load_trainer_args = self._get_checkpoint_load_args()
+                trainer_args = load_checkpoint(
                     executor=exe,
-                    checkpoint_dir=checkpoint_dir,
+                    checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
                     main_program=self.startup_program,
                     role_id=self.trainer_id,
                     is_trainer=True,
-                    load_trainer_args=trainer_args)
+                    load_trainer_args=load_trainer_args)
 
-                if len(trainer_args_ret) != 2:
+                if len(trainer_args) != 2:
                     raise ValueError(
                         "the return trainer_args length do not equal _get_checkpoint_load_args"
                     )
-                self.checkpoint_cfg.epoch_id = int(trainer_args_ret[0])
-                self.checkpoint_cfg.step_id = int(trainer_args_ret[1])
-
-            # Pserver Load
+                self.checkpoint_cfg.epoch_id = int(trainer_args[0])
+                self.checkpoint_cfg.step_id = int(trainer_args[1])
             else:
-                # load model
-                load_checkpoint(
-                    executor=exe,
-                    checkpoint_dir=checkpoint_dir,
-                    main_program=self.startup_program,
-                    role_id=self.pserver_id,
-                    is_trainer=False,
-                    load_models=True,
-                    load_lookup_table=self.lookup_table_name)
-
-                # load lookup table
-                if self.lookup_table_name:
+                if self.checkpoint_cfg.lookup_table_name:
                     load_checkpoint(
                         executor=exe,
-                        checkpoint_dir=checkpoint_dir,
+                        checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
                         main_program=self.startup_program,
-                        role_id=self.pserver_id,
+                        role_id=self.checkpoint_cfg.pserver_id,
                         is_trainer=False,
-                        load_lookup_table=self.lookup_table_name)
+                        load_trainer_args=None,
+                        load_lookup_table=self.checkpoint_cfg.lookup_table_name)
 
 
 def build_feed_var_list(program, feed_order):
@@ -680,12 +636,12 @@ CHECKPOINT_SEPARATOR = "_"
 
 def save_checkpoint(executor,
                     checkpoint_dir,
-                    main_program=None,
-                    trainer_id=0,
-                    save_trainer_args=None,
-                    save_lookup_table=None,
-                    pserver_endpoints=None,
-                    max_num_checkpoints=3):
+                    trainer_id,
+                    main_program,
+                    trainer_args=None,
+                    max_num_checkpoints=3,
+                    lookup_table=None,
+                    pserver_endpoints=None):
     """
     This function filters out all checkpoint variables from the give
     main_program and then saves these variables to the `checkpoint_dir`
@@ -717,7 +673,7 @@
         max_num_checkpoints(int): The max number of total number of existing
             checkpoints.
            Default: 3
-        save_lookup_table(string|None): the lookup table name, when use distribute
+        lookup_table(string|None): the lookup table name, when use distribute
            lookup table, we can get lookup table name by DistributeTranspiler.
            table_name
         pserver_endpoints(list|None): the parameter server ip:port list.
@@ -748,28 +704,31 @@
                             trainer_args=trainer_args,
                             main_program=prog,
                             max_num_checkpoints=3,
-                            save_lookup_table=table_name,
+                            lookup_table=table_name,
                             pserver_endpoints = ps_endpoints)
     """
     if checkpoint_dir is None:
         raise ValueError("'checkpoint_dir' should not be None")
 
-    _make_chekcpoint_dirs(checkpoint_dir)
-    serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1
-    cur_dir = _get_serial_dir(checkpoint_dir, serial, True)
+    if main_program is None:
+        raise ValueError('main_program should not be None.')
+
+    if trainer_args:
+        assert isinstance(trainer_args, dict)
 
     is_chief = trainer_id == 0
 
-    if save_trainer_args is not None:
-        _save_trainer_args(cur_dir, trainer_id, save_trainer_args)
+    _make_chekcpoint_dirs(checkpoint_dir)
+    serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1
+    cur_dir = _get_serial_dir(checkpoint_dir, serial)
+
+    _save_trainer_args(cur_dir, trainer_id, trainer_args)
 
     if is_chief:
-        if main_program is None:
-            raise ValueError('main_program should not be None.')
-        _save_persistable_vars(executor, cur_dir, main_program)
+        _save_persist_vars_without_grad(executor, cur_dir, main_program)
 
-    if is_chief and save_lookup_table and pserver_endpoints:
-        _save_pserver_vars_by_notify(executor, cur_dir, save_lookup_table,
+    if is_chief and lookup_table and pserver_endpoints:
+        _save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
                                      pserver_endpoints)
 
     _scroll_delete(checkpoint_dir, max_num_checkpoints)
@@ -777,10 +736,9 @@
 
 def load_checkpoint(executor,
                     checkpoint_dir,
-                    main_program=None,
+                    main_program,
                     role_id=0,
                     is_trainer=True,
-                    load_models=False,
                     load_trainer_args=None,
                     load_lookup_table=None):
     """
@@ -804,7 +762,7 @@
         executor(Executor): The executor to run for loading checkpoint.
         checkpoint_dir(str): The folder where all checkpoints are.
         serial(int): The serial of checkpoint you would like to load.
-        main_program(Program|None): The program whose checkpoint variables will
+        main_program(Program): The program whose checkpoint variables will
            be loaded.
         role_id(int): the trainer id or the parameter server id.
         is_trainer(bool): trainer is True and parameter server is False.
@@ -836,28 +794,27 @@
     if checkpoint_dir is None:
         raise ValueError("'checkpoint_dir' should not be None")
 
-    # trainer load
-    if is_trainer:
-        if load_models:
-            _load_persistable_vars(executor, checkpoint_dir, main_program, True)
+    serial = _get_latest_checkpoint_serial(checkpoint_dir)
 
-        if load_trainer_args:
-            trainer_args_ret = _load_trainer_args(checkpoint_dir, role_id,
-                                                  load_trainer_args)
-            return trainer_args_ret
-    # pserver load
-    else:
-        if load_models:
-            if load_lookup_table:
-                _load_persistable_vars(executor, checkpoint_dir, main_program,
-                                       True, [load_lookup_table])
-            else:
-                _load_persistable_vars(executor, checkpoint_dir, main_program,
-                                       True)
+    # there are nothing need to be loaded
+    if serial is None or serial < 0:
+        return
+
+    if main_program is None:
+        raise ValueError('main_program should not be None.')
+
+    if is_trainer and load_trainer_args is None:
+        cur_dir = _get_serial_dir(checkpoint_dir, serial)
+        _load_persist_vars_without_grad(executor, cur_dir, main_program, True)
+        return
+
+    if is_trainer and load_trainer_args:
+        return _load_trainer_args(checkpoint_dir, serial, role_id,
+                                  load_trainer_args)
 
-        if load_lookup_table:
-            _load_lookup_table_vars(executor, checkpoint_dir, main_program,
-                                    role_id, load_lookup_table)
+    if not is_trainer and load_lookup_table:
+        _load_lookup_table_vars(executor, checkpoint_dir, main_program, role_id,
+                                load_lookup_table)
 
 
 def clean_checkpoint(checkpoint_dir, delete_dir=False):
@@ -878,11 +835,10 @@
         os.rmdir(checkpoint_dir)
 
 
-def _load_persistable_vars(executor,
-                           dirname,
-                           program,
-                           has_model_dir=False,
-                           except_vars=None):
+def _load_persist_vars_without_grad(executor,
+                                    dirname,
+                                    program,
+                                    has_model_dir=False):
     """
     This function filters out all checkpoint variables from the give
     program and then trys to load these variables from the given directory.
@@ -911,10 +867,10 @@
            exe = fluid.Executor(fluid.CPUPlace())
            param_path = "./my_paddle_model"
            prog = fluid.default_main_program()
-           _load_persistable_vars(executor=exe,
+           _load_persist_vars_without_grad(executor=exe,
                    dirname=param_path, program=prog, has_model_dir=True)
 
-           # In this example, `_load_persistable_vars` function
+           # In this example, `_load_persist_vars_without_grad` function
            # will first filters out all checkpoint variables in the default
            # main program, and then trys to load these variables form the
            # folder "./my_paddle_model/__model__".
@@ -927,7 +883,7 @@
         executor,
         dirname=dirname,
         main_program=program,
-        predicate=_is_checkpoint_var(except_vars),
+        predicate=_is_checkpoint_var,
         filename=None)
 
 
@@ -981,7 +937,7 @@ def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
     executor.run(load_prog)
 
 
-def _save_persistable_vars(executor, dirname, program):
+def _save_persist_vars_without_grad(executor, dirname, program):
     """
     This function filters out all checkpoint variables from the give
     program and then save these variables to a sub-folder '__model__' of
@@ -1008,10 +964,10 @@ def _save_persistable_vars(executor, dirname, program):
            exe = fluid.Executor(fluid.CPUPlace())
            param_path = "./my_paddle_model"
            prog = fluid.default_main_program()
-           _save_persistable_vars(executor=exe,
+           _save_persist_vars_without_grad(executor=exe,
                    dirname=param_path, program=prog)
 
-           # In this example, `_save_persistable_vars` function
+           # In this example, `_save_persist_vars_without_grad` function
            # will first filters out all checkpoint variables in the default
            # main program, and then saves these variables to the folder
            # "./my_paddle_model/__model__".
@@ -1022,13 +978,13 @@ def _save_persistable_vars(executor, dirname, program):
         dirname=cur_dir,
         main_program=program,
         vars=None,
-        predicate=_is_checkpoint_var(),
+        predicate=_is_checkpoint_var,
         filename=None)
     _write_success(cur_dir)
 
 
 def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
-                                 pserver_endpoints):
+                                 ps_endpoint_list):
     """
     This function will send checkpoint notify message from Trainer 0
     to all the pservers.
@@ -1066,7 +1022,7 @@
     checkpoint_notify_block = checkpoint_notify_program.global_block()
 
     attrs = {}
-    attrs['epmap'] = pserver_endpoints.split(",")
+    attrs['epmap'] = ps_endpoint_list
     attrs['dir'] = cur_dir
     attrs['lookup_table'] = lookup_table
 
@@ -1087,7 +1043,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args):
     _write_success(cur_dir)
 
 
-def _load_trainer_args(checkpoint_dir, trainer_id, trainer_args):
+def _load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
     """
     trainer will load some args from it's independent directory,
     such as epoch_id and step_id.
@@ -1113,7 +1069,8 @@
     """
     assert isinstance(trainer_args, list)
 
-    cur_dir = _get_trainer_dir(checkpoint_dir, trainer_id)
+    cur_dir = _get_serial_dir(checkpoint_dir, serial)
+    cur_dir = _get_trainer_dir(cur_dir, trainer_id)
 
     ret_values = []
 
@@ -1125,37 +1082,29 @@
     return ret_values
 
 
-def _is_checkpoint_var(except_vars=None):
-    except_vars = [] if except_vars is None else except_vars
-
-    def _except_vars(var):
-        """
-        the checkpoint will not save or load all the variables.
-        var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
-
-        : param var(Variable)
-        """
-        if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
-            var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
-            var.desc.type() == core.VarDesc.VarType.RAW:
-            return False
-        # @GRAD are named for gradient variables, checkpoint will not save it.
-        if "@GRAD" in var.name:
-            return False
-        # .trainer_ are named for distribute train variables, checkpoint will not save it.
-        if ".trainer_" in var.name:
-            return False
-
-        # .block is named for distribute train variables, checkpoint will not save it.
- if ".block" in var.name: - return False +def _is_checkpoint_var(var): + """ + the checkpoint will not save or load all the variables. + var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded. - if var in except_vars: - return False + : param var(Variable) + """ + if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ + var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ + var.desc.type() == core.VarDesc.VarType.RAW: + return False + # @GRAD are named for gradient variables, checkpoint will not save it. + if "@GRAD" in var.name: + return False + # .trainer_ are named for distribute train variables, checkpoint will not save it. + if ".trainer_" in var.name: + return False - return var.persistable + # .block is named for distribute train variables, checkpoint will not save it. + if ".block" in var.name: + return False - return _except_vars + return var.persistable def _make_chekcpoint_dirs(dirs): @@ -1176,19 +1125,20 @@ def _make_chekcpoint_dirs(dirs): def _get_dir_serial(dirname): + _, serial = dirname.split(CHECKPOINT_SEPARATOR) + try: - _, serial = dirname.split(CHECKPOINT_SEPARATOR) serial_num = int(serial) except ValueError: serial_num = -1 return serial_num -def _get_serial_dir(dirname, serial, makedirs=False): +def _get_serial_dir(dirname, serial): serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial) serial_dir = os.path.join(dirname, serial_folder) - if makedirs: - _make_chekcpoint_dirs(serial_dir) + _make_chekcpoint_dirs(serial_dir) + return serial_dir @@ -1251,6 +1201,8 @@ def _get_latest_checkpoint_serial(checkpoint_dir): : param checkpoint_dir """ + if not checkpoint_dir: + return -1 def has_success(checkpoint_dir, cur_dir): """ @@ -1258,8 +1210,8 @@ def _get_latest_checkpoint_serial(checkpoint_dir): """ serial = _get_dir_serial(cur_dir) - if serial == -1 or \ - not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)): + if serial == -1 or not os.path.isdir( + os.path.join(checkpoint_dir, cur_dir)): return -1 success_path = os.path.join( @@ -1268,11 +1220,10 @@ def _get_latest_checkpoint_serial(checkpoint_dir): if os.path.isfile(success_path): return serial - current_dir = -1 - - if not checkpoint_dir or not os.path.isdir(checkpoint_dir): - return current_dir + if not os.path.isdir(checkpoint_dir): + return -1 + current_dir = -1 dirs = os.listdir(checkpoint_dir) for cur_dir in dirs: success_num = has_success(checkpoint_dir, cur_dir)