diff --git a/deepspeech/frontend/normalizer.py b/deepspeech/frontend/normalizer.py index cbb82b8e301a30e903b42f7b52d9348ac80bb35e..3b1fbe7c6faa7639061ab07d98b7601a2bdf6b35 100644 --- a/deepspeech/frontend/normalizer.py +++ b/deepspeech/frontend/normalizer.py @@ -13,7 +13,6 @@ # limitations under the License. """Contains feature normalizers.""" import json -import random import numpy as np import paddle @@ -27,18 +26,19 @@ from deepspeech.frontend.utility import read_manifest __all__ = ["FeatureNormalizer"] +# https://github.com/PaddlePaddle/Paddle/pull/31481 class CollateFunc(object): - ''' Collate function for AudioDataset - ''' - - def __init__(self): - pass + def __init__(self, feature_func): + self.feature_func = feature_func def __call__(self, batch): mean_stat = None var_stat = None number = 0 - for feat in batch: + for item in batch: + audioseg = AudioSegment.from_file(item['feat']) + feat = self.feature_func(audioseg) #(D, T) + sums = np.sum(feat, axis=1) if mean_stat is None: mean_stat = sums @@ -52,30 +52,25 @@ class CollateFunc(object): var_stat += square_sums number += feat.shape[1] - return paddle.to_tensor(number), paddle.to_tensor( - mean_stat), paddle.to_tensor(var_stat) - #return number, mean_stat, var_stat + return number, mean_stat, var_stat class AudioDataset(Dataset): - def __init__(self, manifest_path, feature_func, num_samples=-1, rng=None): - self.feature_func = feature_func - self._rng = rng + def __init__(self, manifest_path, num_samples=-1, rng=None, random_seed=0): + self._rng = rng if rng else np.random.RandomState(random_seed) manifest = read_manifest(manifest_path) if num_samples == -1: sampled_manifest = manifest else: - sampled_manifest = self._rng.sample(manifest, num_samples) + sampled_manifest = self._rng.choice( + manifest, num_samples, replace=False) self.items = sampled_manifest def __len__(self): return len(self.items) def __getitem__(self, idx): - key = self.items[idx]['feat'] - audioseg = AudioSegment.from_file(key) - feat = self.feature_func(audioseg) #(D, T) - return feat + return self.items[idx] class FeatureNormalizer(object): @@ -112,7 +107,7 @@ class FeatureNormalizer(object): if not (manifest_path and featurize_func): raise ValueError("If mean_std_filepath is None, meanifest_path " "and featurize_func should not be None.") - self._rng = random.Random(random_seed) + self._rng = np.random.RandomState(random_seed) self._compute_mean_std(manifest_path, featurize_func, num_samples, num_workers) else: @@ -150,29 +145,11 @@ class FeatureNormalizer(object): featurize_func, num_samples, num_workers, + batch_size=64, eps=1e-20): """Compute mean and std from randomly sampled instances.""" - # manifest = read_manifest(manifest_path) - # if num_samples == -1: - # sampled_manifest = manifest - # else: - # sampled_manifest = self._rng.sample(manifest, num_samples) - # features = [] - # for instance in sampled_manifest: - # features.append( - # featurize_func(AudioSegment.from_file(instance["feat"]))) - # features = np.hstack(features) #(D, T) - # self._mean = np.mean(features, axis=1) #(D,) - # std = np.std(features, axis=1) #(D,) - # std = np.clip(std, eps, None) - # self._istd = 1.0 / std - - collate_func = CollateFunc() - - dataset = AudioDataset(manifest_path, featurize_func, num_samples, - self._rng) - - batch_size = 20 + collate_func = CollateFunc(featurize_func) + dataset = AudioDataset(manifest_path, num_samples, self._rng) data_loader = DataLoader( dataset, batch_size=batch_size, @@ -185,9 +162,9 @@ class FeatureNormalizer(object): all_var_stat = None all_number = 0 wav_number = 0 - for batch in data_loader(): + for i, batch in enumerate(data_loader): number, mean_stat, var_stat = batch - if all_mean_stat is None: + if i == 0: all_mean_stat = mean_stat all_var_stat = var_stat else: @@ -198,12 +175,12 @@ class FeatureNormalizer(object): if wav_number % 1000 == 0: print('process {} wavs,{} frames'.format(wav_number, - int(all_number))) + all_number)) self.cmvn_info = { 'mean_stat': list(all_mean_stat.tolist()), 'var_stat': list(all_var_stat.tolist()), - 'frame_num': int(all_number), + 'frame_num': all_number, } return self.cmvn_info diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index 4f58d41c6b61bfeda44f085d5124e9c78d7dbf65..333fc0c648fd9e28f94558e5ea5dd8d0101ebfd6 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import io -import random import tarfile import time from collections import namedtuple from typing import Optional +import numpy as np from paddle.io import Dataset from yacs.config import CfgNode @@ -209,7 +209,7 @@ class ManifestDataset(Dataset): use_dB_normalization=use_dB_normalization, target_dB=target_dB) - self._rng = random.Random(random_seed) + self._rng = np.random.RandomState(random_seed) self._keep_transcription_text = keep_transcription_text # for caching tar files info self._local_data = namedtuple('local_data', ['tar2info', 'tar2object'])