Commit 55742773 authored by Haoxin Ma

move redundant params

Parent 698d7a9b
@@ -21,32 +21,18 @@ _C.data = CN(
         train_manifest="",
         dev_manifest="",
         test_manifest="",
-        unit_type="char",
-        vocab_filepath="",
-        spm_model_prefix="",
-        mean_std_filepath="",
-        augmentation_config="",
         max_duration=float('inf'),
         min_duration=0.0,
     ))

-_C.model = CN(
-    dict(
-        num_conv_layers=2,  #Number of stacking convolution layers.
-        num_rnn_layers=3,  #Number of stacking RNN layers.
-        rnn_layer_size=1024,  #RNN layer size (number of RNN cells).
-        use_gru=True,  #Use gru if set True. Use simple rnn if set False.
-        share_rnn_weights=True  #Whether to share input-hidden weights between forward and backward directional RNNs. Notice that for GRU, weight sharing is not supported.
-    ))

 _C.collator = CN(
     dict(
+        augmentation_config="",
+        random_seed=0,
+        mean_std_filepath="",
         unit_type="char",
         vocab_filepath="",
         spm_model_prefix="",
-        mean_std_filepath="",
-        augmentation_config="",
-        random_seed=0,
         specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
         feat_dim=0,  # 'mfcc', 'fbank'
         delta_delta=False,  # 'mfcc', 'fbank'
@@ -65,6 +51,16 @@ _C.collator = CN(
         shuffle_method="batch_shuffle",  # 'batch_shuffle', 'instance_shuffle'
     ))

+_C.model = CN(
+    dict(
+        num_conv_layers=2,  #Number of stacking convolution layers.
+        num_rnn_layers=3,  #Number of stacking RNN layers.
+        rnn_layer_size=1024,  #RNN layer size (number of RNN cells).
+        use_gru=True,  #Use gru if set True. Use simple rnn if set False.
+        share_rnn_weights=True  #Whether to share input-hidden weights between forward and backward directional RNNs. Notice that for GRU, weight sharing is not supported.
+    ))
+
 DeepSpeech2Model.params(_C.model)

 _C.training = CN(
......
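The hunk above keeps the hand-written `_C.model` defaults next to `DeepSpeech2Model.params(_C.model)`, which registers the model's own defaults on that node. A minimal sketch of that `params()` convention, assuming only yacs; `ToyModel` and its keys are illustrative stand-ins, not the real DeepSpeech2Model defaults:

```python
# Sketch of the classmethod-defaults pattern; ToyModel is hypothetical.
from typing import Optional
from yacs.config import CfgNode

class ToyModel:
    @classmethod
    def params(cls, config: Optional[CfgNode] = None) -> CfgNode:
        default = CfgNode(
            dict(
                num_rnn_layers=3,  # stacked RNN layers
                use_gru=True,      # GRU cells instead of a simple RNN
            ))
        if config is not None:
            # Fold the class defaults into an existing node in place
            # (keys must already exist on `config`, or be allowed as new).
            config.merge_from_other_cfg(default)
            return config
        return default

_C = CfgNode()
_C.model = ToyModel.params()  # the node now carries the class defaults
assert _C.model.num_rnn_layers == 3
```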
@@ -143,7 +143,6 @@ class DeepSpeech2Trainer(Trainer):
         train_dataset = ManifestDataset.from_config(config)

         config.data.manifest = config.data.dev_manifest
-        config.data.augmentation_config = ""
         dev_dataset = ManifestDataset.from_config(config)

         if self.parallel:
@@ -165,18 +164,22 @@ class DeepSpeech2Trainer(Trainer):
                 sortagrad=config.collator.sortagrad,
                 shuffle_method=config.collator.shuffle_method)

-        collate_fn = SpeechCollator.from_config(config)
+        collate_fn_train = SpeechCollator.from_config(config)
+
+        config.collator.augmentation_config = ""
+        collate_fn_dev = SpeechCollator.from_config(config)
         self.train_loader = DataLoader(
             train_dataset,
             batch_sampler=batch_sampler,
-            collate_fn=collate_fn,
+            collate_fn=collate_fn_train,
             num_workers=config.collator.num_workers)
         self.valid_loader = DataLoader(
             dev_dataset,
             batch_size=config.collator.batch_size,
             shuffle=False,
             drop_last=False,
-            collate_fn=collate_fn)
+            collate_fn=collate_fn_dev)
         logger.info("Setup train/valid Dataloader!")
@@ -324,8 +327,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         # return raw text

         config.data.manifest = config.data.test_manifest
-        config.data.keep_transcription_text = True
-        config.data.augmentation_config = ""
         # filter test examples, will cause less examples, but no mismatch with training
         # and can use large batch size, save training time, so filter test egs now.
         # config.data.min_input_len = 0.0  # second
@@ -337,6 +338,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         test_dataset = ManifestDataset.from_config(config)

         config.collator.keep_transcription_text = True
+        config.collator.augmentation_config = ""
         # return text ord id
         self.test_loader = DataLoader(
             test_dataset,
......
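The `collate_fn_train` / `collate_fn_dev` split exists because augmentation must touch training batches only: blanking `collator.augmentation_config` before the second `from_config` call yields a deterministic dev collator. A self-contained toy of the same idea (the names and the tagging behavior are illustrative, not the repo's API):

```python
# Toy collator factory: augmentation is on only while the config's
# augmentation_config key is non-empty at construction time.
from yacs.config import CfgNode

def make_collator(config: CfgNode):
    augment = bool(config.collator.augmentation_config)

    def collate(batch):
        # Real code pads/augments audio; here we only tag the batch.
        return {"batch": batch, "augmented": augment}

    return collate

config = CfgNode(dict(collator=dict(augmentation_config="conf/augmentation.json")))
collate_fn_train = make_collator(config)

config.collator.augmentation_config = ""  # dev/test: deterministic batches
collate_fn_dev = make_collator(config)

assert collate_fn_train([1, 2])["augmented"] is True
assert collate_fn_dev([1, 2])["augmented"] is False
```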
@@ -17,21 +17,13 @@ from deepspeech.exps.u2.model import U2Tester
 from deepspeech.exps.u2.model import U2Trainer
 from deepspeech.io.dataset import ManifestDataset
 from deepspeech.models.u2 import U2Model
+from deepspeech.io.collator import SpeechCollator

 _C = CfgNode()

 _C.data = ManifestDataset.params()

-_C.collator =CfgNode(
-    dict(
-        augmentation_config="",
-        unit_type="char",
-        keep_transcription_text=False,
-        batch_size=32,  # batch size
-        num_workers=0,  # data loader workers
-        sortagrad=False,  # sorted in first epoch when True
-        shuffle_method="batch_shuffle"  # 'batch_shuffle', 'instance_shuffle'
-    ))
+_C.collator = SpeechCollator.params()

 _C.model = U2Model.params()
......
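Replacing the hand-maintained dict with `SpeechCollator.params()` gives the collator defaults a single source of truth, so the u2 and deepspeech2 configs cannot drift apart. A hedged sketch of how overrides then land on top of those defaults (`ToyCollator` and its keys are stand-ins):

```python
# Defaults come from the class; yaml/CLI overrides merge on top.
from yacs.config import CfgNode

class ToyCollator:
    @classmethod
    def params(cls) -> CfgNode:
        return CfgNode(dict(
            batch_size=32,
            keep_transcription_text=False,
            shuffle_method="batch_shuffle",
        ))

_C = CfgNode()
_C.collator = ToyCollator.params()
# Overrides (e.g. parsed from argv) merge onto the registered defaults:
_C.merge_from_list(["collator.batch_size", 4])
assert _C.collator.batch_size == 4
assert _C.collator.shuffle_method == "batch_shuffle"  # default preserved
```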
@@ -100,7 +100,7 @@ class U2Trainer(Trainer):
             if (batch_index + 1) % train_conf.log_interval == 0:
                 msg += "train time: {:>.3f}s, ".format(iteration_time)
-                msg += "batch size: {}, ".format(self.config.data.batch_size)
+                msg += "batch size: {}, ".format(self.config.collator.batch_size)
                 msg += "accum: {}, ".format(train_conf.accum_grad)
                 msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                  for k, v in losses_np.items())
@@ -211,51 +211,52 @@ class U2Trainer(Trainer):
     def setup_dataloader(self):
         config = self.config.clone()
         config.defrost()
-        config.data.keep_transcription_text = False
+        config.collator.keep_transcription_text = False

         # train/valid dataset, return token ids
         config.data.manifest = config.data.train_manifest
         train_dataset = ManifestDataset.from_config(config)

         config.data.manifest = config.data.dev_manifest
-        config.data.augmentation_config = ""
         dev_dataset = ManifestDataset.from_config(config)

-        collate_fn = SpeechCollator.from_config(config)
+        collate_fn_train = SpeechCollator.from_config(config)
+
+        config.collator.augmentation_config = ""
+        collate_fn_dev = SpeechCollator.from_config(config)

         if self.parallel:
             batch_sampler = SortagradDistributedBatchSampler(
                 train_dataset,
-                batch_size=config.data.batch_size,
+                batch_size=config.collator.batch_size,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 drop_last=True,
-                sortagrad=config.data.sortagrad,
-                shuffle_method=config.data.shuffle_method)
+                sortagrad=config.collator.sortagrad,
+                shuffle_method=config.collator.shuffle_method)
         else:
             batch_sampler = SortagradBatchSampler(
                 train_dataset,
                 shuffle=True,
-                batch_size=config.data.batch_size,
+                batch_size=config.collator.batch_size,
                 drop_last=True,
-                sortagrad=config.data.sortagrad,
-                shuffle_method=config.data.shuffle_method)
+                sortagrad=config.collator.sortagrad,
+                shuffle_method=config.collator.shuffle_method)
         self.train_loader = DataLoader(
             train_dataset,
             batch_sampler=batch_sampler,
-            collate_fn=collate_fn,
-            num_workers=config.data.num_workers, )
+            collate_fn=collate_fn_train,
+            num_workers=config.collator.num_workers, )
         self.valid_loader = DataLoader(
             dev_dataset,
-            batch_size=config.data.batch_size,
+            batch_size=config.collator.batch_size,
             shuffle=False,
             drop_last=False,
-            collate_fn=collate_fn)
+            collate_fn=collate_fn_dev)

         # test dataset, return raw text
         config.data.manifest = config.data.test_manifest
-        config.data.keep_transcription_text = True
-        config.data.augmentation_config = ""
         # filter test examples, will cause less examples, but no mismatch with training
         # and can use large batch size, save training time, so filter test egs now.
         # config.data.min_input_len = 0.0  # second
@@ -264,9 +265,11 @@ class U2Trainer(Trainer):
         # config.data.max_output_len = float('inf')  # tokens
         # config.data.min_output_input_ratio = 0.00
         # config.data.max_output_input_ratio = float('inf')
+
         test_dataset = ManifestDataset.from_config(config)
         # return text ord id
         config.collator.keep_transcription_text = True
+        config.collator.augmentation_config = ""
         self.test_loader = DataLoader(
             test_dataset,
             batch_size=config.decoding.batch_size,
......
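This hunk moves every loader-related key (`batch_size`, `sortagrad`, `shuffle_method`, `num_workers`, `keep_transcription_text`) from `config.data` to `config.collator`. A hedged sketch of a compatibility check one could run against old yaml files after this change; the key list and helper name are assumptions, not part of the commit:

```python
# Hypothetical migration guard: reject configs that still carry loader
# keys under `data:` instead of `collator:`.
from yacs.config import CfgNode

LOADER_KEYS = ("batch_size", "sortagrad", "shuffle_method", "num_workers")

def check_migrated(config: CfgNode) -> None:
    for key in LOADER_KEYS:
        assert key not in config.data, (
            f"'{key}' belongs under 'collator:' after this refactor")
        assert key in config.collator, f"missing 'collator.{key}'"

config = CfgNode(dict(
    data=dict(train_manifest="data/manifest.tiny"),
    collator=dict(batch_size=4, sortagrad=True,
                  shuffle_method="batch_shuffle", num_workers=0),
))
check_migrated(config)  # passes for a migrated config
```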
@@ -75,8 +75,8 @@ class SpeechCollator():
         """
         assert 'augmentation_config' in config.collator
         assert 'keep_transcription_text' in config.collator
-        assert 'mean_std_filepath' in config.data
-        assert 'vocab_filepath' in config.data
+        assert 'mean_std_filepath' in config.collator
+        assert 'vocab_filepath' in config.collator
         assert 'specgram_type' in config.collator
         assert 'n_fft' in config.collator
         assert config.collator
@@ -94,9 +94,9 @@ class SpeechCollator():
         speech_collator = cls(
             aug_file=aug_file,
             random_seed=0,
-            mean_std_filepath=config.data.mean_std_filepath,
+            mean_std_filepath=config.collator.mean_std_filepath,
             unit_type=config.collator.unit_type,
-            vocab_filepath=config.data.vocab_filepath,
+            vocab_filepath=config.collator.vocab_filepath,
             spm_model_prefix=config.collator.spm_model_prefix,
             specgram_type=config.collator.specgram_type,
             feat_dim=config.collator.feat_dim,
@@ -129,11 +129,31 @@ class SpeechCollator():
             target_dB=-20,
             dither=1.0,
             keep_transcription_text=True):
-        """
-        Padding audio features with zeros to make them have the same shape (or
-        a user-defined shape) within one bach.
-
-        if ``keep_transcription_text`` is False, text is token ids else is raw string.
+        """SpeechCollator.
+
+        Args:
+            unit_type (str): token unit type, e.g. char, word, spm
+            vocab_filepath (str): vocab file path.
+            mean_std_filepath (str): mean and std file path, with suffix *.npy
+            spm_model_prefix (str): spm model prefix, needed if `unit_type` is spm.
+            augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
+            stride_ms (float, optional): stride size in ms. Defaults to 10.0.
+            window_ms (float, optional): window size in ms. Defaults to 20.0.
+            n_fft (int, optional): fft points for rfft. Defaults to None.
+            max_freq (int, optional): max cut freq. Defaults to None.
+            target_sample_rate (int, optional): target sample rate used for training. Defaults to 16000.
+            specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
+            feat_dim (int, optional): audio feature dim, used by 'mfcc' and 'fbank'. Defaults to None.
+            delta_delta (bool, optional): audio feature with delta-delta, used by 'fbank' and 'mfcc'. Defaults to False.
+            use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
+            target_dB (int, optional): target dB. Defaults to -20.
+            random_seed (int, optional): seed for the random generator. Defaults to 0.
+            keep_transcription_text (bool, optional): if True (not in training mode), texts are not tokenized. Defaults to False.
+
+        If ``keep_transcription_text`` is False, text is returned as token ids, otherwise as the raw string.
+
+        Applies augmentations, then pads audio features with zeros so they have
+        the same shape (or a user-defined shape) within one batch.
         """
         self._keep_transcription_text = keep_transcription_text
......
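The `keep_transcription_text` switch documented in the new docstring decides between raw strings (test time, for WER/CER scoring) and token ids (training, for the loss). A toy illustration with a stand-in tokenizer, not the real text featurizer:

```python
# Toy version of the keep_transcription_text behavior.
def encode_text(text: str, vocab: dict, keep_transcription_text: bool):
    if keep_transcription_text:
        return text                    # raw string, e.g. for WER scoring
    return [vocab[ch] for ch in text]  # token ids for the training loss

vocab = {"h": 0, "i": 1}
assert encode_text("hi", vocab, keep_transcription_text=True) == "hi"
assert encode_text("hi", vocab, keep_transcription_text=False) == [0, 1]
```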
@@ -40,15 +40,7 @@ class ManifestDataset(Dataset):
     def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
         default = CfgNode(
             dict(
-                train_manifest="",
-                dev_manifest="",
-                test_manifest="",
                 manifest="",
-                unit_type="char",
-                vocab_filepath="",
-                spm_model_prefix="",
-                mean_std_filepath="",
-                augmentation_config="",
                 max_input_len=27.0,
                 min_input_len=0.0,
                 max_output_len=float('inf'),
@@ -73,25 +65,10 @@ class ManifestDataset(Dataset):
         """
         assert 'manifest' in config.data
         assert config.data.manifest
-        assert 'keep_transcription_text' in config.collator
-
-        if isinstance(config.data.augmentation_config, (str, bytes)):
-            if config.data.augmentation_config:
-                aug_file = io.open(
-                    config.data.augmentation_config, mode='r', encoding='utf8')
-            else:
-                aug_file = io.StringIO(initial_value='{}', newline='')
-        else:
-            aug_file = config.data.augmentation_config
-            assert isinstance(aug_file, io.StringIO)

         dataset = cls(
             manifest_path=config.data.manifest,
-            unit_type=config.data.unit_type,
-            vocab_filepath=config.data.vocab_filepath,
-            mean_std_filepath=config.data.mean_std_filepath,
-            spm_model_prefix=config.data.spm_model_prefix,
-            augmentation_config=aug_file.read(),
             max_input_len=config.data.max_input_len,
             min_input_len=config.data.min_input_len,
             max_output_len=config.data.max_output_len,
@@ -101,23 +78,8 @@ class ManifestDataset(Dataset):
         )
         return dataset

-    def _read_vocab(self, vocab_filepath):
-        """Load vocabulary from file."""
-        vocab_lines = []
-        with open(vocab_filepath, 'r', encoding='utf-8') as file:
-            vocab_lines.extend(file.readlines())
-        vocab_list = [line[:-1] for line in vocab_lines]
-        return vocab_list
-
     def __init__(self,
                  manifest_path,
-                 unit_type,
-                 vocab_filepath,
-                 mean_std_filepath,
-                 spm_model_prefix=None,
-                 augmentation_config='{}',
                  max_input_len=float('inf'),
                  min_input_len=0.0,
                  max_output_len=float('inf'),
@@ -128,34 +90,16 @@ class ManifestDataset(Dataset):
         Args:
             manifest_path (str): manifest json file path
-            unit_type(str): token unit type, e.g. char, word, spm
-            vocab_filepath (str): vocab file path.
-            mean_std_filepath (str): mean and std file path, which suffix is *.npy
-            spm_model_prefix (str): spm model prefix, need if `unit_type` is spm.
-            augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
             max_input_len ([type], optional): maximum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
             min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
             max_output_len (float, optional): maximum output seq length, in modeling units. Defaults to 500.0.
             min_output_len (float, optional): minimum output seq length, in modeling units. Defaults to 0.0.
             max_output_input_ratio (float, optional): maximum output seq length / input seq length ratio. Defaults to 10.0.
             min_output_input_ratio (float, optional): minimum output seq length / input seq length ratio. Defaults to 0.05.
-            stride_ms (float, optional): stride size in ms. Defaults to 10.0.
-            window_ms (float, optional): window size in ms. Defaults to 20.0.
-            n_fft (int, optional): fft points for rfft. Defaults to None.
-            max_freq (int, optional): max cut freq. Defaults to None.
-            target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000.
-            specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
-            feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None.
-            delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False.
-            use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
-            target_dB (int, optional): target dB. Defaults to -20.
-            random_seed (int, optional): for random generator. Defaults to 0.
-            keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False.
         """
         super().__init__()

-        # self._rng = np.random.RandomState(random_seed)

         # read manifest
         self._manifest = read_manifest(
             manifest_path=manifest_path,
@@ -167,51 +111,6 @@ class ManifestDataset(Dataset):
             min_output_input_ratio=min_output_input_ratio)
         self._manifest.sort(key=lambda x: x["feat_shape"][0])

-        # self._vocab_list = self._read_vocab(vocab_filepath)
-
-    # @property
-    # def manifest(self):
-    #     return self._manifest
-
-    # @property
-    # def vocab_size(self):
-    #     """Return the vocabulary size.
-    #     Returns:
-    #         int: Vocabulary size.
-    #     """
-    #     return len(self._vocab_list)
-
-    # @property
-    # def vocab_list(self):
-    #     """Return the vocabulary in list.
-    #     Returns:
-    #         List[str]:
-    #     """
-    #     return self._vocab_list
-
-    # @property
-    # def vocab_dict(self):
-    #     """Return the vocabulary in dict.
-    #     Returns:
-    #         Dict[str, int]:
-    #     """
-    #     vocab_dict = dict(
-    #         [(token, idx) for (idx, token) in enumerate(self._vocab_list)])
-    #     return vocab_dict
-
-    # @property
-    # def feature_size(self):
-    #     """Return the audio feature size.
-    #     Returns:
-    #         int: audio feature size.
-    #     """
-    #     return self._manifest[0]["feat_shape"][-1]
-
     def __len__(self):
         return len(self._manifest)
......
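After this change `ManifestDataset` only reads, filters, and sorts manifest entries; featurization, vocabulary, and augmentation all live in `SpeechCollator`. A minimal sketch of the length filtering it keeps, with the manifest schema inferred from the `feat_shape` access above:

```python
# Toy manifest filter: drop utterances outside the input-length window,
# then sort ascending so a sortagrad-style sampler sees short ones first.
def filter_manifest(entries, max_input_len=27.0, min_input_len=0.0):
    kept = [
        e for e in entries
        if min_input_len <= e["feat_shape"][0] <= max_input_len
    ]
    kept.sort(key=lambda e: e["feat_shape"][0])
    return kept

entries = [
    {"utt": "a", "feat_shape": [12.5]},
    {"utt": "b", "feat_shape": [31.0]},  # filtered out: too long
    {"utt": "c", "feat_shape": [3.2]},
]
assert [e["utt"] for e in filter_manifest(entries)] == ["c", "a"]
```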
@@ -3,17 +3,20 @@ data:
   train_manifest: data/manifest.train
   dev_manifest: data/manifest.dev
   test_manifest: data/manifest.test
-  vocab_filepath: data/vocab.txt
-  unit_type: 'char'
-  spm_model_prefix: ''
-  augmentation_config: conf/augmentation.json
-  batch_size: 64
   min_input_len: 0.5
   max_input_len: 20.0  # second
   min_output_len: 0.0
   max_output_len: 400.0
   min_output_input_ratio: 0.05
   max_output_input_ratio: 10.0

+collator:
+  vocab_filepath: data/vocab.txt
+  unit_type: 'char'
+  spm_model_prefix: ''
+  augmentation_config: conf/augmentation.json
+  batch_size: 64
   raw_wav: True  # use raw_wav or kaldi feature
   specgram_type: fbank  # linear, mfcc, fbank
   feat_dim: 80
@@ -32,7 +35,6 @@ data:
   shuffle_method: batch_shuffle
   num_workers: 2
-
 # network architecture
 model:
   cmvn_file: "data/mean_std.json"
......
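How the re-grouped yaml is consumed: the `data:` block merges onto `_C.data` and the new `collator:` block onto `_C.collator`, each over its registered defaults. A hedged sketch with toy defaults (the real defaults come from the `params()` helpers above):

```python
# Merge a yaml override file onto CfgNode defaults; unspecified keys
# keep their default values.
import tempfile
from yacs.config import CfgNode

_C = CfgNode(dict(
    data=dict(train_manifest=""),
    collator=dict(batch_size=32, unit_type="char"),
))

yaml_text = """
data:
  train_manifest: data/manifest.train
collator:
  batch_size: 64
"""
with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(yaml_text)
_C.merge_from_file(f.name)

assert _C.collator.batch_size == 64      # overridden by the yaml
assert _C.collator.unit_type == "char"   # default preserved
```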
@@ -2,10 +2,7 @@
 data:
   train_manifest: data/manifest.tiny
   dev_manifest: data/manifest.tiny
   test_manifest: data/manifest.tiny
-  mean_std_filepath: data/mean_std.json
-  unit_type: char
-  vocab_filepath: data/vocab.txt
   min_input_len: 0.0
   max_input_len: 27.0
   min_output_len: 0.0
@@ -15,6 +12,9 @@ data:
 collator:
+  mean_std_filepath: data/mean_std.json
+  unit_type: char
+  vocab_filepath: data/vocab.txt
   augmentation_config: conf/augmentation.json
   random_seed: 0
   spm_model_prefix:
@@ -43,7 +43,7 @@ model:
   share_rnn_weights: True

 training:
-  n_epoch: 23
+  n_epoch: 24
   lr: 1e-5
   lr_decay: 1.0
   weight_decay: 1e-06
......
@@ -3,26 +3,20 @@ data:
   train_manifest: data/manifest.tiny
   dev_manifest: data/manifest.tiny
   test_manifest: data/manifest.tiny
-  vocab_filepath: data/vocab.txt
-  unit_type: 'spm'
-  spm_model_prefix: 'data/bpe_unigram_200'
-  mean_std_filepath: ""
-  batch_size: 4
   min_input_len: 0.5  # second
   max_input_len: 20.0  # second
   min_output_len: 0.0  # tokens
   max_output_len: 400.0  # tokens
   min_output_input_ratio: 0.05
   max_output_input_ratio: 10.0
-  raw_wav: True  # use raw_wav or kaldi feature
-  sortagrad: True
-  shuffle_method: batch_shuffle
-  num_workers: 0  #2

 collator:
+  vocab_filepath: data/vocab.txt
+  mean_std_filepath: ""
   augmentation_config: conf/augmentation.json
   random_seed: 0
-  spm_model_prefix:
+  unit_type: 'spm'
+  spm_model_prefix: 'data/bpe_unigram_200'
   specgram_type: fbank
   feat_dim: 80
   delta_delta: False
@@ -35,6 +29,12 @@ collator:
   target_dB: -20
   dither: 1.0
   keep_transcription_text: False
+  batch_size: 4
+  sortagrad: True
+  shuffle_method: batch_shuffle
+  num_workers: 0  #2
+  raw_wav: True  # use raw_wav or kaldi feature

 # network architecture
 model:
......
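The collator block now pins `unit_type: 'spm'` next to its `spm_model_prefix`, which the SpeechCollator docstring says is required for spm units. A small validation sketch of that constraint (the helper name is hypothetical):

```python
# Guard the unit_type / spm_model_prefix pairing from the docstring.
def validate_unit_type(unit_type: str, spm_model_prefix: str) -> None:
    if unit_type == "spm" and not spm_model_prefix:
        raise ValueError("unit_type 'spm' requires spm_model_prefix")

validate_unit_type("spm", "data/bpe_unigram_200")  # ok
validate_unit_type("char", "")                     # ok: prefix unused
```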