提交 11435a1b 编写于 作者: H Hui Zhang

fix compute cmvn, need paddle 2.1

上级 02caa564
......@@ -13,7 +13,6 @@
# limitations under the License.
"""Contains feature normalizers."""
import json
import random
import numpy as np
import paddle
......@@ -27,18 +26,19 @@ from deepspeech.frontend.utility import read_manifest
__all__ = ["FeatureNormalizer"]
# https://github.com/PaddlePaddle/Paddle/pull/31481
class CollateFunc(object):
''' Collate function for AudioDataset
'''
def __init__(self):
pass
def __init__(self, feature_func):
self.feature_func = feature_func
def __call__(self, batch):
mean_stat = None
var_stat = None
number = 0
for feat in batch:
for item in batch:
audioseg = AudioSegment.from_file(item['feat'])
feat = self.feature_func(audioseg) #(D, T)
sums = np.sum(feat, axis=1)
if mean_stat is None:
mean_stat = sums
......@@ -52,30 +52,25 @@ class CollateFunc(object):
var_stat += square_sums
number += feat.shape[1]
return paddle.to_tensor(number), paddle.to_tensor(
mean_stat), paddle.to_tensor(var_stat)
#return number, mean_stat, var_stat
return number, mean_stat, var_stat
class AudioDataset(Dataset):
def __init__(self, manifest_path, feature_func, num_samples=-1, rng=None):
self.feature_func = feature_func
self._rng = rng
def __init__(self, manifest_path, num_samples=-1, rng=None, random_seed=0):
self._rng = rng if rng else np.random.RandomState(random_seed)
manifest = read_manifest(manifest_path)
if num_samples == -1:
sampled_manifest = manifest
else:
sampled_manifest = self._rng.sample(manifest, num_samples)
sampled_manifest = self._rng.choice(
manifest, num_samples, replace=False)
self.items = sampled_manifest
def __len__(self):
return len(self.items)
def __getitem__(self, idx):
key = self.items[idx]['feat']
audioseg = AudioSegment.from_file(key)
feat = self.feature_func(audioseg) #(D, T)
return feat
return self.items[idx]
class FeatureNormalizer(object):
......@@ -112,7 +107,7 @@ class FeatureNormalizer(object):
if not (manifest_path and featurize_func):
raise ValueError("If mean_std_filepath is None, meanifest_path "
"and featurize_func should not be None.")
self._rng = random.Random(random_seed)
self._rng = np.random.RandomState(random_seed)
self._compute_mean_std(manifest_path, featurize_func, num_samples,
num_workers)
else:
......@@ -150,29 +145,11 @@ class FeatureNormalizer(object):
featurize_func,
num_samples,
num_workers,
batch_size=64,
eps=1e-20):
"""Compute mean and std from randomly sampled instances."""
# manifest = read_manifest(manifest_path)
# if num_samples == -1:
# sampled_manifest = manifest
# else:
# sampled_manifest = self._rng.sample(manifest, num_samples)
# features = []
# for instance in sampled_manifest:
# features.append(
# featurize_func(AudioSegment.from_file(instance["feat"])))
# features = np.hstack(features) #(D, T)
# self._mean = np.mean(features, axis=1) #(D,)
# std = np.std(features, axis=1) #(D,)
# std = np.clip(std, eps, None)
# self._istd = 1.0 / std
collate_func = CollateFunc()
dataset = AudioDataset(manifest_path, featurize_func, num_samples,
self._rng)
batch_size = 20
collate_func = CollateFunc(featurize_func)
dataset = AudioDataset(manifest_path, num_samples, self._rng)
data_loader = DataLoader(
dataset,
batch_size=batch_size,
......@@ -185,9 +162,9 @@ class FeatureNormalizer(object):
all_var_stat = None
all_number = 0
wav_number = 0
for batch in data_loader():
for i, batch in enumerate(data_loader):
number, mean_stat, var_stat = batch
if all_mean_stat is None:
if i == 0:
all_mean_stat = mean_stat
all_var_stat = var_stat
else:
......@@ -198,12 +175,12 @@ class FeatureNormalizer(object):
if wav_number % 1000 == 0:
print('process {} wavs,{} frames'.format(wav_number,
int(all_number)))
all_number))
self.cmvn_info = {
'mean_stat': list(all_mean_stat.tolist()),
'var_stat': list(all_var_stat.tolist()),
'frame_num': int(all_number),
'frame_num': all_number,
}
return self.cmvn_info
......@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import random
import tarfile
import time
from collections import namedtuple
from typing import Optional
import numpy as np
from paddle.io import Dataset
from yacs.config import CfgNode
......@@ -209,7 +209,7 @@ class ManifestDataset(Dataset):
use_dB_normalization=use_dB_normalization,
target_dB=target_dB)
self._rng = random.Random(random_seed)
self._rng = np.random.RandomState(random_seed)
self._keep_transcription_text = keep_transcription_text
# for caching tar files info
self._local_data = namedtuple('local_data', ['tar2info', 'tar2object'])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册