Commit 1dd14a70 authored by tangwei12

bug fix

Parent f9f8fbaa
@@ -360,6 +360,7 @@ class Trainer(object):
self.train_program = t.get_pserver_program(current_endpoint)
self.startup_program = t.get_startup_program(current_endpoint,
self.train_program)
self.slice_vars = t.get_slice_vars_and_atts(current_endpoint)
elif training_role == "TRAINER":
self.train_program = t.get_trainer_program()
else:
@@ -474,8 +475,10 @@ class Trainer(object):
self._clean_checkpoint()
return
if self.checkpoint_cfg and self.checkpoint_cfg.load_serial \
and self.checkpoint_cfg.step_id >= step_id and self.checkpoint_cfg.epoch_id == epoch_id:
if self.checkpoint_cfg and \
self.checkpoint_cfg.load_serial is not None and \
self.checkpoint_cfg.step_id >= step_id and \
self.checkpoint_cfg.epoch_id == epoch_id:
continue
begin_event = BeginStepEvent(epoch_id, step_id)
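As an illustration (not part of the commit), the widened resume check above behaves like the following minimal sketch with invented numbers: once a checkpoint recorded at epoch 2, step 100 has been reloaded, every step of epoch 2 up to step 100 is skipped so work already covered by the checkpoint is not repeated.

    # Illustrative sketch only; the concrete epoch/step values are invented.
    saved_epoch_id, saved_step_id = 2, 100   # restored into checkpoint_cfg on load
    for epoch_id in range(5):
        for step_id in range(300):
            # mirrors the resume check: fast-forward through steps the
            # restored checkpoint has already covered
            if saved_step_id >= step_id and saved_epoch_id == epoch_id:
                continue
            pass  # the real training step would run here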
@@ -569,36 +572,58 @@ class Trainer(object):
def _load_checkpoint(self):
with self._prog_and_scope_guard():
exe = executor.Executor(self.place)
checkpoint_dir = _get_serial_dir(self.checkpoint_cfg.checkpoint_dir,
self.checkpoint_cfg.load_serial)
# Trainer Load
if self.checkpoint_cfg.pserver_id is None:
# load model
load_checkpoint(
executor=exe,
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
main_program=self.startup_program)
checkpoint_dir=checkpoint_dir,
main_program=self.startup_program,
role_id=self.trainer_id,
is_trainer=True,
load_models=True)
if not self.checkpoint_cfg.pserver_id:
load_trainer_args = self._get_checkpoint_load_args()
trainer_args = load_checkpoint(
# load trainer_args
trainer_args = self._get_checkpoint_load_args()
trainer_args_ret = load_checkpoint(
executor=exe,
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
checkpoint_dir=checkpoint_dir,
main_program=self.startup_program,
role_id=self.trainer_id,
is_trainer=True,
load_trainer_args=load_trainer_args)
load_trainer_args=trainer_args)
if len(trainer_args) != 2:
if len(trainer_args_ret) != 2:
raise ValueError(
"the return trainer_args length do not equal _get_checkpoint_load_args"
)
self.checkpoint_cfg.epoch_id = int(trainer_args[0])
self.checkpoint_cfg.step_id = int(trainer_args[1])
self.checkpoint_cfg.epoch_id = int(trainer_args_ret[0])
self.checkpoint_cfg.step_id = int(trainer_args_ret[1])
# Pserver Load
else:
# load slice_vars
if self.slice_vars != None and len(self.slice_vars) != 0:
load_checkpoint(
executor=exe,
checkpoint_dir=checkpoint_dir,
main_program=self.startup_program,
role_id=self.checkpoint_cfg.pserver_id,
is_trainer=False,
load_slice_up_vars=self.slice_vars)
# load lookup table
if self.checkpoint_cfg.lookup_table_name:
load_checkpoint(
executor=exe,
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
checkpoint_dir=checkpoint_dir,
main_program=self.startup_program,
role_id=self.checkpoint_cfg.pserver_id,
is_trainer=False,
load_trainer_args=None,
load_lookup_table=self.checkpoint_cfg.lookup_table_name)
@@ -640,7 +665,7 @@ def save_checkpoint(executor,
main_program,
trainer_args=None,
max_num_checkpoints=3,
lookup_table=None,
save_lookup_table=None,
pserver_endpoints=None):
"""
This function filters out all checkpoint variables from the given
@@ -673,7 +698,7 @@ def save_checkpoint(executor,
max_num_checkpoints(int): The maximum number of existing
checkpoints.
Default: 3
lookup_table(string|None): the lookup table name, when use distribute
save_lookup_table(string|None): the lookup table name; when using a distributed
lookup table, we can get the lookup table name from DistributeTranspiler.
table_name
pserver_endpoints(list|None): the parameter server ip:port list.
@@ -704,7 +729,7 @@ def save_checkpoint(executor,
trainer_args=trainer_args,
main_program=prog,
max_num_checkpoints=3,
lookup_table=table_name,
save_lookup_table=table_name,
pserver_endpoints = ps_endpoints)
"""
if checkpoint_dir is None:
@@ -720,15 +745,15 @@ def save_checkpoint(executor,
_make_chekcpoint_dirs(checkpoint_dir)
serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1
cur_dir = _get_serial_dir(checkpoint_dir, serial)
cur_dir = _get_serial_dir(checkpoint_dir, serial, True)
_save_trainer_args(cur_dir, trainer_id, trainer_args)
if is_chief:
_save_persist_vars_without_grad(executor, cur_dir, main_program)
_save_persistable_vars(executor, cur_dir, main_program)
if is_chief and lookup_table and pserver_endpoints:
_save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
if is_chief and save_lookup_table and pserver_endpoints:
_save_pserver_vars_by_notify(executor, cur_dir, save_lookup_table,
pserver_endpoints)
_scroll_delete(checkpoint_dir, max_num_checkpoints)
@@ -736,10 +761,12 @@ def save_checkpoint(executor,
def load_checkpoint(executor,
checkpoint_dir,
main_program,
main_program=None,
role_id=0,
is_trainer=True,
load_models=True,
load_trainer_args=None,
load_slice_up_vars=None,
load_lookup_table=None):
"""
This function filters out all checkpoint variables from the given
@@ -762,7 +789,7 @@ def load_checkpoint(executor,
executor(Executor): The executor to run for loading checkpoint.
checkpoint_dir(str): The folder where all checkpoints are.
serial(int): The serial of checkpoint you would like to load.
main_program(Program): The program whose checkpoint variables will
main_program(Program|None): The program whose checkpoint variables will
be loaded.
role_id(int): the trainer id or the parameter server id.
is_trainer(bool): trainer is True and parameter server is False.
@@ -794,27 +821,23 @@ def load_checkpoint(executor,
if checkpoint_dir is None:
raise ValueError("'checkpoint_dir' should not be None")
serial = _get_latest_checkpoint_serial(checkpoint_dir)
# there is nothing to be loaded
if serial is None or serial < 0:
return
if main_program is None:
raise ValueError('main_program should not be None.')
if is_trainer and load_trainer_args is None:
cur_dir = _get_serial_dir(checkpoint_dir, serial)
_load_persist_vars_without_grad(executor, cur_dir, main_program, True)
# trainer load
if is_trainer:
if load_models:
_load_persistable_vars(executor, checkpoint_dir, main_program, True)
return
if is_trainer and load_trainer_args:
return _load_trainer_args(checkpoint_dir, serial, role_id,
if load_trainer_args:
trainer_args_ret = _load_trainer_args(checkpoint_dir, role_id,
load_trainer_args)
if not is_trainer and load_lookup_table:
_load_lookup_table_vars(executor, checkpoint_dir, main_program, role_id,
load_lookup_table)
return trainer_args_ret
# pserver load
else:
if load_slice_up_vars:
_load_slice_up_vars(executor, checkpoint_dir, load_slice_up_vars)
return
if load_lookup_table:
_load_lookup_table_vars(executor, checkpoint_dir, main_program,
role_id, load_lookup_table)
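Putting the reorganized dispatch together, `load_checkpoint` is now driven by `is_trainer` plus the individual `load_*` arguments, mirroring how `Trainer._load_checkpoint` calls it earlier in this diff. A hedged usage sketch follows; `exe`, `startup_program`, `serial_dir`, `trainer_id`, `pserver_id`, `slice_vars` and `table_name` are placeholders, and the trainer-args list is assumed to be the epoch/step ids as the Trainer code above suggests.

    # Hypothetical usage sketch of the new load_checkpoint entry points.
    # Trainer side: restore model variables, then read back epoch_id/step_id.
    load_checkpoint(executor=exe, checkpoint_dir=serial_dir,
                    main_program=startup_program, role_id=trainer_id,
                    is_trainer=True, load_models=True)
    trainer_args_ret = load_checkpoint(executor=exe, checkpoint_dir=serial_dir,
                                       main_program=startup_program,
                                       role_id=trainer_id, is_trainer=True,
                                       load_models=False,  # only the args are wanted here
                                       load_trainer_args=["epoch_id", "step_id"])
    epoch_id, step_id = int(trainer_args_ret[0]), int(trainer_args_ret[1])

    # Pserver side: restore sliced variables and, if used, the lookup table.
    load_checkpoint(executor=exe, checkpoint_dir=serial_dir,
                    main_program=startup_program, role_id=pserver_id,
                    is_trainer=False, load_slice_up_vars=slice_vars)
    load_checkpoint(executor=exe, checkpoint_dir=serial_dir,
                    main_program=startup_program, role_id=pserver_id,
                    is_trainer=False, load_lookup_table=table_name)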
def clean_checkpoint(checkpoint_dir, delete_dir=False):
@@ -835,10 +858,7 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
os.rmdir(checkpoint_dir)
def _load_persist_vars_without_grad(executor,
dirname,
program,
has_model_dir=False):
def _load_persistable_vars(executor, dirname, program, has_model_dir=False):
"""
This function filters out all checkpoint variables from the given
program and then tries to load these variables from the given directory.
@@ -867,10 +887,10 @@ def _load_persist_vars_without_grad(executor,
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
_load_persist_vars_without_grad(executor=exe,
_load_persistable_vars(executor=exe,
dirname=param_path, program=prog, has_model_dir=True)
# In this example, `_load_persist_vars_without_grad` function
# In this example, `_load_persistable_vars` function
# will first filter out all checkpoint variables in the default
# main program, and then try to load these variables from the
# folder "./my_paddle_model/__model__".
@@ -887,6 +907,51 @@ def _load_persist_vars_without_grad(executor,
filename=None)
def _load_slice_up_vars(executor, dirname, slice_vars):
if slice_vars == None or len(slice_vars) == 0:
return
dirname = _get_model_dir(dirname)
load_prog = framework.Program()
load_block = load_prog.global_block()
for var_tuple in slice_vars:
orig_var = var_tuple[0]
start = var_tuple[1]
slice_var = var_tuple[2]
end = start + reduce(lambda x, y: x * y, slice_var.shape)
clone_orig_var = load_block.create_var(
name=orig_var.name,
type=orig_var.type,
shape=orig_var.shape,
dtype=orig_var.dtype,
persistable=True)
clone_slice_var = load_block.create_var(
name=slice_var.name,
type=slice_var.type,
shape=slice_var.shape,
dtype=slice_var.dtype,
persistable=True)
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [clone_orig_var]},
attrs={'file_path': os.path.join(dirname, clone_orig_var.name)})
load_block.append_op(
type="slice",
inputs={'Input': clone_orig_var},
outputs={'Out': clone_slice_var},
attrs={'axes': [0],
'starts': [start],
'ends': [end]})
executor.run(load_prog)
def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
"""
The parameter server will load the lookup table's local file in
@@ -937,7 +1002,7 @@ def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
executor.run(load_prog)
def _save_persist_vars_without_grad(executor, dirname, program):
def _save_persistable_vars(executor, dirname, program):
"""
This function filters out all checkpoint variables from the given
program and then save these variables to a sub-folder '__model__' of
@@ -964,10 +1029,10 @@ def _save_persist_vars_without_grad(executor, dirname, program):
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
_save_persist_vars_without_grad(executor=exe,
_save_persistable_vars(executor=exe,
dirname=param_path, program=prog)
# In this example, `_save_persist_vars_without_grad` function
# In this example, `_save_persistable_vars` function
# will first filter out all checkpoint variables in the default
# main program, and then save these variables to the folder
# "./my_paddle_model/__model__".
@@ -1043,7 +1108,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args):
_write_success(cur_dir)
def _load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
def _load_trainer_args(checkpoint_dir, trainer_id, trainer_args):
"""
The trainer will load some args from its independent directory,
such as epoch_id and step_id.
@@ -1069,8 +1134,7 @@ def _load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
"""
assert isinstance(trainer_args, list)
cur_dir = _get_serial_dir(checkpoint_dir, serial)
cur_dir = _get_trainer_dir(cur_dir, trainer_id)
cur_dir = _get_trainer_dir(checkpoint_dir, trainer_id)
ret_values = []
@@ -1125,20 +1189,19 @@ def _make_chekcpoint_dirs(dirs):
def _get_dir_serial(dirname):
_, serial = dirname.split(CHECKPOINT_SEPARATOR)
try:
_, serial = dirname.split(CHECKPOINT_SEPARATOR)
serial_num = int(serial)
except ValueError:
serial_num = -1
return serial_num
def _get_serial_dir(dirname, serial):
def _get_serial_dir(dirname, serial, makedirs=False):
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
serial_dir = os.path.join(dirname, serial_folder)
if makedirs:
_make_chekcpoint_dirs(serial_dir)
return serial_dir
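For reference (not part of the commit), a small self-contained sketch of the checkpoint directory naming these helpers implement; the prefix and separator values are assumptions standing in for the CHECKPOINT_PREFIX / CHECKPOINT_SEPARATOR constants defined elsewhere in the module. It also shows the effect of moving the split() inside the try block above: folder names that do not follow the pattern now map to -1 instead of raising.

    import os

    # Assumed values; the real module defines CHECKPOINT_PREFIX / CHECKPOINT_SEPARATOR.
    CHECKPOINT_PREFIX = "checkpoint"
    CHECKPOINT_SEPARATOR = "_"

    def get_serial_dir(dirname, serial):
        # mirrors _get_serial_dir, minus the optional makedirs side effect
        folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
        return os.path.join(dirname, folder)

    def get_dir_serial(dirname):
        # mirrors the fixed _get_dir_serial: anything that does not split into
        # exactly "<prefix>_<int>" maps to -1 instead of raising
        try:
            _, serial = dirname.split(CHECKPOINT_SEPARATOR)
            return int(serial)
        except ValueError:
            return -1

    print(get_serial_dir("/tmp/ckpt", 3))    # /tmp/ckpt/checkpoint_3
    print(get_dir_serial("checkpoint_3"))    # 3
    print(get_dir_serial("some_other_dir"))  # -1 (unpacking three tokens raises ValueError)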
@@ -719,6 +719,28 @@ class DistributeTranspiler(object):
}) for ep in self.pserver_endpoints
]
def get_slice_vars_and_atts(self, endpoint):
slice_vars_and_atts = []
block_suffix = ".block"
for param in self.param_grad_ep_mapping[endpoint]["params"]:
suff_idx = param.name.find(block_suffix)
if suff_idx <= 0:
continue
orig_var_name = param.name[:suff_idx]
block_idx = int(param.name[suff_idx + len(block_suffix):])
orig_var = self.origin_program.global_block().vars[orig_var_name]
skip_numel = 0
slice_vars = self.param_var_mapping[orig_var_name]
for slice_var in slice_vars[:block_idx]:
skip_numel += reduce(lambda x, y: x * y, slice_var.shape)
slice_vars_and_atts.append([orig_var, skip_numel, param])
return slice_vars_and_atts
# transpiler function for dis lookup_table
def _replace_lookup_table_op_with_prefetch(self, program,
pserver_endpoints):
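Finally, a small standalone sketch (with invented shapes) of the offset bookkeeping shared by `get_slice_vars_and_atts` and `_load_slice_up_vars`: the skip_numel recorded for block k is the total element count of blocks 0..k-1, and at load time the pserver reloads the full saved variable and slices out the range [skip_numel, skip_numel + numel(block)) along axis 0.

    from functools import reduce  # a builtin in the Python 2 code above

    # Invented example: an original variable of 12 elements split into three
    # ".block" slices of 4 elements each.
    slice_shapes = [[4], [4], [4]]

    def numel(shape):
        return reduce(lambda x, y: x * y, shape)

    offsets, skip = [], 0
    for shape in slice_shapes:
        offsets.append(skip)   # skip_numel for this block
        skip += numel(shape)
    print(offsets)  # [0, 4, 8]

    # _load_slice_up_vars then appends, per block, a 'load' op for the full
    # variable and a 'slice' op with axes=[0], starts=[offset],
    # ends=[offset + numel(block)], i.e. the ranges [0,4), [4,8), [8,12).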