提交 2c05e37a 编写于 作者: T tangwei12

hidden slice_vars in distribute transpile, hidden it to users

上级 438de1e0
...@@ -1281,6 +1281,7 @@ class Program(object): ...@@ -1281,6 +1281,7 @@ class Program(object):
self._seed = 0 self._seed = 0
self._current_role = core.op_proto_and_checker_maker.OpRole.Forward self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
self._op_role_var = [] self._op_role_var = []
self._slice_vars_and_atts = []
@property @property
def op_role(self): def op_role(self):
......
...@@ -369,6 +369,7 @@ def load_vars(executor, ...@@ -369,6 +369,7 @@ def load_vars(executor,
load_vars( load_vars(
executor, executor,
dirname=dirname, dirname=dirname,
main_program=main_program,
vars=filter(predicate, main_program.list_vars()), vars=filter(predicate, main_program.list_vars()),
filename=filename) filename=filename)
else: else:
...@@ -401,6 +402,10 @@ def load_vars(executor, ...@@ -401,6 +402,10 @@ def load_vars(executor,
outputs={"Out": load_var_list}, outputs={"Out": load_var_list},
attrs={'file_path': os.path.join(dirname, filename)}) attrs={'file_path': os.path.join(dirname, filename)})
if main_program._slice_vars_and_atts:
_load_slice_up_vars(executor, dirname,
main_program._slice_vars_and_atts)
executor.run(load_prog) executor.run(load_prog)
...@@ -888,3 +893,46 @@ def get_test_program(filelist, program=None, startup_program=None): ...@@ -888,3 +893,46 @@ def get_test_program(filelist, program=None, startup_program=None):
program._sync_with_cpp() program._sync_with_cpp()
return program return program
def _load_slice_up_vars(executor, dirname, _slice_vars_and_atts):
if slice_vars == None or len(slice_vars) == 0:
return
load_prog = Program()
load_block = load_prog.global_block()
for var_tuple in slice_vars:
orig_var = var_tuple[0]
start = var_tuple[1]
slice_var = var_tuple[2]
end = start + reduce(lambda x, y: x * y, slice_var.shape)
clone_orig_var = load_block.create_var(
name=orig_var.name,
type=orig_var.type,
shape=orig_var.shape,
dtype=orig_var.dtype,
persistable=True)
clone_slice_var = load_block.create_var(
name=slice_var.name,
type=slice_var.type,
shape=slice_var.shape,
dtype=slice_var.dtype,
persistable=True)
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [clone_orig_var]},
attrs={'file_path': os.path.join(dirname, clone_orig_var.name)})
load_block.append_op(
type="slice",
inputs={'Input': clone_orig_var},
outputs={'Out': clone_slice_var},
attrs={'axes': [0],
'starts': [start],
'ends': [end]})
executor.run(load_prog)
...@@ -360,7 +360,6 @@ class Trainer(object): ...@@ -360,7 +360,6 @@ class Trainer(object):
self.train_program = t.get_pserver_program(current_endpoint) self.train_program = t.get_pserver_program(current_endpoint)
self.startup_program = t.get_startup_program(current_endpoint, self.startup_program = t.get_startup_program(current_endpoint,
self.train_program) self.train_program)
self.slice_vars = t.get_slice_vars_and_atts(current_endpoint)
elif training_role == "TRAINER": elif training_role == "TRAINER":
self.train_program = t.get_trainer_program() self.train_program = t.get_trainer_program()
else: else:
...@@ -609,16 +608,6 @@ class Trainer(object): ...@@ -609,16 +608,6 @@ class Trainer(object):
# Pserver Load # Pserver Load
else: else:
# load slice_vars
if self.slice_vars != None and len(self.slice_vars) != 0:
load_checkpoint(
executor=exe,
checkpoint_dir=checkpoint_dir,
main_program=self.startup_program,
role_id=self.checkpoint_cfg.pserver_id,
is_trainer=False,
load_slice_up_vars=self.slice_vars)
# load lookup table # load lookup table
if self.checkpoint_cfg.lookup_table_name: if self.checkpoint_cfg.lookup_table_name:
load_checkpoint( load_checkpoint(
...@@ -766,7 +755,6 @@ def load_checkpoint(executor, ...@@ -766,7 +755,6 @@ def load_checkpoint(executor,
is_trainer=True, is_trainer=True,
load_models=False, load_models=False,
load_trainer_args=None, load_trainer_args=None,
load_slice_up_vars=None,
load_lookup_table=None): load_lookup_table=None):
""" """
This function filters out all checkpoint variables from the give This function filters out all checkpoint variables from the give
...@@ -827,18 +815,11 @@ def load_checkpoint(executor, ...@@ -827,18 +815,11 @@ def load_checkpoint(executor,
_load_persistable_vars(executor, checkpoint_dir, main_program, True) _load_persistable_vars(executor, checkpoint_dir, main_program, True)
return return
if load_trainer_args: if load_trainer_args:
print("checkpoint_dir: {}, role_id: {}, load_trainer_args: {}".
format(checkpoint_dir, role_id, load_trainer_args))
trainer_args_ret = _load_trainer_args(checkpoint_dir, role_id, trainer_args_ret = _load_trainer_args(checkpoint_dir, role_id,
load_trainer_args) load_trainer_args)
return trainer_args_ret return trainer_args_ret
# pserver load # pserver load
else: else:
if load_slice_up_vars:
_load_slice_up_vars(executor, checkpoint_dir, load_slice_up_vars)
return
if load_lookup_table: if load_lookup_table:
_load_lookup_table_vars(executor, checkpoint_dir, main_program, _load_lookup_table_vars(executor, checkpoint_dir, main_program,
role_id, load_lookup_table) role_id, load_lookup_table)
...@@ -911,51 +892,6 @@ def _load_persistable_vars(executor, dirname, program, has_model_dir=False): ...@@ -911,51 +892,6 @@ def _load_persistable_vars(executor, dirname, program, has_model_dir=False):
filename=None) filename=None)
def _load_slice_up_vars(executor, dirname, slice_vars):
if slice_vars == None or len(slice_vars) == 0:
return
dirname = _get_model_dir(dirname)
load_prog = framework.Program()
load_block = load_prog.global_block()
for var_tuple in slice_vars:
orig_var = var_tuple[0]
start = var_tuple[1]
slice_var = var_tuple[2]
end = start + reduce(lambda x, y: x * y, slice_var.shape)
clone_orig_var = load_block.create_var(
name=orig_var.name,
type=orig_var.type,
shape=orig_var.shape,
dtype=orig_var.dtype,
persistable=True)
clone_slice_var = load_block.create_var(
name=slice_var.name,
type=slice_var.type,
shape=slice_var.shape,
dtype=slice_var.dtype,
persistable=True)
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [clone_orig_var]},
attrs={'file_path': os.path.join(dirname, clone_orig_var.name)})
load_block.append_op(
type="slice",
inputs={'Input': clone_orig_var},
outputs={'Out': clone_slice_var},
attrs={'axes': [0],
'starts': [start],
'ends': [end]})
executor.run(load_prog)
def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name): def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
""" """
The parameter server will load lookup table's local file in The parameter server will load lookup table's local file in
......
...@@ -525,6 +525,10 @@ class DistributeTranspiler(object): ...@@ -525,6 +525,10 @@ class DistributeTranspiler(object):
outputs={}, outputs={},
attrs=attrs) attrs=attrs)
# add slice vars
slice_vars_and_atts = self._get_slice_vars_and_atts(endpoint)
pserver_program._slice_vars_and_atts = slice_vars_and_atts
pserver_program._sync_with_cpp() pserver_program._sync_with_cpp()
return pserver_program return pserver_program
...@@ -587,8 +591,35 @@ class DistributeTranspiler(object): ...@@ -587,8 +591,35 @@ class DistributeTranspiler(object):
inputs=new_inputs, inputs=new_inputs,
outputs=new_outputs, outputs=new_outputs,
attrs=op.attrs) attrs=op.attrs)
# add slice vars
slice_vars_and_atts = self._get_slice_vars_and_atts(endpoint)
s_prog._slice_vars_and_atts = slice_vars_and_atts
return s_prog return s_prog
def _get_slice_vars_and_atts(self, endpoint):
slice_vars_and_atts = []
block_suffix = ".block"
for param in self.param_grad_ep_mapping[endpoint]["params"]:
suff_idx = param.name.find(block_suffix)
if suff_idx <= 0:
continue
orig_var_name = param.name[:suff_idx]
block_idx = int(param.name[suff_idx + len(block_suffix):])
orig_var = self.origin_program.global_block().vars[orig_var_name]
skip_numel = 0
slice_vars = self.param_var_mapping[orig_var_name]
for slice_var in slice_vars[:block_idx]:
skip_numel += reduce(lambda x, y: x * y, slice_var.shape)
slice_vars_and_atts.append([orig_var, skip_numel, param])
return slice_vars_and_atts
# ====================== private transpiler functions ===================== # ====================== private transpiler functions =====================
def _has_distributed_lookup_table(self): def _has_distributed_lookup_table(self):
...@@ -719,28 +750,6 @@ class DistributeTranspiler(object): ...@@ -719,28 +750,6 @@ class DistributeTranspiler(object):
}) for ep in self.pserver_endpoints }) for ep in self.pserver_endpoints
] ]
def get_slice_vars_and_atts(self, endpoint):
slice_vars_and_atts = []
block_suffix = ".block"
for param in self.param_grad_ep_mapping[endpoint]["params"]:
suff_idx = param.name.find(block_suffix)
if suff_idx <= 0:
continue
orig_var_name = param.name[:suff_idx]
block_idx = int(param.name[suff_idx + len(block_suffix):])
orig_var = self.origin_program.global_block().vars[orig_var_name]
skip_numel = 0
slice_vars = self.param_var_mapping[orig_var_name]
for slice_var in slice_vars[:block_idx]:
skip_numel += reduce(lambda x, y: x * y, slice_var.shape)
slice_vars_and_atts.append([orig_var, skip_numel, param])
return slice_vars_and_atts
# transpiler function for dis lookup_table # transpiler function for dis lookup_table
def _replace_lookup_table_op_with_prefetch(self, program, def _replace_lookup_table_op_with_prefetch(self, program,
pserver_endpoints): pserver_endpoints):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册