Commit ac0ae57e authored by: J Junkun

add collator and evaluation code for ST

Parent: 03231519
@@ -24,7 +24,6 @@ from typing import Tuple
 import numpy as np
 import paddle
-import sacrebleu
 from paddle import distributed as dist
 from paddle.io import DataLoader
 from yacs.config import CfgNode
@@ -32,6 +31,7 @@ from yacs.config import CfgNode
 from deepspeech.io.collator_st import KaldiPrePorocessedCollator
 from deepspeech.io.collator_st import SpeechCollator
 from deepspeech.io.collator_st import TripletKaldiPrePorocessedCollator
+from deepspeech.io.collator_st import TripletSpeechCollator
 from deepspeech.io.dataset import ManifestDataset
 from deepspeech.io.dataset import TripletManifestDataset
 from deepspeech.io.sampler import SortagradBatchSampler
@@ -40,6 +40,7 @@ from deepspeech.models.u2_st import U2STModel
 from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
 from deepspeech.training.scheduler import WarmupLR
 from deepspeech.training.trainer import Trainer
+from deepspeech.utils import bleu_score
 from deepspeech.utils import ctc_utils
 from deepspeech.utils import error_rate
 from deepspeech.utils import layer_tools
@@ -248,7 +249,11 @@ class U2STTrainer(Trainer):
         dev_dataset = Dataset.from_config(config)
         if config.collator.raw_wav:
-            TestCollator = Collator = SpeechCollator
+            if config.model.model_conf.asr_weight > 0.:
+                Collator = TripletSpeechCollator
+                TestCollator = SpeechCollator
+            else:
+                TestCollator = Collator = SpeechCollator
             # Not yet implement the mtl loader for raw_wav.
         else:
             if config.model.model_conf.asr_weight > 0.:
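When asr_weight > 0. the trainer does joint ST + ASR training, so the raw-wav branch now picks TripletSpeechCollator for training while evaluation keeps the plain SpeechCollator. A minimal sketch of the selection logic, for illustration only: the pre-processed-feature branch is inferred from the imports above (the rest of that hunk is not shown here), and the surrounding sampler/DataLoader setup is omitted.

```python
# Sketch only: mirrors the collator choice implied by this hunk.
# The pre-processed (Kaldi) branch is an assumption based on the imports.
from deepspeech.io.collator_st import (KaldiPrePorocessedCollator,
                                       SpeechCollator,
                                       TripletKaldiPrePorocessedCollator,
                                       TripletSpeechCollator)


def pick_collators(config):
    """Return (train_collator_cls, test_collator_cls) for ST training."""
    joint_asr = config.model.model_conf.asr_weight > 0.
    if config.collator.raw_wav:
        train_cls = TripletSpeechCollator if joint_asr else SpeechCollator
        test_cls = SpeechCollator
    else:
        train_cls = (TripletKaldiPrePorocessedCollator
                     if joint_asr else KaldiPrePorocessedCollator)
        test_cls = KaldiPrePorocessedCollator
    return train_cls, test_cls
```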
@@ -393,7 +398,7 @@ class U2STTester(U2STTrainer):
            lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm',  # Filepath for language model.
            decoding_method='attention',  # Decoding method. Options: 'attention', 'ctc_greedy_search',
            # 'ctc_prefix_beam_search', 'attention_rescoring'
-           error_rate_type='wer',  # Error rate type for evaluation. Options `wer`, 'cer'
+           error_rate_type='bleu',  # Error rate type for evaluation. Options `bleu`, 'char_bleu'
            num_proc_bsearch=8,  # # of CPUs for beam search.
            beam_size=10,  # Beam search width.
            batch_size=16,  # decoding batch size
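With this change the ST tester defaults to BLEU-based evaluation instead of WER. A hedged sketch of setting these decoding options from Python (in the recipes these values normally come from the experiment's YAML config; only the keys visible in this hunk are used here):

```python
# Sketch only: build a decoding config with the keys shown above.
# yacs CfgNode accepts a plain dict; values here are the hunk's defaults.
from yacs.config import CfgNode

decoding = CfgNode(
    dict(
        decoding_method='attention',  # or 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
        error_rate_type='bleu',  # 'bleu' (word-level) or 'char_bleu' (char-level)
        beam_size=10,
        batch_size=16,
    ))
```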
@@ -428,10 +433,10 @@ class U2STTester(U2STTrainer):
                                     audio_len,
                                     texts,
                                     texts_len,
+                                    bleu_func,
                                     fout=None):
         cfg = self.config.decoding
         len_refs, num_ins = 0, 0
-        bleu_func = sacrebleu.corpus_bleu
         start_time = time.time()
         text_feature = self.test_loader.collate_fn.text_feature
@@ -487,6 +492,9 @@ class U2STTester(U2STTrainer):
         self.model.eval()
         logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
+        cfg = self.config.decoding
+        bleu_func = bleu_score.char_bleu if cfg.error_rate_type == 'char-bleu' else bleu_score.bleu
         stride_ms = self.test_loader.collate_fn.stride_ms
         hyps, refs = [], []
         len_refs, num_ins = 0, 0
@@ -495,7 +503,7 @@ class U2STTester(U2STTrainer):
         with open(self.args.result_file, 'w') as fout:
             for i, batch in enumerate(self.test_loader):
                 metrics = self.compute_translation_metrics(
-                    *batch, fout=fout)
+                    *batch, bleu_func=bleu_func, fout=fout)
                 hyps += metrics['hyps']
                 refs += metrics['refs']
                 bleu = metrics['bleu']
@@ -504,19 +512,16 @@ class U2STTester(U2STTrainer):
                 len_refs += metrics['len_refs']
                 num_ins += metrics['num_ins']
                 rtf = num_time / (num_frames * stride_ms)
-                logger.info("RTF: %f, BELU (%d) = %f" %
-                            (rtf, num_ins, bleu))
+                logger.info("RTF: %f, BELU (%d) = %f" % (rtf, num_ins, bleu))
         rtf = num_time / (num_frames * stride_ms)
         msg = "Test: "
         msg += "epoch: {}, ".format(self.epoch)
         msg += "step: {}, ".format(self.iteration)
         msg += "RTF: {}, ".format(rtf)
-        msg += "Test set [%s]: %s" % (
-            len(hyps), str(sacrebleu.corpus_bleu(hyps, [refs])))
+        msg += "Test set [%s]: %s" % (len(hyps), str(bleu_func(hyps, [refs])))
         logger.info(msg)
-        bleu_meta_path = os.path.splitext(
-            self.args.result_file)[0] + '.bleu'
+        bleu_meta_path = os.path.splitext(self.args.result_file)[0] + '.bleu'
         err_type_str = "BLEU"
         with open(bleu_meta_path, 'w') as f:
             data = json.dumps({
@@ -527,7 +532,7 @@ class U2STTester(U2STTrainer):
                 "rtf":
                 rtf,
                 err_type_str:
-                sacrebleu.corpus_bleu(hyps, [refs]).score,
+                bleu_func(hyps, [refs]).score,
                 "dataset_hour": (num_frames * stride_ms) / 1000.0 / 3600.0,
                 "process_hour":
                 num_time / 1000.0 / 3600.0,
......
This diff is collapsed.
@@ -19,9 +19,7 @@ from yacs.config import CfgNode
 from deepspeech.frontend.utility import read_manifest
 from deepspeech.utils.log import Log
-__all__ = [
-    "ManifestDataset",
-]
+__all__ = ["ManifestDataset", "TripletManifestDataset"]
 logger = Log(__name__).getlog()
@@ -105,3 +103,16 @@ class ManifestDataset(Dataset):
     def __getitem__(self, idx):
         instance = self._manifest[idx]
         return instance["utt"], instance["feat"], instance["text"]
+
+
+class TripletManifestDataset(ManifestDataset):
+    """
+    For Joint Training of Speech Translation and ASR.
+    text: translation,
+    text1: transcript.
+    """
+
+    def __getitem__(self, idx):
+        instance = self._manifest[idx]
+        return instance["utt"], instance["feat"], instance["text"], instance[
+            "text1"]
This diff is collapsed.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module provides functions to calculate bleu score in different level.
e.g. wer for word-level, cer for char-level.
"""
import numpy as np
import sacrebleu
__all__ = ['bleu', 'char_bleu']
def bleu(hypothesis, reference):
    """Calculate BLEU. BLEU compares reference text and
    hypothesis text in word-level using sacrebleu.

    :param reference: The reference sentences.
    :type reference: list[list[str]]
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: list[str]
    :raises ValueError: If the reference length is zero.
    """
    return sacrebleu.corpus_bleu(hypothesis, reference)
def char_bleu(hypothesis, reference):
    """Calculate BLEU. BLEU compares reference text and
    hypothesis text in char-level using sacrebleu.

    :param reference: The reference sentences.
    :type reference: list[list[str]]
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: list[str]
    :raises ValueError: If the reference number is zero.
    """
    hypothesis = [' '.join(list(hyp.replace(' ', ''))) for hyp in hypothesis]
    reference = [[' '.join(list(ref_i.replace(' ', ''))) for ref_i in ref]
                 for ref in reference]
    return sacrebleu.corpus_bleu(hypothesis, reference)
\ No newline at end of file
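A quick usage sketch of the new module (the sentences are invented; sacrebleu.corpus_bleu takes a list of hypothesis strings plus a list of reference streams and returns an object whose score attribute is the corpus BLEU):

```python
from deepspeech.utils import bleu_score

hyps = ["the cat is on the mat", "he reads a book"]
refs = [["the cat sat on the mat", "he is reading a book"]]  # one reference stream

print(bleu_score.bleu(hyps, refs).score)       # word-level corpus BLEU
print(bleu_score.char_bleu(hyps, refs).score)  # char-level corpus BLEU
```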