collator.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

from deepspeech.frontend.utility import IGNORE_ID
from deepspeech.io.utility import pad_sequence
from deepspeech.utils.log import Log

__all__ = ["SpeechCollator"]

logger = Log(__name__).getlog()


class SpeechCollator():
    def __init__(self, keep_transcription_text=True):
        """
        Padding audio features with zeros to make them have the same shape (or
        a user-defined shape) within one batch.

        If ``keep_transcription_text`` is False, the text is a sequence of
        token ids; otherwise it is kept as the raw transcription string.
        """
        self._keep_transcription_text = keep_transcription_text

    def __call__(self, batch):
        """batch examples

        Args:
            batch ([List]): batch is (audio, text)
                audio (np.ndarray) shape (D, T)
                text (List[int] or str): shape (U,)

        Returns:
            tuple(audio, text, audio_lens, text_lens): batched data.
                audio : (B, Tmax, D)
                audio_lens: (B)
                text : (B, Umax)
                text_lens: (B)
        """
        audios = []
        audio_lens = []
        texts = []
        text_lens = []
        utts = []
        for utt, audio, text in batch:
            # utterance id
            utts.append(utt)
            # audio
            audios.append(audio.T)  # [T, D]
            audio_lens.append(audio.shape[1])
            # text
            # for training, text is already a sequence of token ids;
            # otherwise text is a raw string, converted to unicode code points
            tokens = []
            if self._keep_transcription_text:
                assert isinstance(text, str), (type(text), text)
                tokens = [ord(t) for t in text]
            else:
                tokens = text  # token ids
            tokens = tokens if isinstance(tokens, np.ndarray) else np.array(
                tokens, dtype=np.int64)
            texts.append(tokens)
            text_lens.append(tokens.shape[0])

        padded_audios = pad_sequence(
            audios, padding_value=0.0).astype(np.float32)  #[B, T, D]
        audio_lens = np.array(audio_lens).astype(np.int64)
        # pad texts with IGNORE_ID so padded positions can be ignored downstream
        padded_texts = pad_sequence(
            texts, padding_value=IGNORE_ID).astype(np.int64)
        text_lens = np.array(text_lens).astype(np.int64)
        return utts, padded_audios, audio_lens, padded_texts, text_lens
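

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): it builds
# a toy batch of (utt, audio, text) tuples -- audio as assumed (D, T) feature
# matrices, text as raw transcription strings -- and logs the padded shapes
# produced by SpeechCollator. Utterance ids and shapes are made up.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    collator = SpeechCollator(keep_transcription_text=True)
    toy_batch = [
        ("utt_001", np.random.randn(80, 120).astype(np.float32), "hello"),
        ("utt_002", np.random.randn(80, 90).astype(np.float32), "hi"),
    ]
    utts, audios, audio_lens, texts, text_lens = collator(toy_batch)
    # expected: audios (2, 120, 80), audio_lens [120, 90]
    logger.info(f"audios: {audios.shape}, audio_lens: {audio_lens}")
    # expected: texts (2, 5) padded with IGNORE_ID, text_lens [5, 2]
    logger.info(f"texts: {texts.shape}, text_lens: {text_lens}")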