diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto index f5800cdb7f7e124426bbb970d00b429894a110b4..e0902320cff003797b12ed0204f7f99c44554b62 100644 --- a/paddle/fluid/operators/detail/send_recv.proto +++ b/paddle/fluid/operators/detail/send_recv.proto @@ -81,9 +81,3 @@ message VariableMessage { } message VoidMessage {} - -message CheckpointMessage { - string varname = 1; - string notify_type = 2; - string checkpoint_dir = 3; -} diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index f77c0f65dcb97bd48c94783a903380fd10a9bd95..6fc456f47556282a854acfbb2e892bf4ab368b82 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -74,8 +74,8 @@ class CheckpointConfig(object): self.epoch_id = 0 self.step_id = 0 self.load_serial = None - self.is_pserver = False - self.has_lookup_table = False + self.pserver_id = -1, + self.lookup_table_name = None def check_and_get_place(place): @@ -174,7 +174,7 @@ class Trainer(object): self.checkpoint_cfg.load_serial, self.startup_program) - if not self.checkpoint_cfg.is_pserver: + if self.checkpoint_cfg.pserver_id != -1: epoch_id, step_id = io.load_trainer_args( self.checkpoint_cfg.checkpoint_dir, self.checkpoint_cfg.load_serial, self.trainer_id, @@ -182,10 +182,12 @@ class Trainer(object): self.checkpoint_cfg.epoch_id = int(epoch_id) self.checkpoint_cfg.step_id = int(step_id) else: - if self.checkpoint_cfg.has_lookup_table: + if self.checkpoint_cfg.lookup_table_name: io.load_lookup_table_vars( - exe, self.checkpoint_cfg.checkpoint_dir, 0, - "table_name") + exe, self.checkpoint_cfg.checkpoint_dir, + self.startup_program, + self.checkpoint_cfg.pserver_id, + self.checkpoint_cfg.lookup_table_name) if param_path and os.path.isdir(param_path): # load params from param_path into scope @@ -255,7 +257,10 @@ class Trainer(object): self.trainer_id, pservers=pserver_endpoints, trainers=trainers) if training_role == "PSERVER": if self.checkpoint_cfg: - self.is_pserver = True + pserver_id = eplist.index(current_endpoint) + self.checkpoint_cfg.pserver_id = pserver_id + if t.has_distributed_lookup_table: + self.checkpoint_cfg.lookup_table_name = t.table_name self.train_program = t.get_pserver_program(current_endpoint) self.startup_program = t.get_startup_program(current_endpoint,