normalizer.py 6.7 KB
Newer Older
H
Hui Zhang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
"""Contains feature normalizers."""
15
import json
16 17

import numpy as np
18 19 20 21
import paddle
from paddle.io import DataLoader
from paddle.io import Dataset

H
Hui Zhang 已提交
22
from deepspeech.frontend.audio import AudioSegment
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
from deepspeech.frontend.utility import load_cmvn
from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.log import Log

__all__ = ["FeatureNormalizer"]

logger = Log(__name__).getlog()


# https://github.com/PaddlePaddle/Paddle/pull/31481
class CollateFunc(object):
    """Collate a batch of manifest items into partial normalization statistics.

    Instead of stacking features, each batch is reduced to per-dimension sums
    so that the caller only has to add the per-batch results together.

    :param feature_func: Callable applied to the ``AudioSegment`` loaded for
                         each item; its result is summed over axis 0 and its
                         ``shape[0]`` is counted as the number of frames.
    :type feature_func: callable
    """

    def __init__(self, feature_func):
        self.feature_func = feature_func

    def __call__(self, batch):
        """Reduce ``batch`` to ``(number, mean_stat, var_stat)``.

        :param batch: Iterable of items; ``item['feat']`` is handed to
                      ``AudioSegment.from_file`` (presumably an audio file
                      path -- confirm against the manifest format).
        :type batch: iterable of dict
        :return: Total frame count, per-dimension sum of the features and
                 per-dimension sum of the squared features. Both sums are
                 ``None`` (and the count is 0) when ``batch`` is empty.
        :rtype: tuple
        """
        mean_stat = None
        var_stat = None
        number = 0
        for item in batch:
            audioseg = AudioSegment.from_file(item['feat'])
            feat = self.feature_func(audioseg)  # (T, D)

            sums = np.sum(feat, axis=0)
            square_sums = np.sum(np.square(feat), axis=0)
            if mean_stat is None:
                # First item: both accumulators start out unset together.
                mean_stat, var_stat = sums, square_sums
            else:
                mean_stat += sums
                var_stat += square_sums

            number += feat.shape[0]
        return number, mean_stat, var_stat


class AudioDataset(Dataset):
    """Dataset over the entries of a manifest, optionally randomly subsampled.

    :param manifest_path: Path handed to ``read_manifest``.
    :type manifest_path: str
    :param num_samples: Number of entries to draw without replacement;
                        ``-1`` keeps the whole manifest. A request larger
                        than the manifest is reduced to the manifest size.
    :type num_samples: int
    :param rng: Random state used for sampling. When not given, a new
                ``np.random.RandomState(random_seed)`` is created.
    :type rng: None|np.random.RandomState
    :param random_seed: Seed for the random state created when ``rng`` is
                        not given.
    :type random_seed: int
    """

    def __init__(self, manifest_path, num_samples=-1, rng=None, random_seed=0):
        self._rng = rng if rng else np.random.RandomState(random_seed)
        manifest = read_manifest(manifest_path)
        if num_samples == -1:
            sampled_manifest = manifest
        else:
            available = len(manifest)
            if num_samples > available:
                # choice(..., replace=False) raises when asked for more
                # samples than there are entries; fall back to using all.
                logger.info(
                    f'num_samples ({num_samples}) exceeds manifest size '
                    f'({available}); using all {available} entries.')
                num_samples = available
            sampled_manifest = self._rng.choice(
                manifest, num_samples, replace=False)
        self.items = sampled_manifest

    def __len__(self):
        """Return the number of (sampled) manifest entries."""
        return len(self.items)

    def __getitem__(self, idx):
        """Return the manifest entry at position ``idx``."""
        return self.items[idx]
77 78 79


class FeatureNormalizer(object):
    """Feature normalizer. Normalize features to be of zero mean and unit
    stddev.

    If mean_std_filepath is provided (not None), the normalizer will directly
    initialize from the file. Otherwise, both manifest_path and featurize_func
    should be given for on-the-fly mean and stddev computing.

    :param mean_std_filepath: File containing the pre-computed mean and stddev.
    :type mean_std_filepath: None|str
    :param manifest_path: Manifest of instances for computing mean and stddev.
    :type manifest_path: None|str
    :param featurize_func: Function to extract features. It should be callable
                           with ``featurize_func(audio_segment)``.
    :type featurize_func: None|callable
    :param num_samples: Number of random samples for computing mean and stddev.
    :type num_samples: int
    :param num_workers: Forwarded to the ``DataLoader`` used for on-the-fly
                        computing.
    :type num_workers: int
    :param random_seed: Random seed for sampling instances.
    :type random_seed: int
    :raises ValueError: If both mean_std_filepath and manifest_path
                        (or both mean_std_filepath and featurize_func) are None.
    """

    def __init__(self,
                 mean_std_filepath,
                 manifest_path=None,
                 featurize_func=None,
                 num_samples=500,
                 num_workers=0,
                 random_seed=0):
        if not mean_std_filepath:
            if not (manifest_path and featurize_func):
                raise ValueError("If mean_std_filepath is None, manifest_path "
                                 "and featurize_func should not be None.")
            self._rng = np.random.RandomState(random_seed)
            self._compute_mean_std(manifest_path, featurize_func, num_samples,
                                   num_workers)
        else:
            self._read_mean_std_from_file(mean_std_filepath)

    def apply(self, features):
        """Normalize features to be of zero mean and unit stddev.

        :param features: Input features to be normalized.
        :type features: ndarray, shape (T, D)
        :return: Normalized features.
        :rtype: ndarray
        """
        return (features - self._mean) * self._istd

    def _read_mean_std_from_file(self, filepath, eps=1e-20):
        """Load mean and inverse stddev from file.

        :param filepath: File to load; the text after its last ``.`` is passed
                         to ``load_cmvn`` as ``filetype``.
        :type filepath: str
        :param eps: Not used by this method; kept so the signature stays
                    unchanged.
        :type eps: float
        """
        filetype = filepath.split(".")[-1]
        mean, istd = load_cmvn(filepath, filetype=filetype)
        # Add a leading axis so the statistics broadcast over the frames.
        self._mean = np.expand_dims(mean, axis=0)
        self._istd = np.expand_dims(istd, axis=0)

    def write_to_file(self, filepath):
        """Write the accumulated statistics to the file as JSON.

        The statistics (``self.cmvn_info``) are only set by the on-the-fly
        computation, not when the normalizer was initialized from a file.

        :param filepath: File to write the statistics to.
        :type filepath: str
        """
        with open(filepath, 'w') as fout:
            fout.write(json.dumps(self.cmvn_info))

    @staticmethod
    def _stats_to_mean_istd(mean_stat, var_stat, frame_num, eps=1e-20):
        """Turn accumulated sums into mean and inverse stddev.

        :param mean_stat: Per-dimension sum of the features.
        :type mean_stat: sequence of float
        :param var_stat: Per-dimension sum of the squared features.
        :type var_stat: sequence of float
        :param frame_num: Number of frames the sums were taken over.
        :type frame_num: int
        :param eps: Lower bound applied to the variance before the square
                    root, so constant dimensions do not divide by zero.
        :type eps: float
        :return: ``(mean, istd)``, each with a leading axis of size 1.
        :rtype: tuple of ndarray
        """
        mean = np.asarray(mean_stat, dtype=np.float64) / frame_num
        # Var[x] = E[x^2] - E[x]^2
        variance = (np.asarray(var_stat, dtype=np.float64) / frame_num -
                    np.square(mean))
        istd = 1.0 / np.sqrt(np.maximum(variance, eps))
        return np.expand_dims(mean, axis=0), np.expand_dims(istd, axis=0)

    def _compute_mean_std(self,
                          manifest_path,
                          featurize_func,
                          num_samples,
                          num_workers,
                          batch_size=64,
                          eps=1e-20):
        """Compute mean and std from randomly sampled instances.

        Sets ``self.cmvn_info`` (the raw sums and frame count) as well as
        ``self._mean`` and ``self._istd`` so that ``apply`` can be used
        afterwards.

        :param manifest_path: Manifest of instances to sample from.
        :type manifest_path: str
        :param featurize_func: Function to extract features.
        :type featurize_func: callable
        :param num_samples: Number of instances to sample.
        :type num_samples: int
        :param num_workers: Forwarded to the ``DataLoader``.
        :type num_workers: int
        :param batch_size: Number of instances reduced per batch.
        :type batch_size: int
        :param eps: Variance floor used when deriving the inverse stddev.
        :type eps: float
        :return: Dict with keys ``mean_stat``, ``var_stat`` and ``frame_num``.
        :rtype: dict
        :raises ValueError: If no frames were collected.
        """
        paddle.set_device('cpu')

        collate_func = CollateFunc(featurize_func)
        dataset = AudioDataset(manifest_path, num_samples, self._rng)
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            collate_fn=collate_func)
        total_wavs = len(dataset)
        log_every = 1000

        with paddle.no_grad():
            all_mean_stat = None
            all_var_stat = None
            all_number = 0
            wav_number = 0
            for batch in data_loader:
                number, mean_stat, var_stat = batch
                if all_mean_stat is None:
                    all_mean_stat = mean_stat
                    all_var_stat = var_stat
                else:
                    all_mean_stat += mean_stat
                    all_var_stat += var_stat
                all_number += number

                # The last batch may be partial, so cap at the dataset size,
                # and log whenever another `log_every` boundary is crossed
                # (a plain modulo test misses it unless batch_size divides it).
                already_logged = wav_number // log_every
                wav_number = min(wav_number + batch_size, total_wavs)
                if wav_number // log_every > already_logged:
                    logger.info(
                        f'process {wav_number} wavs,{all_number} frames.')

        if all_mean_stat is None or all_number == 0:
            raise ValueError(
                "No frames were collected; cannot compute mean and stddev.")

        self.cmvn_info = {
            'mean_stat': list(all_mean_stat.tolist()),
            'var_stat': list(all_var_stat.tolist()),
            'frame_num': all_number,
        }
        self._mean, self._istd = self._stats_to_mean_istd(
            self.cmvn_info['mean_stat'], self.cmvn_info['var_stat'],
            all_number, eps)

        return self.cmvn_info