From ae87bc8c7ad534a931b7298cd36d7959e69b6f5a Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 22 Sep 2021 12:26:00 +0000 Subject: [PATCH] dump decode result as jsonlines --- deepspeech/exps/deepspeech2/model.py | 10 ++++++---- deepspeech/exps/u2/model.py | 10 ++++++---- deepspeech/exps/u2_kaldi/model.py | 10 ++++++---- deepspeech/exps/u2_st/model.py | 9 ++++++--- 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 12053981..646f6f23 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -18,6 +18,7 @@ from collections import defaultdict from contextlib import nullcontext from pathlib import Path from typing import Optional +import jsonlines import numpy as np import paddle @@ -305,9 +306,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): len_refs += len_ref num_ins += 1 if fout: - fout.write(utt + " " + result + "\n") - logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % - (target, result)) + fout.write({"utt": utt, "ref": target, "hyp": result}) + logger.info(f"Utt: {utt}") + logger.info(f"Ref: {target}") + logger.info(f"Hyp: {result}") logger.info("Current error rate [%s] = %f" % (cfg.error_rate_type, error_rate_func(target, result))) @@ -350,7 +352,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): cfg = self.config error_rate_type = None errors_sum, len_refs, num_ins = 0.0, 0, 0 - with open(self.args.result_file, 'w') as fout: + with jsonlines.open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): utts, audio, audio_len, texts, texts_len = batch metrics = self.compute_metrics(utts, audio, audio_len, texts, diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 4dd05489..f1970334 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -21,6 +21,7 @@ from collections import OrderedDict from contextlib import nullcontext from pathlib 
import Path from typing import Optional +import jsonlines import numpy as np import paddle @@ -466,9 +467,10 @@ class U2Tester(U2Trainer): len_refs += len_ref num_ins += 1 if fout: - fout.write(utt + " " + result + "\n") - logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % - (target, result)) + fout.write({"utt": utt, "ref": target, "hyp": result}) + logger.info(f"Utt: {utt}") + logger.info(f"Ref: {target}") + logger.info(f"Hyp: {result}") logger.info("One example error rate [%s] = %f" % (cfg.error_rate_type, error_rate_func(target, result))) @@ -493,7 +495,7 @@ class U2Tester(U2Trainer): errors_sum, len_refs, num_ins = 0.0, 0, 0 num_frames = 0.0 num_time = 0.0 - with open(self.args.result_file, 'w') as fout: + with jsonlines.open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): metrics = self.compute_metrics(*batch, fout=fout) num_frames += metrics['num_frames'] diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py index e8482aa9..00d78081 100644 --- a/deepspeech/exps/u2_kaldi/model.py +++ b/deepspeech/exps/u2_kaldi/model.py @@ -20,6 +20,7 @@ from collections import defaultdict from contextlib import nullcontext from pathlib import Path from typing import Optional +import jsonlines import numpy as np import paddle @@ -445,9 +446,10 @@ class U2Tester(U2Trainer): len_refs += len_ref num_ins += 1 if fout: - fout.write(utt + " " + result + "\n") - logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % - (target, result)) + fout.write({"utt": utt, "ref": target, "hyp": result}) + logger.info(f"Utt: {utt}") + logger.info(f"Ref: {target}") + logger.info(f"Hyp: {result}") logger.info("One example error rate [%s] = %f" % (cfg.error_rate_type, error_rate_func(target, result))) @@ -472,7 +474,7 @@ class U2Tester(U2Trainer): errors_sum, len_refs, num_ins = 0.0, 0, 0 num_frames = 0.0 num_time = 0.0 - with open(self.args.result_file, 'w') as fout: + with 
jsonlines.open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): metrics = self.compute_metrics(*batch, fout=fout) num_frames += metrics['num_frames'] diff --git a/deepspeech/exps/u2_st/model.py b/deepspeech/exps/u2_st/model.py index c98f5e69..86bb649b 100644 --- a/deepspeech/exps/u2_st/model.py +++ b/deepspeech/exps/u2_st/model.py @@ -20,6 +20,7 @@ from collections import defaultdict from contextlib import nullcontext from pathlib import Path from typing import Optional +import jsonlines import numpy as np import paddle @@ -479,8 +480,10 @@ class U2STTester(U2STTrainer): len_refs += len(target.split()) num_ins += 1 if fout: - fout.write(utt + " " + result + "\n") - logger.info("\nReference: %s\nHypothesis: %s" % (target, result)) + fout.write({"utt": utt, "ref": target, "hyp": result}) + logger.info(f"Utt: {utt}") + logger.info(f"Ref: {target}") + logger.info(f"Hyp: {result}") logger.info("One example BLEU = %s" % (bleu_func([result], [[target]]).prec_str)) @@ -508,7 +511,7 @@ class U2STTester(U2STTrainer): len_refs, num_ins = 0, 0 num_frames = 0.0 num_time = 0.0 - with open(self.args.result_file, 'w') as fout: + with jsonlines.open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): metrics = self.compute_metrics( *batch, bleu_func=bleu_func, fout=fout) -- GitLab