提交 d712af25 编写于 作者: T tangwei12

add distribute config

上级 0deb6f90
...@@ -81,7 +81,8 @@ class CheckpointConfig(object): ...@@ -81,7 +81,8 @@ class CheckpointConfig(object):
self.epoch_id = 0 self.epoch_id = 0
self.step_id = 0 self.step_id = 0
self._load_serial = None self.load_serial = None
self.is_pserver = False
def check_and_get_place(place): def check_and_get_place(place):
...@@ -145,7 +146,7 @@ class Trainer(object): ...@@ -145,7 +146,7 @@ class Trainer(object):
"The checkpoint_config shoule be an instance of CheckpointConfig" "The checkpoint_config shoule be an instance of CheckpointConfig"
) )
else: else:
self.checkpoint._load_serial = io.need_load_checkpoint( self.checkpoint.load_serial = io.need_load_checkpoint(
self.checkpoint.checkpoint_dir) self.checkpoint.checkpoint_dir)
self.scope = core.Scope() self.scope = core.Scope()
...@@ -176,17 +177,18 @@ class Trainer(object): ...@@ -176,17 +177,18 @@ class Trainer(object):
exe = executor.Executor(place) exe = executor.Executor(place)
exe.run(self.startup_program) exe.run(self.startup_program)
if self.checkpoint and self.checkpoint._load_serial: if self.checkpoint and self.checkpoint.load_serial:
exe = executor.Executor(place) exe = executor.Executor(place)
io.load_checkpoint(exe, self.checkpoint.checkpoint_dir, io.load_checkpoint(exe, self.checkpoint.checkpoint_dir,
self.checkpoint._load_serial, self.checkpoint.load_serial,
self.startup_program) self.startup_program)
epoch_id, step_id = io.load_trainer_args( if not self.checkpoint.is_pserver:
self.checkpoint.checkpoint_dir, self.checkpoint._load_serial, epoch_id, step_id = io.load_trainer_args(
self.trainer_id, ["epoch_id", "step_id"]) self.checkpoint.checkpoint_dir, self.checkpoint.load_serial,
self.checkpoint.epoch_id = int(epoch_id) self.trainer_id, ["epoch_id", "step_id"])
self.checkpoint.step_id = int(step_id) self.checkpoint.epoch_id = int(epoch_id)
self.checkpoint.step_id = int(step_id)
if param_path and os.path.isdir(param_path): if param_path and os.path.isdir(param_path):
# load params from param_path into scope # load params from param_path into scope
...@@ -259,6 +261,9 @@ class Trainer(object): ...@@ -259,6 +261,9 @@ class Trainer(object):
t.transpile( t.transpile(
trainer_id, pservers=pserver_endpoints, trainers=trainers) trainer_id, pservers=pserver_endpoints, trainers=trainers)
if training_role == "PSERVER": if training_role == "PSERVER":
if self.checkpoint:
self.is_pserver = True
self.train_program = t.get_pserver_program(current_endpoint) self.train_program = t.get_pserver_program(current_endpoint)
self.startup_program = t.get_startup_program(current_endpoint, self.startup_program = t.get_startup_program(current_endpoint,
self.train_program) self.train_program)
...@@ -362,7 +367,7 @@ class Trainer(object): ...@@ -362,7 +367,7 @@ class Trainer(object):
self._clean_checkpoint() self._clean_checkpoint()
return return
if self.checkpoint and self.checkpoint._load_serial \ if self.checkpoint and self.checkpoint.load_serial \
and self.checkpoint.step_id >= step_id and self.checkpoint.epoch_id == epoch_id: and self.checkpoint.step_id >= step_id and self.checkpoint.epoch_id == epoch_id:
continue continue
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册