From db8ae9ed4525142bbcac447d92baa012d6c4a850 Mon Sep 17 00:00:00 2001 From: tangwei Date: Wed, 29 Apr 2020 13:12:30 +0800 Subject: [PATCH] fix dataloader --- fleetrec/core/trainers/single_trainer.py | 6 +++--- fleetrec/core/utils/dataloader_instance.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fleetrec/core/trainers/single_trainer.py b/fleetrec/core/trainers/single_trainer.py index c3da7edd..989297a0 100644 --- a/fleetrec/core/trainers/single_trainer.py +++ b/fleetrec/core/trainers/single_trainer.py @@ -73,6 +73,7 @@ class SingleTrainer(TranspileTrainer): metrics_format.append("{}: {{}}".format("batch")) for name, var in self.model.get_metrics().items(): + metrics_varnames.append(var.name) metrics_format.append("{}: {{}}".format(name)) metrics_format = ", ".join(metrics_format) @@ -86,12 +87,11 @@ class SingleTrainer(TranspileTrainer): program=program, fetch_list=metrics_varnames) - metrics_rets = np.mean(metrics_rets, axis=0) metrics = [epoch, batch_id] - metrics.extend(metrics_rets.tolist()) + metrics.extend(metrics_rets) if batch_id % 10 == 0 and batch_id != 0: - print(metrics_format.format(metrics)) + print(metrics_format.format(*metrics)) batch_id += 1 except fluid.core.EOFException: reader.reset() diff --git a/fleetrec/core/utils/dataloader_instance.py b/fleetrec/core/utils/dataloader_instance.py index 9786c262..8c43c610 100644 --- a/fleetrec/core/utils/dataloader_instance.py +++ b/fleetrec/core/utils/dataloader_instance.py @@ -40,7 +40,7 @@ def dataloader(readerclass, train, yaml_file): for file in files: with open(file, 'r') as f: for line in f: - line = line.rstrip('\n').split('\t') + line = line.rstrip('\n') iter = reader.generate_sample(line) for parsed_line in iter(): if parsed_line is None: -- GitLab