Unverified commit 8d349432 authored by Zth9730, committed by GitHub

[ASR] wav2vec2_en, test=asr (#2637)

* wav2vec2_en, test=asr

* wav2vec2_en, test=asr

* wav2vec2_en, test=asr
Parent 07279848
...
@@ -22,7 +22,7 @@ Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER |
 Model | Pre-Train Method | Pre-Train Data | Finetune Data | Size | Descriptions | CER | WER | Example Link |
 :-------------:| :------------:| :-----: | -----: | :-----: |:-----:| :-----: | :-----: | :-----: |
 [Wav2vec2-large-960h-lv60-self Model](https://paddlespeech.bj.bcebos.com/wav2vec/wav2vec2-large-960h-lv60-self.pdparams) | wav2vec2 | Librispeech and LV-60k Dataset (5.3w h) | - | 1.18 GB | Pre-trained Wav2vec2.0 Model | - | - | - |
-[Wav2vec2ASR-large-960h-librispeech Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.0.model.tar.gz) | wav2vec2 | Librispeech and LV-60k Dataset (5.3w h) | Librispeech (960 h) | 1.18 GB | Encoder: Wav2vec2.0, Decoder: CTC, Decoding method: Greedy search | - | 0.0189 | [Wav2vecASR Librispeech ASR3](../../examples/librispeech/asr3) |
+[Wav2vec2ASR-large-960h-librispeech Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.1.model.tar.gz) | wav2vec2 | Librispeech and LV-60k Dataset (5.3w h) | Librispeech (960 h) | 718 MB | Encoder: Wav2vec2.0, Decoder: CTC, Decoding method: Greedy search | - | 0.0189 | [Wav2vecASR Librispeech ASR3](../../examples/librispeech/asr3) |
 ### Language Model based on NGram
 Language Model | Training Data | Token-based | Size | Descriptions
...
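For reference, a minimal sketch (not part of the commit) of pulling down the updated 1.3.1 checkpoint listed in the table above; the local file names and extraction directory are illustrative:

```python
# Minimal sketch: download and unpack the Wav2vec2ASR 1.3.1 checkpoint from the
# released-model table. Local paths here are illustrative choices.
import tarfile
import urllib.request

CKPT_URL = ("https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/"
            "wav2vec2ASR-large-960h-librispeech_ckpt_1.3.1.model.tar.gz")

archive, _ = urllib.request.urlretrieve(CKPT_URL, "wav2vec2ASR_1.3.1.tar.gz")
with tarfile.open(archive, "r:gz") as tar:
    tar.extractall("exp/wav2vec2ASR")   # yields model params and example config
```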
...
@@ -70,7 +70,6 @@ train_manifest: data/manifest.train
 dev_manifest: data/manifest.dev
 test_manifest: data/manifest.test-clean
-
 ###########################################
 # Dataloader                              #
 ###########################################
...
@@ -95,6 +94,12 @@ dist_sampler: True
 shortest_first: True
 return_lens_rate: True
+############################################
+# Data Augmentation                        #
+############################################
+audio_augment:     # for raw audio
+    sample_rate: 16000
+    speeds: [95, 100, 105]
 ###########################################
 # Training                                #
...
@@ -115,6 +120,3 @@ log_interval: 1
 checkpoint:
   kbest_n: 50
   latest_n: 5
-augment: True
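The new `audio_augment` block replaces the old boolean `augment` flag and is consumed as keyword arguments: the trainer (next diff) builds `TimeDomainSpecAugment(**config.audio_augment)` only when the key is present. A small sketch of that mapping, assuming a PyYAML load and with a stub standing in for the real module:

```python
# Sketch of the config-to-kwargs flow; the stub below stands in for
# PaddleSpeech's TimeDomainSpecAugment.
import yaml

CONF = """
audio_augment:     # for raw audio
    sample_rate: 16000
    speeds: [95, 100, 105]
"""

class TimeDomainSpecAugmentStub:
    def __init__(self, sample_rate=16000, speeds=(95, 100, 105)):
        # speeds are percentages: 95/100/105 -> 0.95x, 1.0x, 1.05x playback rate
        self.sample_rate = sample_rate
        self.speeds = speeds

cfg = yaml.safe_load(CONF)
if 'audio_augment' in cfg:                  # mirrors the trainer's hasattr() gate
    aug = TimeDomainSpecAugmentStub(**cfg['audio_augment'])
    print(aug.sample_rate, aug.speeds)      # 16000 [95, 100, 105]
```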
...
@@ -71,7 +71,8 @@ class Wav2Vec2ASRTrainer(Trainer):
         wavs_lens_rate = wavs_lens / wav.shape[1]
         target_lens_rate = target_lens / target.shape[1]
         wav = wav[:, :, 0]
-        wav = self.speech_augmentation(wav, wavs_lens_rate)
+        if hasattr(train_conf, 'speech_augment'):
+            wav = self.speech_augmentation(wav, wavs_lens_rate)
         loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
         # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
...
@@ -277,7 +278,9 @@ class Wav2Vec2ASRTrainer(Trainer):
         logger.info("Setup model!")
         # setup speech augmentation for wav2vec2
-        self.speech_augmentation = TimeDomainSpecAugment()
+        if hasattr(config, 'audio_augment') and self.train:
+            self.speech_augmentation = TimeDomainSpecAugment(
+                **config.audio_augment)
         if not self.train:
             return
...
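The training-step change above hinges on relative lengths: padded batches carry `wavs_lens_rate`, the fraction of each row that is valid audio, and augmentation runs on the raw waveform before the model sees it. A self-contained sketch of that pattern, where `model`, `augment`, and `config` are stand-ins for the trainer's attributes (`config` is attribute-style, as with a yacs CfgNode):

```python
# Illustrative sketch of the gated augmentation step; all names are stand-ins.
import paddle

def train_step(model, augment, config, wav, wavs_lens):
    # wav: [batch, time] zero-padded raw audio; wavs_lens: [batch] valid samples
    wavs_lens_rate = wavs_lens / wav.shape[1]   # fraction of real audio per row
    if hasattr(config, 'audio_augment'):        # only augment when configured
        wav = augment(wav, wavs_lens_rate)      # time-domain augmentation
    return model(wav, wavs_lens_rate)
```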
...
@@ -641,14 +641,11 @@ class DropChunk(nn.Layer):
 class TimeDomainSpecAugment(nn.Layer):
     """A time-domain approximation of the SpecAugment algorithm.
-
     This augmentation module implements three augmentations in
     the time-domain.
-
     1. Drop chunks of the audio (zero amplitude or white noise)
     2. Drop frequency bands (with band-drop filters)
     3. Speed perturbation (via resampling to slightly different rate)
-
     Arguments
     ---------
     perturb_prob : float from 0 to 1
...
@@ -677,7 +674,6 @@ class TimeDomainSpecAugment(nn.Layer):
     drop_chunk_noise_factor : float
         The noise factor used to scale the white noise inserted, relative to
         the average amplitude of the utterance. Default 0 (no noise inserted).
-
     Example
     -------
     >>> inputs = paddle.randn([10, 16000])
...
@@ -718,7 +714,6 @@ class TimeDomainSpecAugment(nn.Layer):
     def forward(self, waveforms, lengths):
         """Returns the distorted waveforms.
-
         Arguments
         ---------
         waveforms : tensor
...
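Rounding out the docstring's own example, a usage sketch with the config values from this commit; the import path is an assumption based on where this file sits in the repo and may differ across versions:

```python
# Usage sketch for TimeDomainSpecAugment, following the docstring's example.
import paddle
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import (
    TimeDomainSpecAugment)

inputs = paddle.randn([10, 16000])     # batch of 10 one-second, 16 kHz waveforms
lengths = paddle.ones([10])            # relative lengths: every row fully valid
augment = TimeDomainSpecAugment(sample_rate=16000, speeds=[95, 100, 105])
feats = augment(inputs, lengths)       # drop-chunk / drop-freq / speed perturb
print(feats.shape)                     # [10, T']; speed perturbation may change T
```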