提交 ae87bc8c 编写于 作者: H Hui Zhang

dump decode result as jsonlines

上级 c6e8a33b
...@@ -18,6 +18,7 @@ from collections import defaultdict ...@@ -18,6 +18,7 @@ from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import jsonlines
import numpy as np import numpy as np
import paddle import paddle
...@@ -305,9 +306,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): ...@@ -305,9 +306,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
len_refs += len_ref len_refs += len_ref
num_ins += 1 num_ins += 1
if fout: if fout:
fout.write(utt + " " + result + "\n") fout.write({"utt": utt, "ref", target, "hyp": result})
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % logger.info(f"Utt: {utt}")
(target, result)) logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("Current error rate [%s] = %f" % logger.info("Current error rate [%s] = %f" %
(cfg.error_rate_type, error_rate_func(target, result))) (cfg.error_rate_type, error_rate_func(target, result)))
...@@ -350,7 +352,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): ...@@ -350,7 +352,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
cfg = self.config cfg = self.config
error_rate_type = None error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_sum, len_refs, num_ins = 0.0, 0, 0
with open(self.args.result_file, 'w') as fout: with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader): for i, batch in enumerate(self.test_loader):
utts, audio, audio_len, texts, texts_len = batch utts, audio, audio_len, texts, texts_len = batch
metrics = self.compute_metrics(utts, audio, audio_len, texts, metrics = self.compute_metrics(utts, audio, audio_len, texts,
......
...@@ -21,6 +21,7 @@ from collections import OrderedDict ...@@ -21,6 +21,7 @@ from collections import OrderedDict
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import jsonlines
import numpy as np import numpy as np
import paddle import paddle
...@@ -466,9 +467,10 @@ class U2Tester(U2Trainer): ...@@ -466,9 +467,10 @@ class U2Tester(U2Trainer):
len_refs += len_ref len_refs += len_ref
num_ins += 1 num_ins += 1
if fout: if fout:
fout.write(utt + " " + result + "\n") fout.write({"utt": utt, "ref", target, "hyp": result})
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % logger.info(f"Utt: {utt}")
(target, result)) logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("One example error rate [%s] = %f" % logger.info("One example error rate [%s] = %f" %
(cfg.error_rate_type, error_rate_func(target, result))) (cfg.error_rate_type, error_rate_func(target, result)))
...@@ -493,7 +495,7 @@ class U2Tester(U2Trainer): ...@@ -493,7 +495,7 @@ class U2Tester(U2Trainer):
errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0 num_frames = 0.0
num_time = 0.0 num_time = 0.0
with open(self.args.result_file, 'w') as fout: with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader): for i, batch in enumerate(self.test_loader):
metrics = self.compute_metrics(*batch, fout=fout) metrics = self.compute_metrics(*batch, fout=fout)
num_frames += metrics['num_frames'] num_frames += metrics['num_frames']
......
...@@ -20,6 +20,7 @@ from collections import defaultdict ...@@ -20,6 +20,7 @@ from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import jsonlines
import numpy as np import numpy as np
import paddle import paddle
...@@ -445,9 +446,10 @@ class U2Tester(U2Trainer): ...@@ -445,9 +446,10 @@ class U2Tester(U2Trainer):
len_refs += len_ref len_refs += len_ref
num_ins += 1 num_ins += 1
if fout: if fout:
fout.write(utt + " " + result + "\n") fout.write({"utt": utt, "ref", target, "hyp": result})
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % logger.info(f"Utt: {utt}")
(target, result)) logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("One example error rate [%s] = %f" % logger.info("One example error rate [%s] = %f" %
(cfg.error_rate_type, error_rate_func(target, result))) (cfg.error_rate_type, error_rate_func(target, result)))
...@@ -472,7 +474,7 @@ class U2Tester(U2Trainer): ...@@ -472,7 +474,7 @@ class U2Tester(U2Trainer):
errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0 num_frames = 0.0
num_time = 0.0 num_time = 0.0
with open(self.args.result_file, 'w') as fout: with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader): for i, batch in enumerate(self.test_loader):
metrics = self.compute_metrics(*batch, fout=fout) metrics = self.compute_metrics(*batch, fout=fout)
num_frames += metrics['num_frames'] num_frames += metrics['num_frames']
......
...@@ -20,6 +20,7 @@ from collections import defaultdict ...@@ -20,6 +20,7 @@ from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import jsonlines
import numpy as np import numpy as np
import paddle import paddle
...@@ -479,8 +480,10 @@ class U2STTester(U2STTrainer): ...@@ -479,8 +480,10 @@ class U2STTester(U2STTrainer):
len_refs += len(target.split()) len_refs += len(target.split())
num_ins += 1 num_ins += 1
if fout: if fout:
fout.write(utt + " " + result + "\n") fout.write({"utt": utt, "ref", target, "hyp": result})
logger.info("\nReference: %s\nHypothesis: %s" % (target, result)) logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("One example BLEU = %s" % logger.info("One example BLEU = %s" %
(bleu_func([result], [[target]]).prec_str)) (bleu_func([result], [[target]]).prec_str))
...@@ -508,7 +511,7 @@ class U2STTester(U2STTrainer): ...@@ -508,7 +511,7 @@ class U2STTester(U2STTrainer):
len_refs, num_ins = 0, 0 len_refs, num_ins = 0, 0
num_frames = 0.0 num_frames = 0.0
num_time = 0.0 num_time = 0.0
with open(self.args.result_file, 'w') as fout: with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader): for i, batch in enumerate(self.test_loader):
metrics = self.compute_translation_metrics( metrics = self.compute_translation_metrics(
*batch, bleu_func=bleu_func, fout=fout) *batch, bleu_func=bleu_func, fout=fout)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册