From 1dd14a704dad85686af048c269993ffde8cf82fb Mon Sep 17 00:00:00 2001
From: tangwei12
Date: Thu, 19 Jul 2018 15:05:25 +0800
Subject: [PATCH] bug fix

---
 python/paddle/fluid/trainer.py                 | 189 ++++++++++++------
 .../fluid/transpiler/distribute_transpiler.py  |  22 ++
 2 files changed, 148 insertions(+), 63 deletions(-)

diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py
index 64049a93c..573e0cdab 100644
--- a/python/paddle/fluid/trainer.py
+++ b/python/paddle/fluid/trainer.py
@@ -360,6 +360,7 @@ class Trainer(object):
                 self.train_program = t.get_pserver_program(current_endpoint)
                 self.startup_program = t.get_startup_program(current_endpoint,
                                                              self.train_program)
+                self.slice_vars = t.get_slice_vars_and_atts(current_endpoint)
             elif training_role == "TRAINER":
                 self.train_program = t.get_trainer_program()
             else:
@@ -474,8 +475,10 @@ class Trainer(object):
                     self._clean_checkpoint()
                 return
 
-            if self.checkpoint_cfg and self.checkpoint_cfg.load_serial \
-                    and self.checkpoint_cfg.step_id >= step_id and self.checkpoint_cfg.epoch_id == epoch_id:
+            if self.checkpoint_cfg and \
+                    self.checkpoint_cfg.load_serial is not None and \
+                    self.checkpoint_cfg.step_id >= step_id and \
+                    self.checkpoint_cfg.epoch_id == epoch_id:
                 continue
 
             begin_event = BeginStepEvent(epoch_id, step_id)
@@ -569,36 +572,58 @@ class Trainer(object):
     def _load_checkpoint(self):
         with self._prog_and_scope_guard():
             exe = executor.Executor(self.place)
-            load_checkpoint(
-                executor=exe,
-                checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
-                main_program=self.startup_program)
-            if not self.checkpoint_cfg.pserver_id:
-                load_trainer_args = self._get_checkpoint_load_args()
-                trainer_args = load_checkpoint(
+            checkpoint_dir = _get_serial_dir(self.checkpoint_cfg.checkpoint_dir,
+                                             self.checkpoint_cfg.load_serial)
+
+            # Trainer Load
+            if self.checkpoint_cfg.pserver_id is None:
+                # load model
+                load_checkpoint(
                     executor=exe,
-                    checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
+                    checkpoint_dir=checkpoint_dir,
                     main_program=self.startup_program,
                     role_id=self.trainer_id,
                     is_trainer=True,
-                    load_trainer_args=load_trainer_args)
+                    load_models=True)
 
-                if len(trainer_args) != 2:
+                # load trainer_args
+                trainer_args = self._get_checkpoint_load_args()
+                trainer_args_ret = load_checkpoint(
+                    executor=exe,
+                    checkpoint_dir=checkpoint_dir,
+                    main_program=self.startup_program,
+                    role_id=self.trainer_id,
+                    is_trainer=True,
+                    load_trainer_args=trainer_args)
+
+                if len(trainer_args_ret) != 2:
                     raise ValueError(
                         "the return trainer_args length do not equal _get_checkpoint_load_args"
                     )
-                self.checkpoint_cfg.epoch_id = int(trainer_args[0])
-                self.checkpoint_cfg.step_id = int(trainer_args[1])
+                self.checkpoint_cfg.epoch_id = int(trainer_args_ret[0])
+                self.checkpoint_cfg.step_id = int(trainer_args_ret[1])
+
+            # Pserver Load
             else:
+                # load slice_vars
+                if self.slice_vars != None and len(self.slice_vars) != 0:
+                    load_checkpoint(
+                        executor=exe,
+                        checkpoint_dir=checkpoint_dir,
+                        main_program=self.startup_program,
+                        role_id=self.checkpoint_cfg.pserver_id,
+                        is_trainer=False,
+                        load_slice_up_vars=self.slice_vars)
+
+                # load lookup table
                 if self.checkpoint_cfg.lookup_table_name:
                     load_checkpoint(
                         executor=exe,
-                        checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
+                        checkpoint_dir=checkpoint_dir,
                         main_program=self.startup_program,
                         role_id=self.checkpoint_cfg.pserver_id,
                         is_trainer=False,
-                        load_trainer_args=None,
                         load_lookup_table=self.checkpoint_cfg.lookup_table_name)
 
 
@@ -640,7 +665,7 @@ def save_checkpoint(executor,
                     main_program,
                     trainer_args=None,
                     max_num_checkpoints=3,
-                    lookup_table=None,
+                    save_lookup_table=None,
                     pserver_endpoints=None):
     """
     This function filters out all checkpoint variables from the give
@@ -673,7 +698,7 @@ def save_checkpoint(executor,
         max_num_checkpoints(int): The max number of total number of existing
            checkpoints.
            Default: 3
-        lookup_table(string|None): the lookup table name, when use distribute
+        save_lookup_table(string|None): the lookup table name, when use distribute
            lookup table, we can get lookup table name by DistributeTranspiler.
            table_name
         pserver_endpoints(list|None): the parameter server ip:port list.
@@ -704,7 +729,7 @@ def save_checkpoint(executor,
                 trainer_args=trainer_args,
                 main_program=prog,
                 max_num_checkpoints=3,
-                lookup_table=table_name,
+                save_lookup_table=table_name,
                 pserver_endpoints = ps_endpoints)
     """
     if checkpoint_dir is None:
@@ -720,15 +745,15 @@ def save_checkpoint(executor,
 
     _make_chekcpoint_dirs(checkpoint_dir)
     serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1
-    cur_dir = _get_serial_dir(checkpoint_dir, serial)
+    cur_dir = _get_serial_dir(checkpoint_dir, serial, True)
 
     _save_trainer_args(cur_dir, trainer_id, trainer_args)
 
     if is_chief:
-        _save_persist_vars_without_grad(executor, cur_dir, main_program)
+        _save_persistable_vars(executor, cur_dir, main_program)
 
-    if is_chief and lookup_table and pserver_endpoints:
-        _save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
+    if is_chief and save_lookup_table and pserver_endpoints:
+        _save_pserver_vars_by_notify(executor, cur_dir, save_lookup_table,
                                      pserver_endpoints)
 
     _scroll_delete(checkpoint_dir, max_num_checkpoints)
@@ -736,10 +761,12 @@ def save_checkpoint(executor,
 
 def load_checkpoint(executor,
                     checkpoint_dir,
-                    main_program,
+                    main_program=None,
                     role_id=0,
                     is_trainer=True,
+                    load_models=True,
                     load_trainer_args=None,
+                    load_slice_up_vars=None,
                     load_lookup_table=None):
     """
     This function filters out all checkpoint variables from the give
@@ -762,7 +789,7 @@ def load_checkpoint(executor,
         executor(Executor): The executor to run for loading checkpoint.
         checkpoint_dir(str): The folder where all checkpoints are.
         serial(int): The serial of checkpoint you would like to load.
-        main_program(Program): The program whose checkpoint variables will
+        main_program(Program|None): The program whose checkpoint variables will
                                 be loaded.
         role_id(int): the trainer id or the parameter server id.
         is_trainer(bool): trainer is True and parameter server is False.
@@ -794,27 +821,23 @@ def load_checkpoint(executor,
     if checkpoint_dir is None:
         raise ValueError("'checkpoint_dir' should not be None")
 
-    serial = _get_latest_checkpoint_serial(checkpoint_dir)
-
-    # there are nothing need to be loaded
-    if serial is None or serial < 0:
-        return
-
-    if main_program is None:
-        raise ValueError('main_program should not be None.')
-
-    if is_trainer and load_trainer_args is None:
-        cur_dir = _get_serial_dir(checkpoint_dir, serial)
-        _load_persist_vars_without_grad(executor, cur_dir, main_program, True)
-        return
-
-    if is_trainer and load_trainer_args:
-        return _load_trainer_args(checkpoint_dir, serial, role_id,
-                                  load_trainer_args)
-
-    if not is_trainer and load_lookup_table:
-        _load_lookup_table_vars(executor, checkpoint_dir, main_program, role_id,
-                                load_lookup_table)
+    # trainer load
+    if is_trainer:
+        if load_models:
+            _load_persistable_vars(executor, checkpoint_dir, main_program, True)
+            return
+        if load_trainer_args:
+            trainer_args_ret = _load_trainer_args(checkpoint_dir, role_id,
+                                                  load_trainer_args)
+            return trainer_args_ret
+    # pserver load
+    else:
+        if load_slice_up_vars:
+            _load_slice_up_vars(executor, checkpoint_dir, load_slice_up_vars)
+            return
+        if load_lookup_table:
+            _load_lookup_table_vars(executor, checkpoint_dir, main_program,
+                                    role_id, load_lookup_table)
 
 
 def clean_checkpoint(checkpoint_dir, delete_dir=False):
@@ -835,10 +858,7 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
         os.rmdir(checkpoint_dir)
 
 
-def _load_persist_vars_without_grad(executor,
-                                    dirname,
-                                    program,
-                                    has_model_dir=False):
+def _load_persistable_vars(executor, dirname, program, has_model_dir=False):
     """
     This function filters out all checkpoint variables from the give
     program and then trys to load these variables from the given directory.
@@ -867,10 +887,10 @@ def _load_persist_vars_without_grad(executor,
            exe = fluid.Executor(fluid.CPUPlace())
            param_path = "./my_paddle_model"
            prog = fluid.default_main_program()
-           _load_persist_vars_without_grad(executor=exe,
+           _load_persistable_vars(executor=exe,
                    dirname=param_path, program=prog, has_model_dir=True)
-           # In this example, `_load_persist_vars_without_grad` function
+           # In this example, `_load_persistable_vars` function
            # will first filters out all checkpoint variables in the default
            # main program, and then trys to load these variables form the
            # folder "./my_paddle_model/__model__".
 
@@ -887,6 +907,51 @@ def _load_persist_vars_without_grad(executor,
         filename=None)
 
 
+def _load_slice_up_vars(executor, dirname, slice_vars):
+    if slice_vars == None or len(slice_vars) == 0:
+        return
+
+    dirname = _get_model_dir(dirname)
+
+    load_prog = framework.Program()
+    load_block = load_prog.global_block()
+
+    for var_tuple in slice_vars:
+        orig_var = var_tuple[0]
+        start = var_tuple[1]
+        slice_var = var_tuple[2]
+        end = start + reduce(lambda x, y: x * y, slice_var.shape)
+
+        clone_orig_var = load_block.create_var(
+            name=orig_var.name,
+            type=orig_var.type,
+            shape=orig_var.shape,
+            dtype=orig_var.dtype,
+            persistable=True)
+
+        clone_slice_var = load_block.create_var(
+            name=slice_var.name,
+            type=slice_var.type,
+            shape=slice_var.shape,
+            dtype=slice_var.dtype,
+            persistable=True)
+
+        load_block.append_op(
+            type='load',
+            inputs={},
+            outputs={'Out': [clone_orig_var]},
+            attrs={'file_path': os.path.join(dirname, clone_orig_var.name)})
+        load_block.append_op(
+            type="slice",
+            inputs={'Input': clone_orig_var},
+            outputs={'Out': clone_slice_var},
+            attrs={'axes': [0],
+                   'starts': [start],
+                   'ends': [end]})
+
+    executor.run(load_prog)
+
+
 def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
     """
     The parameter server will load lookup table's local file in
@@ -937,7 +1002,7 @@ def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
     executor.run(load_prog)
 
 
-def _save_persist_vars_without_grad(executor, dirname, program):
+def _save_persistable_vars(executor, dirname, program):
     """
     This function filters out all checkpoint variables from the give
     program and then save these variables to a sub-folder '__model__' of
@@ -964,10 +1029,10 @@ def _save_persist_vars_without_grad(executor, dirname, program):
            exe = fluid.Executor(fluid.CPUPlace())
            param_path = "./my_paddle_model"
            prog = fluid.default_main_program()
-           _save_persist_vars_without_grad(executor=exe,
+           _save_persistable_vars(executor=exe,
                    dirname=param_path, program=prog)
-           # In this example, `_save_persist_vars_without_grad` function
+           # In this example, `_save_persistable_vars` function
            # will first filters out all checkpoint variables in the default
            # main program, and then saves these variables to the folder
            # "./my_paddle_model/__model__".
 
@@ -1043,7 +1108,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args):
     _write_success(cur_dir)
 
 
-def _load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
+def _load_trainer_args(checkpoint_dir, trainer_id, trainer_args):
     """
     trainer will load some args from it's independent directory,
    such as epoch_id and step_id.
@@ -1069,8 +1134,7 @@ def _load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
     """
     assert isinstance(trainer_args, list)
 
-    cur_dir = _get_serial_dir(checkpoint_dir, serial)
-    cur_dir = _get_trainer_dir(cur_dir, trainer_id)
+    cur_dir = _get_trainer_dir(checkpoint_dir, trainer_id)
 
     ret_values = []
 
@@ -1125,20 +1189,19 @@ def _make_chekcpoint_dirs(dirs):
 
 
 def _get_dir_serial(dirname):
-    _, serial = dirname.split(CHECKPOINT_SEPARATOR)
-
     try:
+        _, serial = dirname.split(CHECKPOINT_SEPARATOR)
         serial_num = int(serial)
     except ValueError:
         serial_num = -1
     return serial_num
 
 
-def _get_serial_dir(dirname, serial):
+def _get_serial_dir(dirname, serial, makedirs=False):
     serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
     serial_dir = os.path.join(dirname, serial_folder)
-    _make_chekcpoint_dirs(serial_dir)
-
+    if makedirs:
+        _make_chekcpoint_dirs(serial_dir)
     return serial_dir
 
 
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index c2044bf03..efed6cf15 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -719,6 +719,28 @@ class DistributeTranspiler(object):
             })
             for ep in self.pserver_endpoints
         ]
+    def get_slice_vars_and_atts(self, endpoint):
+        slice_vars_and_atts = []
+        block_suffix = ".block"
+        for param in self.param_grad_ep_mapping[endpoint]["params"]:
+
+            suff_idx = param.name.find(block_suffix)
+            if suff_idx <= 0:
+                continue
+
+            orig_var_name = param.name[:suff_idx]
+            block_idx = int(param.name[suff_idx + len(block_suffix):])
+
+            orig_var = self.origin_program.global_block().vars[orig_var_name]
+
+            skip_numel = 0
+            slice_vars = self.param_var_mapping[orig_var_name]
+            for slice_var in slice_vars[:block_idx]:
+                skip_numel += reduce(lambda x, y: x * y, slice_var.shape)
+            slice_vars_and_atts.append([orig_var, skip_numel, param])
+
+        return slice_vars_and_atts
+
     # transpiler function for dis lookup_table
     def _replace_lookup_table_op_with_prefetch(self, program,
                                                pserver_endpoints):
--
GitLab