提交 11435a1b 编写于 作者: H Hui Zhang

fix compute cmvn, need paddle 2.1

上级 02caa564
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
"""Contains feature normalizers.""" """Contains feature normalizers."""
import json import json
import random
import numpy as np import numpy as np
import paddle import paddle
...@@ -27,18 +26,19 @@ from deepspeech.frontend.utility import read_manifest ...@@ -27,18 +26,19 @@ from deepspeech.frontend.utility import read_manifest
__all__ = ["FeatureNormalizer"] __all__ = ["FeatureNormalizer"]
# https://github.com/PaddlePaddle/Paddle/pull/31481
class CollateFunc(object): class CollateFunc(object):
''' Collate function for AudioDataset def __init__(self, feature_func):
''' self.feature_func = feature_func
def __init__(self):
pass
def __call__(self, batch): def __call__(self, batch):
mean_stat = None mean_stat = None
var_stat = None var_stat = None
number = 0 number = 0
for feat in batch: for item in batch:
audioseg = AudioSegment.from_file(item['feat'])
feat = self.feature_func(audioseg) #(D, T)
sums = np.sum(feat, axis=1) sums = np.sum(feat, axis=1)
if mean_stat is None: if mean_stat is None:
mean_stat = sums mean_stat = sums
...@@ -52,30 +52,25 @@ class CollateFunc(object): ...@@ -52,30 +52,25 @@ class CollateFunc(object):
var_stat += square_sums var_stat += square_sums
number += feat.shape[1] number += feat.shape[1]
return paddle.to_tensor(number), paddle.to_tensor( return number, mean_stat, var_stat
mean_stat), paddle.to_tensor(var_stat)
#return number, mean_stat, var_stat
class AudioDataset(Dataset): class AudioDataset(Dataset):
def __init__(self, manifest_path, feature_func, num_samples=-1, rng=None): def __init__(self, manifest_path, num_samples=-1, rng=None, random_seed=0):
self.feature_func = feature_func self._rng = rng if rng else np.random.RandomState(random_seed)
self._rng = rng
manifest = read_manifest(manifest_path) manifest = read_manifest(manifest_path)
if num_samples == -1: if num_samples == -1:
sampled_manifest = manifest sampled_manifest = manifest
else: else:
sampled_manifest = self._rng.sample(manifest, num_samples) sampled_manifest = self._rng.choice(
manifest, num_samples, replace=False)
self.items = sampled_manifest self.items = sampled_manifest
def __len__(self): def __len__(self):
return len(self.items) return len(self.items)
def __getitem__(self, idx): def __getitem__(self, idx):
key = self.items[idx]['feat'] return self.items[idx]
audioseg = AudioSegment.from_file(key)
feat = self.feature_func(audioseg) #(D, T)
return feat
class FeatureNormalizer(object): class FeatureNormalizer(object):
...@@ -112,7 +107,7 @@ class FeatureNormalizer(object): ...@@ -112,7 +107,7 @@ class FeatureNormalizer(object):
if not (manifest_path and featurize_func): if not (manifest_path and featurize_func):
raise ValueError("If mean_std_filepath is None, meanifest_path " raise ValueError("If mean_std_filepath is None, meanifest_path "
"and featurize_func should not be None.") "and featurize_func should not be None.")
self._rng = random.Random(random_seed) self._rng = np.random.RandomState(random_seed)
self._compute_mean_std(manifest_path, featurize_func, num_samples, self._compute_mean_std(manifest_path, featurize_func, num_samples,
num_workers) num_workers)
else: else:
...@@ -150,29 +145,11 @@ class FeatureNormalizer(object): ...@@ -150,29 +145,11 @@ class FeatureNormalizer(object):
featurize_func, featurize_func,
num_samples, num_samples,
num_workers, num_workers,
batch_size=64,
eps=1e-20): eps=1e-20):
"""Compute mean and std from randomly sampled instances.""" """Compute mean and std from randomly sampled instances."""
# manifest = read_manifest(manifest_path) collate_func = CollateFunc(featurize_func)
# if num_samples == -1: dataset = AudioDataset(manifest_path, num_samples, self._rng)
# sampled_manifest = manifest
# else:
# sampled_manifest = self._rng.sample(manifest, num_samples)
# features = []
# for instance in sampled_manifest:
# features.append(
# featurize_func(AudioSegment.from_file(instance["feat"])))
# features = np.hstack(features) #(D, T)
# self._mean = np.mean(features, axis=1) #(D,)
# std = np.std(features, axis=1) #(D,)
# std = np.clip(std, eps, None)
# self._istd = 1.0 / std
collate_func = CollateFunc()
dataset = AudioDataset(manifest_path, featurize_func, num_samples,
self._rng)
batch_size = 20
data_loader = DataLoader( data_loader = DataLoader(
dataset, dataset,
batch_size=batch_size, batch_size=batch_size,
...@@ -185,9 +162,9 @@ class FeatureNormalizer(object): ...@@ -185,9 +162,9 @@ class FeatureNormalizer(object):
all_var_stat = None all_var_stat = None
all_number = 0 all_number = 0
wav_number = 0 wav_number = 0
for batch in data_loader(): for i, batch in enumerate(data_loader):
number, mean_stat, var_stat = batch number, mean_stat, var_stat = batch
if all_mean_stat is None: if i == 0:
all_mean_stat = mean_stat all_mean_stat = mean_stat
all_var_stat = var_stat all_var_stat = var_stat
else: else:
...@@ -198,12 +175,12 @@ class FeatureNormalizer(object): ...@@ -198,12 +175,12 @@ class FeatureNormalizer(object):
if wav_number % 1000 == 0: if wav_number % 1000 == 0:
print('process {} wavs,{} frames'.format(wav_number, print('process {} wavs,{} frames'.format(wav_number,
int(all_number))) all_number))
self.cmvn_info = { self.cmvn_info = {
'mean_stat': list(all_mean_stat.tolist()), 'mean_stat': list(all_mean_stat.tolist()),
'var_stat': list(all_var_stat.tolist()), 'var_stat': list(all_var_stat.tolist()),
'frame_num': int(all_number), 'frame_num': all_number,
} }
return self.cmvn_info return self.cmvn_info
...@@ -12,12 +12,12 @@ ...@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import io import io
import random
import tarfile import tarfile
import time import time
from collections import namedtuple from collections import namedtuple
from typing import Optional from typing import Optional
import numpy as np
from paddle.io import Dataset from paddle.io import Dataset
from yacs.config import CfgNode from yacs.config import CfgNode
...@@ -209,7 +209,7 @@ class ManifestDataset(Dataset): ...@@ -209,7 +209,7 @@ class ManifestDataset(Dataset):
use_dB_normalization=use_dB_normalization, use_dB_normalization=use_dB_normalization,
target_dB=target_dB) target_dB=target_dB)
self._rng = random.Random(random_seed) self._rng = np.random.RandomState(random_seed)
self._keep_transcription_text = keep_transcription_text self._keep_transcription_text = keep_transcription_text
# for caching tar files info # for caching tar files info
self._local_data = namedtuple('local_data', ['tar2info', 'tar2object']) self._local_data = namedtuple('local_data', ['tar2info', 'tar2object'])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册