diff --git a/deepspeech/exps/u2_st/model.py b/deepspeech/exps/u2_st/model.py
index e4e70292cda53226bd1eeaebffae7f3752275f87..f5a514c72fd8635efc4fa6b5c0e7daf44231138c 100644
--- a/deepspeech/exps/u2_st/model.py
+++ b/deepspeech/exps/u2_st/model.py
@@ -28,10 +28,8 @@
 from paddle import distributed as dist
 from paddle.io import DataLoader
 from yacs.config import CfgNode
-from deepspeech.io.collator_st import KaldiPrePorocessedCollator
-from deepspeech.io.collator_st import SpeechCollator
-from deepspeech.io.collator_st import TripletKaldiPrePorocessedCollator
-from deepspeech.io.collator_st import TripletSpeechCollator
+from deepspeech.io.collator import SpeechCollator
+from deepspeech.io.collator import TripletSpeechCollator
 from deepspeech.io.dataset import ManifestDataset
 from deepspeech.io.dataset import TripletManifestDataset
 from deepspeech.io.sampler import SortagradBatchSampler
@@ -258,22 +256,13 @@ class U2STTrainer(Trainer):
             config.data.manifest = config.data.dev_manifest
             dev_dataset = Dataset.from_config(config)
 
-        if config.collator.raw_wav:
-            if config.model.model_conf.asr_weight > 0.:
-                Collator = TripletSpeechCollator
-                TestCollator = SpeechCollator
-            else:
-                TestCollator = Collator = SpeechCollator
-            # Not yet implement the mtl loader for raw_wav.
+        if config.model.model_conf.asr_weight > 0.:
+            Collator = TripletSpeechCollator
+            TestCollator = SpeechCollator
         else:
-            if config.model.model_conf.asr_weight > 0.:
-                Collator = TripletKaldiPrePorocessedCollator
-                TestCollator = KaldiPrePorocessedCollator
-            else:
-                TestCollator = Collator = KaldiPrePorocessedCollator
+            TestCollator = Collator = SpeechCollator
 
         collate_fn_train = Collator.from_config(config)
-        config.collator.augmentation_config = ""
         collate_fn_dev = Collator.from_config(config)
diff --git a/deepspeech/frontend/audio.py b/deepspeech/frontend/audio.py
index ffdcd4b3a5f3b3e7bd8c3725fa8199c3e89e40c4..13dc3a44d49b1e2f98b457086190a09c16f38fd9 100644
--- a/deepspeech/frontend/audio.py
+++ b/deepspeech/frontend/audio.py
@@ -24,8 +24,10 @@
 import soundfile
 import soxbindings as sox
 from scipy import signal
 
+from .utility import subfile_from_tar
-class AudioSegment(object):
+
+class AudioSegment():
     """Monaural audio segment abstraction.
 
     :param samples: Audio samples [num_samples x num_channels].
@@ -68,16 +70,20 @@ class AudioSegment():
                                         self.duration, self.rms_db))
 
     @classmethod
-    def from_file(cls, file):
+    def from_file(cls, file, infos=None):
         """Create audio segment from audio file.
-
-        :param filepath: Filepath or file object to audio file.
-        :type filepath: str|file
-        :return: Audio segment instance.
-        :rtype: AudioSegment
+
+        Args:
+            file (str|file): Filepath or file object to audio file.
+            infos (TarLocalData, optional): tar2object and tar2info maps. Defaults to None.
+
+        Returns:
+            AudioSegment: Audio segment instance.
         """
         if isinstance(file, str) and re.findall(r".seqbin_\d+$", file):
             return cls.from_sequence_file(file)
+        elif isinstance(file, str) and file.startswith('tar:'):
+            return cls.from_file(subfile_from_tar(file, infos))
         else:
            samples, sample_rate = soundfile.read(file, dtype='float32')
             return cls(samples, sample_rate)
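Note: `AudioSegment.from_file` now also accepts `tar:`-scheme paths and resolves them through `subfile_from_tar` (added in `deepspeech/frontend/utility.py` below). A minimal usage sketch, assuming the `tar:<tarpath>#<filename>` convention documented below; archive and member names here are made up:

```python
from deepspeech.frontend.audio import AudioSegment
from deepspeech.frontend.utility import TarLocalData

# Optional cache so repeated reads of the same archive reuse the open tarfile.
cache = TarLocalData(tar2info={}, tar2object={})
seg = AudioSegment.from_file('tar:data/train_audio.tar#utt_0001.wav', infos=cache)
```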
""" if isinstance(file, str) and re.findall(r".seqbin_\d+$", file): return cls.from_sequence_file(file) + elif isinstance(file, str) and file.startswith('tar:'): + return cls.from_file(subfile_from_tar(file, infos)) else: samples, sample_rate = soundfile.read(file, dtype='float32') return cls(samples, sample_rate) diff --git a/deepspeech/frontend/featurizer/speech_featurizer.py b/deepspeech/frontend/featurizer/speech_featurizer.py index 5082850d69146b4f4d2d34c7df69c63ddcb57835..f9f7d7c270079384c67d43fb88c28f6285900bbc 100644 --- a/deepspeech/frontend/featurizer/speech_featurizer.py +++ b/deepspeech/frontend/featurizer/speech_featurizer.py @@ -64,8 +64,12 @@ class SpeechFeaturizer(): target_sample_rate=16000, use_dB_normalization=True, target_dB=-20, - dither=1.0): - self._audio_featurizer = AudioFeaturizer( + dither=1.0, + maskctc=False): + self.stride_ms = stride_ms + self.window_ms = window_ms + + self.audio_feature = AudioFeaturizer( specgram_type=specgram_type, feat_dim=feat_dim, delta_delta=delta_delta, @@ -77,8 +81,12 @@ class SpeechFeaturizer(): use_dB_normalization=use_dB_normalization, target_dB=target_dB, dither=dither) - self._text_featurizer = TextFeaturizer(unit_type, vocab_filepath, - spm_model_prefix) + + self.text_feature = TextFeaturizer( + unit_type=unit_type, + vocab_filepath=vocab_filepath, + spm_model_prefix=spm_model_prefix, + maskctc=maskctc) def featurize(self, speech_segment, keep_transcription_text): """Extract features for speech segment. @@ -94,60 +102,33 @@ class SpeechFeaturizer(): Returns: tuple: 1) spectrogram audio feature in 2darray, 2) list oftoken indices. """ - spec_feature = self._audio_featurizer.featurize(speech_segment) + spec_feature = self.audio_feature.featurize(speech_segment) + if keep_transcription_text: return spec_feature, speech_segment.transcript + if speech_segment.has_token: text_ids = speech_segment.token_ids else: - text_ids = self._text_featurizer.featurize( - speech_segment.transcript) + text_ids = self.text_feature.featurize(speech_segment.transcript) return spec_feature, text_ids - @property - def vocab_size(self): - """Return the vocabulary size. - Returns: - int: Vocabulary size. - """ - return self._text_featurizer.vocab_size - - @property - def vocab_list(self): - """Return the vocabulary in list. - Returns: - List[str]: - """ - return self._text_featurizer.vocab_list + def text_featurize(self, text, keep_transcription_text): + """Extract features for speech segment. - @property - def vocab_dict(self): - """Return the vocabulary in dict. - Returns: - Dict[str, int]: - """ - return self._text_featurizer.vocab_dict + 1. For audio parts, extract the audio features. + 2. For transcript parts, keep the original text or convert text string + to a list of token indices in char-level. - @property - def feature_size(self): - """Return the audio feature size. - Returns: - int: audio feature size. - """ - return self._audio_featurizer.feature_size + Args: + text (str): text. + keep_transcription_text (bool): True, keep transcript text, False, token ids - @property - def stride_ms(self): - """time length in `ms` unit per frame Returns: - float: time(ms)/frame + (str|List[int]): text, or list of token indices. """ - return self._audio_featurizer.stride_ms + if keep_transcription_text: + return text - @property - def text_feature(self): - """Return the text feature object. - Returns: - TextFeaturizer: object. 
- """ - return self._text_featurizer + text_ids = self.text_feature.featurize(text) + return text_ids diff --git a/deepspeech/frontend/speech.py b/deepspeech/frontend/speech.py index e58795c0e88e3d73db8ffb53a3c1f261cbd6975f..9eed9725ad2c9b7a8efeac82ad3b5609f154bbbc 100644 --- a/deepspeech/frontend/speech.py +++ b/deepspeech/frontend/speech.py @@ -68,7 +68,12 @@ class SpeechSegment(AudioSegment): return not self.__eq__(other) @classmethod - def from_file(cls, filepath, transcript, tokens=None, token_ids=None): + def from_file(cls, + filepath, + transcript, + tokens=None, + token_ids=None, + infos=None): """Create speech segment from audio file and corresponding transcript. Args: @@ -76,12 +81,12 @@ class SpeechSegment(AudioSegment): transcript (str): Transcript text for the speech. tokens (List[str], optional): text tokens. Defaults to None. token_ids (List[int], optional): text token ids. Defaults to None. + infos (TarLocalData, optional): tar2obj and tar2infos. Defaults to None. Returns: SpeechSegment: Speech segment instance. """ - - audio = AudioSegment.from_file(filepath) + audio = AudioSegment.from_file(filepath, infos) return cls(audio.samples, audio.sample_rate, transcript, tokens, token_ids) diff --git a/deepspeech/frontend/utility.py b/deepspeech/frontend/utility.py index 3a972b50a4bbf53c346ea55ec3a671d0624f4979..2a58123240dd1673a87cc59fe5c1cb0953a47985 100644 --- a/deepspeech/frontend/utility.py +++ b/deepspeech/frontend/utility.py @@ -14,6 +14,7 @@ """Contains data helper functions.""" import json import math +import tarfile from typing import List from typing import Optional from typing import Text @@ -112,6 +113,51 @@ def read_manifest( return manifest +# Tar File read +TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object']) + + +def parse_tar(file): + """Parse a tar file to get a tarfile object + and a map containing tarinfoes + """ + result = {} + f = tarfile.open(file) + for tarinfo in f.getmembers(): + result[tarinfo.name] = tarinfo + return f, result + + +def subfile_from_tar(file, local_data=None): + """Get subfile object from tar. + + tar:tarpath#filename + + It will return a subfile object from tar file + and cached tar file info for next reading request. + """ + tarpath, filename = file.split(':', 1)[1].split('#', 1) + + if local_data is None: + local_data = TarLocalData(tar2info={}, tar2object={}) + + assert isinstance(local_data, TarLocalData) + + if 'tar2info' not in local_data.__dict__: + local_data.tar2info = {} + if 'tar2object' not in local_data.__dict__: + local_data.tar2object = {} + + if tarpath not in local_data.tar2info: + fobj, infos = parse_tar(tarpath) + local_data.tar2info[tarpath] = infos + local_data.tar2object[tarpath] = fobj + else: + fobj = local_data.tar2object[tarpath] + infos = local_data.tar2info[tarpath] + return fobj.extractfile(infos[filename]) + + def rms_to_db(rms: float): """Root Mean Square to dB. diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index 15b89ab9f7768518314ba30a11f046c143a8d860..c5c0a414674f0d5ffa4338885a104d09bc3d7833 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py
index 15b89ab9f7768518314ba30a11f046c143a8d860..c5c0a414674f0d5ffa4338885a104d09bc3d7833 100644
--- a/deepspeech/io/collator.py
+++ b/deepspeech/io/collator.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import io
-from collections import namedtuple
 from typing import Optional
 
 import numpy as np
@@ -23,96 +22,17 @@
 from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
 from deepspeech.frontend.normalizer import FeatureNormalizer
 from deepspeech.frontend.speech import SpeechSegment
 from deepspeech.frontend.utility import IGNORE_ID
+from deepspeech.frontend.utility import TarLocalData
+from deepspeech.io.reader import LoadInputsAndTargets
 from deepspeech.io.utility import pad_list
 from deepspeech.utils.log import Log
 
-__all__ = ["SpeechCollator"]
+__all__ = ["SpeechCollator", "TripletSpeechCollator"]
 
 logger = Log(__name__).getlog()
 
-# namedtupe need global for pickle.
-TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
-
-
-class SpeechCollator():
-    @classmethod
-    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
-        default = CfgNode(
-            dict(
-                augmentation_config="",
-                random_seed=0,
-                mean_std_filepath="",
-                unit_type="char",
-                vocab_filepath="",
-                spm_model_prefix="",
-                specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
-                feat_dim=0,  # 'mfcc', 'fbank'
-                delta_delta=False,  # 'mfcc', 'fbank'
-                stride_ms=10.0,  # ms
-                window_ms=20.0,  # ms
-                n_fft=None,  # fft points
-                max_freq=None,  # None for samplerate/2
-                target_sample_rate=16000,  # target sample rate
-                use_dB_normalization=True,
-                target_dB=-20,
-                dither=1.0,  # feature dither
-                keep_transcription_text=False))
-
-        if config is not None:
-            config.merge_from_other_cfg(default)
-        return default
-
-    @classmethod
-    def from_config(cls, config):
-        """Build a SpeechCollator object from a config.
-
-        Args:
-            config (yacs.config.CfgNode): configs object.
-
-        Returns:
-            SpeechCollator: collator object.
-        """
-        assert 'augmentation_config' in config.collator
-        assert 'keep_transcription_text' in config.collator
-        assert 'mean_std_filepath' in config.collator
-        assert 'vocab_filepath' in config.collator
-        assert 'specgram_type' in config.collator
-        assert 'n_fft' in config.collator
-        assert config.collator
-
-        if isinstance(config.collator.augmentation_config, (str, bytes)):
-            if config.collator.augmentation_config:
-                aug_file = io.open(
-                    config.collator.augmentation_config,
-                    mode='r',
-                    encoding='utf8')
-            else:
-                aug_file = io.StringIO(initial_value='{}', newline='')
-        else:
-            aug_file = config.collator.augmentation_config
-        assert isinstance(aug_file, io.StringIO)
-
-        speech_collator = cls(
-            aug_file=aug_file,
-            random_seed=0,
-            mean_std_filepath=config.collator.mean_std_filepath,
-            unit_type=config.collator.unit_type,
-            vocab_filepath=config.collator.vocab_filepath,
-            spm_model_prefix=config.collator.spm_model_prefix,
-            specgram_type=config.collator.specgram_type,
-            feat_dim=config.collator.feat_dim,
-            delta_delta=config.collator.delta_delta,
-            stride_ms=config.collator.stride_ms,
-            window_ms=config.collator.window_ms,
-            n_fft=config.collator.n_fft,
-            max_freq=config.collator.max_freq,
-            target_sample_rate=config.collator.target_sample_rate,
-            use_dB_normalization=config.collator.use_dB_normalization,
-            target_dB=config.collator.target_dB,
-            dither=config.collator.dither,
-            keep_transcription_text=config.collator.keep_transcription_text)
-        return speech_collator
+class SpeechCollatorBase():
     def __init__(
             self,
             aug_file,
             mean_std_filepath,
             vocab_filepath,
             spm_model_prefix,
             random_seed=0,
             unit_type="char",
-            specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
+            spectrum_type='linear',  # 'linear', 'mfcc', 'fbank'
             feat_dim=0,  # 'mfcc', 'fbank'
             delta_delta=False,  # 'mfcc', 'fbank'
             stride_ms=10.0,  # ms
@@ -146,7 +66,7 @@
             n_fft (int, optional): fft points for rfft. Defaults to None.
             max_freq (int, optional): max cut freq. Defaults to None.
             target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000.
-            specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
+            spectrum_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
             feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None.
             delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False.
             use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
@@ -159,23 +79,27 @@
             Padding audio features with zeros to make them have the same shape (or
             a user-defined shape) within one batch.
         """
-        self._keep_transcription_text = keep_transcription_text
+        self.keep_transcription_text = keep_transcription_text
+        self.stride_ms = stride_ms
+        self.window_ms = window_ms
+        self.feat_dim = feat_dim
+
+        self.loader = LoadInputsAndTargets()
+        # only for tar filetype
         self._local_data = TarLocalData(tar2info={}, tar2object={})
-        self._augmentation_pipeline = AugmentationPipeline(
+
+        self.augmentation = AugmentationPipeline(
             augmentation_config=aug_file.read(), random_seed=random_seed)
 
         self._normalizer = FeatureNormalizer(
             mean_std_filepath) if mean_std_filepath else None
 
-        self._stride_ms = stride_ms
-        self._target_sample_rate = target_sample_rate
-
         self._speech_featurizer = SpeechFeaturizer(
             unit_type=unit_type,
             vocab_filepath=vocab_filepath,
             spm_model_prefix=spm_model_prefix,
-            specgram_type=specgram_type,
+            spectrum_type=spectrum_type,
             feat_dim=feat_dim,
             delta_delta=delta_delta,
             stride_ms=stride_ms,
@@ -187,33 +111,11 @@
             target_dB=target_dB,
             dither=dither)
 
-    def _parse_tar(self, file):
-        """Parse a tar file to get a tarfile object
-        and a map containing tarinfoes
-        """
-        result = {}
-        f = tarfile.open(file)
-        for tarinfo in f.getmembers():
-            result[tarinfo.name] = tarinfo
-        return f, result
-
-    def _subfile_from_tar(self, file):
-        """Get subfile object from tar.
-
-        It will return a subfile object from tar file
-        and cached tar file info for next reading request.
-        """
-        tarpath, filename = file.split(':', 1)[1].split('#', 1)
-        if 'tar2info' not in self._local_data.__dict__:
-            self._local_data.tar2info = {}
-        if 'tar2object' not in self._local_data.__dict__:
-            self._local_data.tar2object = {}
-        if tarpath not in self._local_data.tar2info:
-            object, infoes = self._parse_tar(tarpath)
-            self._local_data.tar2info[tarpath] = infoes
-            self._local_data.tar2object[tarpath] = object
-        return self._local_data.tar2object[tarpath].extractfile(
-            self._local_data.tar2info[tarpath][filename])
+        self.feature_size = self._speech_featurizer.audio_feature.feature_size
+        self.text_feature = self._speech_featurizer.text_feature
+        self.vocab_dict = self.text_feature.vocab_dict
+        self.vocab_list = self.text_feature.vocab_list
+        self.vocab_size = self.text_feature.vocab_size
 
     def process_utterance(self, audio_file, transcript):
         """Load, augment, featurize and normalize for speech data.
 
         :param audio_file: Filepath or file object of audio file.
         :type audio_file: str | file
         :return: Tuple of audio feature tensor and data of transcription part,
             where transcription part could be token ids or text.
         :rtype: tuple of (2darray, list)
         """
-        if isinstance(audio_file, str) and audio_file.startswith('tar:'):
-            speech_segment = SpeechSegment.from_file(
-                self._subfile_from_tar(audio_file), transcript)
+        filetype = self.loader.file_type(audio_file)
+
+        if filetype != 'sound':
+            spectrum = self.loader._get_from_loader(audio_file, filetype)
+            feat_dim = spectrum.shape[1]
+            assert feat_dim == self.feat_dim, f"expect feat dim {self.feat_dim}, but got {feat_dim}"
+
+            if self.keep_transcription_text:
+                transcript_part = transcript
+            else:
+                text_ids = self.text_feature.featurize(transcript)
+                transcript_part = text_ids
         else:
-            speech_segment = SpeechSegment.from_file(audio_file, transcript)
+            # read audio
+            speech_segment = SpeechSegment.from_file(
+                audio_file, transcript, infos=self._local_data)
+            # audio augment
+            self.augmentation.transform_audio(speech_segment)
 
-        # audio augment
-        self._augmentation_pipeline.transform_audio(speech_segment)
+            # extract speech feature
+            spectrum, transcript_part = self._speech_featurizer.featurize(
+                speech_segment, self.keep_transcription_text)
 
-        specgram, transcript_part = self._speech_featurizer.featurize(
-            speech_segment, self._keep_transcription_text)
-        if self._normalizer:
-            specgram = self._normalizer.apply(specgram)
+            # CMVN spectrum
+            if self._normalizer:
+                spectrum = self._normalizer.apply(spectrum)
 
-        # specgram augment
-        specgram = self._augmentation_pipeline.transform_feature(specgram)
-        return specgram, transcript_part
+            # spectrum augment
+            spectrum = self.augmentation.transform_feature(spectrum)
+        return spectrum, transcript_part
 
     def __call__(self, batch):
         """batch examples
@@ -272,16 +187,14 @@
             audios.append(audio)  # [T, D]
             audio_lens.append(audio.shape[0])
             # text
-            # for training, text is token ids
-            # else text is string, convert to unicode ord
+            # for training, text is token ids, else text is string, convert to unicode ord
             tokens = []
-            if self._keep_transcription_text:
+            if self.keep_transcription_text:
                 assert isinstance(text, str), (type(text), text)
                 tokens = [ord(t) for t in text]
             else:
                 tokens = text  # token ids
-            tokens = tokens if isinstance(tokens, np.ndarray) else np.array(
-                tokens, dtype=np.int64)
+            tokens = np.array(tokens, dtype=np.int64)
             texts.append(tokens)
             text_lens.append(tokens.shape[0])
@@ -292,26 +205,162 @@
         olens = np.array(text_lens).astype(np.int64)
         return utts, xs_pad, ilens, ys_pad, olens
 
-    @property
-    def vocab_size(self):
-        return self._speech_featurizer.vocab_size
-
-    @property
-    def vocab_list(self):
-        return self._speech_featurizer.vocab_list
+
+class SpeechCollator(SpeechCollatorBase):
+    @classmethod
+    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
+        default = CfgNode(
+            dict(
+                augmentation_config="",
+                random_seed=0,
+                mean_std_filepath="",
+                unit_type="char",
+                vocab_filepath="",
+                spm_model_prefix="",
+                spectrum_type='linear',  # 'linear', 'mfcc', 'fbank'
+                feat_dim=0,  # 'mfcc', 'fbank'
+                delta_delta=False,  # 'mfcc', 'fbank'
+                stride_ms=10.0,  # ms
+                window_ms=20.0,  # ms
+                n_fft=None,  # fft points
+                max_freq=None,  # None for samplerate/2
+                target_sample_rate=16000,  # target sample rate
+                use_dB_normalization=True,
+                target_dB=-20,
+                dither=1.0,  # feature dither
+                keep_transcription_text=False))
+
+        if config is not None:
+            config.merge_from_other_cfg(default)
+        return default
+
+    @classmethod
+    def from_config(cls, config):
+        """Build a SpeechCollator object from a config.
+
+        Args:
+            config (yacs.config.CfgNode): configs object.
+
+        Returns:
+            SpeechCollator: collator object.
+        """
+        assert 'augmentation_config' in config.collator
+        assert 'keep_transcription_text' in config.collator
+        assert 'mean_std_filepath' in config.collator
+        assert 'vocab_filepath' in config.collator
+        assert 'spectrum_type' in config.collator
+        assert 'n_fft' in config.collator
+        assert config.collator
+
+        if isinstance(config.collator.augmentation_config, (str, bytes)):
+            if config.collator.augmentation_config:
+                aug_file = io.open(
+                    config.collator.augmentation_config,
+                    mode='r',
+                    encoding='utf8')
+            else:
+                aug_file = io.StringIO(initial_value='{}', newline='')
+        else:
+            aug_file = config.collator.augmentation_config
+        assert isinstance(aug_file, io.StringIO)
+
+        speech_collator = cls(
+            aug_file=aug_file,
+            random_seed=0,
+            mean_std_filepath=config.collator.mean_std_filepath,
+            unit_type=config.collator.unit_type,
+            vocab_filepath=config.collator.vocab_filepath,
+            spm_model_prefix=config.collator.spm_model_prefix,
+            spectrum_type=config.collator.spectrum_type,
+            feat_dim=config.collator.feat_dim,
+            delta_delta=config.collator.delta_delta,
+            stride_ms=config.collator.stride_ms,
+            window_ms=config.collator.window_ms,
+            n_fft=config.collator.n_fft,
+            max_freq=config.collator.max_freq,
+            target_sample_rate=config.collator.target_sample_rate,
+            use_dB_normalization=config.collator.use_dB_normalization,
+            target_dB=config.collator.target_dB,
+            dither=config.collator.dither,
+            keep_transcription_text=config.collator.keep_transcription_text)
+        return speech_collator
+
+
+class TripletSpeechCollator(SpeechCollator):
+    def process_utterance(self, audio_file, translation, transcript):
+        """Load, augment, featurize and normalize for speech data.
 
-    @property
-    def vocab_dict(self):
-        return self._speech_featurizer.vocab_dict
+        :param audio_file: Filepath or file object of audio file.
+        :type audio_file: str | file
+        :param translation: translation text.
+        :type translation: str
+        :return: Tuple of audio feature tensor and data of translation part,
+            where translation part could be token ids or text.
+        :rtype: tuple of (2darray, list)
+        """
+        spectrum, translation_part = super().process_utterance(audio_file,
+                                                               translation)
+        transcript_part = self._speech_featurizer.text_featurize(
+            transcript, self.keep_transcription_text)
+        return spectrum, translation_part, transcript_part
 
-    @property
-    def text_feature(self):
-        return self._speech_featurizer.text_feature
+    def __call__(self, batch):
+        """batch examples
+
+        Args:
+            batch ([List]): batch is (audio, text)
+                audio (np.ndarray) shape (T, D)
+                text (List[int] or str): shape (U,)
 
-    @property
-    def feature_size(self):
-        return self._speech_featurizer.feature_size
+        Returns:
+            tuple(audio, text, audio_lens, text_lens): batched data.
+            audio : (B, Tmax, D)
+            audio_lens: (B)
+            text : (B, Umax)
+            text_lens: (B)
+        """
+        audios = []
+        audio_lens = []
+        translation_text = []
+        translation_text_lens = []
+        transcription_text = []
+        transcription_text_lens = []
 
-    @property
-    def stride_ms(self):
-        return self._speech_featurizer.stride_ms
+        utts = []
+        for utt, audio, translation, transcription in batch:
+            audio, translation, transcription = self.process_utterance(
+                audio, translation, transcription)
+            #utt
+            utts.append(utt)
+            # audio
+            audios.append(audio)  # [T, D]
+            audio_lens.append(audio.shape[0])
+            # text
+            # for training, text is token ids
+            # else text is string, convert to unicode ord
+            tokens = [[], []]
+            for idx, text in enumerate([translation, transcription]):
+                if self.keep_transcription_text:
+                    assert isinstance(text, str), (type(text), text)
+                    tokens[idx] = [ord(t) for t in text]
+                else:
+                    tokens[idx] = text  # token ids
+                tokens[idx] = np.array(tokens[idx], dtype=np.int64)
+
+            translation_text.append(tokens[0])
+            translation_text_lens.append(tokens[0].shape[0])
+            transcription_text.append(tokens[1])
+            transcription_text_lens.append(tokens[1].shape[0])
+
+        padded_audios = pad_sequence(
+            audios, padding_value=0.0).astype(np.float32)  #[B, T, D]
+        audio_lens = np.array(audio_lens).astype(np.int64)
+        padded_translation = pad_sequence(
+            translation_text, padding_value=IGNORE_ID).astype(np.int64)
+        translation_lens = np.array(translation_text_lens).astype(np.int64)
+        padded_transcription = pad_sequence(
+            transcription_text, padding_value=IGNORE_ID).astype(np.int64)
+        transcription_lens = np.array(transcription_text_lens).astype(np.int64)
+        return utts, padded_audios, audio_lens, (
+            padded_translation, padded_transcription), (translation_lens,
                                                         transcription_lens)
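Note: the rewritten collator routes on file type instead of relying on separate Kaldi collators — precomputed features (e.g. `*.ark`) go through `LoadInputsAndTargets` and skip audio augmentation and CMVN, while raw audio takes the full augment/featurize/normalize path. A hedged end-to-end sketch (config loading elided, paths hypothetical):

```python
collate_fn = SpeechCollator.from_config(config)

# Raw wav entry: augmented, featurized, CMVN-normalized.
spectrum, text_ids = collate_fn.process_utterance('utt_0001.wav', 'hello')
# Precomputed Kaldi feature entry: loaded directly, feat dim checked.
spectrum, text_ids = collate_fn.process_utterance('feats.ark:42', 'hello')

# Batching: each item is (utt_id, audio_or_feat_path, transcript).
utts, xs_pad, ilens, ys_pad, olens = collate_fn(batch)
```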
diff --git a/deepspeech/io/collator_st.py b/deepspeech/io/collator_st.py
deleted file mode 100644
index 28573366bb3bdced232fd0d31c2978446b1c1ba9..0000000000000000000000000000000000000000
--- a/deepspeech/io/collator_st.py
+++ /dev/null
@@ -1,631 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import io
-from collections import namedtuple
-from typing import Optional
-
-import kaldiio
-import numpy as np
-from yacs.config import CfgNode
-
-from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
-from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
-from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
-from deepspeech.frontend.normalizer import FeatureNormalizer
-from deepspeech.frontend.speech import SpeechSegment
-from deepspeech.frontend.utility import IGNORE_ID
-from deepspeech.io.utility import pad_sequence
-from deepspeech.utils.log import Log
-
-__all__ = ["SpeechCollator", "KaldiPrePorocessedCollator"]
-
-logger = Log(__name__).getlog()
-
-# namedtupe need global for pickle.
-TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
-
-
-class SpeechCollator():
-    @classmethod
-    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
-        default = CfgNode(
-            dict(
-                augmentation_config="",
-                random_seed=0,
-                mean_std_filepath="",
-                unit_type="char",
-                vocab_filepath="",
-                spm_model_prefix="",
-                specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
-                feat_dim=0,  # 'mfcc', 'fbank'
-                delta_delta=False,  # 'mfcc', 'fbank'
-                stride_ms=10.0,  # ms
-                window_ms=20.0,  # ms
-                n_fft=None,  # fft points
-                max_freq=None,  # None for samplerate/2
-                target_sample_rate=16000,  # target sample rate
-                use_dB_normalization=True,
-                target_dB=-20,
-                dither=1.0,  # feature dither
-                keep_transcription_text=False))
-
-        if config is not None:
-            config.merge_from_other_cfg(default)
-        return default
-
-    @classmethod
-    def from_config(cls, config):
-        """Build a SpeechCollator object from a config.
-
-        Args:
-            config (yacs.config.CfgNode): configs object.
-
-        Returns:
-            SpeechCollator: collator object.
-        """
-        assert 'augmentation_config' in config.collator
-        assert 'keep_transcription_text' in config.collator
-        assert 'mean_std_filepath' in config.collator
-        assert 'vocab_filepath' in config.collator
-        assert 'specgram_type' in config.collator
-        assert 'n_fft' in config.collator
-        assert config.collator
-
-        if isinstance(config.collator.augmentation_config, (str, bytes)):
-            if config.collator.augmentation_config:
-                aug_file = io.open(
-                    config.collator.augmentation_config,
-                    mode='r',
-                    encoding='utf8')
-            else:
-                aug_file = io.StringIO(initial_value='{}', newline='')
-        else:
-            aug_file = config.collator.augmentation_config
-        assert isinstance(aug_file, io.StringIO)
-
-        speech_collator = cls(
-            aug_file=aug_file,
-            random_seed=0,
-            mean_std_filepath=config.collator.mean_std_filepath,
-            unit_type=config.collator.unit_type,
-            vocab_filepath=config.collator.vocab_filepath,
-            spm_model_prefix=config.collator.spm_model_prefix,
-            specgram_type=config.collator.specgram_type,
-            feat_dim=config.collator.feat_dim,
-            delta_delta=config.collator.delta_delta,
-            stride_ms=config.collator.stride_ms,
-            window_ms=config.collator.window_ms,
-            n_fft=config.collator.n_fft,
-            max_freq=config.collator.max_freq,
-            target_sample_rate=config.collator.target_sample_rate,
-            use_dB_normalization=config.collator.use_dB_normalization,
-            target_dB=config.collator.target_dB,
-            dither=config.collator.dither,
-            keep_transcription_text=config.collator.keep_transcription_text)
-        return speech_collator
-
-    def __init__(
-            self,
-            aug_file,
-            mean_std_filepath,
-            vocab_filepath,
-            spm_model_prefix,
-            random_seed=0,
-            unit_type="char",
-            specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
-            feat_dim=0,  # 'mfcc', 'fbank'
-            delta_delta=False,  # 'mfcc', 'fbank'
-            stride_ms=10.0,  # ms
-            window_ms=20.0,  # ms
-            n_fft=None,  # fft points
-            max_freq=None,  # None for samplerate/2
-            target_sample_rate=16000,  # target sample rate
-            use_dB_normalization=True,
-            target_dB=-20,
-            dither=1.0,
-            keep_transcription_text=True):
-        """SpeechCollator Collator
-
-        Args:
-            unit_type(str): token unit type, e.g. char, word, spm
-            vocab_filepath (str): vocab file path.
-            mean_std_filepath (str): mean and std file path, which suffix is *.npy
-            spm_model_prefix (str): spm model prefix, need if `unit_type` is spm.
-            augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
-            stride_ms (float, optional): stride size in ms. Defaults to 10.0.
-            window_ms (float, optional): window size in ms. Defaults to 20.0.
-            n_fft (int, optional): fft points for rfft. Defaults to None.
-            max_freq (int, optional): max cut freq. Defaults to None.
-            target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000.
-            specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
-            feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None.
-            delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False.
-            use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
-            target_dB (int, optional): target dB. Defaults to -20.
-            random_seed (int, optional): for random generator. Defaults to 0.
-            keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False.
-            if ``keep_transcription_text`` is False, text is token ids else is raw string.
-
-        Do augmentations
-        Padding audio features with zeros to make them have the same shape (or
-        a user-defined shape) within one batch.
-        """
-        self._keep_transcription_text = keep_transcription_text
-
-        self._local_data = TarLocalData(tar2info={}, tar2object={})
-        self._augmentation_pipeline = AugmentationPipeline(
-            augmentation_config=aug_file.read(), random_seed=random_seed)
-
-        self._normalizer = FeatureNormalizer(
-            mean_std_filepath) if mean_std_filepath else None
-
-        self._stride_ms = stride_ms
-        self._target_sample_rate = target_sample_rate
-
-        self._speech_featurizer = SpeechFeaturizer(
-            unit_type=unit_type,
-            vocab_filepath=vocab_filepath,
-            spm_model_prefix=spm_model_prefix,
-            specgram_type=specgram_type,
-            feat_dim=feat_dim,
-            delta_delta=delta_delta,
-            stride_ms=stride_ms,
-            window_ms=window_ms,
-            n_fft=n_fft,
-            max_freq=max_freq,
-            target_sample_rate=target_sample_rate,
-            use_dB_normalization=use_dB_normalization,
-            target_dB=target_dB,
-            dither=dither)
-
-    def _parse_tar(self, file):
-        """Parse a tar file to get a tarfile object
-        and a map containing tarinfoes
-        """
-        result = {}
-        f = tarfile.open(file)
-        for tarinfo in f.getmembers():
-            result[tarinfo.name] = tarinfo
-        return f, result
-
-    def _subfile_from_tar(self, file):
-        """Get subfile object from tar.
-
-        It will return a subfile object from tar file
-        and cached tar file info for next reading request.
-        """
-        tarpath, filename = file.split(':', 1)[1].split('#', 1)
-        if 'tar2info' not in self._local_data.__dict__:
-            self._local_data.tar2info = {}
-        if 'tar2object' not in self._local_data.__dict__:
-            self._local_data.tar2object = {}
-        if tarpath not in self._local_data.tar2info:
-            object, infoes = self._parse_tar(tarpath)
-            self._local_data.tar2info[tarpath] = infoes
-            self._local_data.tar2object[tarpath] = object
-        return self._local_data.tar2object[tarpath].extractfile(
-            self._local_data.tar2info[tarpath][filename])
-
-    @property
-    def manifest(self):
-        return self._manifest
-
-    @property
-    def vocab_size(self):
-        return self._speech_featurizer.vocab_size
-
-    @property
-    def vocab_list(self):
-        return self._speech_featurizer.vocab_list
-
-    @property
-    def vocab_dict(self):
-        return self._speech_featurizer.vocab_dict
-
-    @property
-    def text_feature(self):
-        return self._speech_featurizer.text_feature
-
-    @property
-    def feature_size(self):
-        return self._speech_featurizer.feature_size
-
-    @property
-    def stride_ms(self):
-        return self._speech_featurizer.stride_ms
-
-    def process_utterance(self, audio_file, translation):
-        """Load, augment, featurize and normalize for speech data.
-
-        :param audio_file: Filepath or file object of audio file.
-        :type audio_file: str | file
-        :param translation: translation text.
-        :type translation: str
-        :return: Tuple of audio feature tensor and data of translation part,
-            where translation part could be token ids or text.
-        :rtype: tuple of (2darray, list)
-        """
-        if isinstance(audio_file, str) and audio_file.startswith('tar:'):
-            speech_segment = SpeechSegment.from_file(
-                self._subfile_from_tar(audio_file), translation)
-        else:
-            speech_segment = SpeechSegment.from_file(audio_file, translation)
-
-        # audio augment
-        self._augmentation_pipeline.transform_audio(speech_segment)
-
-        specgram, translation_part = self._speech_featurizer.featurize(
-            speech_segment, self._keep_transcription_text)
-        if self._normalizer:
-            specgram = self._normalizer.apply(specgram)
-
-        # specgram augment
-        specgram = self._augmentation_pipeline.transform_feature(specgram)
-        return specgram, translation_part
-
-    def __call__(self, batch):
-        """batch examples
-
-        Args:
-            batch ([List]): batch is (audio, text)
-                audio (np.ndarray) shape (T, D)
-                text (List[int] or str): shape (U,)
-
-        Returns:
-            tuple(audio, text, audio_lens, text_lens): batched data.
-            audio : (B, Tmax, D)
-            audio_lens: (B)
-            text : (B, Umax)
-            text_lens: (B)
-        """
-        audios = []
-        audio_lens = []
-        texts = []
-        text_lens = []
-        utts = []
-        for utt, audio, text in batch:
-            audio, text = self.process_utterance(audio, text)
-            #utt
-            utts.append(utt)
-            # audio
-            audios.append(audio)  # [T, D]
-            audio_lens.append(audio.shape[0])
-            # text
-            # for training, text is token ids
-            # else text is string, convert to unicode ord
-            tokens = []
-            if self._keep_transcription_text:
-                assert isinstance(text, str), (type(text), text)
-                tokens = [ord(t) for t in text]
-            else:
-                tokens = text  # token ids
-            tokens = tokens if isinstance(tokens, np.ndarray) else np.array(
-                tokens, dtype=np.int64)
-            texts.append(tokens)
-            text_lens.append(tokens.shape[0])
-
-        padded_audios = pad_sequence(
-            audios, padding_value=0.0).astype(np.float32)  #[B, T, D]
-        audio_lens = np.array(audio_lens).astype(np.int64)
-        padded_texts = pad_sequence(
-            texts, padding_value=IGNORE_ID).astype(np.int64)
-        text_lens = np.array(text_lens).astype(np.int64)
-        return utts, padded_audios, audio_lens, padded_texts, text_lens
-
-
-class TripletSpeechCollator(SpeechCollator):
-    def process_utterance(self, audio_file, translation, transcript):
-        """Load, augment, featurize and normalize for speech data.
-
-        :param audio_file: Filepath or file object of audio file.
-        :type audio_file: str | file
-        :param translation: translation text.
-        :type translation: str
-        :return: Tuple of audio feature tensor and data of translation part,
-            where translation part could be token ids or text.
-        :rtype: tuple of (2darray, list)
-        """
-        if isinstance(audio_file, str) and audio_file.startswith('tar:'):
-            speech_segment = SpeechSegment.from_file(
-                self._subfile_from_tar(audio_file), translation)
-        else:
-            speech_segment = SpeechSegment.from_file(audio_file, translation)
-
-        # audio augment
-        self._augmentation_pipeline.transform_audio(speech_segment)
-
-        specgram, translation_part = self._speech_featurizer.featurize(
-            speech_segment, self._keep_transcription_text)
-        transcript_part = self._speech_featurizer._text_featurizer.featurize(
-            transcript)
-        if self._normalizer:
-            specgram = self._normalizer.apply(specgram)
-
-        # specgram augment
-        specgram = self._augmentation_pipeline.transform_feature(specgram)
-        return specgram, translation_part, transcript_part
-
-    def __call__(self, batch):
-        """batch examples
-
-        Args:
-            batch ([List]): batch is (audio, text)
-                audio (np.ndarray) shape (T, D)
-                text (List[int] or str): shape (U,)
-
-        Returns:
-            tuple(audio, text, audio_lens, text_lens): batched data.
-            audio : (B, Tmax, D)
-            audio_lens: (B)
-            text : (B, Umax)
-            text_lens: (B)
-        """
-        audios = []
-        audio_lens = []
-        translation_text = []
-        translation_text_lens = []
-        transcription_text = []
-        transcription_text_lens = []
-
-        utts = []
-        for utt, audio, translation, transcription in batch:
-            audio, translation, transcription = self.process_utterance(
-                audio, translation, transcription)
-            #utt
-            utts.append(utt)
-            # audio
-            audios.append(audio)  # [T, D]
-            audio_lens.append(audio.shape[0])
-            # text
-            # for training, text is token ids
-            # else text is string, convert to unicode ord
-            tokens = [[], []]
-            for idx, text in enumerate([translation, transcription]):
-                if self._keep_transcription_text:
-                    assert isinstance(text, str), (type(text), text)
-                    tokens[idx] = [ord(t) for t in text]
-                else:
-                    tokens[idx] = text  # token ids
-                tokens[idx] = tokens[idx] if isinstance(
-                    tokens[idx], np.ndarray) else np.array(
-                        tokens[idx], dtype=np.int64)
-            translation_text.append(tokens[0])
-            translation_text_lens.append(tokens[0].shape[0])
-            transcription_text.append(tokens[1])
-            transcription_text_lens.append(tokens[1].shape[0])
-
-        padded_audios = pad_sequence(
-            audios, padding_value=0.0).astype(np.float32)  #[B, T, D]
-        audio_lens = np.array(audio_lens).astype(np.int64)
-        padded_translation = pad_sequence(
-            translation_text, padding_value=IGNORE_ID).astype(np.int64)
-        translation_lens = np.array(translation_text_lens).astype(np.int64)
-        padded_transcription = pad_sequence(
-            transcription_text, padding_value=IGNORE_ID).astype(np.int64)
-        transcription_lens = np.array(transcription_text_lens).astype(np.int64)
-        return utts, padded_audios, audio_lens, (
-            padded_translation, padded_transcription), (translation_lens,
-                                                        transcription_lens)
-
-
-class KaldiPrePorocessedCollator(SpeechCollator):
-    @classmethod
-    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
-        default = CfgNode(
-            dict(
-                augmentation_config="",
-                random_seed=0,
-                unit_type="char",
-                vocab_filepath="",
-                spm_model_prefix="",
-                feat_dim=0,
-                stride_ms=10.0,
-                keep_transcription_text=False))
-
-        if config is not None:
-            config.merge_from_other_cfg(default)
-        return default
-
-    @classmethod
-    def from_config(cls, config):
-        """Build a SpeechCollator object from a config.
-
-        Args:
-            config (yacs.config.CfgNode): configs object.
-
-        Returns:
-            SpeechCollator: collator object.
-        """
-        assert 'augmentation_config' in config.collator
-        assert 'keep_transcription_text' in config.collator
-        assert 'vocab_filepath' in config.collator
-        assert config.collator
-
-        if isinstance(config.collator.augmentation_config, (str, bytes)):
-            if config.collator.augmentation_config:
-                aug_file = io.open(
-                    config.collator.augmentation_config,
-                    mode='r',
-                    encoding='utf8')
-            else:
-                aug_file = io.StringIO(initial_value='{}', newline='')
-        else:
-            aug_file = config.collator.augmentation_config
-        assert isinstance(aug_file, io.StringIO)
-
-        speech_collator = cls(
-            aug_file=aug_file,
-            random_seed=0,
-            unit_type=config.collator.unit_type,
-            vocab_filepath=config.collator.vocab_filepath,
-            spm_model_prefix=config.collator.spm_model_prefix,
-            feat_dim=config.collator.feat_dim,
-            stride_ms=config.collator.stride_ms,
-            keep_transcription_text=config.collator.keep_transcription_text)
-        return speech_collator
-
-    def __init__(self,
-                 aug_file,
-                 vocab_filepath,
-                 spm_model_prefix,
-                 random_seed=0,
-                 unit_type="char",
-                 feat_dim=0,
-                 stride_ms=10.0,
-                 keep_transcription_text=True):
-        """SpeechCollator Collator
-
-        Args:
-            unit_type(str): token unit type, e.g. char, word, spm
-            vocab_filepath (str): vocab file path.
-            spm_model_prefix (str): spm model prefix, need if `unit_type` is spm.
-            augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
-            random_seed (int, optional): for random generator. Defaults to 0.
-            keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False.
-            if ``keep_transcription_text`` is False, text is token ids else is raw string.
-
-        Do augmentations
-        Padding audio features with zeros to make them have the same shape (or
-        a user-defined shape) within one batch.
-        """
-        self._keep_transcription_text = keep_transcription_text
-        self._feat_dim = feat_dim
-        self._stride_ms = stride_ms
-
-        self._local_data = TarLocalData(tar2info={}, tar2object={})
-        self._augmentation_pipeline = AugmentationPipeline(
-            augmentation_config=aug_file.read(), random_seed=random_seed)
-
-        self._text_featurizer = TextFeaturizer(unit_type, vocab_filepath,
-                                               spm_model_prefix)
-
-    def process_utterance(self, audio_file, translation):
-        """Load, augment, featurize and normalize for speech data.
-
-        :param audio_file: Filepath or file object of kaldi processed feature.
-        :type audio_file: str | file
-        :param translation: Translation text.
-        :type translation: str
-        :return: Tuple of audio feature tensor and data of translation part,
-            where translation part could be token ids or text.
-        :rtype: tuple of (2darray, list)
-        """
-        specgram = kaldiio.load_mat(audio_file)
-        assert specgram.shape[
-            1] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
-                self._feat_dim, specgram.shape[1])
-
-        # specgram augment
-        specgram = self._augmentation_pipeline.transform_feature(specgram)
-
-        if self._keep_transcription_text:
-            return specgram, translation
-        else:
-            text_ids = self._text_featurizer.featurize(translation)
-            return specgram, text_ids
-
-
-class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
-    def process_utterance(self, audio_file, translation, transcript):
-        """Load, augment, featurize and normalize for speech data.
-
-        :param audio_file: Filepath or file object of kali processed feature.
-        :type audio_file: str | file
-        :param translation: Translation text.
-        :type translation: str
-        :param transcript: Transcription text.
-        :type transcript: str
-        :return: Tuple of audio feature tensor and data of translation and transcription parts,
-            where translation and transcription parts could be token ids or text.
-        :rtype: tuple of (2darray, (list, list))
-        """
-        specgram = kaldiio.load_mat(audio_file)
-        assert specgram.shape[
-            1] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
-                self._feat_dim, specgram.shape[1])
-
-        # specgram augment
-        specgram = self._augmentation_pipeline.transform_feature(specgram)
-
-        if self._keep_transcription_text:
-            return specgram, translation, transcript
-        else:
-            translation_text_ids = self._text_featurizer.featurize(translation)
-            transcript_text_ids = self._text_featurizer.featurize(transcript)
-            return specgram, translation_text_ids, transcript_text_ids
-
-    def __call__(self, batch):
-        """batch examples
-
-        Args:
-            batch ([List]): batch is (audio, text)
-                audio (np.ndarray) shape (T, D)
-                translation (List[int] or str): shape (U,)
-                transcription (List[int] or str): shape (V,)
-
-        Returns:
-            tuple(audio, text, audio_lens, text_lens): batched data.
-            audio : (B, Tmax, D)
-            audio_lens: (B)
-            translation_text : (B, Umax)
-            translation_text_lens: (B)
-            transcription_text : (B, Vmax)
-            transcription_text_lens: (B)
-        """
-        audios = []
-        audio_lens = []
-        translation_text = []
-        translation_text_lens = []
-        transcription_text = []
-        transcription_text_lens = []
-
-        utts = []
-        for utt, audio, translation, transcription in batch:
-            audio, translation, transcription = self.process_utterance(
-                audio, translation, transcription)
-            #utt
-            utts.append(utt)
-            # audio
-            audios.append(audio)  # [T, D]
-            audio_lens.append(audio.shape[0])
-            # text
-            # for training, text is token ids
-            # else text is string, convert to unicode ord
-            tokens = [[], []]
-            for idx, text in enumerate([translation, transcription]):
-                if self._keep_transcription_text:
-                    assert isinstance(text, str), (type(text), text)
-                    tokens[idx] = [ord(t) for t in text]
-                else:
-                    tokens[idx] = text  # token ids
-                tokens[idx] = tokens[idx] if isinstance(
-                    tokens[idx], np.ndarray) else np.array(
-                        tokens[idx], dtype=np.int64)
-            translation_text.append(tokens[0])
-            translation_text_lens.append(tokens[0].shape[0])
-            transcription_text.append(tokens[1])
-            transcription_text_lens.append(tokens[1].shape[0])
-
-        padded_audios = pad_sequence(
-            audios, padding_value=0.0).astype(np.float32)  #[B, T, D]
-        audio_lens = np.array(audio_lens).astype(np.int64)
-        padded_translation = pad_sequence(
-            translation_text, padding_value=IGNORE_ID).astype(np.int64)
-        translation_lens = np.array(translation_text_lens).astype(np.int64)
-        padded_transcription = pad_sequence(
-            transcription_text, padding_value=IGNORE_ID).astype(np.int64)
-        transcription_lens = np.array(transcription_text_lens).astype(np.int64)
-        return utts, padded_audios, audio_lens, (
-            padded_translation, padded_transcription), (translation_lens,
-                                                        transcription_lens)
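Note: collator_st.py is removed wholesale, and `KaldiPrePorocessedCollator`/`TripletKaldiPrePorocessedCollator` have no direct replacement class — their role is absorbed by the file-type dispatch in the unified collator. A hypothetical migration sketch for an ST recipe (paths and texts are made up):

```python
# before: TripletKaldiPrePorocessedCollator.from_config(config)
# after: the unified collator handles 'feats.ark:...' manifest entries as-is.
collate_fn = TripletSpeechCollator.from_config(config)
audio, translation_ids, transcript_ids = collate_fn.process_utterance(
    'feats.ark:42', 'hello', 'hola')
```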
diff --git a/deepspeech/io/reader.py b/deepspeech/io/reader.py
index 95cdbb951cce19203a437877f962c42a9177fbd6..30ae98f06de9d640475214caf843d3b796576ff3 100644
--- a/deepspeech/io/reader.py
+++ b/deepspeech/io/reader.py
@@ -321,6 +321,22 @@ class LoadInputsAndTargets():
             raise NotImplementedError(
                 "Not supported: loader_type={}".format(filetype))
 
+    def file_type(self, filepath):
+        suffix = filepath.split(":")[0].split('.')[1]
+        if suffix == 'ark':
+            return 'mat'
+        elif suffix == 'scp':
+            return 'scp'
+        elif suffix == 'npy':
+            return 'npy'
+        elif suffix == 'npz':
+            return 'npz'
+        elif suffix in ['wav', 'flac']:
+            # PCM16
+            return 'sound'
+        else:
+            raise ValueError(f"Not support filetype: {suffix}")
+
 
 class SoundHDF5File():
     """Collecting sound files to a HDF5 file
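Note: expected mappings of the new `file_type` helper, derived from the suffix conventions above (paths are illustrative; construction of `LoadInputsAndTargets` with no arguments mirrors its use in the collator):

```python
loader = LoadInputsAndTargets()
loader.file_type('feats.ark:13')  # -> 'mat'  (Kaldi matrix, with offset)
loader.file_type('feats.scp')     # -> 'scp'
loader.file_type('feat.npy')      # -> 'npy'
loader.file_type('utt1.wav')      # -> 'sound'
loader.file_type('utt1.mp3')      # raises ValueError: unsupported suffix
```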