diff --git a/fleetrec/trainer/cluster_trainer.py b/fleetrec/trainer/cluster_trainer.py index 791b7b5919953cacd8242ce198b9579ea6d4dd70..b82799a0de6dad6bb7b153ae73725e4d0bced52c 100644 --- a/fleetrec/trainer/cluster_trainer.py +++ b/fleetrec/trainer/cluster_trainer.py @@ -71,7 +71,7 @@ class ClusterTrainerWithDataset(TranspileTrainer): def init(self, context): self.model.input() - self.model.net() + self.model.build_model() self.model.metrics() self.model.avg_loss() optimizer = self.model.optimizer() @@ -112,7 +112,7 @@ class ClusterTrainerWithDataset(TranspileTrainer): fetch_info=self.fetch_alias, print_period=self.fetch_period) self.save(i, "train", is_fleet=True) - context['status'] = 'infer_pass' + context['status'] = 'terminal_pass' fleet.stop_worker() def infer(self, context): diff --git a/fleetrec/trainer/transpiler_trainer.py b/fleetrec/trainer/transpiler_trainer.py index 3d1fa934300023285d0fe2e725245096d616f7e9..f7fa8e679d27b9480479e5e7f9b4200f76855151 100644 --- a/fleetrec/trainer/transpiler_trainer.py +++ b/fleetrec/trainer/transpiler_trainer.py @@ -103,7 +103,7 @@ class TranspileTrainer(Trainer): dirname = os.path.join(dirname, str(epoch_id)) if is_fleet: - fleet.save_persistables(dirname) + fleet.save_persistables(self._exe, dirname) else: fluid.io.save_persistables(self._exe, dirname) self.increment_models.append((epoch_id, dirname))