From c8368410e291f5ef0992309ed0fd19fc9f59865b Mon Sep 17 00:00:00 2001 From: Haoxin Ma <745165806@qq.com> Date: Fri, 4 Jun 2021 12:41:48 +0000 Subject: [PATCH] utt datapipeline --- deepspeech/exps/deepspeech2/model.py | 4 ++-- deepspeech/io/collator.py | 4 ++-- deepspeech/io/dataset.py | 12 ++++++++---- deepspeech/models/deepspeech2.py | 2 +- examples/chinese_g2p/local/ignore_sandhi.py | 7 +++++-- examples/dataset/librispeech/.gitignore | 14 +++++++------- examples/librispeech/s0/README.md | 2 +- examples/tiny/s0/run.sh | 2 +- 8 files changed, 27 insertions(+), 20 deletions(-) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 8e8a1824..05b55f75 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -75,7 +75,7 @@ class DeepSpeech2Trainer(Trainer): for i, batch in enumerate(self.valid_loader): loss = self.model(*batch) if paddle.isfinite(loss): - num_utts = batch[0].shape[0] + num_utts = batch[1].shape[0] num_seen_utts += num_utts total_loss += float(loss) * num_utts valid_losses['val_loss'].append(float(loss)) @@ -191,7 +191,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): trans.append(''.join([chr(i) for i in ids])) return trans - def compute_metrics(self, audio, audio_len, texts, texts_len): + def compute_metrics(self, utt, audio, audio_len, texts, texts_len): cfg = self.config.decoding errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index 7f019039..5b521fbd 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -51,7 +51,7 @@ class SpeechCollator(): audio_lens = [] texts = [] text_lens = [] - for audio, text in batch: + for utt, audio, text in batch: # audio audios.append(audio.T) # [T, D] audio_lens.append(audio.shape[1]) @@ -75,4 +75,4 @@ class SpeechCollator(): padded_texts = pad_sequence( texts, padding_value=IGNORE_ID).astype(np.int64) text_lens = np.array(text_lens).astype(np.int64) - return padded_audios, audio_lens, padded_texts, text_lens + return utt, padded_audios, audio_lens, padded_texts, text_lens diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index fba5f7c6..eaa57a4e 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -284,7 +284,7 @@ class ManifestDataset(Dataset): return self._local_data.tar2object[tarpath].extractfile( self._local_data.tar2info[tarpath][filename]) - def process_utterance(self, audio_file, transcript): + def process_utterance(self, utt, audio_file, transcript): """Load, augment, featurize and normalize for speech data. :param audio_file: Filepath or file object of audio file. 
@@ -323,7 +323,7 @@ class ManifestDataset(Dataset): specgram = self._augmentation_pipeline.transform_feature(specgram) feature_aug_time = time.time() - start_time #logger.debug(f"audio feature augmentation time: {feature_aug_time}") - return specgram, transcript_part + return utt, specgram, transcript_part def _instance_reader_creator(self, manifest): """ @@ -336,7 +336,9 @@ class ManifestDataset(Dataset): def reader(): for instance in manifest: - inst = self.process_utterance(instance["feat"], + # inst = self.process_utterance(instance["feat"], + # instance["text"]) + inst = self.process_utterance(instance["utt"], instance["feat"], instance["text"]) yield inst @@ -347,4 +349,6 @@ class ManifestDataset(Dataset): def __getitem__(self, idx): instance = self._manifest[idx] - return self.process_utterance(instance["feat"], instance["text"]) + return self.process_utterance(instance["utt"], instance["feat"], + instance["text"]) + # return self.process_utterance(instance["feat"], instance["text"]) diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py index 0ff5514d..ab617a53 100644 --- a/deepspeech/models/deepspeech2.py +++ b/deepspeech/models/deepspeech2.py @@ -161,7 +161,7 @@ class DeepSpeech2Model(nn.Layer): reduction=True, # sum batch_average=True) # sum / batch_size - def forward(self, audio, audio_len, text, text_len): + def forward(self, utt, audio, audio_len, text, text_len): """Compute Model loss Args: diff --git a/examples/chinese_g2p/local/ignore_sandhi.py b/examples/chinese_g2p/local/ignore_sandhi.py index cda1bd14..b7f37a27 100644 --- a/examples/chinese_g2p/local/ignore_sandhi.py +++ b/examples/chinese_g2p/local/ignore_sandhi.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import argparse -from typing import List, Union from pathlib import Path +from typing import List +from typing import Union def erized(syllable: str) -> bool: @@ -67,7 +68,9 @@ def ignore_sandhi(reference: List[str], generated: List[str]) -> List[str]: return result -def convert_transcriptions(reference: Union[str, Path], generated: Union[str, Path], output: Union[str, Path]): +def convert_transcriptions(reference: Union[str, Path], + generated: Union[str, Path], + output: Union[str, Path]): with open(reference, 'rt') as f_ref: with open(generated, 'rt') as f_gen: with open(output, 'wt') as f_out: diff --git a/examples/dataset/librispeech/.gitignore b/examples/dataset/librispeech/.gitignore index a8d8eb76..dfd5c67b 100644 --- a/examples/dataset/librispeech/.gitignore +++ b/examples/dataset/librispeech/.gitignore @@ -1,7 +1,7 @@ -dev-clean/ -dev-other/ -test-clean/ -test-other/ -train-clean-100/ -train-clean-360/ -train-other-500/ +dev-clean +dev-other +test-clean +test-other +train-clean-100 +train-clean-360 +train-other-500 diff --git a/examples/librispeech/s0/README.md b/examples/librispeech/s0/README.md index 09f700da..393dd457 100644 --- a/examples/librispeech/s0/README.md +++ b/examples/librispeech/s0/README.md @@ -3,7 +3,7 @@ ## Deepspeech2 | Model | release | Config | Test set | Loss | WER | -| --- | --- | --- | --- | --- | --- | +| --- | --- | --- | --- | --- | --- | | DeepSpeech2 | 2.1.0 | conf/deepspeech2.yaml | 15.184467315673828 | test-clean | 0.072154 | | DeepSpeech2 | 2.0.0 | conf/deepspeech2.yaml | - | test-clean | 0.073973 | | DeepSpeech2 | 1.8.5 | - | test-clean | - | 0.074939 | diff --git a/examples/tiny/s0/run.sh b/examples/tiny/s0/run.sh index d4961adb..0f2e3fd1 100755 --- a/examples/tiny/s0/run.sh +++ b/examples/tiny/s0/run.sh @@ -11,7 +11,7 @@ avg_num=1 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; avg_ckpt=avg_${avg_num} -ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') +ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') ###ckpt = deepspeech2 echo "checkpoint name ${ckpt}" if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then -- GitLab
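
For reference, here is a minimal, self-contained sketch of the batch layout this patch introduces: each dataset item becomes an (utt, audio, text) triple and the collator yields a (utt, padded_audios, audio_lens, padded_texts, text_lens) tuple. It is illustrative only — the numpy padding below stands in for the project's pad_sequence, the toy features are random, and for clarity it gathers every utterance id into a list.

# Sketch only: mirrors the (utt, audio, text) -> padded batch flow of
# deepspeech/io/collator.py after this patch; helpers and data are stand-ins.
import numpy as np

IGNORE_ID = -1  # assumed text padding value, as used by the collator


def collate(batch):
    """Collate (utt, audio[D, T], text) items into one padded batch."""
    utts, audios, audio_lens, texts, text_lens = [], [], [], [], []
    for utt, audio, text in batch:
        utts.append(utt)
        audios.append(audio.T)                # [T, D]
        audio_lens.append(audio.shape[1])
        texts.append(text)
        text_lens.append(text.shape[0])
    t_max = max(a.shape[0] for a in audios)
    padded_audios = np.stack(
        [np.pad(a, ((0, t_max - a.shape[0]), (0, 0))) for a in audios])
    l_max = max(text_lens)
    padded_texts = np.stack(
        [np.pad(t, (0, l_max - t.shape[0]), constant_values=IGNORE_ID)
         for t in texts]).astype(np.int64)
    return (utts, padded_audios,
            np.array(audio_lens, dtype=np.int64),
            padded_texts,
            np.array(text_lens, dtype=np.int64))


# Two toy utterances: 80-dim features of different lengths.
batch = [
    ("fake-utt-0001", np.random.rand(80, 120).astype("float32"),
     np.array([3, 5, 7])),
    ("fake-utt-0002", np.random.rand(80, 95).astype("float32"),
     np.array([2, 4])),
]
utt, audio, audio_len, text, text_len = collate(batch)
print(len(utt), audio.shape, audio_len, text, text_len)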
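
The consumer side changes accordingly: the utterance ids ride along as batch[0], so the trainer now reads the batch size from batch[1] (the padded audio), and DeepSpeech2Model.forward as well as compute_metrics take utt as their first argument. Below is a runnable sketch of that unpacking; DummyModel and the fake batch are hypothetical stand-ins, only the argument order comes from the patch.

# Sketch only: how a validation-style loop unpacks the new 5-element batch.
import numpy as np


class DummyModel:
    def __call__(self, utt, audio, audio_len, text, text_len):
        return float(audio.mean())  # pretend loss, keeps the sketch runnable


def valid(model, loader):
    total_loss, num_seen_utts = 0.0, 0
    for batch in loader:
        loss = model(*batch)              # utt is forwarded as the first argument
        num_utts = batch[1].shape[0]      # batch size read from the audio tensor
        num_seen_utts += num_utts
        total_loss += loss * num_utts
    return total_loss / num_seen_utts


fake_batch = (["fake-utt-0001", "fake-utt-0002"],
              np.random.rand(2, 120, 80).astype("float32"),
              np.array([120, 95], dtype=np.int64),
              np.array([[3, 5, 7], [2, 4, -1]], dtype=np.int64),
              np.array([3, 2], dtype=np.int64))
print(valid(DummyModel(), [fake_batch]))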