From 3d994f5c23a86b97d058400b6f6b06dafad064ce Mon Sep 17 00:00:00 2001
From: tianhao zhang <15600919271@163.com>
Date: Tue, 11 Oct 2022 16:53:10 +0000
Subject: [PATCH] format wav2vec2 demo

---
 examples/librispeech/asr3/conf/preprocess.yaml    | 4 ++--
 examples/librispeech/asr3/conf/tuning/decode.yaml | 9 +--------
 examples/librispeech/asr3/run.sh                  | 3 +--
 paddlespeech/audio/transform/spectrogram.py       | 2 +-
 paddlespeech/s2t/exps/wav2vec2/bin/test.py        | 2 --
 paddlespeech/s2t/exps/wav2vec2/model.py           | 3 ---
 paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py  | 2 +-
 paddlespeech/s2t/modules/ctc.py                   | 7 +++++--
 8 files changed, 11 insertions(+), 21 deletions(-)

diff --git a/examples/librispeech/asr3/conf/preprocess.yaml b/examples/librispeech/asr3/conf/preprocess.yaml
index 3979d256..4a908a83 100644
--- a/examples/librispeech/asr3/conf/preprocess.yaml
+++ b/examples/librispeech/asr3/conf/preprocess.yaml
@@ -1,4 +1,4 @@
 process:
-  # extract kaldi fbank from PCM
+  # use raw audio
   - type: wav_process
-    dither: 0.1
+    dither: 0.0
diff --git a/examples/librispeech/asr3/conf/tuning/decode.yaml b/examples/librispeech/asr3/conf/tuning/decode.yaml
index c2261fb2..2ba39326 100644
--- a/examples/librispeech/asr3/conf/tuning/decode.yaml
+++ b/examples/librispeech/asr3/conf/tuning/decode.yaml
@@ -1,11 +1,4 @@
 decode_batch_size: 1
 error_rate_type: wer
-decoding_method: ctc_greedy_search # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
+decoding_method: ctc_greedy_search # 'ctc_greedy_search', 'ctc_prefix_beam_search'
 beam_size: 10
-ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
-decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
-                        # <0: for decoding, use full chunk.
-                        # >0: for decoding, use fixed chunk size as set.
-                        # 0: used for training, it's prohibited here.
-num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
-simulate_streaming: False # simulate streaming inference. Defaults to False.
diff --git a/examples/librispeech/asr3/run.sh b/examples/librispeech/asr3/run.sh
index 55b2ca86..3b1abb11 100644
--- a/examples/librispeech/asr3/run.sh
+++ b/examples/librispeech/asr3/run.sh
@@ -36,9 +36,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     avg.sh best exp/${ckpt}/checkpoints ${avg_num}
 fi
 
-
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-    # attetion resocre decoder
+    # greedy search decoder
     CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
 fi
 
diff --git a/paddlespeech/audio/transform/spectrogram.py b/paddlespeech/audio/transform/spectrogram.py
index 2e519939..cba60cfd 100644
--- a/paddlespeech/audio/transform/spectrogram.py
+++ b/paddlespeech/audio/transform/spectrogram.py
@@ -383,7 +383,7 @@ class LogMelSpectrogramKaldi():
 
 
 class WavProcess():
-    def __init__(self, dither=0.1):
+    def __init__(self, dither=0.0):
         """
         Args:
             dither (float): Dithering constant
diff --git a/paddlespeech/s2t/exps/wav2vec2/bin/test.py b/paddlespeech/s2t/exps/wav2vec2/bin/test.py
index 4fa224c3..d1a6fd40 100644
--- a/paddlespeech/s2t/exps/wav2vec2/bin/test.py
+++ b/paddlespeech/s2t/exps/wav2vec2/bin/test.py
@@ -20,8 +20,6 @@ from paddlespeech.s2t.exps.wav2vec2.model import Wav2Vec2ASRTester as Tester
 from paddlespeech.s2t.training.cli import default_argument_parser
 from paddlespeech.s2t.utils.utility import print_arguments
 
-# TODO(hui zhang): dynamic load
-
 
 def main_sp(config, args):
     exp = Tester(config, args)
diff --git a/paddlespeech/s2t/exps/wav2vec2/model.py b/paddlespeech/s2t/exps/wav2vec2/model.py
index 32cf0b47..d845d8c6 100644
--- a/paddlespeech/s2t/exps/wav2vec2/model.py
+++ b/paddlespeech/s2t/exps/wav2vec2/model.py
@@ -25,9 +25,7 @@ import paddle
 from paddle import distributed as dist
 
 from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
-from paddlespeech.s2t.io.dataloader import BatchDataLoader
 from paddlespeech.s2t.io.dataloader import DataLoaderFactory
-from paddlespeech.s2t.io.dataloader import StreamDataLoader
 from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment
 from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR
 from paddlespeech.s2t.training.optimizer import OptimizerFactory
@@ -300,7 +298,6 @@ class Wav2Vec2ASRTrainer(Trainer):
             "epsilon": optim_conf.epsilon,
             "rho": optim_conf.rho,
             "parameters": parameters,
-            "epsilon": 1e-9 if optim_type == 'noam' else None,
             "beta1": 0.9 if optim_type == 'noam' else None,
             "beat2": 0.98 if optim_type == 'noam' else None,
         }
diff --git a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
index f54748f8..0d99e870 100644
--- a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
+++ b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
@@ -39,7 +39,7 @@ class Wav2vec2ASR(nn.Layer):
             enc_n_units=config.dnn_neurons,
             blank_id=config.blank_id,
             dropout_rate=config.ctc_dropout_rate,
-            reduction=True)
+            reduction='mean')
 
     def forward(self, wav, wavs_lens_rate, target, target_lens_rate):
         if self.normalize_wav:
diff --git a/paddlespeech/s2t/modules/ctc.py b/paddlespeech/s2t/modules/ctc.py
index 0f50db21..e0c01ab4 100644
--- a/paddlespeech/s2t/modules/ctc.py
+++ b/paddlespeech/s2t/modules/ctc.py
@@ -53,7 +53,7 @@ class CTCDecoderBase(nn.Layer):
                  enc_n_units,
                  blank_id=0,
                  dropout_rate: float=0.0,
-                 reduction: bool=True,
+                 reduction: Union[str, bool]=True,
                  batch_average: bool=True,
                  grad_norm_type: Union[str, None]=None):
         """CTC decoder
@@ -73,7 +73,10 @@ class CTCDecoderBase(nn.Layer):
         self.odim = odim
         self.dropout = nn.Dropout(dropout_rate)
         self.ctc_lo = Linear(enc_n_units, self.odim)
-        reduction_type = "sum" if reduction else "none"
+        if isinstance(reduction, bool):
+            reduction_type = "sum" if reduction else "none"
+        else:
+            reduction_type = reduction
         self.criterion = CTCLoss(
             blank=self.blank_id,
             reduction=reduction_type,
-- 
GitLab
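
Note on the `paddlespeech/s2t/modules/ctc.py` change above: `CTCDecoderBase` now accepts either the legacy bool or a reduction string that is forwarded to `CTCLoss`, which is how `Wav2vec2ASR` passes `reduction='mean'`. Below is a minimal standalone sketch of that mapping, not the PaddleSpeech class itself; the helper name `resolve_reduction` is illustrative only.

```python
from typing import Union


def resolve_reduction(reduction: Union[str, bool] = True) -> str:
    """Map the `reduction` argument to a CTC loss reduction string.

    Bool values keep the old behaviour (True -> "sum", False -> "none");
    strings such as "mean" pass through unchanged, as in this patch.
    """
    if isinstance(reduction, bool):
        return "sum" if reduction else "none"
    return reduction


# resolve_reduction(True)   -> "sum"   (previous default)
# resolve_reduction(False)  -> "none"
# resolve_reduction("mean") -> "mean"  (new wav2vec2 demo setting)
```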