提交 7efd73ac 编写于 作者: T tangwei12

code clean

上级 bccf8df5
...@@ -81,9 +81,3 @@ message VariableMessage { ...@@ -81,9 +81,3 @@ message VariableMessage {
} }
message VoidMessage {} message VoidMessage {}
message CheckpointMessage {
string varname = 1;
string notify_type = 2;
string checkpoint_dir = 3;
}
...@@ -74,8 +74,8 @@ class CheckpointConfig(object): ...@@ -74,8 +74,8 @@ class CheckpointConfig(object):
self.epoch_id = 0 self.epoch_id = 0
self.step_id = 0 self.step_id = 0
self.load_serial = None self.load_serial = None
self.is_pserver = False self.pserver_id = -1,
self.has_lookup_table = False self.lookup_table_name = None
def check_and_get_place(place): def check_and_get_place(place):
...@@ -174,7 +174,7 @@ class Trainer(object): ...@@ -174,7 +174,7 @@ class Trainer(object):
self.checkpoint_cfg.load_serial, self.checkpoint_cfg.load_serial,
self.startup_program) self.startup_program)
if not self.checkpoint_cfg.is_pserver: if self.checkpoint_cfg.pserver_id != -1:
epoch_id, step_id = io.load_trainer_args( epoch_id, step_id = io.load_trainer_args(
self.checkpoint_cfg.checkpoint_dir, self.checkpoint_cfg.checkpoint_dir,
self.checkpoint_cfg.load_serial, self.trainer_id, self.checkpoint_cfg.load_serial, self.trainer_id,
...@@ -182,10 +182,12 @@ class Trainer(object): ...@@ -182,10 +182,12 @@ class Trainer(object):
self.checkpoint_cfg.epoch_id = int(epoch_id) self.checkpoint_cfg.epoch_id = int(epoch_id)
self.checkpoint_cfg.step_id = int(step_id) self.checkpoint_cfg.step_id = int(step_id)
else: else:
if self.checkpoint_cfg.has_lookup_table: if self.checkpoint_cfg.lookup_table_name:
io.load_lookup_table_vars( io.load_lookup_table_vars(
exe, self.checkpoint_cfg.checkpoint_dir, 0, exe, self.checkpoint_cfg.checkpoint_dir,
"table_name") self.startup_program,
self.checkpoint_cfg.pserver_id,
self.checkpoint_cfg.lookup_table_name)
if param_path and os.path.isdir(param_path): if param_path and os.path.isdir(param_path):
# load params from param_path into scope # load params from param_path into scope
...@@ -255,7 +257,10 @@ class Trainer(object): ...@@ -255,7 +257,10 @@ class Trainer(object):
self.trainer_id, pservers=pserver_endpoints, trainers=trainers) self.trainer_id, pservers=pserver_endpoints, trainers=trainers)
if training_role == "PSERVER": if training_role == "PSERVER":
if self.checkpoint_cfg: if self.checkpoint_cfg:
self.is_pserver = True pserver_id = eplist.index(current_endpoint)
self.checkpoint_cfg.pserver_id = pserver_id
if t.has_distributed_lookup_table:
self.checkpoint_cfg.lookup_table_name = t.table_name
self.train_program = t.get_pserver_program(current_endpoint) self.train_program = t.get_pserver_program(current_endpoint)
self.startup_program = t.get_startup_program(current_endpoint, self.startup_program = t.get_startup_program(current_endpoint,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册