提交 db8ae9ed 编写于 作者: T tangwei

fix dataloader

上级 30355623
...@@ -73,6 +73,7 @@ class SingleTrainer(TranspileTrainer): ...@@ -73,6 +73,7 @@ class SingleTrainer(TranspileTrainer):
metrics_format.append("{}: {{}}".format("batch")) metrics_format.append("{}: {{}}".format("batch"))
for name, var in self.model.get_metrics().items(): for name, var in self.model.get_metrics().items():
metrics_varnames.append(var.name)
metrics_format.append("{}: {{}}".format(name)) metrics_format.append("{}: {{}}".format(name))
metrics_format = ", ".join(metrics_format) metrics_format = ", ".join(metrics_format)
...@@ -86,12 +87,11 @@ class SingleTrainer(TranspileTrainer): ...@@ -86,12 +87,11 @@ class SingleTrainer(TranspileTrainer):
program=program, program=program,
fetch_list=metrics_varnames) fetch_list=metrics_varnames)
metrics_rets = np.mean(metrics_rets, axis=0)
metrics = [epoch, batch_id] metrics = [epoch, batch_id]
metrics.extend(metrics_rets.tolist()) metrics.extend(metrics_rets)
if batch_id % 10 == 0 and batch_id != 0: if batch_id % 10 == 0 and batch_id != 0:
print(metrics_format.format(metrics)) print(metrics_format.format(*metrics))
batch_id += 1 batch_id += 1
except fluid.core.EOFException: except fluid.core.EOFException:
reader.reset() reader.reset()
......
...@@ -40,7 +40,7 @@ def dataloader(readerclass, train, yaml_file): ...@@ -40,7 +40,7 @@ def dataloader(readerclass, train, yaml_file):
for file in files: for file in files:
with open(file, 'r') as f: with open(file, 'r') as f:
for line in f: for line in f:
line = line.rstrip('\n').split('\t') line = line.rstrip('\n')
iter = reader.generate_sample(line) iter = reader.generate_sample(line)
for parsed_line in iter(): for parsed_line in iter():
if parsed_line is None: if parsed_line is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册