提交 9ff4f3fb 编写于 作者: T tangwei

code fix

上级 e3fb25b6
...@@ -71,7 +71,7 @@ class ClusterTrainerWithDataset(TranspileTrainer): ...@@ -71,7 +71,7 @@ class ClusterTrainerWithDataset(TranspileTrainer):
def init(self, context): def init(self, context):
self.model.input() self.model.input()
self.model.net() self.model.build_model()
self.model.metrics() self.model.metrics()
self.model.avg_loss() self.model.avg_loss()
optimizer = self.model.optimizer() optimizer = self.model.optimizer()
...@@ -112,7 +112,7 @@ class ClusterTrainerWithDataset(TranspileTrainer): ...@@ -112,7 +112,7 @@ class ClusterTrainerWithDataset(TranspileTrainer):
fetch_info=self.fetch_alias, fetch_info=self.fetch_alias,
print_period=self.fetch_period) print_period=self.fetch_period)
self.save(i, "train", is_fleet=True) self.save(i, "train", is_fleet=True)
context['status'] = 'infer_pass' context['status'] = 'terminal_pass'
fleet.stop_worker() fleet.stop_worker()
def infer(self, context): def infer(self, context):
......
...@@ -103,7 +103,7 @@ class TranspileTrainer(Trainer): ...@@ -103,7 +103,7 @@ class TranspileTrainer(Trainer):
dirname = os.path.join(dirname, str(epoch_id)) dirname = os.path.join(dirname, str(epoch_id))
if is_fleet: if is_fleet:
fleet.save_persistables(dirname) fleet.save_persistables(self._exe, dirname)
else: else:
fluid.io.save_persistables(self._exe, dirname) fluid.io.save_persistables(self._exe, dirname)
self.increment_models.append((epoch_id, dirname)) self.increment_models.append((epoch_id, dirname))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册