From 17092cbbe24b7f47894e5bbd8dd272f4255eec99 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Sat, 9 Oct 2021 09:24:32 +0000 Subject: [PATCH] D,T to T,D --- deepspeech/frontend/audio.py | 20 +-- deepspeech/frontend/augmentor/spec_augment.py | 8 +- deepspeech/frontend/featurizer/__init__.py | 3 + .../frontend/featurizer/audio_featurizer.py | 86 +++++++------ .../frontend/featurizer/speech_featurizer.py | 119 +++++------------- deepspeech/frontend/normalizer.py | 17 +-- deepspeech/frontend/speech.py | 11 +- deepspeech/io/collator.py | 4 +- examples/librispeech/s1/path.sh | 3 +- 9 files changed, 122 insertions(+), 149 deletions(-) diff --git a/deepspeech/frontend/audio.py b/deepspeech/frontend/audio.py index ffdcd4b3..13dc3a44 100644 --- a/deepspeech/frontend/audio.py +++ b/deepspeech/frontend/audio.py @@ -24,8 +24,10 @@ import soundfile import soxbindings as sox from scipy import signal +from .utility import subfile_from_tar -class AudioSegment(object): + +class AudioSegment(): """Monaural audio segment abstraction. :param samples: Audio samples [num_samples x num_channels]. @@ -68,16 +70,20 @@ class AudioSegment(object): self.duration, self.rms_db)) @classmethod - def from_file(cls, file): + def from_file(cls, file, infos=None): """Create audio segment from audio file. - - :param filepath: Filepath or file object to audio file. - :type filepath: str|file - :return: Audio segment instance. - :rtype: AudioSegment + + Args: + filepath (str|file): Filepath or file object to audio file. + infos (TarLocalData, optional): tar2obj and tar2infos. Defaults to None. + + Returns: + AudioSegment: Audio segment instance. """ if isinstance(file, str) and re.findall(r".seqbin_\d+$", file): return cls.from_sequence_file(file) + elif isinstance(file, str) and file.startswith('tar:'): + return cls.from_file(subfile_from_tar(file, infos)) else: samples, sample_rate = soundfile.read(file, dtype='float32') return cls(samples, sample_rate) diff --git a/deepspeech/frontend/augmentor/spec_augment.py b/deepspeech/frontend/augmentor/spec_augment.py index 26c94d41..e78f6f6a 100644 --- a/deepspeech/frontend/augmentor/spec_augment.py +++ b/deepspeech/frontend/augmentor/spec_augment.py @@ -29,10 +29,10 @@ class SpecAugmentor(AugmentorBase): SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition https://arxiv.org/abs/1904.08779 - + SpecAugment on Large Scale Datasets https://arxiv.org/abs/1912.05533 - + """ def __init__(self, @@ -61,7 +61,7 @@ class SpecAugmentor(AugmentorBase): adaptive_size_ratio (float): adaptive size ratio for time masking max_n_time_masks (int): maximum number of time masking replace_with_zero (bool): pad zero on mask if true else use mean - warp_mode (str): "PIL" (default, fast, not differentiable) + warp_mode (str): "PIL" (default, fast, not differentiable) or "sparse_image_warp" (slow, differentiable) """ super().__init__() @@ -133,7 +133,7 @@ class SpecAugmentor(AugmentorBase): return self._time_mask def __repr__(self): - return f"specaug: F-{F}, T-{T}, F-n-{n_freq_masks}, T-n-{n_time_masks}" + return f"specaug: F-{self.F}, T-{self.T}, F-n-{self.n_freq_masks}, T-n-{self.n_time_masks}" def time_warp(self, x, mode='PIL'): """time warp for spec augment diff --git a/deepspeech/frontend/featurizer/__init__.py b/deepspeech/frontend/featurizer/__init__.py index 185a92b8..6992700d 100644 --- a/deepspeech/frontend/featurizer/__init__.py +++ b/deepspeech/frontend/featurizer/__init__.py @@ -11,3 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
+from .audio_featurizer import AudioFeaturizer #noqa: F401
+from .speech_featurizer import SpeechFeaturizer
+from .text_featurizer import TextFeaturizer
diff --git a/deepspeech/frontend/featurizer/audio_featurizer.py b/deepspeech/frontend/featurizer/audio_featurizer.py
index 2f3163fa..6f3b646c 100644
--- a/deepspeech/frontend/featurizer/audio_featurizer.py
+++ b/deepspeech/frontend/featurizer/audio_featurizer.py
@@ -18,7 +18,7 @@ from python_speech_features import logfbank
 from python_speech_features import mfcc
 
 
-class AudioFeaturizer(object):
+class AudioFeaturizer():
     """Audio featurizer, for extracting features from audio contents of
     AudioSegment or SpeechSegment.
 
@@ -167,32 +167,6 @@ class AudioFeaturizer(object):
             raise ValueError("Unknown spectrum_type %s. "
                              "Supported values: linear." % self._spectrum_type)
 
-    def _compute_linear_specgram(self,
-                                 samples,
-                                 sample_rate,
-                                 stride_ms=10.0,
-                                 window_ms=20.0,
-                                 max_freq=None,
-                                 eps=1e-14):
-        """Compute the linear spectrogram from FFT energy."""
-        if max_freq is None:
-            max_freq = sample_rate / 2
-        if max_freq > sample_rate / 2:
-            raise ValueError("max_freq must not be greater than half of "
-                             "sample rate.")
-        if stride_ms > window_ms:
-            raise ValueError("Stride size must not be greater than "
-                             "window size.")
-        stride_size = int(0.001 * sample_rate * stride_ms)
-        window_size = int(0.001 * sample_rate * window_ms)
-        specgram, freqs = self._specgram_real(
-            samples,
-            window_size=window_size,
-            stride_size=stride_size,
-            sample_rate=sample_rate)
-        ind = np.where(freqs <= max_freq)[0][-1] + 1
-        return np.log(specgram[:ind, :] + eps)
-
     def _specgram_real(self, samples, window_size, stride_size, sample_rate):
         """Compute the spectrogram for samples from a real signal."""
         # extract strided windows
@@ -217,26 +191,65 @@ class AudioFeaturizer(object):
         freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
         return fft, freqs
 
+    def _compute_linear_specgram(self,
+                                 samples,
+                                 sample_rate,
+                                 stride_ms=10.0,
+                                 window_ms=20.0,
+                                 max_freq=None,
+                                 eps=1e-14):
+        """Compute the linear spectrogram from FFT energy.
+
+        Args:
+            samples (np.ndarray): audio samples, 1-D array.
+            sample_rate (int): sample rate of the audio.
+            stride_ms (float, optional): stride length in milliseconds. Defaults to 10.0.
+            window_ms (float, optional): window length in milliseconds. Defaults to 20.0.
+            max_freq (float, optional): highest frequency to keep. Defaults to None (half of sample rate).
+            eps (float, optional): small constant added before taking the log. Defaults to 1e-14.
+
+        Raises:
+            ValueError: when max_freq is greater than half of sample rate.
+            ValueError: when stride_ms is greater than window_ms.
+
+        Returns:
+            np.ndarray: log spectrogram, (time, freq)
+        """
+        if max_freq is None:
+            max_freq = sample_rate / 2
+        if max_freq > sample_rate / 2:
+            raise ValueError("max_freq must not be greater than half of "
+                             "sample rate.")
+        if stride_ms > window_ms:
+            raise ValueError("Stride size must not be greater than "
+                             "window size.")
+        stride_size = int(0.001 * sample_rate * stride_ms)
+        window_size = int(0.001 * sample_rate * window_ms)
+        specgram, freqs = self._specgram_real(
+            samples,
+            window_size=window_size,
+            stride_size=stride_size,
+            sample_rate=sample_rate)
+        ind = np.where(freqs <= max_freq)[0][-1] + 1
+        # (freq, time)
+        spec = np.log(specgram[:ind, :] + eps)
+        return np.transpose(spec)
+
     def _concat_delta_delta(self, feat):
         """Append delta and delta-delta features.
        Args:
-            feat (np.ndarray): (D, T)
+            feat (np.ndarray): (T, D)
 
         Returns:
-            np.ndarray: feat with delta-delta, (3*D, T)
+            np.ndarray: feat with delta-delta, (T, 3*D)
         """
-        feat = np.transpose(feat)
         # Deltas
         d_feat = delta(feat, 2)
         # Deltas-Deltas
         dd_feat = delta(d_feat, 2)
-        # transpose
-        feat = np.transpose(feat)
-        d_feat = np.transpose(d_feat)
-        dd_feat = np.transpose(dd_feat)
         # concat above three features
-        concat_feat = np.concatenate((feat, d_feat, dd_feat))
+        concat_feat = np.concatenate((feat, d_feat, dd_feat), axis=1)
         return concat_feat
 
     def _compute_mfcc(self,
@@ -292,7 +305,6 @@ class AudioFeaturizer(object):
             ceplifter=22,
             useEnergy=True,
             winfunc='povey')
-        mfcc_feat = np.transpose(mfcc_feat)
         if delta_delta:
             mfcc_feat = self._concat_delta_delta(mfcc_feat)
         return mfcc_feat
@@ -346,8 +358,6 @@ class AudioFeaturizer(object):
             remove_dc_offset=True,
             preemph=0.97,
             wintype='povey')
-
-        fbank_feat = np.transpose(fbank_feat)
         if delta_delta:
             fbank_feat = self._concat_delta_delta(fbank_feat)
         return fbank_feat
diff --git a/deepspeech/frontend/featurizer/speech_featurizer.py b/deepspeech/frontend/featurizer/speech_featurizer.py
index 50856e16..25687140 100644
--- a/deepspeech/frontend/featurizer/speech_featurizer.py
+++ b/deepspeech/frontend/featurizer/speech_featurizer.py
@@ -16,38 +16,8 @@ from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer
 from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
 
 
-class SpeechFeaturizer(object):
-    """Speech featurizer, for extracting features from both audio and transcript
-    contents of SpeechSegment.
-
-    Currently, for audio parts, it supports feature types of linear
-    spectrogram and mfcc; for transcript parts, it only supports char-level
-    tokenizing and conversion into a list of token indices. Note that the
-    token indexing order follows the given vocabulary file.
-
-    :param vocab_filepath: Filepath to load vocabulary for token indices
-                           conversion.
-    :type spectrum_type: str
-    :param spectrum_type: Specgram feature type. Options: 'linear', 'mfcc'.
-    :type spectrum_type: str
-    :param stride_ms: Striding size (in milliseconds) for generating frames.
-    :type stride_ms: float
-    :param window_ms: Window size (in milliseconds) for generating frames.
-    :type window_ms: float
-    :param max_freq: When spectrum_type is 'linear', only FFT bins
-                     corresponding to frequencies between [0, max_freq] are
-                     returned; when spectrum_type is 'mfcc', max_freq is the
-                     highest band edge of mel filters.
-    :types max_freq: None|float
-    :param target_sample_rate: Speech are resampled (if upsampling or
-                               downsampling is allowed) to this before
-                               extracting spectrogram features.
-    :type target_sample_rate: float
-    :param use_dB_normalization: Whether to normalize the audio to a certain
-                                 decibels before extracting the features.
-    :type use_dB_normalization: bool
-    :param target_dB: Target audio decibels for normalization.
-    :type target_dB: float
+class SpeechFeaturizer():
+    """Speech and Text feature extraction.
""" def __init__(self, @@ -64,8 +34,12 @@ class SpeechFeaturizer(object): target_sample_rate=16000, use_dB_normalization=True, target_dB=-20, - dither=1.0): - self._audio_featurizer = AudioFeaturizer( + dither=1.0, + maskctc=False): + self.stride_ms = stride_ms + self.window_ms = window_ms + + self.audio_feature = AudioFeaturizer( spectrum_type=spectrum_type, feat_dim=feat_dim, delta_delta=delta_delta, @@ -77,8 +51,14 @@ class SpeechFeaturizer(object): use_dB_normalization=use_dB_normalization, target_dB=target_dB, dither=dither) - self._text_featurizer = TextFeaturizer(unit_type, vocab_filepath, - spm_model_prefix) + self.feature_size = self.audio_feature.feature_size + + self.text_feature = TextFeaturizer( + unit_type=unit_type, + vocab_filepath=vocab_filepath, + spm_model_prefix=spm_model_prefix, + maskctc=maskctc) + self.vocab_size = self.text_feature.vocab_size def featurize(self, speech_segment, keep_transcription_text): """Extract features for speech segment. @@ -94,66 +74,33 @@ class SpeechFeaturizer(object): Returns: tuple: 1) spectrogram audio feature in 2darray, 2) list oftoken indices. """ - spec_feature = self._audio_featurizer.featurize(speech_segment) + spec_feature = self.audio_feature.featurize(speech_segment) + if keep_transcription_text: return spec_feature, speech_segment.transcript + if speech_segment.has_token: text_ids = speech_segment.token_ids else: - text_ids = self._text_featurizer.featurize( - speech_segment.transcript) + text_ids = self.text_feature.featurize(speech_segment.transcript) return spec_feature, text_ids - @property - def vocab_size(self): - """Return the vocabulary size. - - Returns: - int: Vocabulary size. - """ - return self._text_featurizer.vocab_size - - @property - def vocab_list(self): - """Return the vocabulary in list. - - Returns: - List[str]: - """ - return self._text_featurizer.vocab_list - - @property - def vocab_dict(self): - """Return the vocabulary in dict. - - Returns: - Dict[str, int]: - """ - return self._text_featurizer.vocab_dict - - @property - def feature_size(self): - """Return the audio feature size. + def text_featurize(self, text, keep_transcription_text): + """Extract features for speech segment. - Returns: - int: audio feature size. - """ - return self._audio_featurizer.feature_size + 1. For audio parts, extract the audio features. + 2. For transcript parts, keep the original text or convert text string + to a list of token indices in char-level. - @property - def stride_ms(self): - """time length in `ms` unit per frame + Args: + text (str): text. + keep_transcription_text (bool): True, keep transcript text, False, token ids Returns: - float: time(ms)/frame + (str|List[int]): text, or list of token indices. """ - return self._audio_featurizer.stride_ms - - @property - def text_feature(self): - """Return the text feature object. + if keep_transcription_text: + return text - Returns: - TextFeaturizer: object. 
- """ - return self._text_featurizer + text_ids = self.text_feature.featurize(text) + return text_ids diff --git a/deepspeech/frontend/normalizer.py b/deepspeech/frontend/normalizer.py index 287b51e5..6ace4fc6 100644 --- a/deepspeech/frontend/normalizer.py +++ b/deepspeech/frontend/normalizer.py @@ -40,21 +40,21 @@ class CollateFunc(object): number = 0 for item in batch: audioseg = AudioSegment.from_file(item['feat']) - feat = self.feature_func(audioseg) #(D, T) + feat = self.feature_func(audioseg) #(T, D) - sums = np.sum(feat, axis=1) + sums = np.sum(feat, axis=0) if mean_stat is None: mean_stat = sums else: mean_stat += sums - square_sums = np.sum(np.square(feat), axis=1) + square_sums = np.sum(np.square(feat), axis=0) if var_stat is None: var_stat = square_sums else: var_stat += square_sums - number += feat.shape[1] + number += feat.shape[0] return number, mean_stat, var_stat @@ -120,7 +120,7 @@ class FeatureNormalizer(object): """Normalize features to be of zero mean and unit stddev. :param features: Input features to be normalized. - :type features: ndarray, shape (D, T) + :type features: ndarray, shape (T, D) :param eps: added to stddev to provide numerical stablibity. :type eps: float :return: Normalized features. @@ -130,9 +130,10 @@ class FeatureNormalizer(object): def _read_mean_std_from_file(self, filepath, eps=1e-20): """Load mean and std from file.""" - mean, istd = load_cmvn(filepath, filetype='json') - self._mean = np.expand_dims(mean, axis=-1) - self._istd = np.expand_dims(istd, axis=-1) + filetype = filepath.split(".")[-1] + mean, istd = load_cmvn(filepath, filetype=filetype) + self._mean = np.expand_dims(mean, axis=0) + self._istd = np.expand_dims(istd, axis=0) def write_to_file(self, filepath): """Write the mean and stddev to the file. diff --git a/deepspeech/frontend/speech.py b/deepspeech/frontend/speech.py index e58795c0..9eed9725 100644 --- a/deepspeech/frontend/speech.py +++ b/deepspeech/frontend/speech.py @@ -68,7 +68,12 @@ class SpeechSegment(AudioSegment): return not self.__eq__(other) @classmethod - def from_file(cls, filepath, transcript, tokens=None, token_ids=None): + def from_file(cls, + filepath, + transcript, + tokens=None, + token_ids=None, + infos=None): """Create speech segment from audio file and corresponding transcript. Args: @@ -76,12 +81,12 @@ class SpeechSegment(AudioSegment): transcript (str): Transcript text for the speech. tokens (List[str], optional): text tokens. Defaults to None. token_ids (List[int], optional): text token ids. Defaults to None. + infos (TarLocalData, optional): tar2obj and tar2infos. Defaults to None. Returns: SpeechSegment: Speech segment instance. 
""" - - audio = AudioSegment.from_file(filepath) + audio = AudioSegment.from_file(filepath, infos) return cls(audio.samples, audio.sample_rate, transcript, tokens, token_ids) diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index 280a4073..04d433d6 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -56,8 +56,8 @@ class SpeechCollator(): for utt, audio, text in batch: utts.append(utt) # audio - audios.append(audio.T) # [T, D] - audio_lens.append(audio.shape[1]) + audios.append(audio) # [T, D] + audio_lens.append(audio.shape[0]) # text # for training, text is token ids # else text is string, convert to unicode ord diff --git a/examples/librispeech/s1/path.sh b/examples/librispeech/s1/path.sh index 30adb6ca..dcdfa45e 100644 --- a/examples/librispeech/s1/path.sh +++ b/examples/librispeech/s1/path.sh @@ -3,8 +3,9 @@ export MAIN_ROOT=${PWD}/../../../ export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} export LC_ALL=C +export PYTHONDONTWRITEBYTECODE=1 # Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C -export PYTHONIOENCODING=UTF-8 +export PYTHONIOENCODING=UTF-8 export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ -- GitLab