提交 a8959162 编写于 作者: T tangwei12

[wip] add load lookup table in io and trianer

上级 41701969
......@@ -25,7 +25,8 @@ __all__ = [
'load_persistables', 'save_inference_model', 'load_inference_model',
'get_inference_program', 'save_checkpoint', 'load_checkpoint',
'clean_checkpoint', 'load_persist_vars_without_grad',
'save_persist_vars_without_grad', 'get_latest_checkpoint_serial'
'load_lookup_table_vars', 'save_persist_vars_without_grad',
'get_latest_checkpoint_serial'
]
......@@ -459,7 +460,9 @@ def get_parameter_value_by_name(name, executor, program=None):
SUCCESS_MARK_FILENAME = "_SUCCESS"
CHECKPOINT_PREFIX = "checkpoint"
MODEL_DIR = "__model__"
LOOKUP_TABLE_DIR = "__lookup_table__"
TRAINER_PREFIX = "trainer"
PSERVER_PREFIX = "pserver"
CHECKPOINT_SEPARATOR = "_"
......@@ -567,6 +570,14 @@ def load_persist_vars_without_grad(executor,
filename=None)
def load_lookup_table_vars(executor, dirname, pserver_id, table_name):
lookup_table_dir = os.path.join(dirname, LOOKUP_TABLE_DIR)
table_file = table_name + CHECKPOINT_SEPARATOR + PSERVER_PREFIX + CHECKPOINT_SEPARATOR + str(
pserver_id)
load_vars(executor, lookup_table_dir, vars=table_name, filename=table_file)
def save_persist_vars_without_grad(executor, dirname, program):
"""
save_persist_vars_without_grad will save variables to a directory by an executor,
......
......@@ -62,27 +62,20 @@ class CheckpointConfig(object):
max_num_checkpoints=3,
epoch_interval=1,
step_interval=10):
if checkpoint_dir is None:
self.checkpoint_dir = os.getcwd()
else:
self.checkpoint_dir = checkpoint_dir
self.max_num_checkpoints = max_num_checkpoints
if epoch_interval < 1:
self.epoch_interval = 1
else:
self.epoch_interval = epoch_interval
if step_interval < 1:
self.step_interval = 10
else:
self.step_interval = step_interval
assert epoch_interval >= 1
assert step_interval >= 1
self.checkpoint_dir = checkpoint_dir if checkpoint_dir is not None else os.getcwd(
)
self.max_num_checkpoints = max_num_checkpoints
self.epoch_interval = epoch_interval
self.step_interval = step_interval
self.epoch_id = 0
self.step_id = 0
self.load_serial = None
self.is_pserver = False
self.has_lookup_table = False
def check_and_get_place(place):
......@@ -181,13 +174,18 @@ class Trainer(object):
self.checkpoint_cfg.load_serial,
self.startup_program)
if not self.checkpoint_cfg.is_pserver:
epoch_id, step_id = io.load_trainer_args(
self.checkpoint_cfg.checkpoint_dir,
self.checkpoint_cfg.load_serial, self.trainer_id,
self._get_checkpoint_load_args())
self.checkpoint_cfg.epoch_id = int(epoch_id)
self.checkpoint_cfg.step_id = int(step_id)
if not self.checkpoint_cfg.is_pserver:
epoch_id, step_id = io.load_trainer_args(
self.checkpoint_cfg.checkpoint_dir,
self.checkpoint_cfg.load_serial, self.trainer_id,
self._get_checkpoint_load_args())
self.checkpoint_cfg.epoch_id = int(epoch_id)
self.checkpoint_cfg.step_id = int(step_id)
else:
if self.checkpoint_cfg.has_lookup_table:
io.load_lookup_table_vars(
exe, self.checkpoint_cfg.checkpoint_dir, 0,
"table_name")
if param_path and os.path.isdir(param_path):
# load params from param_path into scope
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册