提交 9ff4f3fb 编写于 作者: T tangwei

code fix

上级 e3fb25b6
......@@ -71,7 +71,7 @@ class ClusterTrainerWithDataset(TranspileTrainer):
def init(self, context):
self.model.input()
self.model.net()
self.model.build_model()
self.model.metrics()
self.model.avg_loss()
optimizer = self.model.optimizer()
......@@ -112,7 +112,7 @@ class ClusterTrainerWithDataset(TranspileTrainer):
fetch_info=self.fetch_alias,
print_period=self.fetch_period)
self.save(i, "train", is_fleet=True)
context['status'] = 'infer_pass'
context['status'] = 'terminal_pass'
fleet.stop_worker()
def infer(self, context):
......
......@@ -103,7 +103,7 @@ class TranspileTrainer(Trainer):
dirname = os.path.join(dirname, str(epoch_id))
if is_fleet:
fleet.save_persistables(dirname)
fleet.save_persistables(self._exe, dirname)
else:
fluid.io.save_persistables(self._exe, dirname)
self.increment_models.append((epoch_id, dirname))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册