From 2c05e37a77c921d48f7c2205f10e3eaa8f31ac21 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 24 Jul 2018 23:01:48 +0800 Subject: [PATCH] hidden slice_vars in distribute transpile, hidden it to users --- python/paddle/fluid/framework.py | 1 + python/paddle/fluid/io.py | 48 ++++++++++++++ python/paddle/fluid/trainer.py | 64 ------------------- .../fluid/transpiler/distribute_transpiler.py | 53 ++++++++------- 4 files changed, 80 insertions(+), 86 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 03e0ac75758..1a61fce37fa 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1281,6 +1281,7 @@ class Program(object): self._seed = 0 self._current_role = core.op_proto_and_checker_maker.OpRole.Forward self._op_role_var = [] + self._slice_vars_and_atts = [] @property def op_role(self): diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index e1d26474e63..5183f1d8b86 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -369,6 +369,7 @@ def load_vars(executor, load_vars( executor, dirname=dirname, + main_program=main_program, vars=filter(predicate, main_program.list_vars()), filename=filename) else: @@ -401,6 +402,10 @@ def load_vars(executor, outputs={"Out": load_var_list}, attrs={'file_path': os.path.join(dirname, filename)}) + if main_program._slice_vars_and_atts: + _load_slice_up_vars(executor, dirname, + main_program._slice_vars_and_atts) + executor.run(load_prog) @@ -888,3 +893,46 @@ def get_test_program(filelist, program=None, startup_program=None): program._sync_with_cpp() return program + + +def _load_slice_up_vars(executor, dirname, _slice_vars_and_atts): + if slice_vars == None or len(slice_vars) == 0: + return + + load_prog = Program() + load_block = load_prog.global_block() + + for var_tuple in slice_vars: + orig_var = var_tuple[0] + start = var_tuple[1] + slice_var = var_tuple[2] + end = start + reduce(lambda x, y: x * y, slice_var.shape) + + clone_orig_var = load_block.create_var( + name=orig_var.name, + type=orig_var.type, + shape=orig_var.shape, + dtype=orig_var.dtype, + persistable=True) + + clone_slice_var = load_block.create_var( + name=slice_var.name, + type=slice_var.type, + shape=slice_var.shape, + dtype=slice_var.dtype, + persistable=True) + + load_block.append_op( + type='load', + inputs={}, + outputs={'Out': [clone_orig_var]}, + attrs={'file_path': os.path.join(dirname, clone_orig_var.name)}) + load_block.append_op( + type="slice", + inputs={'Input': clone_orig_var}, + outputs={'Out': clone_slice_var}, + attrs={'axes': [0], + 'starts': [start], + 'ends': [end]}) + + executor.run(load_prog) diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index 2285df4c63b..ac1b3eb84af 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -360,7 +360,6 @@ class Trainer(object): self.train_program = t.get_pserver_program(current_endpoint) self.startup_program = t.get_startup_program(current_endpoint, self.train_program) - self.slice_vars = t.get_slice_vars_and_atts(current_endpoint) elif training_role == "TRAINER": self.train_program = t.get_trainer_program() else: @@ -609,16 +608,6 @@ class Trainer(object): # Pserver Load else: - # load slice_vars - if self.slice_vars != None and len(self.slice_vars) != 0: - load_checkpoint( - executor=exe, - checkpoint_dir=checkpoint_dir, - main_program=self.startup_program, - role_id=self.checkpoint_cfg.pserver_id, - is_trainer=False, - load_slice_up_vars=self.slice_vars) - # load lookup table if self.checkpoint_cfg.lookup_table_name: load_checkpoint( @@ -766,7 +755,6 @@ def load_checkpoint(executor, is_trainer=True, load_models=False, load_trainer_args=None, - load_slice_up_vars=None, load_lookup_table=None): """ This function filters out all checkpoint variables from the give @@ -827,18 +815,11 @@ def load_checkpoint(executor, _load_persistable_vars(executor, checkpoint_dir, main_program, True) return if load_trainer_args: - - print("checkpoint_dir: {}, role_id: {}, load_trainer_args: {}". - format(checkpoint_dir, role_id, load_trainer_args)) - trainer_args_ret = _load_trainer_args(checkpoint_dir, role_id, load_trainer_args) return trainer_args_ret # pserver load else: - if load_slice_up_vars: - _load_slice_up_vars(executor, checkpoint_dir, load_slice_up_vars) - return if load_lookup_table: _load_lookup_table_vars(executor, checkpoint_dir, main_program, role_id, load_lookup_table) @@ -911,51 +892,6 @@ def _load_persistable_vars(executor, dirname, program, has_model_dir=False): filename=None) -def _load_slice_up_vars(executor, dirname, slice_vars): - if slice_vars == None or len(slice_vars) == 0: - return - - dirname = _get_model_dir(dirname) - - load_prog = framework.Program() - load_block = load_prog.global_block() - - for var_tuple in slice_vars: - orig_var = var_tuple[0] - start = var_tuple[1] - slice_var = var_tuple[2] - end = start + reduce(lambda x, y: x * y, slice_var.shape) - - clone_orig_var = load_block.create_var( - name=orig_var.name, - type=orig_var.type, - shape=orig_var.shape, - dtype=orig_var.dtype, - persistable=True) - - clone_slice_var = load_block.create_var( - name=slice_var.name, - type=slice_var.type, - shape=slice_var.shape, - dtype=slice_var.dtype, - persistable=True) - - load_block.append_op( - type='load', - inputs={}, - outputs={'Out': [clone_orig_var]}, - attrs={'file_path': os.path.join(dirname, clone_orig_var.name)}) - load_block.append_op( - type="slice", - inputs={'Input': clone_orig_var}, - outputs={'Out': clone_slice_var}, - attrs={'axes': [0], - 'starts': [start], - 'ends': [end]}) - - executor.run(load_prog) - - def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name): """ The parameter server will load lookup table's local file in diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index efed6cf1584..22a461125e9 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -525,6 +525,10 @@ class DistributeTranspiler(object): outputs={}, attrs=attrs) + # add slice vars + slice_vars_and_atts = self._get_slice_vars_and_atts(endpoint) + pserver_program._slice_vars_and_atts = slice_vars_and_atts + pserver_program._sync_with_cpp() return pserver_program @@ -587,8 +591,35 @@ class DistributeTranspiler(object): inputs=new_inputs, outputs=new_outputs, attrs=op.attrs) + + # add slice vars + slice_vars_and_atts = self._get_slice_vars_and_atts(endpoint) + s_prog._slice_vars_and_atts = slice_vars_and_atts + return s_prog + def _get_slice_vars_and_atts(self, endpoint): + slice_vars_and_atts = [] + block_suffix = ".block" + for param in self.param_grad_ep_mapping[endpoint]["params"]: + + suff_idx = param.name.find(block_suffix) + if suff_idx <= 0: + continue + + orig_var_name = param.name[:suff_idx] + block_idx = int(param.name[suff_idx + len(block_suffix):]) + + orig_var = self.origin_program.global_block().vars[orig_var_name] + + skip_numel = 0 + slice_vars = self.param_var_mapping[orig_var_name] + for slice_var in slice_vars[:block_idx]: + skip_numel += reduce(lambda x, y: x * y, slice_var.shape) + slice_vars_and_atts.append([orig_var, skip_numel, param]) + + return slice_vars_and_atts + # ====================== private transpiler functions ===================== def _has_distributed_lookup_table(self): @@ -719,28 +750,6 @@ class DistributeTranspiler(object): }) for ep in self.pserver_endpoints ] - def get_slice_vars_and_atts(self, endpoint): - slice_vars_and_atts = [] - block_suffix = ".block" - for param in self.param_grad_ep_mapping[endpoint]["params"]: - - suff_idx = param.name.find(block_suffix) - if suff_idx <= 0: - continue - - orig_var_name = param.name[:suff_idx] - block_idx = int(param.name[suff_idx + len(block_suffix):]) - - orig_var = self.origin_program.global_block().vars[orig_var_name] - - skip_numel = 0 - slice_vars = self.param_var_mapping[orig_var_name] - for slice_var in slice_vars[:block_idx]: - skip_numel += reduce(lambda x, y: x * y, slice_var.shape) - slice_vars_and_atts.append([orig_var, skip_numel, param]) - - return slice_vars_and_atts - # transpiler function for dis lookup_table def _replace_lookup_table_op_with_prefetch(self, program, pserver_endpoints): -- GitLab