提交 a8959162 编写于 作者: T tangwei12

[wip] add load lookup table in io and trianer

上级 41701969
...@@ -25,7 +25,8 @@ __all__ = [ ...@@ -25,7 +25,8 @@ __all__ = [
'load_persistables', 'save_inference_model', 'load_inference_model', 'load_persistables', 'save_inference_model', 'load_inference_model',
'get_inference_program', 'save_checkpoint', 'load_checkpoint', 'get_inference_program', 'save_checkpoint', 'load_checkpoint',
'clean_checkpoint', 'load_persist_vars_without_grad', 'clean_checkpoint', 'load_persist_vars_without_grad',
'save_persist_vars_without_grad', 'get_latest_checkpoint_serial' 'load_lookup_table_vars', 'save_persist_vars_without_grad',
'get_latest_checkpoint_serial'
] ]
...@@ -459,7 +460,9 @@ def get_parameter_value_by_name(name, executor, program=None): ...@@ -459,7 +460,9 @@ def get_parameter_value_by_name(name, executor, program=None):
SUCCESS_MARK_FILENAME = "_SUCCESS" SUCCESS_MARK_FILENAME = "_SUCCESS"
CHECKPOINT_PREFIX = "checkpoint" CHECKPOINT_PREFIX = "checkpoint"
MODEL_DIR = "__model__" MODEL_DIR = "__model__"
LOOKUP_TABLE_DIR = "__lookup_table__"
TRAINER_PREFIX = "trainer" TRAINER_PREFIX = "trainer"
PSERVER_PREFIX = "pserver"
CHECKPOINT_SEPARATOR = "_" CHECKPOINT_SEPARATOR = "_"
...@@ -567,6 +570,14 @@ def load_persist_vars_without_grad(executor, ...@@ -567,6 +570,14 @@ def load_persist_vars_without_grad(executor,
filename=None) filename=None)
def load_lookup_table_vars(executor, dirname, pserver_id, table_name):
lookup_table_dir = os.path.join(dirname, LOOKUP_TABLE_DIR)
table_file = table_name + CHECKPOINT_SEPARATOR + PSERVER_PREFIX + CHECKPOINT_SEPARATOR + str(
pserver_id)
load_vars(executor, lookup_table_dir, vars=table_name, filename=table_file)
def save_persist_vars_without_grad(executor, dirname, program): def save_persist_vars_without_grad(executor, dirname, program):
""" """
save_persist_vars_without_grad will save variables to a directory by an executor, save_persist_vars_without_grad will save variables to a directory by an executor,
......
...@@ -62,27 +62,20 @@ class CheckpointConfig(object): ...@@ -62,27 +62,20 @@ class CheckpointConfig(object):
max_num_checkpoints=3, max_num_checkpoints=3,
epoch_interval=1, epoch_interval=1,
step_interval=10): step_interval=10):
if checkpoint_dir is None:
self.checkpoint_dir = os.getcwd()
else:
self.checkpoint_dir = checkpoint_dir
self.max_num_checkpoints = max_num_checkpoints assert epoch_interval >= 1
assert step_interval >= 1
if epoch_interval < 1: self.checkpoint_dir = checkpoint_dir if checkpoint_dir is not None else os.getcwd(
self.epoch_interval = 1 )
else: self.max_num_checkpoints = max_num_checkpoints
self.epoch_interval = epoch_interval self.epoch_interval = epoch_interval
if step_interval < 1:
self.step_interval = 10
else:
self.step_interval = step_interval self.step_interval = step_interval
self.epoch_id = 0 self.epoch_id = 0
self.step_id = 0 self.step_id = 0
self.load_serial = None self.load_serial = None
self.is_pserver = False self.is_pserver = False
self.has_lookup_table = False
def check_and_get_place(place): def check_and_get_place(place):
...@@ -188,6 +181,11 @@ class Trainer(object): ...@@ -188,6 +181,11 @@ class Trainer(object):
self._get_checkpoint_load_args()) self._get_checkpoint_load_args())
self.checkpoint_cfg.epoch_id = int(epoch_id) self.checkpoint_cfg.epoch_id = int(epoch_id)
self.checkpoint_cfg.step_id = int(step_id) self.checkpoint_cfg.step_id = int(step_id)
else:
if self.checkpoint_cfg.has_lookup_table:
io.load_lookup_table_vars(
exe, self.checkpoint_cfg.checkpoint_dir, 0,
"table_name")
if param_path and os.path.isdir(param_path): if param_path and os.path.isdir(param_path):
# load params from param_path into scope # load params from param_path into scope
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册