diff --git a/fleetrec/core/trainers/single_trainer.py b/fleetrec/core/trainers/single_trainer.py index c3da7edd9619458b53988f415c8e339256e3d6de..989297a04f6d031588ec58d606f9e74a0d5e4556 100644 --- a/fleetrec/core/trainers/single_trainer.py +++ b/fleetrec/core/trainers/single_trainer.py @@ -73,6 +73,7 @@ class SingleTrainer(TranspileTrainer): metrics_format.append("{}: {{}}".format("batch")) for name, var in self.model.get_metrics().items(): + metrics_varnames.append(var.name) metrics_format.append("{}: {{}}".format(name)) metrics_format = ", ".join(metrics_format) @@ -86,12 +87,11 @@ class SingleTrainer(TranspileTrainer): program=program, fetch_list=metrics_varnames) - metrics_rets = np.mean(metrics_rets, axis=0) metrics = [epoch, batch_id] - metrics.extend(metrics_rets.tolist()) + metrics.extend(metrics_rets) if batch_id % 10 == 0 and batch_id != 0: - print(metrics_format.format(metrics)) + print(metrics_format.format(*metrics)) batch_id += 1 except fluid.core.EOFException: reader.reset() diff --git a/fleetrec/core/utils/dataloader_instance.py b/fleetrec/core/utils/dataloader_instance.py index 9786c262eb9078f2a463c1230f9fb1412adece9a..8c43c610a29394b3b334d755bc95d21fb0f514ce 100644 --- a/fleetrec/core/utils/dataloader_instance.py +++ b/fleetrec/core/utils/dataloader_instance.py @@ -40,7 +40,7 @@ def dataloader(readerclass, train, yaml_file): for file in files: with open(file, 'r') as f: for line in f: - line = line.rstrip('\n').split('\t') + line = line.rstrip('\n') iter = reader.generate_sample(line) for parsed_line in iter(): if parsed_line is None: