提交 db8ae9ed 编写于 作者: T tangwei

fix dataloader

上级 30355623
......@@ -73,6 +73,7 @@ class SingleTrainer(TranspileTrainer):
metrics_format.append("{}: {{}}".format("batch"))
for name, var in self.model.get_metrics().items():
metrics_varnames.append(var.name)
metrics_format.append("{}: {{}}".format(name))
metrics_format = ", ".join(metrics_format)
......@@ -86,12 +87,11 @@ class SingleTrainer(TranspileTrainer):
program=program,
fetch_list=metrics_varnames)
metrics_rets = np.mean(metrics_rets, axis=0)
metrics = [epoch, batch_id]
metrics.extend(metrics_rets.tolist())
metrics.extend(metrics_rets)
if batch_id % 10 == 0 and batch_id != 0:
print(metrics_format.format(metrics))
print(metrics_format.format(*metrics))
batch_id += 1
except fluid.core.EOFException:
reader.reset()
......
......@@ -40,7 +40,7 @@ def dataloader(readerclass, train, yaml_file):
for file in files:
with open(file, 'r') as f:
for line in f:
line = line.rstrip('\n').split('\t')
line = line.rstrip('\n')
iter = reader.generate_sample(line)
for parsed_line in iter():
if parsed_line is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册