提交 3c2869a6 编写于 作者: T tangwei

fix cluster

上级 d2522c61
......@@ -93,9 +93,6 @@ class ClusterTrainer(TranspileTrainer):
context['is_exit'] = True
def dataloader_train(self, context):
    """Placeholder for dataloader-based training; currently does nothing.

    Args:
        context: Trainer context mapping. It is accepted but neither read
            nor modified by this method.

    Returns:
        None.
    """
    return None
def dataset_train(self, context):
self._exe.run(fleet.startup_program)
fleet.init_worker()
......@@ -142,6 +139,23 @@ class ClusterTrainer(TranspileTrainer):
fleet.stop_worker()
context['status'] = 'terminal_pass'
def dataset_train(self, context):
    """Run dataset-based training on a fleet worker, then mark the run done.

    Steps, in order: execute ``fleet.startup_program`` on ``self._exe``,
    call ``fleet.init_worker()``, fetch the dataset from
    ``self._get_dataset()``, train for the number of epochs configured
    under the global env key ``"train.epochs"`` (saving after each epoch),
    call ``fleet.stop_worker()``, and finally set ``context['status']``.

    Args:
        context: Trainer context mapping; its ``'status'`` entry is set to
            ``'terminal_pass'`` once training has finished.
    """
    self._exe.run(fleet.startup_program)
    fleet.init_worker()

    train_dataset = self._get_dataset()
    num_epochs = envs.get_global_env("train.epochs")

    # NOTE(review): the per-epoch save is placed inside the loop because it
    # takes the epoch index as its first argument -- confirm against the
    # original file's indentation.
    for epoch_id in range(num_epochs):
        self._exe.train_from_dataset(
            program=fluid.default_main_program(),
            dataset=train_dataset,
            fetch_list=self.fetch_vars,
            fetch_info=self.fetch_alias,
            print_period=self.fetch_period,
        )
        self.save(epoch_id, "train", is_fleet=True)

    fleet.stop_worker()
    context['status'] = 'terminal_pass'
def infer(self, context):
context['status'] = 'terminal_pass'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册