From 9ff4f3fbad5fe7545a4934e6f034a8e28b31dc35 Mon Sep 17 00:00:00 2001 From: tangwei Date: Fri, 10 Apr 2020 12:50:36 +0800 Subject: [PATCH] code fix --- fleetrec/trainer/cluster_trainer.py | 4 ++-- fleetrec/trainer/transpiler_trainer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fleetrec/trainer/cluster_trainer.py b/fleetrec/trainer/cluster_trainer.py index 791b7b59..b82799a0 100644 --- a/fleetrec/trainer/cluster_trainer.py +++ b/fleetrec/trainer/cluster_trainer.py @@ -71,7 +71,7 @@ class ClusterTrainerWithDataset(TranspileTrainer): def init(self, context): self.model.input() - self.model.net() + self.model.build_model() self.model.metrics() self.model.avg_loss() optimizer = self.model.optimizer() @@ -112,7 +112,7 @@ class ClusterTrainerWithDataset(TranspileTrainer): fetch_info=self.fetch_alias, print_period=self.fetch_period) self.save(i, "train", is_fleet=True) - context['status'] = 'infer_pass' + context['status'] = 'terminal_pass' fleet.stop_worker() def infer(self, context): diff --git a/fleetrec/trainer/transpiler_trainer.py b/fleetrec/trainer/transpiler_trainer.py index 3d1fa934..f7fa8e67 100644 --- a/fleetrec/trainer/transpiler_trainer.py +++ b/fleetrec/trainer/transpiler_trainer.py @@ -103,7 +103,7 @@ class TranspileTrainer(Trainer): dirname = os.path.join(dirname, str(epoch_id)) if is_fleet: - fleet.save_persistables(dirname) + fleet.save_persistables(self._exe, dirname) else: fluid.io.save_persistables(self._exe, dirname) self.increment_models.append((epoch_id, dirname)) -- GitLab