提交 21630068 编写于 作者: T tangwei

mac/windows fix

上级 6ef9bab0
...@@ -64,6 +64,7 @@ class ClusterTrainer(TranspileTrainer): ...@@ -64,6 +64,7 @@ class ClusterTrainer(TranspileTrainer):
assert strategy is not None assert strategy is not None
self.strategy = strategy
return strategy return strategy
def init(self, context): def init(self, context):
...@@ -96,20 +97,50 @@ class ClusterTrainer(TranspileTrainer): ...@@ -96,20 +97,50 @@ class ClusterTrainer(TranspileTrainer):
def dataset_train(self, context): def dataset_train(self, context):
self._exe.run(fleet.startup_program) self._exe.run(fleet.startup_program)
fleet.init_worker() fleet.init_worker()
dataset = self._get_dataset() reader = self._get_dataloader()
epochs = envs.get_global_env("train.epochs") epochs = envs.get_global_env("train.epochs")
for i in range(epochs): program = fluid.compiler.CompiledProgram(
self._exe.train_from_dataset(program=fluid.default_main_program(), fleet.main_program).with_data_parallel(
dataset=dataset, loss_name=self.model.get_cost_op().name,
fetch_list=self.fetch_vars, build_strategy=self.strategy.get_build_strategy(),
fetch_info=self.fetch_alias, exec_strategy=self.strategy.get_execute_strategy())
print_period=self.fetch_period)
self.save(i, "train", is_fleet=True) metrics_varnames = []
context['status'] = 'terminal_pass' metrics_format = []
metrics_format.append("{}: {{}}".format("epoch"))
metrics_format.append("{}: {{}}".format("batch"))
for name, var in self.model.get_metrics().items():
metrics_varnames.append(var.name)
metrics_format.append("{}: {{}}".format(name))
metrics_format = ", ".join(metrics_format)
for epoch in range(epochs):
reader.start()
batch_id = 0
try:
while True:
metrics_rets = self._exe.run(
program=program,
fetch_list=metrics_varnames)
metrics = [epoch, batch_id]
metrics.extend(metrics_rets)
if batch_id % 10 == 0 and batch_id != 0:
print(metrics_format.format(*metrics))
batch_id += 1
except fluid.core.EOFException:
reader.reset()
fleet.stop_worker() fleet.stop_worker()
context['status'] = 'terminal_pass'
def infer(self, context): def infer(self, context):
context['status'] = 'terminal_pass' context['status'] = 'terminal_pass'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册