提交 ee5a0c48 编写于 作者: H Hui Zhang

fix cmvn compute

上级 9876bdb4
......@@ -3,7 +3,7 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "warming-contrast",
"id": "future-wesley",
"metadata": {},
"outputs": [
{
......@@ -32,7 +32,7 @@
{
"cell_type": "code",
"execution_count": 2,
"id": "genuine-marker",
"id": "eleven-istanbul",
"metadata": {},
"outputs": [
{
......@@ -91,7 +91,7 @@
{
"cell_type": "code",
"execution_count": 3,
"id": "accepting-genesis",
"id": "provincial-mexico",
"metadata": {},
"outputs": [
{
......@@ -815,7 +815,7 @@
{
"cell_type": "code",
"execution_count": 4,
"id": "baking-ozone",
"id": "choice-psychology",
"metadata": {},
"outputs": [
{
......@@ -1528,7 +1528,7 @@
{
"cell_type": "code",
"execution_count": 5,
"id": "committed-supplier",
"id": "enabling-botswana",
"metadata": {},
"outputs": [
{
......@@ -1551,7 +1551,7 @@
{
"cell_type": "code",
"execution_count": 6,
"id": "wooden-rugby",
"id": "acute-hunter",
"metadata": {},
"outputs": [],
"source": [
......@@ -1566,7 +1566,7 @@
{
"cell_type": "code",
"execution_count": 7,
"id": "streaming-queue",
"id": "impossible-mount",
"metadata": {},
"outputs": [
{
......@@ -1662,7 +1662,7 @@
{
"cell_type": "code",
"execution_count": 8,
"id": "cardiovascular-controversy",
"id": "dying-ideal",
"metadata": {},
"outputs": [],
"source": [
......@@ -1741,7 +1741,7 @@
{
"cell_type": "code",
"execution_count": 9,
"id": "sorted-nursery",
"id": "pleased-isaac",
"metadata": {},
"outputs": [
{
......@@ -1777,7 +1777,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "typical-destruction",
"id": "appreciated-carpet",
"metadata": {},
"outputs": [],
"source": []
......@@ -1785,7 +1785,7 @@
{
"cell_type": "code",
"execution_count": 10,
"id": "junior-toner",
"id": "suitable-railway",
"metadata": {},
"outputs": [
{
......@@ -1924,7 +1924,7 @@
{
"cell_type": "code",
"execution_count": 11,
"id": "dense-brake",
"id": "afraid-translation",
"metadata": {},
"outputs": [],
"source": [
......@@ -1934,7 +1934,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "voluntary-arcade",
"id": "answering-slide",
"metadata": {},
"outputs": [],
"source": []
......@@ -1942,7 +1942,7 @@
{
"cell_type": "code",
"execution_count": 12,
"id": "surprising-teach",
"id": "undefined-glenn",
"metadata": {},
"outputs": [
{
......@@ -1972,7 +1972,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "permanent-loading",
"id": "twenty-funds",
"metadata": {},
"outputs": [],
"source": []
......@@ -1980,7 +1980,7 @@
{
"cell_type": "code",
"execution_count": 13,
"id": "criminal-setup",
"id": "threatened-phase",
"metadata": {},
"outputs": [
{
......@@ -2003,7 +2003,7 @@
{
"cell_type": "code",
"execution_count": 14,
"id": "brazilian-happening",
"id": "ordered-denver",
"metadata": {},
"outputs": [
{
......@@ -2021,7 +2021,7 @@
{
"cell_type": "code",
"execution_count": 15,
"id": "separate-eligibility",
"id": "above-investigator",
"metadata": {},
"outputs": [
{
......@@ -2053,7 +2053,7 @@
{
"cell_type": "code",
"execution_count": 33,
"id": "alternate-comment",
"id": "dimensional-introduction",
"metadata": {},
"outputs": [
{
......@@ -2216,7 +2216,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "polish-opportunity",
"id": "basic-basement",
"metadata": {},
"outputs": [],
"source": []
......@@ -2224,7 +2224,7 @@
{
"cell_type": "code",
"execution_count": 17,
"id": "improved-alabama",
"id": "decreased-automation",
"metadata": {},
"outputs": [
{
......@@ -2246,7 +2246,7 @@
{
"cell_type": "code",
"execution_count": 18,
"id": "metric-destruction",
"id": "marine-middle",
"metadata": {},
"outputs": [],
"source": [
......@@ -2256,7 +2256,7 @@
{
"cell_type": "code",
"execution_count": 20,
"id": "turkish-watch",
"id": "young-reserve",
"metadata": {},
"outputs": [],
"source": [
......@@ -2267,7 +2267,7 @@
{
"cell_type": "code",
"execution_count": 47,
"id": "drawn-crash",
"id": "differential-mileage",
"metadata": {},
"outputs": [],
"source": [
......@@ -2328,7 +2328,7 @@
{
"cell_type": "code",
"execution_count": 48,
"id": "informative-optics",
"id": "industrial-server",
"metadata": {},
"outputs": [
{
......@@ -2385,7 +2385,7 @@
{
"cell_type": "code",
"execution_count": 49,
"id": "northern-advisory",
"id": "noticed-soviet",
"metadata": {},
"outputs": [
{
......@@ -2416,7 +2416,7 @@
{
"cell_type": "code",
"execution_count": 50,
"id": "prospective-death",
"id": "clinical-matter",
"metadata": {},
"outputs": [
{
......@@ -2455,7 +2455,7 @@
{
"cell_type": "code",
"execution_count": 51,
"id": "closed-partner",
"id": "checked-picking",
"metadata": {},
"outputs": [],
"source": []
......@@ -2463,7 +2463,7 @@
{
"cell_type": "code",
"execution_count": 52,
"id": "silent-animal",
"id": "normal-airfare",
"metadata": {},
"outputs": [
{
......@@ -2483,7 +2483,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "fatal-board",
"id": "fewer-drill",
"metadata": {},
"outputs": [],
"source": []
......
......@@ -15,10 +15,69 @@
import random
import numpy as np
import paddle
from paddle.io import DataLoader
from paddle.io import Dataset
from deepspeech.frontend.audio import AudioSegment
from deepspeech.frontend.utility import load_cmvn
from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.log import Log
__all__ = ["FeatureNormalizer"]
logger = Log(__name__).getlog()
class CollateFunc(object):
''' Collate function for AudioDataset
'''
def __init__(self):
pass
def __call__(self, batch):
mean_stat = None
var_stat = None
number = 0
for feat in batch:
sums = np.sum(feat, axis=1)
if mean_stat is None:
mean_stat = sums
else:
mean_stat += sums
square_sums = np.sum(np.square(feat), axis=1)
if var_stat is None:
var_stat = square_sums
else:
var_stat += square_sums
number += feat.shape[1]
return paddle.to_tensor(number), paddle.to_tensor(
mean_stat), paddle.to_tensor(var_stat)
#return number, mean_stat, var_stat
class AudioDataset(Dataset):
def __init__(self, manifest_path, feature_func, num_samples=-1, rng=None):
self.feature_func = feature_func
self._rng = rng
manifest = read_manifest(manifest_path)
if num_samples == -1:
sampled_manifest = manifest
else:
sampled_manifest = self._rng.sample(manifest, num_samples)
self.items = sampled_manifest
def __len__(self):
return len(self.items)
def __getitem__(self, idx):
key = self.items[idx]['feat']
audioseg = AudioSegment.from_file(key)
feat = self.feature_func(audioseg) #(D, T)
return feat
class FeatureNormalizer(object):
......@@ -49,13 +108,15 @@ class FeatureNormalizer(object):
manifest_path=None,
featurize_func=None,
num_samples=500,
num_workers=0,
random_seed=0):
if not mean_std_filepath:
if not (manifest_path and featurize_func):
raise ValueError("If mean_std_filepath is None, meanifest_path "
"and featurize_func should not be None.")
self._rng = random.Random(random_seed)
self._compute_mean_std(manifest_path, featurize_func, num_samples)
self._compute_mean_std(manifest_path, featurize_func, num_samples,
num_workers)
else:
self._read_mean_std_from_file(mean_std_filepath)
......@@ -71,37 +132,79 @@ class FeatureNormalizer(object):
"""
return (features - self._mean) * self._istd
def _read_mean_std_from_file(self, filepath, eps=1e-20):
"""Load mean and std from file."""
mean, istd = load_cmvn(filepath, filetype='json')
self._mean = mean
self._istd = istd
def write_to_file(self, filepath):
"""Write the mean and stddev to the file.
:param filepath: File to write mean and stddev.
:type filepath: str
"""
np.savez(filepath, mean=self._mean, istd=self._istd)
def _read_mean_std_from_file(self, filepath, eps=1e-20):
"""Load mean and std from file."""
mean, istd = load_cmvn(filepath, filetype='npz')
self._mean = mean.T
self._istd = istd.T
with open(filepath, 'w') as fout:
fout.write(json.dumps(self.cmvn_info))
def _compute_mean_std(self,
manifest_path,
featurize_func,
num_samples,
num_workers,
eps=1e-20):
"""Compute mean and std from randomly sampled instances."""
manifest = read_manifest(manifest_path)
if num_samples == -1:
sampled_manifest = manifest
else:
sampled_manifest = self._rng.sample(manifest, num_samples)
features = []
for instance in sampled_manifest:
features.append(
featurize_func(AudioSegment.from_file(instance["feat"])))
features = np.hstack(features) #(D, T)
self._mean = np.mean(features, axis=1) #(D,)
std = np.std(features, axis=1) #(D,)
std = np.clip(std, eps, None)
self._istd = 1.0 / std
# manifest = read_manifest(manifest_path)
# if num_samples == -1:
# sampled_manifest = manifest
# else:
# sampled_manifest = self._rng.sample(manifest, num_samples)
# features = []
# for instance in sampled_manifest:
# features.append(
# featurize_func(AudioSegment.from_file(instance["feat"])))
# features = np.hstack(features) #(D, T)
# self._mean = np.mean(features, axis=1) #(D,)
# std = np.std(features, axis=1) #(D,)
# std = np.clip(std, eps, None)
# self._istd = 1.0 / std
collate_func = CollateFunc()
dataset = AudioDataset(manifest_path, featurize_func, num_samples)
batch_size = 20
data_loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
collate_fn=collate_func)
with paddle.no_grad():
all_mean_stat = None
all_var_stat = None
all_number = 0
wav_number = 0
for batch in data_loader():
number, mean_stat, var_stat = batch
if all_mean_stat is None:
all_mean_stat = mean_stat
all_var_stat = var_stat
else:
all_mean_stat += mean_stat
all_var_stat += var_stat
all_number += number
wav_number += batch_size
if wav_number % 1000 == 0:
logger.info('process {} wavs,{} frames'.format(
wav_number, int(all_number)))
self.cmvn_info = {
'mean_stat': list(all_mean_stat.tolist()),
'var_stat': list(all_var_stat.tolist()),
'frame_num': int(all_number),
}
return self.cmvn_info
......@@ -42,6 +42,7 @@ python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
--stride_ms=10.0 \
--window_ms=25.0 \
--sample_rate=16000 \
--num_workers=0 \
--output_path="data/mean_std.npz"
if [ $? -ne 0 ]; then
......
......@@ -39,6 +39,10 @@ add_arg('sample_rate', int, 16000, "target sample rate.")
add_arg('manifest_path', str,
'data/librispeech/manifest.train',
"Filepath of manifest to compute normalizer's mean and stddev.")
add_arg('num_workers',
default=0,
type=int,
help='num of subprocess workers for processing')
add_arg('output_path', str,
'data/librispeech/mean_std.npz',
"Filepath of write mean and stddev to (.npz).")
......@@ -70,7 +74,8 @@ def main():
mean_std_filepath=None,
manifest_path=args.manifest_path,
featurize_func=augment_and_featurize,
num_samples=args.num_samples)
num_samples=args.num_samples,
num_workers=args.num_workers)
normalizer.write_to_file(args.output_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册