提交 b9110af9 编写于 作者: H Haoxin Ma

feat_dim, vocab_size

上级 3855522e
......@@ -137,7 +137,7 @@ class DeepSpeech2Trainer(Trainer):
def setup_dataloader(self):
config = self.config.clone()
config.defrost()
config.data.keep_transcription_text = False
config.collator.keep_transcription_text = False
config.data.manifest = config.data.train_manifest
train_dataset = ManifestDataset.from_config(config)
......@@ -165,7 +165,7 @@ class DeepSpeech2Trainer(Trainer):
sortagrad=config.data.sortagrad,
shuffle_method=config.data.shuffle_method)
collate_fn = SpeechCollator(config=config, keep_transcription_text=False)
collate_fn = SpeechCollator.from_config(config)
self.train_loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
......
......@@ -104,50 +104,7 @@ class SpeechFeaturizer(object):
speech_segment.transcript)
return spec_feature, text_ids
@property
def vocab_size(self):
"""Return the vocabulary size.
Returns:
int: Vocabulary size.
"""
return self._text_featurizer.vocab_size
@property
def vocab_list(self):
"""Return the vocabulary in list.
Returns:
List[str]:
"""
return self._text_featurizer.vocab_list
@property
def vocab_dict(self):
"""Return the vocabulary in dict.
Returns:
Dict[str, int]:
"""
return self._text_featurizer.vocab_dict
@property
def feature_size(self):
"""Return the audio feature size.
Returns:
int: audio feature size.
"""
return self._audio_featurizer.feature_size
@property
def stride_ms(self):
"""time length in `ms` unit per frame
Returns:
float: time(ms)/frame
"""
return self._audio_featurizer.stride_ms
@property
def text_feature(self):
......
......@@ -82,7 +82,7 @@ def read_manifest(
]
if all(conditions):
manifest.append(json_data)
return manifest
return manifest, json_data["feat_shape"][-1]
def rms_to_db(rms: float):
......
......@@ -22,6 +22,8 @@ from deepspeech.frontend.normalizer import FeatureNormalizer
from deepspeech.frontend.speech import SpeechSegment
import io
import time
from yacs.config import CfgNode
from typing import Optional
from collections import namedtuple
......@@ -33,51 +35,134 @@ logger = Log(__name__).getlog()
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
class SpeechCollator():
def __init__(self, config, keep_transcription_text=True):
"""
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one bach.
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
default = CfgNode(
dict(
augmentation_config="",
random_seed=0,
mean_std_filepath="",
unit_type="char",
vocab_filepath="",
spm_model_prefix="",
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delta_delta=False, # 'mfcc', 'fbank'
stride_ms=10.0, # ms
window_ms=20.0, # ms
n_fft=None, # fft points
max_freq=None, # None for samplerate/2
target_sample_rate=16000, # target sample rate
use_dB_normalization=True,
target_dB=-20,
dither=1.0, # feature dither
keep_transcription_text=True
))
if ``keep_transcription_text`` is False, text is token ids else is raw string.
if config is not None:
config.merge_from_other_cfg(default)
return default
@classmethod
def from_config(cls, config):
"""Build a SpeechCollator object from a config.
Args:
config (yacs.config.CfgNode): configs object.
Returns:
SpeechCollator: collator object.
"""
self._keep_transcription_text = keep_transcription_text
assert 'augmentation_config' in config.collator
assert 'keep_transcription_text' in config.collator
assert 'mean_std_filepath' in config.collator
assert 'vocab_filepath' in config.data
assert 'specgram_type' in config.collator
assert 'n_fft' in config.collator
assert config.collator
if isinstance(config.data.augmentation_config, (str, bytes)):
if config.data.augmentation_config:
if isinstance(config.collator.augmentation_config, (str, bytes)):
if config.collator.augmentation_config:
aug_file = io.open(
config.data.augmentation_config, mode='r', encoding='utf8')
config.collator.augmentation_config, mode='r', encoding='utf8')
else:
aug_file = io.StringIO(initial_value='{}', newline='')
else:
aug_file = config.data.augmentation_config
aug_file = config.collator.augmentation_config
assert isinstance(aug_file, io.StringIO)
speech_collator = cls(
aug_file=aug_file,
random_seed=0,
mean_std_filepath=config.collator.mean_std_filepath,
unit_type=config.collator.unit_type,
vocab_filepath=config.data.vocab_filepath,
spm_model_prefix=config.collator.spm_model_prefix,
specgram_type=config.collator.specgram_type,
feat_dim=config.collator.feat_dim,
delta_delta=config.collator.delta_delta,
stride_ms=config.collator.stride_ms,
window_ms=config.collator.window_ms,
n_fft=config.collator.n_fft,
max_freq=config.collator.max_freq,
target_sample_rate=config.collator.target_sample_rate,
use_dB_normalization=config.collator.use_dB_normalization,
target_dB=config.collator.target_dB,
dither=config.collator.dither,
keep_transcription_text=config.collator.keep_transcription_text
)
return speech_collator
def __init__(self, aug_file, mean_std_filepath,
vocab_filepath, spm_model_prefix,
random_seed=0,
unit_type="char",
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delta_delta=False, # 'mfcc', 'fbank'
stride_ms=10.0, # ms
window_ms=20.0, # ms
n_fft=None, # fft points
max_freq=None, # None for samplerate/2
target_sample_rate=16000, # target sample rate
use_dB_normalization=True,
target_dB=-20,
dither=1.0,
keep_transcription_text=True):
"""
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one bach.
if ``keep_transcription_text`` is False, text is token ids else is raw string.
"""
self._keep_transcription_text = keep_transcription_text
self._local_data = TarLocalData(tar2info={}, tar2object={})
self._augmentation_pipeline = AugmentationPipeline(
augmentation_config=aug_file.read(),
random_seed=config.data.random_seed)
random_seed=random_seed)
self._normalizer = FeatureNormalizer(
config.data.mean_std_filepath) if config.data.mean_std_filepath else None
mean_std_filepath) if mean_std_filepath else None
self._stride_ms = config.data.stride_ms
self._target_sample_rate = config.data.target_sample_rate
self._stride_ms = stride_ms
self._target_sample_rate = target_sample_rate
self._speech_featurizer = SpeechFeaturizer(
unit_type=config.data.unit_type,
vocab_filepath=config.data.vocab_filepath,
spm_model_prefix=config.data.spm_model_prefix,
specgram_type=config.data.specgram_type,
feat_dim=config.data.feat_dim,
delta_delta=config.data.delta_delta,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
dither=config.data.dither)
unit_type=unit_type,
vocab_filepath=vocab_filepath,
spm_model_prefix=spm_model_prefix,
specgram_type=specgram_type,
feat_dim=feat_dim,
delta_delta=delta_delta,
stride_ms=stride_ms,
window_ms=window_ms,
n_fft=n_fft,
max_freq=max_freq,
target_sample_rate=target_sample_rate,
use_dB_normalization=use_dB_normalization,
target_dB=target_dB,
dither=dither)
def _parse_tar(self, file):
"""Parse a tar file to get a tarfile object
......@@ -196,3 +281,28 @@ class SpeechCollator():
texts, padding_value=IGNORE_ID).astype(np.int64)
text_lens = np.array(text_lens).astype(np.int64)
return utts, padded_audios, audio_lens, padded_texts, text_lens
@property
def vocab_size(self):
return self._speech_featurizer.vocab_size
@property
def vocab_list(self):
return self._speech_featurizer.vocab_list
@property
def vocab_dict(self):
return self._speech_featurizer.vocab_dict
@property
def text_feature(self):
return self._text_featurizer
self._speech_featurizer.text_feature
@property
def feature_size(self):
return self._speech_featurizer.feature_size
@property
def stride_ms(self):
return self._speech_featurizer.stride_ms
......@@ -55,20 +55,6 @@ class ManifestDataset(Dataset):
min_output_len=0.0,
max_output_input_ratio=float('inf'),
min_output_input_ratio=0.0,
stride_ms=10.0, # ms
window_ms=20.0, # ms
n_fft=None, # fft points
max_freq=None, # None for samplerate/2
raw_wav=True, # use raw_wav or kaldi feature
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delta_delta=False, # 'mfcc', 'fbank'
dither=1.0, # feature dither
target_sample_rate=16000, # target sample rate
use_dB_normalization=True,
target_dB=-20,
random_seed=0,
keep_transcription_text=False,
batch_size=32, # batch size
num_workers=0, # data loader workers
sortagrad=False, # sorted in first epoch when True
......@@ -116,21 +102,19 @@ class ManifestDataset(Dataset):
min_output_len=config.data.min_output_len,
max_output_input_ratio=config.data.max_output_input_ratio,
min_output_input_ratio=config.data.min_output_input_ratio,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
feat_dim=config.data.feat_dim,
delta_delta=config.data.delta_delta,
dither=config.data.dither,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=config.data.keep_transcription_text)
)
return dataset
def _read_vocab(self, vocab_filepath):
"""Load vocabulary from file."""
vocab_lines = []
with open(vocab_filepath, 'r', encoding='utf-8') as file:
vocab_lines.extend(file.readlines())
vocab_list = [line[:-1] for line in vocab_lines]
return vocab_list
def __init__(self,
manifest_path,
unit_type,
......@@ -143,20 +127,7 @@ class ManifestDataset(Dataset):
max_output_len=float('inf'),
min_output_len=0.0,
max_output_input_ratio=float('inf'),
min_output_input_ratio=0.0,
stride_ms=10.0,
window_ms=20.0,
n_fft=None,
max_freq=None,
target_sample_rate=16000,
specgram_type='linear',
feat_dim=None,
delta_delta=False,
dither=1.0,
use_dB_normalization=True,
target_dB=-20,
random_seed=0,
keep_transcription_text=False):
min_output_input_ratio=0.0):
"""Manifest Dataset
Args:
......@@ -186,30 +157,11 @@ class ManifestDataset(Dataset):
keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False.
"""
super().__init__()
self._stride_ms = stride_ms
self._target_sample_rate = target_sample_rate
self._speech_featurizer = SpeechFeaturizer(
unit_type=unit_type,
vocab_filepath=vocab_filepath,
spm_model_prefix=spm_model_prefix,
specgram_type=specgram_type,
feat_dim=feat_dim,
delta_delta=delta_delta,
stride_ms=stride_ms,
window_ms=window_ms,
n_fft=n_fft,
max_freq=max_freq,
target_sample_rate=target_sample_rate,
use_dB_normalization=use_dB_normalization,
target_dB=target_dB,
dither=dither)
self._rng = np.random.RandomState(random_seed)
self._keep_transcription_text = keep_transcription_text
# self._rng = np.random.RandomState(random_seed)
# read manifest
self._manifest = read_manifest(
self._manifest, self._feature_size = read_manifest(
manifest_path=manifest_path,
max_input_len=max_input_len,
min_input_len=min_input_len,
......@@ -219,9 +171,59 @@ class ManifestDataset(Dataset):
min_output_input_ratio=min_output_input_ratio)
self._manifest.sort(key=lambda x: x["feat_shape"][0])
self._vocab_list = self._read_vocab(vocab_filepath)
@property
def manifest(self):
return self._manifest
@property
def vocab_size(self):
"""Return the vocabulary size.
Returns:
int: Vocabulary size.
"""
return len(self._vocab_list)
@property
def vocab_list(self):
"""Return the vocabulary in list.
Returns:
List[str]:
"""
return self._vocab_list
@property
def vocab_dict(self):
"""Return the vocabulary in dict.
Returns:
Dict[str, int]:
"""
vocab_dict = dict(
[(token, idx) for (idx, token) in enumerate(self._vocab_list)])
return vocab_dict
@property
def feature_size(self):
"""Return the audio feature size.
Returns:
int: audio feature size.
"""
return self._feature_size
@property
def stride_ms(self):
"""time length in `ms` unit per frame
Returns:
float: time(ms)/frame
"""
return self._audio_featurizer.stride_ms
def __len__(self):
return len(self._manifest)
......
......@@ -4,9 +4,10 @@ data:
dev_manifest: data/manifest.tiny
test_manifest: data/manifest.tiny
mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
batch_size: 2
batch_size: 4
min_input_len: 0.0
max_input_len: 27.0
min_output_len: 0.0
......@@ -28,6 +29,24 @@ data:
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
collator:
augmentation_config: conf/augmentation.json
random_seed: 0
mean_std_filepath: data/mean_std.json
spm_model_prefix:
specgram_type: linear
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
dither: 1.0
keep_transcription_text: True
model:
num_conv_layers: 2
......@@ -37,7 +56,7 @@ model:
share_rnn_weights: True
training:
n_epoch: 10
n_epoch: 21
lr: 1e-5
lr_decay: 1.0
weight_decay: 1e-06
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册