提交 2c05e37a 编写于 作者: T tangwei12

hidden slice_vars in distribute transpile, hidden it to users

上级 438de1e0
......@@ -1281,6 +1281,7 @@ class Program(object):
self._seed = 0
self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
self._op_role_var = []
self._slice_vars_and_atts = []
@property
def op_role(self):
......
......@@ -369,6 +369,7 @@ def load_vars(executor,
load_vars(
executor,
dirname=dirname,
main_program=main_program,
vars=filter(predicate, main_program.list_vars()),
filename=filename)
else:
......@@ -401,6 +402,10 @@ def load_vars(executor,
outputs={"Out": load_var_list},
attrs={'file_path': os.path.join(dirname, filename)})
if main_program._slice_vars_and_atts:
_load_slice_up_vars(executor, dirname,
main_program._slice_vars_and_atts)
executor.run(load_prog)
......@@ -888,3 +893,46 @@ def get_test_program(filelist, program=None, startup_program=None):
program._sync_with_cpp()
return program
def _load_slice_up_vars(executor, dirname, _slice_vars_and_atts):
if slice_vars == None or len(slice_vars) == 0:
return
load_prog = Program()
load_block = load_prog.global_block()
for var_tuple in slice_vars:
orig_var = var_tuple[0]
start = var_tuple[1]
slice_var = var_tuple[2]
end = start + reduce(lambda x, y: x * y, slice_var.shape)
clone_orig_var = load_block.create_var(
name=orig_var.name,
type=orig_var.type,
shape=orig_var.shape,
dtype=orig_var.dtype,
persistable=True)
clone_slice_var = load_block.create_var(
name=slice_var.name,
type=slice_var.type,
shape=slice_var.shape,
dtype=slice_var.dtype,
persistable=True)
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [clone_orig_var]},
attrs={'file_path': os.path.join(dirname, clone_orig_var.name)})
load_block.append_op(
type="slice",
inputs={'Input': clone_orig_var},
outputs={'Out': clone_slice_var},
attrs={'axes': [0],
'starts': [start],
'ends': [end]})
executor.run(load_prog)
......@@ -360,7 +360,6 @@ class Trainer(object):
self.train_program = t.get_pserver_program(current_endpoint)
self.startup_program = t.get_startup_program(current_endpoint,
self.train_program)
self.slice_vars = t.get_slice_vars_and_atts(current_endpoint)
elif training_role == "TRAINER":
self.train_program = t.get_trainer_program()
else:
......@@ -609,16 +608,6 @@ class Trainer(object):
# Pserver Load
else:
# load slice_vars
if self.slice_vars != None and len(self.slice_vars) != 0:
load_checkpoint(
executor=exe,
checkpoint_dir=checkpoint_dir,
main_program=self.startup_program,
role_id=self.checkpoint_cfg.pserver_id,
is_trainer=False,
load_slice_up_vars=self.slice_vars)
# load lookup table
if self.checkpoint_cfg.lookup_table_name:
load_checkpoint(
......@@ -766,7 +755,6 @@ def load_checkpoint(executor,
is_trainer=True,
load_models=False,
load_trainer_args=None,
load_slice_up_vars=None,
load_lookup_table=None):
"""
This function filters out all checkpoint variables from the give
......@@ -827,18 +815,11 @@ def load_checkpoint(executor,
_load_persistable_vars(executor, checkpoint_dir, main_program, True)
return
if load_trainer_args:
print("checkpoint_dir: {}, role_id: {}, load_trainer_args: {}".
format(checkpoint_dir, role_id, load_trainer_args))
trainer_args_ret = _load_trainer_args(checkpoint_dir, role_id,
load_trainer_args)
return trainer_args_ret
# pserver load
else:
if load_slice_up_vars:
_load_slice_up_vars(executor, checkpoint_dir, load_slice_up_vars)
return
if load_lookup_table:
_load_lookup_table_vars(executor, checkpoint_dir, main_program,
role_id, load_lookup_table)
......@@ -911,51 +892,6 @@ def _load_persistable_vars(executor, dirname, program, has_model_dir=False):
filename=None)
def _load_slice_up_vars(executor, dirname, slice_vars):
if slice_vars == None or len(slice_vars) == 0:
return
dirname = _get_model_dir(dirname)
load_prog = framework.Program()
load_block = load_prog.global_block()
for var_tuple in slice_vars:
orig_var = var_tuple[0]
start = var_tuple[1]
slice_var = var_tuple[2]
end = start + reduce(lambda x, y: x * y, slice_var.shape)
clone_orig_var = load_block.create_var(
name=orig_var.name,
type=orig_var.type,
shape=orig_var.shape,
dtype=orig_var.dtype,
persistable=True)
clone_slice_var = load_block.create_var(
name=slice_var.name,
type=slice_var.type,
shape=slice_var.shape,
dtype=slice_var.dtype,
persistable=True)
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [clone_orig_var]},
attrs={'file_path': os.path.join(dirname, clone_orig_var.name)})
load_block.append_op(
type="slice",
inputs={'Input': clone_orig_var},
outputs={'Out': clone_slice_var},
attrs={'axes': [0],
'starts': [start],
'ends': [end]})
executor.run(load_prog)
def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
"""
The parameter server will load lookup table's local file in
......
......@@ -525,6 +525,10 @@ class DistributeTranspiler(object):
outputs={},
attrs=attrs)
# add slice vars
slice_vars_and_atts = self._get_slice_vars_and_atts(endpoint)
pserver_program._slice_vars_and_atts = slice_vars_and_atts
pserver_program._sync_with_cpp()
return pserver_program
......@@ -587,8 +591,35 @@ class DistributeTranspiler(object):
inputs=new_inputs,
outputs=new_outputs,
attrs=op.attrs)
# add slice vars
slice_vars_and_atts = self._get_slice_vars_and_atts(endpoint)
s_prog._slice_vars_and_atts = slice_vars_and_atts
return s_prog
def _get_slice_vars_and_atts(self, endpoint):
slice_vars_and_atts = []
block_suffix = ".block"
for param in self.param_grad_ep_mapping[endpoint]["params"]:
suff_idx = param.name.find(block_suffix)
if suff_idx <= 0:
continue
orig_var_name = param.name[:suff_idx]
block_idx = int(param.name[suff_idx + len(block_suffix):])
orig_var = self.origin_program.global_block().vars[orig_var_name]
skip_numel = 0
slice_vars = self.param_var_mapping[orig_var_name]
for slice_var in slice_vars[:block_idx]:
skip_numel += reduce(lambda x, y: x * y, slice_var.shape)
slice_vars_and_atts.append([orig_var, skip_numel, param])
return slice_vars_and_atts
# ====================== private transpiler functions =====================
def _has_distributed_lookup_table(self):
......@@ -719,28 +750,6 @@ class DistributeTranspiler(object):
}) for ep in self.pserver_endpoints
]
def get_slice_vars_and_atts(self, endpoint):
slice_vars_and_atts = []
block_suffix = ".block"
for param in self.param_grad_ep_mapping[endpoint]["params"]:
suff_idx = param.name.find(block_suffix)
if suff_idx <= 0:
continue
orig_var_name = param.name[:suff_idx]
block_idx = int(param.name[suff_idx + len(block_suffix):])
orig_var = self.origin_program.global_block().vars[orig_var_name]
skip_numel = 0
slice_vars = self.param_var_mapping[orig_var_name]
for slice_var in slice_vars[:block_idx]:
skip_numel += reduce(lambda x, y: x * y, slice_var.shape)
slice_vars_and_atts.append([orig_var, skip_numel, param])
return slice_vars_and_atts
# transpiler function for dis lookup_table
def _replace_lookup_table_op_with_prefetch(self, program,
pserver_endpoints):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册