提交 55742773 编写于 作者: H Haoxin Ma

move redundant params

上级 698d7a9b
......@@ -21,32 +21,18 @@ _C.data = CN(
train_manifest="",
dev_manifest="",
test_manifest="",
unit_type="char",
vocab_filepath="",
spm_model_prefix="",
mean_std_filepath="",
augmentation_config="",
max_duration=float('inf'),
min_duration=0.0,
))
_C.model = CN(
dict(
num_conv_layers=2, #Number of stacking convolution layers.
num_rnn_layers=3, #Number of stacking RNN layers.
rnn_layer_size=1024, #RNN layer size (number of RNN cells).
use_gru=True, #Use gru if set True. Use simple rnn if set False.
share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported.
))
_C.collator =CN(
dict(
augmentation_config="",
random_seed=0,
mean_std_filepath="",
unit_type="char",
vocab_filepath="",
spm_model_prefix="",
mean_std_filepath="",
augmentation_config="",
random_seed=0,
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delta_delta=False, # 'mfcc', 'fbank'
......@@ -65,6 +51,16 @@ _C.collator =CN(
shuffle_method="batch_shuffle", # 'batch_shuffle', 'instance_shuffle'
))
_C.model = CN(
dict(
num_conv_layers=2, #Number of stacking convolution layers.
num_rnn_layers=3, #Number of stacking RNN layers.
rnn_layer_size=1024, #RNN layer size (number of RNN cells).
use_gru=True, #Use gru if set True. Use simple rnn if set False.
share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported.
))
DeepSpeech2Model.params(_C.model)
_C.training = CN(
......
......@@ -143,7 +143,6 @@ class DeepSpeech2Trainer(Trainer):
train_dataset = ManifestDataset.from_config(config)
config.data.manifest = config.data.dev_manifest
config.data.augmentation_config = ""
dev_dataset = ManifestDataset.from_config(config)
if self.parallel:
......@@ -165,18 +164,22 @@ class DeepSpeech2Trainer(Trainer):
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
collate_fn = SpeechCollator.from_config(config)
collate_fn_train = SpeechCollator.from_config(config)
config.collator.augmentation_config = ""
collate_fn_dev = SpeechCollator.from_config(config)
self.train_loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn,
collate_fn=collate_fn_train,
num_workers=config.collator.num_workers)
self.valid_loader = DataLoader(
dev_dataset,
batch_size=config.collator.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn)
collate_fn=collate_fn_dev)
logger.info("Setup train/valid Dataloader!")
......@@ -324,8 +327,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
# return raw text
config.data.manifest = config.data.test_manifest
config.data.keep_transcription_text = True
config.data.augmentation_config = ""
# filter test examples, will cause less examples, but no mismatch with training
# and can use large batch size , save training time, so filter test egs now.
# config.data.min_input_len = 0.0 # second
......@@ -337,6 +338,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
test_dataset = ManifestDataset.from_config(config)
config.collator.keep_transcription_text = True
config.collator.augmentation_config = ""
# return text ord id
self.test_loader = DataLoader(
test_dataset,
......
......@@ -17,21 +17,13 @@ from deepspeech.exps.u2.model import U2Tester
from deepspeech.exps.u2.model import U2Trainer
from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.u2 import U2Model
from deepspeech.io.collator import SpeechCollator
_C = CfgNode()
_C.data = ManifestDataset.params()
_C.collator =CfgNode(
dict(
augmentation_config="",
unit_type="char",
keep_transcription_text=False,
batch_size=32, # batch size
num_workers=0, # data loader workers
sortagrad=False, # sorted in first epoch when True
shuffle_method="batch_shuffle" # 'batch_shuffle', 'instance_shuffle'
))
_C.collator = SpeechCollator.params()
_C.model = U2Model.params()
......
......@@ -100,7 +100,7 @@ class U2Trainer(Trainer):
if (batch_index + 1) % train_conf.log_interval == 0:
msg += "train time: {:>.3f}s, ".format(iteration_time)
msg += "batch size: {}, ".format(self.config.data.batch_size)
msg += "batch size: {}, ".format(self.config.collator.batch_size)
msg += "accum: {}, ".format(train_conf.accum_grad)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_np.items())
......@@ -211,51 +211,52 @@ class U2Trainer(Trainer):
def setup_dataloader(self):
config = self.config.clone()
config.defrost()
config.data.keep_transcription_text = False
config.collator.keep_transcription_text = False
# train/valid dataset, return token ids
config.data.manifest = config.data.train_manifest
train_dataset = ManifestDataset.from_config(config)
config.data.manifest = config.data.dev_manifest
config.data.augmentation_config = ""
dev_dataset = ManifestDataset.from_config(config)
collate_fn = SpeechCollator.from_config(config)
collate_fn_train = SpeechCollator.from_config(config)
config.collator.augmentation_config = ""
collate_fn_dev = SpeechCollator.from_config(config)
if self.parallel:
batch_sampler = SortagradDistributedBatchSampler(
train_dataset,
batch_size=config.data.batch_size,
batch_size=config.collator.batch_size,
num_replicas=None,
rank=None,
shuffle=True,
drop_last=True,
sortagrad=config.data.sortagrad,
shuffle_method=config.data.shuffle_method)
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
else:
batch_sampler = SortagradBatchSampler(
train_dataset,
shuffle=True,
batch_size=config.data.batch_size,
batch_size=config.collator.batch_size,
drop_last=True,
sortagrad=config.data.sortagrad,
shuffle_method=config.data.shuffle_method)
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
self.train_loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn,
num_workers=config.data.num_workers, )
collate_fn=collate_fn_train,
num_workers=config.collator.num_workers, )
self.valid_loader = DataLoader(
dev_dataset,
batch_size=config.data.batch_size,
batch_size=config.collator.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn)
collate_fn=collate_fn_dev)
# test dataset, return raw text
config.data.manifest = config.data.test_manifest
config.data.keep_transcription_text = True
config.data.augmentation_config = ""
# filter test examples, will cause less examples, but no mismatch with training
# and can use large batch size , save training time, so filter test egs now.
# config.data.min_input_len = 0.0 # second
......@@ -264,9 +265,11 @@ class U2Trainer(Trainer):
# config.data.max_output_len = float('inf') # tokens
# config.data.min_output_input_ratio = 0.00
# config.data.max_output_input_ratio = float('inf')
test_dataset = ManifestDataset.from_config(config)
# return text ord id
config.collator.keep_transcription_text = True
config.collator.augmentation_config = ""
self.test_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
......
......@@ -75,8 +75,8 @@ class SpeechCollator():
"""
assert 'augmentation_config' in config.collator
assert 'keep_transcription_text' in config.collator
assert 'mean_std_filepath' in config.data
assert 'vocab_filepath' in config.data
assert 'mean_std_filepath' in config.collator
assert 'vocab_filepath' in config.collator
assert 'specgram_type' in config.collator
assert 'n_fft' in config.collator
assert config.collator
......@@ -94,9 +94,9 @@ class SpeechCollator():
speech_collator = cls(
aug_file=aug_file,
random_seed=0,
mean_std_filepath=config.data.mean_std_filepath,
mean_std_filepath=config.collator.mean_std_filepath,
unit_type=config.collator.unit_type,
vocab_filepath=config.data.vocab_filepath,
vocab_filepath=config.collator.vocab_filepath,
spm_model_prefix=config.collator.spm_model_prefix,
specgram_type=config.collator.specgram_type,
feat_dim=config.collator.feat_dim,
......@@ -129,11 +129,31 @@ class SpeechCollator():
target_dB=-20,
dither=1.0,
keep_transcription_text=True):
"""
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one bach.
"""SpeechCollator Collator
if ``keep_transcription_text`` is False, text is token ids else is raw string.
Args:
unit_type(str): token unit type, e.g. char, word, spm
vocab_filepath (str): vocab file path.
mean_std_filepath (str): mean and std file path, which suffix is *.npy
spm_model_prefix (str): spm model prefix, need if `unit_type` is spm.
augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
stride_ms (float, optional): stride size in ms. Defaults to 10.0.
window_ms (float, optional): window size in ms. Defaults to 20.0.
n_fft (int, optional): fft points for rfft. Defaults to None.
max_freq (int, optional): max cut freq. Defaults to None.
target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000.
specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None.
delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False.
use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
target_dB (int, optional): target dB. Defaults to -20.
random_seed (int, optional): for random generator. Defaults to 0.
keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False.
if ``keep_transcription_text`` is False, text is token ids else is raw string.
Do augmentations
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one batch.
"""
self._keep_transcription_text = keep_transcription_text
......
......@@ -40,15 +40,7 @@ class ManifestDataset(Dataset):
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
default = CfgNode(
dict(
train_manifest="",
dev_manifest="",
test_manifest="",
manifest="",
unit_type="char",
vocab_filepath="",
spm_model_prefix="",
mean_std_filepath="",
augmentation_config="",
max_input_len=27.0,
min_input_len=0.0,
max_output_len=float('inf'),
......@@ -73,25 +65,10 @@ class ManifestDataset(Dataset):
"""
assert 'manifest' in config.data
assert config.data.manifest
assert 'keep_transcription_text' in config.collator
if isinstance(config.data.augmentation_config, (str, bytes)):
if config.data.augmentation_config:
aug_file = io.open(
config.data.augmentation_config, mode='r', encoding='utf8')
else:
aug_file = io.StringIO(initial_value='{}', newline='')
else:
aug_file = config.data.augmentation_config
assert isinstance(aug_file, io.StringIO)
dataset = cls(
manifest_path=config.data.manifest,
unit_type=config.data.unit_type,
vocab_filepath=config.data.vocab_filepath,
mean_std_filepath=config.data.mean_std_filepath,
spm_model_prefix=config.data.spm_model_prefix,
augmentation_config=aug_file.read(),
max_input_len=config.data.max_input_len,
min_input_len=config.data.min_input_len,
max_output_len=config.data.max_output_len,
......@@ -101,23 +78,8 @@ class ManifestDataset(Dataset):
)
return dataset
def _read_vocab(self, vocab_filepath):
"""Load vocabulary from file."""
vocab_lines = []
with open(vocab_filepath, 'r', encoding='utf-8') as file:
vocab_lines.extend(file.readlines())
vocab_list = [line[:-1] for line in vocab_lines]
return vocab_list
def __init__(self,
manifest_path,
unit_type,
vocab_filepath,
mean_std_filepath,
spm_model_prefix=None,
augmentation_config='{}',
max_input_len=float('inf'),
min_input_len=0.0,
max_output_len=float('inf'),
......@@ -128,34 +90,16 @@ class ManifestDataset(Dataset):
Args:
manifest_path (str): manifest josn file path
unit_type(str): token unit type, e.g. char, word, spm
vocab_filepath (str): vocab file path.
mean_std_filepath (str): mean and std file path, which suffix is *.npy
spm_model_prefix (str): spm model prefix, need if `unit_type` is spm.
augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
max_input_len ([type], optional): maximum output seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0.
min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0.
max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0.
min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05.
stride_ms (float, optional): stride size in ms. Defaults to 10.0.
window_ms (float, optional): window size in ms. Defaults to 20.0.
n_fft (int, optional): fft points for rfft. Defaults to None.
max_freq (int, optional): max cut freq. Defaults to None.
target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000.
specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None.
delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False.
use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
target_dB (int, optional): target dB. Defaults to -20.
random_seed (int, optional): for random generator. Defaults to 0.
keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False.
"""
super().__init__()
# self._rng = np.random.RandomState(random_seed)
# read manifest
self._manifest = read_manifest(
manifest_path=manifest_path,
......@@ -167,51 +111,6 @@ class ManifestDataset(Dataset):
min_output_input_ratio=min_output_input_ratio)
self._manifest.sort(key=lambda x: x["feat_shape"][0])
# self._vocab_list = self._read_vocab(vocab_filepath)
# @property
# def manifest(self):
# return self._manifest
# @property
# def vocab_size(self):
# """Return the vocabulary size.
# Returns:
# int: Vocabulary size.
# """
# return len(self._vocab_list)
# @property
# def vocab_list(self):
# """Return the vocabulary in list.
# Returns:
# List[str]:
# """
# return self._vocab_list
# @property
# def vocab_dict(self):
# """Return the vocabulary in dict.
# Returns:
# Dict[str, int]:
# """
# vocab_dict = dict(
# [(token, idx) for (idx, token) in enumerate(self._vocab_list)])
# return vocab_dict
# @property
# def feature_size(self):
# """Return the audio feature size.
# Returns:
# int: audio feature size.
# """
# return self._manifest[0]["feat_shape"][-1]
def __len__(self):
return len(self._manifest)
......
......@@ -3,17 +3,20 @@ data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
vocab_filepath: data/vocab.txt
unit_type: 'char'
spm_model_prefix: ''
augmentation_config: conf/augmentation.json
batch_size: 64
min_input_len: 0.5
max_input_len: 20.0 # second
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'char'
spm_model_prefix: ''
augmentation_config: conf/augmentation.json
batch_size: 64
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
......@@ -32,7 +35,6 @@ data:
shuffle_method: batch_shuffle
num_workers: 2
# network architecture
model:
cmvn_file: "data/mean_std.json"
......
......@@ -2,10 +2,7 @@
data:
train_manifest: data/manifest.tiny
dev_manifest: data/manifest.tiny
test_manifest: data/manifest.tiny
mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt
test_manifest: data/manifest.tiny
min_input_len: 0.0
max_input_len: 27.0
min_output_len: 0.0
......@@ -15,6 +12,9 @@ data:
collator:
mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
......@@ -43,7 +43,7 @@ model:
share_rnn_weights: True
training:
n_epoch: 23
n_epoch: 24
lr: 1e-5
lr_decay: 1.0
weight_decay: 1e-06
......
......@@ -3,26 +3,20 @@ data:
train_manifest: data/manifest.tiny
dev_manifest: data/manifest.tiny
test_manifest: data/manifest.tiny
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_200'
mean_std_filepath: ""
batch_size: 4
min_input_len: 0.5 # second
max_input_len: 20.0 # second
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
raw_wav: True # use raw_wav or kaldi feature
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0 #2
collator:
vocab_filepath: data/vocab.txt
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_200'
specgram_type: fbank
feat_dim: 80
delta_delta: False
......@@ -35,6 +29,12 @@ collator:
target_dB: -20
dither: 1.0
keep_transcription_text: False
batch_size: 4
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0 #2
raw_wav: True # use raw_wav or kaldi feature
# network architecture
model:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册