提交 ae87bc8c 编写于 作者: H Hui Zhang

dump decode result as jsonlines

上级 c6e8a33b
......@@ -18,6 +18,7 @@ from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
import jsonlines
import numpy as np
import paddle
......@@ -305,9 +306,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
len_refs += len_ref
num_ins += 1
if fout:
fout.write(utt + " " + result + "\n")
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result))
fout.write({"utt": utt, "ref", target, "hyp": result})
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("Current error rate [%s] = %f" %
(cfg.error_rate_type, error_rate_func(target, result)))
......@@ -350,7 +352,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
cfg = self.config
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0
with open(self.args.result_file, 'w') as fout:
with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
utts, audio, audio_len, texts, texts_len = batch
metrics = self.compute_metrics(utts, audio, audio_len, texts,
......
......@@ -21,6 +21,7 @@ from collections import OrderedDict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
import jsonlines
import numpy as np
import paddle
......@@ -466,9 +467,10 @@ class U2Tester(U2Trainer):
len_refs += len_ref
num_ins += 1
if fout:
fout.write(utt + " " + result + "\n")
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result))
fout.write({"utt": utt, "ref", target, "hyp": result})
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("One example error rate [%s] = %f" %
(cfg.error_rate_type, error_rate_func(target, result)))
......@@ -493,7 +495,7 @@ class U2Tester(U2Trainer):
errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0
num_time = 0.0
with open(self.args.result_file, 'w') as fout:
with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
metrics = self.compute_metrics(*batch, fout=fout)
num_frames += metrics['num_frames']
......
......@@ -20,6 +20,7 @@ from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
import jsonlines
import numpy as np
import paddle
......@@ -445,9 +446,10 @@ class U2Tester(U2Trainer):
len_refs += len_ref
num_ins += 1
if fout:
fout.write(utt + " " + result + "\n")
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result))
fout.write({"utt": utt, "ref", target, "hyp": result})
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("One example error rate [%s] = %f" %
(cfg.error_rate_type, error_rate_func(target, result)))
......@@ -472,7 +474,7 @@ class U2Tester(U2Trainer):
errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0
num_time = 0.0
with open(self.args.result_file, 'w') as fout:
with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
metrics = self.compute_metrics(*batch, fout=fout)
num_frames += metrics['num_frames']
......
......@@ -20,6 +20,7 @@ from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
import jsonlines
import numpy as np
import paddle
......@@ -479,8 +480,10 @@ class U2STTester(U2STTrainer):
len_refs += len(target.split())
num_ins += 1
if fout:
fout.write(utt + " " + result + "\n")
logger.info("\nReference: %s\nHypothesis: %s" % (target, result))
fout.write({"utt": utt, "ref", target, "hyp": result})
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("One example BLEU = %s" %
(bleu_func([result], [[target]]).prec_str))
......@@ -508,7 +511,7 @@ class U2STTester(U2STTrainer):
len_refs, num_ins = 0, 0
num_frames = 0.0
num_time = 0.0
with open(self.args.result_file, 'w') as fout:
with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
metrics = self.compute_translation_metrics(
*batch, bleu_func=bleu_func, fout=fout)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册