From 2163006838bca9a4f00fd00bf707d6ea84c579ab Mon Sep 17 00:00:00 2001 From: tangwei Date: Wed, 29 Apr 2020 14:42:18 +0800 Subject: [PATCH] mac/windows fix --- fleetrec/core/trainers/cluster_trainer.py | 49 ++++++++++++++++++----- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/fleetrec/core/trainers/cluster_trainer.py b/fleetrec/core/trainers/cluster_trainer.py index 1ab70d3b..4a204bce 100644 --- a/fleetrec/core/trainers/cluster_trainer.py +++ b/fleetrec/core/trainers/cluster_trainer.py @@ -64,6 +64,7 @@ class ClusterTrainer(TranspileTrainer): assert strategy is not None + self.strategy = strategy return strategy def init(self, context): @@ -96,20 +97,50 @@ class ClusterTrainer(TranspileTrainer): def dataset_train(self, context): self._exe.run(fleet.startup_program) + fleet.init_worker() - dataset = self._get_dataset() + reader = self._get_dataloader() epochs = envs.get_global_env("train.epochs") - for i in range(epochs): - self._exe.train_from_dataset(program=fluid.default_main_program(), - dataset=dataset, - fetch_list=self.fetch_vars, - fetch_info=self.fetch_alias, - print_period=self.fetch_period) - self.save(i, "train", is_fleet=True) - context['status'] = 'terminal_pass' + program = fluid.compiler.CompiledProgram( + fleet.main_program).with_data_parallel( + loss_name=self.model.get_cost_op().name, + build_strategy=self.strategy.get_build_strategy(), + exec_strategy=self.strategy.get_execute_strategy()) + + metrics_varnames = [] + metrics_format = [] + + metrics_format.append("{}: {{}}".format("epoch")) + metrics_format.append("{}: {{}}".format("batch")) + + for name, var in self.model.get_metrics().items(): + metrics_varnames.append(var.name) + metrics_format.append("{}: {{}}".format(name)) + + metrics_format = ", ".join(metrics_format) + + for epoch in range(epochs): + reader.start() + batch_id = 0 + try: + while True: + metrics_rets = self._exe.run( + program=program, + fetch_list=metrics_varnames) + + metrics = [epoch, batch_id] + metrics.extend(metrics_rets) + + if batch_id % 10 == 0 and batch_id != 0: + print(metrics_format.format(*metrics)) + batch_id += 1 + except fluid.core.EOFException: + reader.reset() + fleet.stop_worker() + context['status'] = 'terminal_pass' def infer(self, context): context['status'] = 'terminal_pass' -- GitLab