diff --git a/deepspeech/exps/deepspeech2/config.py b/deepspeech/exps/deepspeech2/config.py index a8d452a9918739d0232793ec209b37a4fa1d1787..2f0f5c24bb89a629322a961fa49aec1520666baf 100644 --- a/deepspeech/exps/deepspeech2/config.py +++ b/deepspeech/exps/deepspeech2/config.py @@ -11,74 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from yacs.config import CfgNode as CN +from yacs.config import CfgNode +from deepspeech.exps.deepspeech2.model import DeepSpeech2Tester +from deepspeech.exps.deepspeech2.model import DeepSpeech2Trainer +from deepspeech.io.collator import SpeechCollator +from deepspeech.io.dataset import ManifestDataset from deepspeech.models.deepspeech2 import DeepSpeech2Model -_C = CN() -_C.data = CN( - dict( - train_manifest="", - dev_manifest="", - test_manifest="", - unit_type="char", - vocab_filepath="", - spm_model_prefix="", - mean_std_filepath="", - augmentation_config="", - max_duration=float('inf'), - min_duration=0.0, - stride_ms=10.0, # ms - window_ms=20.0, # ms - n_fft=None, # fft points - max_freq=None, # None for samplerate/2 - specgram_type='linear', # 'linear', 'mfcc', 'fbank' - feat_dim=0, # 'mfcc', 'fbank' - delat_delta=False, # 'mfcc', 'fbank' - target_sample_rate=16000, # target sample rate - use_dB_normalization=True, - target_dB=-20, - random_seed=0, - keep_transcription_text=False, - batch_size=32, # batch size - num_workers=0, # data loader workers - sortagrad=False, # sorted in first epoch when True - shuffle_method="batch_shuffle", # 'batch_shuffle', 'instance_shuffle' - )) +_C = CfgNode() -_C.model = CN( - dict( - num_conv_layers=2, #Number of stacking convolution layers. - num_rnn_layers=3, #Number of stacking RNN layers. - rnn_layer_size=1024, #RNN layer size (number of RNN cells). - use_gru=True, #Use gru if set True. Use simple rnn if set False. - share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported. - )) +_C.data = ManifestDataset.params() -DeepSpeech2Model.params(_C.model) +_C.collator = SpeechCollator.params() -_C.training = CN( - dict( - lr=5e-4, # learning rate - lr_decay=1.0, # learning rate decay - weight_decay=1e-6, # the coeff of weight decay - global_grad_clip=5.0, # the global norm clip - n_epoch=50, # train epochs - )) +_C.model = DeepSpeech2Model.params() -_C.decoding = CN( - dict( - alpha=2.5, # Coef of LM for beam search. - beta=0.3, # Coef of WC for beam search. - cutoff_prob=1.0, # Cutoff probability for pruning. - cutoff_top_n=40, # Cutoff number for pruning. - lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. - decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy - error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' - num_proc_bsearch=8, # # of CPUs for beam search. - beam_size=500, # Beam search width. 
- batch_size=128, # decoding batch size - )) +_C.training = DeepSpeech2Trainer.params() + +_C.decoding = DeepSpeech2Tester.params() def get_cfg_defaults(): diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 468bc65216cf6256ca96d1c4eef0663f673edaf7..deb8752b72deb7efc52bd1f86becf6de7d4d0073 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -15,11 +15,13 @@ import time from collections import defaultdict from pathlib import Path +from typing import Optional import numpy as np import paddle from paddle import distributed as dist from paddle.io import DataLoader +from yacs.config import CfgNode from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset @@ -33,11 +35,26 @@ from deepspeech.utils import error_rate from deepspeech.utils import layer_tools from deepspeech.utils import mp_tools from deepspeech.utils.log import Log - logger = Log(__name__).getlog() class DeepSpeech2Trainer(Trainer): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # training config + default = CfgNode( + dict( + lr=5e-4, # learning rate + lr_decay=1.0, # learning rate decay + weight_decay=1e-6, # the coeff of weight decay + global_grad_clip=5.0, # the global norm clip + n_epoch=50, # train epochs + )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + def __init__(self, config, args): super().__init__(config, args) @@ -55,7 +72,7 @@ class DeepSpeech2Trainer(Trainer): 'train_loss': float(loss), } msg += "train time: {:>.3f}s, ".format(iteration_time) - msg += "batch size: {}, ".format(self.config.data.batch_size) + msg += "batch size: {}, ".format(self.config.collator.batch_size) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses_np.items()) logger.info(msg) @@ -102,8 +119,8 @@ class DeepSpeech2Trainer(Trainer): def setup_model(self): config = self.config model = DeepSpeech2Model( - feat_size=self.train_loader.dataset.feature_size, - dict_size=self.train_loader.dataset.vocab_size, + feat_size=self.train_loader.collate_fn.feature_size, + dict_size=self.train_loader.collate_fn.vocab_size, num_conv_layers=config.model.num_conv_layers, num_rnn_layers=config.model.num_rnn_layers, rnn_size=config.model.rnn_layer_size, @@ -137,50 +154,73 @@ class DeepSpeech2Trainer(Trainer): def setup_dataloader(self): config = self.config.clone() config.defrost() - config.data.keep_transcription_text = False + config.collator.keep_transcription_text = False config.data.manifest = config.data.train_manifest train_dataset = ManifestDataset.from_config(config) config.data.manifest = config.data.dev_manifest - config.data.augmentation_config = "" dev_dataset = ManifestDataset.from_config(config) if self.parallel: batch_sampler = SortagradDistributedBatchSampler( train_dataset, - batch_size=config.data.batch_size, + batch_size=config.collator.batch_size, num_replicas=None, rank=None, shuffle=True, drop_last=True, - sortagrad=config.data.sortagrad, - shuffle_method=config.data.shuffle_method) + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) else: batch_sampler = SortagradBatchSampler( train_dataset, shuffle=True, - batch_size=config.data.batch_size, + batch_size=config.collator.batch_size, drop_last=True, - sortagrad=config.data.sortagrad, - shuffle_method=config.data.shuffle_method) + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) - collate_fn = 
SpeechCollator(keep_transcription_text=False) + collate_fn_train = SpeechCollator.from_config(config) + + config.collator.augmentation_config = "" + collate_fn_dev = SpeechCollator.from_config(config) self.train_loader = DataLoader( train_dataset, batch_sampler=batch_sampler, - collate_fn=collate_fn, - num_workers=config.data.num_workers) + collate_fn=collate_fn_train, + num_workers=config.collator.num_workers) self.valid_loader = DataLoader( dev_dataset, - batch_size=config.data.batch_size, + batch_size=config.collator.batch_size, shuffle=False, drop_last=False, - collate_fn=collate_fn) + collate_fn=collate_fn_dev) logger.info("Setup train/valid Dataloader!") class DeepSpeech2Tester(DeepSpeech2Trainer): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # testing config + default = CfgNode( + dict( + alpha=2.5, # Coef of LM for beam search. + beta=0.3, # Coef of WC for beam search. + cutoff_prob=1.0, # Cutoff probability for pruning. + cutoff_top_n=40, # Cutoff number for pruning. + lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. + decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy + error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' + num_proc_bsearch=8, # # of CPUs for beam search. + beam_size=500, # Beam search width. + batch_size=128, # decoding batch size + )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + def __init__(self, config, args): super().__init__(config, args) @@ -193,13 +233,19 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): trans.append(''.join([chr(i) for i in ids])) return trans - def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout = None): + def compute_metrics(self, + utts, + audio, + audio_len, + texts, + texts_len, + fout=None): cfg = self.config.decoding errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer - vocab_list = self.test_loader.dataset.vocab_list + vocab_list = self.test_loader.collate_fn.vocab_list target_transcripts = self.ordid2token(texts, texts_len) result_transcripts = self.model.decode( @@ -215,7 +261,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): cutoff_top_n=cfg.cutoff_top_n, num_processes=cfg.num_proc_bsearch) - for utt, target, result in zip(utts, target_transcripts, result_transcripts): + for utt, target, result in zip(utts, target_transcripts, + result_transcripts): errors, len_ref = errors_func(target, result) errors_sum += errors len_refs += len_ref @@ -245,7 +292,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): with open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): utts, audio, audio_len, texts, texts_len = batch - metrics = self.compute_metrics(utts, audio, audio_len, texts, texts_len, fout) + metrics = self.compute_metrics(utts, audio, audio_len, texts, + texts_len, fout) errors_sum += metrics['errors_sum'] len_refs += metrics['len_refs'] num_ins += metrics['num_ins'] @@ -272,7 +320,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): infer_model = DeepSpeech2InferModel.from_pretrained( self.test_loader.dataset, self.config, self.args.checkpoint_path) infer_model.eval() - feat_dim = self.test_loader.dataset.feature_size + feat_dim = self.test_loader.collate_fn.feature_size static_model = paddle.jit.to_static( infer_model, input_spec=[ 
@@ -308,8 +356,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): def setup_model(self): config = self.config model = DeepSpeech2Model( - feat_size=self.test_loader.dataset.feature_size, - dict_size=self.test_loader.dataset.vocab_size, + feat_size=self.test_loader.collate_fn.feature_size, + dict_size=self.test_loader.collate_fn.vocab_size, num_conv_layers=config.model.num_conv_layers, num_rnn_layers=config.model.num_rnn_layers, rnn_size=config.model.rnn_layer_size, @@ -324,8 +372,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): # return raw text config.data.manifest = config.data.test_manifest - config.data.keep_transcription_text = True - config.data.augmentation_config = "" # filter test examples, will cause less examples, but no mismatch with training # and can use large batch size , save training time, so filter test egs now. # config.data.min_input_len = 0.0 # second @@ -336,13 +382,15 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): # config.data.max_output_input_ratio = float('inf') test_dataset = ManifestDataset.from_config(config) + config.collator.keep_transcription_text = True + config.collator.augmentation_config = "" # return text ord id self.test_loader = DataLoader( test_dataset, batch_size=config.decoding.batch_size, shuffle=False, drop_last=False, - collate_fn=SpeechCollator(keep_transcription_text=True)) + collate_fn=SpeechCollator.from_config(config)) logger.info("Setup test Dataloader!") def setup_output_dir(self): diff --git a/deepspeech/exps/u2/config.py b/deepspeech/exps/u2/config.py index 5a0b53f9a3850a9af5d700e5f036d5215e35a170..4ec7bd1908715bb583987bd1e2aae2165eadc683 100644 --- a/deepspeech/exps/u2/config.py +++ b/deepspeech/exps/u2/config.py @@ -15,6 +15,7 @@ from yacs.config import CfgNode from deepspeech.exps.u2.model import U2Tester from deepspeech.exps.u2.model import U2Trainer +from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset from deepspeech.models.u2 import U2Model @@ -22,6 +23,8 @@ _C = CfgNode() _C.data = ManifestDataset.params() +_C.collator = SpeechCollator.params() + _C.model = U2Model.params() _C.training = U2Trainer.params() diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 334d6bc8e94a47d3fe4ba644f24965df3ea45579..055518755d1bc48db8ad5c34b050c91cb7031c35 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -78,7 +78,8 @@ class U2Trainer(Trainer): start = time.time() utt, audio, audio_len, text, text_len = batch_data - loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len) + loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, + text_len) # loss div by `batch_size * accum_grad` loss /= train_conf.accum_grad loss.backward() @@ -100,7 +101,7 @@ class U2Trainer(Trainer): if (batch_index + 1) % train_conf.log_interval == 0: msg += "train time: {:>.3f}s, ".format(iteration_time) - msg += "batch size: {}, ".format(self.config.data.batch_size) + msg += "batch size: {}, ".format(self.config.collator.batch_size) msg += "accum: {}, ".format(train_conf.accum_grad) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses_np.items()) @@ -121,7 +122,8 @@ class U2Trainer(Trainer): total_loss = 0.0 for i, batch in enumerate(self.valid_loader): utt, audio, audio_len, text, text_len = batch - loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len) + loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, + text_len) if paddle.isfinite(loss): num_utts = batch[1].shape[0] num_seen_utts += 
num_utts @@ -211,51 +213,52 @@ class U2Trainer(Trainer): def setup_dataloader(self): config = self.config.clone() config.defrost() - config.data.keep_transcription_text = False + config.collator.keep_transcription_text = False # train/valid dataset, return token ids config.data.manifest = config.data.train_manifest train_dataset = ManifestDataset.from_config(config) config.data.manifest = config.data.dev_manifest - config.data.augmentation_config = "" dev_dataset = ManifestDataset.from_config(config) - collate_fn = SpeechCollator(keep_transcription_text=False) + collate_fn_train = SpeechCollator.from_config(config) + + config.collator.augmentation_config = "" + collate_fn_dev = SpeechCollator.from_config(config) + if self.parallel: batch_sampler = SortagradDistributedBatchSampler( train_dataset, - batch_size=config.data.batch_size, + batch_size=config.collator.batch_size, num_replicas=None, rank=None, shuffle=True, drop_last=True, - sortagrad=config.data.sortagrad, - shuffle_method=config.data.shuffle_method) + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) else: batch_sampler = SortagradBatchSampler( train_dataset, shuffle=True, - batch_size=config.data.batch_size, + batch_size=config.collator.batch_size, drop_last=True, - sortagrad=config.data.sortagrad, - shuffle_method=config.data.shuffle_method) + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) self.train_loader = DataLoader( train_dataset, batch_sampler=batch_sampler, - collate_fn=collate_fn, - num_workers=config.data.num_workers, ) + collate_fn=collate_fn_train, + num_workers=config.collator.num_workers, ) self.valid_loader = DataLoader( dev_dataset, - batch_size=config.data.batch_size, + batch_size=config.collator.batch_size, shuffle=False, drop_last=False, - collate_fn=collate_fn) + collate_fn=collate_fn_dev) # test dataset, return raw text config.data.manifest = config.data.test_manifest - config.data.keep_transcription_text = True - config.data.augmentation_config = "" # filter test examples, will cause less examples, but no mismatch with training # and can use large batch size , save training time, so filter test egs now. 
# config.data.min_input_len = 0.0 # second @@ -264,22 +267,25 @@ class U2Trainer(Trainer): # config.data.max_output_len = float('inf') # tokens # config.data.min_output_input_ratio = 0.00 # config.data.max_output_input_ratio = float('inf') + test_dataset = ManifestDataset.from_config(config) # return text ord id + config.collator.keep_transcription_text = True + config.collator.augmentation_config = "" self.test_loader = DataLoader( test_dataset, batch_size=config.decoding.batch_size, shuffle=False, drop_last=False, - collate_fn=SpeechCollator(keep_transcription_text=True)) + collate_fn=SpeechCollator.from_config(config)) logger.info("Setup train/valid/test Dataloader!") def setup_model(self): config = self.config model_conf = config.model model_conf.defrost() - model_conf.input_dim = self.train_loader.dataset.feature_size - model_conf.output_dim = self.train_loader.dataset.vocab_size + model_conf.input_dim = self.train_loader.collate_fn.feature_size + model_conf.output_dim = self.train_loader.collate_fn.vocab_size model_conf.freeze() model = U2Model.from_config(model_conf) @@ -368,14 +374,20 @@ class U2Tester(U2Trainer): trans.append(''.join([chr(i) for i in ids])) return trans - def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout=None): + def compute_metrics(self, + utts, + audio, + audio_len, + texts, + texts_len, + fout=None): cfg = self.config.decoding errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer start_time = time.time() - text_feature = self.test_loader.dataset.text_feature + text_feature = self.test_loader.collate_fn.text_feature target_transcripts = self.ordid2token(texts, texts_len) result_transcripts = self.model.decode( audio, @@ -395,7 +407,8 @@ class U2Tester(U2Trainer): simulate_streaming=cfg.simulate_streaming) decode_time = time.time() - start_time - for utt, target, result in zip(utts, target_transcripts, result_transcripts): + for utt, target, result in zip(utts, target_transcripts, + result_transcripts): errors, len_ref = errors_func(target, result) errors_sum += errors len_refs += len_ref @@ -423,7 +436,7 @@ class U2Tester(U2Trainer): self.model.eval() logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") - stride_ms = self.test_loader.dataset.stride_ms + stride_ms = self.test_loader.collate_fn.stride_ms error_rate_type = None errors_sum, len_refs, num_ins = 0.0, 0, 0 num_frames = 0.0 @@ -496,7 +509,7 @@ class U2Tester(U2Trainer): infer_model = U2InferModel.from_pretrained(self.test_loader.dataset, self.config.model.clone(), self.args.checkpoint_path) - feat_dim = self.test_loader.dataset.feature_size + feat_dim = self.test_loader.collate_fn.feature_size input_spec = [ paddle.static.InputSpec( shape=[None, feat_dim, None], diff --git a/deepspeech/frontend/featurizer/speech_featurizer.py b/deepspeech/frontend/featurizer/speech_featurizer.py index e6761cb52ec954f6fce15d06e0b63918dcfa4f62..0fbbc5648835b34f8a5c0cf3ea5d0b06c0dcaefa 100644 --- a/deepspeech/frontend/featurizer/speech_featurizer.py +++ b/deepspeech/frontend/featurizer/speech_featurizer.py @@ -107,7 +107,6 @@ class SpeechFeaturizer(object): @property def vocab_size(self): """Return the vocabulary size. - Returns: int: Vocabulary size. """ @@ -116,7 +115,6 @@ class SpeechFeaturizer(object): @property def vocab_list(self): """Return the vocabulary in list. 
-
         Returns:
             List[str]:
         """
@@ -125,7 +123,6 @@
     @property
     def vocab_dict(self):
         """Return the vocabulary in dict.
-
         Returns:
             Dict[str, int]:
         """
@@ -134,7 +131,6 @@
     @property
     def feature_size(self):
         """Return the audio feature size.
-
         Returns:
             int: audio feature size.
         """
@@ -143,7 +139,6 @@
     @property
     def stride_ms(self):
         """time length in `ms` unit per frame
-
         Returns:
             float: time(ms)/frame
         """
@@ -152,7 +147,6 @@
     @property
     def text_feature(self):
         """Return the text feature object.
-
         Returns:
             TextFeaturizer: object.
         """
diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py
index 3bec9875f43b8242d4e90ce921b5edeb19f10414..1061f97cfabc8e259ceefc353b264bde3d21a758 100644
--- a/deepspeech/io/collator.py
+++ b/deepspeech/io/collator.py
@@ -11,8 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import io
+import tarfile
+from collections import namedtuple
+from typing import Optional
+
 import numpy as np
+from yacs.config import CfgNode
 
+from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
+from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
+from deepspeech.frontend.normalizer import FeatureNormalizer
+from deepspeech.frontend.speech import SpeechSegment
 from deepspeech.frontend.utility import IGNORE_ID
 from deepspeech.io.utility import pad_sequence
 from deepspeech.utils.log import Log
@@ -21,17 +31,225 @@
 __all__ = ["SpeechCollator"]
 
 logger = Log(__name__).getlog()
 
+# namedtuple needs to be global for pickle.
+TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
+
 
 class SpeechCollator():
-    def __init__(self, keep_transcription_text=True):
+    @classmethod
+    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
+        default = CfgNode(
+            dict(
+                augmentation_config="",
+                random_seed=0,
+                mean_std_filepath="",
+                unit_type="char",
+                vocab_filepath="",
+                spm_model_prefix="",
+                specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
+                feat_dim=0,  # 'mfcc', 'fbank'
+                delta_delta=False,  # 'mfcc', 'fbank'
+                stride_ms=10.0,  # ms
+                window_ms=20.0,  # ms
+                n_fft=None,  # fft points
+                max_freq=None,  # None for samplerate/2
+                target_sample_rate=16000,  # target sample rate
+                use_dB_normalization=True,
+                target_dB=-20,
+                dither=1.0,  # feature dither
+                batch_size=32,  # batch size
+                num_workers=0,  # data loader workers
+                sortagrad=False,  # sorted in first epoch when True
+                shuffle_method="batch_shuffle",  # 'batch_shuffle', 'instance_shuffle'
+                raw_wav=True,  # use raw_wav or kaldi feature
+                keep_transcription_text=False))
+
+        if config is not None:
+            config.merge_from_other_cfg(default)
+        return default
+
+    @classmethod
+    def from_config(cls, config):
+        """Build a SpeechCollator object from a config.
+
+        Args:
+            config (yacs.config.CfgNode): configs object.
+
+        Returns:
+            SpeechCollator: collator object.
         """
-        Padding audio features with zeros to make them have the same shape (or
-        a user-defined shape) within one bach.
+        assert 'augmentation_config' in config.collator
+        assert 'keep_transcription_text' in config.collator
+        assert 'mean_std_filepath' in config.collator
+        assert 'vocab_filepath' in config.collator
+        assert 'specgram_type' in config.collator
+        assert 'n_fft' in config.collator
+        assert config.collator
 
-        if ``keep_transcription_text`` is False, text is token ids else is raw string.
+        if isinstance(config.collator.augmentation_config, (str, bytes)):
+            if config.collator.augmentation_config:
+                aug_file = io.open(
+                    config.collator.augmentation_config,
+                    mode='r',
+                    encoding='utf8')
+            else:
+                aug_file = io.StringIO(initial_value='{}', newline='')
+        else:
+            aug_file = config.collator.augmentation_config
+        assert isinstance(aug_file, io.StringIO)
+
+        speech_collator = cls(
+            aug_file=aug_file,
+            random_seed=config.collator.random_seed,
+            mean_std_filepath=config.collator.mean_std_filepath,
+            unit_type=config.collator.unit_type,
+            vocab_filepath=config.collator.vocab_filepath,
+            spm_model_prefix=config.collator.spm_model_prefix,
+            specgram_type=config.collator.specgram_type,
+            feat_dim=config.collator.feat_dim,
+            delta_delta=config.collator.delta_delta,
+            stride_ms=config.collator.stride_ms,
+            window_ms=config.collator.window_ms,
+            n_fft=config.collator.n_fft,
+            max_freq=config.collator.max_freq,
+            target_sample_rate=config.collator.target_sample_rate,
+            use_dB_normalization=config.collator.use_dB_normalization,
+            target_dB=config.collator.target_dB,
+            dither=config.collator.dither,
+            keep_transcription_text=config.collator.keep_transcription_text)
+        return speech_collator
+
+    def __init__(
+            self,
+            aug_file,
+            mean_std_filepath,
+            vocab_filepath,
+            spm_model_prefix,
+            random_seed=0,
+            unit_type="char",
+            specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
+            feat_dim=0,  # 'mfcc', 'fbank'
+            delta_delta=False,  # 'mfcc', 'fbank'
+            stride_ms=10.0,  # ms
+            window_ms=20.0,  # ms
+            n_fft=None,  # fft points
+            max_freq=None,  # None for samplerate/2
+            target_sample_rate=16000,  # target sample rate
+            use_dB_normalization=True,
+            target_dB=-20,
+            dither=1.0,
+            keep_transcription_text=True):
+        """Speech collator: augment, featurize and pad a batch of utterances.
+
+        Args:
+            unit_type(str): token unit type, e.g. char, word, spm
+            vocab_filepath (str): vocab file path.
+            mean_std_filepath (str): mean and std file path, which suffix is *.npy
+            spm_model_prefix (str): spm model prefix, need if `unit_type` is spm.
+            aug_file (io.StringIO): augmentation config json file object, e.g. io.StringIO('{}') for no augmentation.
+            stride_ms (float, optional): stride size in ms. Defaults to 10.0.
+            window_ms (float, optional): window size in ms. Defaults to 20.0.
+            n_fft (int, optional): fft points for rfft. Defaults to None.
+            max_freq (int, optional): max cut freq. Defaults to None.
+            target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000.
+            specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
+            feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None.
+            delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False.
+            use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
+            target_dB (int, optional): target dB. Defaults to -20.
+            random_seed (int, optional): for random generator. Defaults to 0.
+            keep_transcription_text (bool, optional): if True (not in training mode), texts are kept as raw strings;
+                if False, texts are tokenized into token ids. Defaults to True.
+
+        Does augmentation, then pads audio features with zeros to make them
+        have the same shape (or a user-defined shape) within one batch.
+
""" self._keep_transcription_text = keep_transcription_text + self._local_data = TarLocalData(tar2info={}, tar2object={}) + self._augmentation_pipeline = AugmentationPipeline( + augmentation_config=aug_file.read(), random_seed=random_seed) + + self._normalizer = FeatureNormalizer( + mean_std_filepath) if mean_std_filepath else None + + self._stride_ms = stride_ms + self._target_sample_rate = target_sample_rate + + self._speech_featurizer = SpeechFeaturizer( + unit_type=unit_type, + vocab_filepath=vocab_filepath, + spm_model_prefix=spm_model_prefix, + specgram_type=specgram_type, + feat_dim=feat_dim, + delta_delta=delta_delta, + stride_ms=stride_ms, + window_ms=window_ms, + n_fft=n_fft, + max_freq=max_freq, + target_sample_rate=target_sample_rate, + use_dB_normalization=use_dB_normalization, + target_dB=target_dB, + dither=dither) + + def _parse_tar(self, file): + """Parse a tar file to get a tarfile object + and a map containing tarinfoes + """ + result = {} + f = tarfile.open(file) + for tarinfo in f.getmembers(): + result[tarinfo.name] = tarinfo + return f, result + + def _subfile_from_tar(self, file): + """Get subfile object from tar. + + It will return a subfile object from tar file + and cached tar file info for next reading request. + """ + tarpath, filename = file.split(':', 1)[1].split('#', 1) + if 'tar2info' not in self._local_data.__dict__: + self._local_data.tar2info = {} + if 'tar2object' not in self._local_data.__dict__: + self._local_data.tar2object = {} + if tarpath not in self._local_data.tar2info: + object, infoes = self._parse_tar(tarpath) + self._local_data.tar2info[tarpath] = infoes + self._local_data.tar2object[tarpath] = object + return self._local_data.tar2object[tarpath].extractfile( + self._local_data.tar2info[tarpath][filename]) + + def process_utterance(self, audio_file, transcript): + """Load, augment, featurize and normalize for speech data. + + :param audio_file: Filepath or file object of audio file. + :type audio_file: str | file + :param transcript: Transcription text. + :type transcript: str + :return: Tuple of audio feature tensor and data of transcription part, + where transcription part could be token ids or text. 
+ :rtype: tuple of (2darray, list) + """ + if isinstance(audio_file, str) and audio_file.startswith('tar:'): + speech_segment = SpeechSegment.from_file( + self._subfile_from_tar(audio_file), transcript) + else: + speech_segment = SpeechSegment.from_file(audio_file, transcript) + + # audio augment + self._augmentation_pipeline.transform_audio(speech_segment) + + specgram, transcript_part = self._speech_featurizer.featurize( + speech_segment, self._keep_transcription_text) + if self._normalizer: + specgram = self._normalizer.apply(specgram) + + # specgram augment + specgram = self._augmentation_pipeline.transform_feature(specgram) + return specgram, transcript_part + def __call__(self, batch): """batch examples @@ -53,6 +265,7 @@ class SpeechCollator(): text_lens = [] utts = [] for utt, audio, text in batch: + audio, text = self.process_utterance(audio, text) #utt utts.append(utt) # audio @@ -79,3 +292,31 @@ class SpeechCollator(): texts, padding_value=IGNORE_ID).astype(np.int64) text_lens = np.array(text_lens).astype(np.int64) return utts, padded_audios, audio_lens, padded_texts, text_lens + + @property + def manifest(self): + return self._manifest + + @property + def vocab_size(self): + return self._speech_featurizer.vocab_size + + @property + def vocab_list(self): + return self._speech_featurizer.vocab_list + + @property + def vocab_dict(self): + return self._speech_featurizer.vocab_dict + + @property + def text_feature(self): + return self._speech_featurizer.text_feature + + @property + def feature_size(self): + return self._speech_featurizer.feature_size + + @property + def stride_ms(self): + return self._speech_featurizer.stride_ms diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index 1cf3827d344ad57bfa18d0b3ce227cc5b2f6e6f8..3fc4e98872246a23af78bd1990df58b4ed4e7691 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -11,20 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import io -import tarfile -import time -from collections import namedtuple from typing import Optional -import numpy as np from paddle.io import Dataset from yacs.config import CfgNode -from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline -from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer -from deepspeech.frontend.normalizer import FeatureNormalizer -from deepspeech.frontend.speech import SpeechSegment from deepspeech.frontend.utility import read_manifest from deepspeech.utils.log import Log @@ -34,49 +25,19 @@ __all__ = [ logger = Log(__name__).getlog() -# namedtupe need global for pickle. 
-TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
-
 
 class ManifestDataset(Dataset):
     @classmethod
     def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
         default = CfgNode(
             dict(
                 train_manifest="",
                 dev_manifest="",
                 test_manifest="",
                 manifest="",
-                unit_type="char",
-                vocab_filepath="",
-                spm_model_prefix="",
-                mean_std_filepath="",
-                augmentation_config="",
                 max_input_len=27.0,
                 min_input_len=0.0,
                 max_output_len=float('inf'),
                 min_output_len=0.0,
                 max_output_input_ratio=float('inf'),
-                min_output_input_ratio=0.0,
-                stride_ms=10.0,  # ms
-                window_ms=20.0,  # ms
-                n_fft=None,  # fft points
-                max_freq=None,  # None for samplerate/2
-                raw_wav=True,  # use raw_wav or kaldi feature
-                specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
-                feat_dim=0,  # 'mfcc', 'fbank'
-                delta_delta=False,  # 'mfcc', 'fbank'
-                dither=1.0,  # feature dither
-                target_sample_rate=16000,  # target sample rate
-                use_dB_normalization=True,
-                target_dB=-20,
-                random_seed=0,
-                keep_transcription_text=False,
-                batch_size=32,  # batch size
-                num_workers=0,  # data loader workers
-                sortagrad=False,  # sorted in first epoch when True
-                shuffle_method="batch_shuffle",  # 'batch_shuffle', 'instance_shuffle'
-                ))
+                min_output_input_ratio=0.0, ))
 
         if config is not None:
             config.merge_from_other_cfg(default)
@@ -94,128 +58,38 @@
         """
         assert 'manifest' in config.data
         assert config.data.manifest
-        assert 'keep_transcription_text' in config.data
-
-        if isinstance(config.data.augmentation_config, (str, bytes)):
-            if config.data.augmentation_config:
-                aug_file = io.open(
-                    config.data.augmentation_config, mode='r', encoding='utf8')
-            else:
-                aug_file = io.StringIO(initial_value='{}', newline='')
-        else:
-            aug_file = config.data.augmentation_config
-        assert isinstance(aug_file, io.StringIO)
 
         dataset = cls(
             manifest_path=config.data.manifest,
-            unit_type=config.data.unit_type,
-            vocab_filepath=config.data.vocab_filepath,
-            mean_std_filepath=config.data.mean_std_filepath,
-            spm_model_prefix=config.data.spm_model_prefix,
-            augmentation_config=aug_file.read(),
             max_input_len=config.data.max_input_len,
            min_input_len=config.data.min_input_len,
            max_output_len=config.data.max_output_len,
            min_output_len=config.data.min_output_len,
            max_output_input_ratio=config.data.max_output_input_ratio,
-            min_output_input_ratio=config.data.min_output_input_ratio,
-            stride_ms=config.data.stride_ms,
-            window_ms=config.data.window_ms,
-            n_fft=config.data.n_fft,
-            max_freq=config.data.max_freq,
-            target_sample_rate=config.data.target_sample_rate,
-            specgram_type=config.data.specgram_type,
-            feat_dim=config.data.feat_dim,
-            delta_delta=config.data.delta_delta,
-            dither=config.data.dither,
-            use_dB_normalization=config.data.use_dB_normalization,
-            target_dB=config.data.target_dB,
-            random_seed=config.data.random_seed,
-            keep_transcription_text=config.data.keep_transcription_text)
+            min_output_input_ratio=config.data.min_output_input_ratio, )
         return dataset
 
     def __init__(self,
                  manifest_path,
-                 unit_type,
-                 vocab_filepath,
-                 mean_std_filepath,
-                 spm_model_prefix=None,
-                 augmentation_config='{}',
                  max_input_len=float('inf'),
                  min_input_len=0.0,
                  max_output_len=float('inf'),
                  min_output_len=0.0,
                  max_output_input_ratio=float('inf'),
-                 min_output_input_ratio=0.0,
-                 stride_ms=10.0,
-                 window_ms=20.0,
-                 n_fft=None,
-                 max_freq=None,
-                 target_sample_rate=16000,
-                 specgram_type='linear',
-                 feat_dim=None,
-                 delta_delta=False,
-                 dither=1.0,
-                 use_dB_normalization=True,
-                 target_dB=-20,
-                 random_seed=0,
-                 keep_transcription_text=False):
+                 min_output_input_ratio=0.0):
         """Manifest Dataset
 
         Args:
             manifest_path (str): manifest josn file path
-            unit_type(str): token unit type, e.g. char, word, spm
-            vocab_filepath (str): vocab file path.
-            mean_std_filepath (str): mean and std file path, which suffix is *.npy
-            spm_model_prefix (str): spm model prefix, need if `unit_type` is spm.
-            augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
             max_input_len ([type], optional): maximum output seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
             min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
             max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0.
             min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0.
             max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0.
             min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05.
-            stride_ms (float, optional): stride size in ms. Defaults to 10.0.
-            window_ms (float, optional): window size in ms. Defaults to 20.0.
-            n_fft (int, optional): fft points for rfft. Defaults to None.
-            max_freq (int, optional): max cut freq. Defaults to None.
-            target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000.
-            specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
-            feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None.
-            delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False.
-            use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
-            target_dB (int, optional): target dB. Defaults to -20.
-            random_seed (int, optional): for random generator. Defaults to 0.
-            keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False.
+ """ super().__init__() - self._stride_ms = stride_ms - self._target_sample_rate = target_sample_rate - - self._normalizer = FeatureNormalizer( - mean_std_filepath) if mean_std_filepath else None - self._augmentation_pipeline = AugmentationPipeline( - augmentation_config=augmentation_config, random_seed=random_seed) - self._speech_featurizer = SpeechFeaturizer( - unit_type=unit_type, - vocab_filepath=vocab_filepath, - spm_model_prefix=spm_model_prefix, - specgram_type=specgram_type, - feat_dim=feat_dim, - delta_delta=delta_delta, - stride_ms=stride_ms, - window_ms=window_ms, - n_fft=n_fft, - max_freq=max_freq, - target_sample_rate=target_sample_rate, - use_dB_normalization=use_dB_normalization, - target_dB=target_dB, - dither=dither) - - self._rng = np.random.RandomState(random_seed) - self._keep_transcription_text = keep_transcription_text - # for caching tar files info - self._local_data = TarLocalData(tar2info={}, tar2object={}) # read manifest self._manifest = read_manifest( @@ -228,125 +99,9 @@ class ManifestDataset(Dataset): min_output_input_ratio=min_output_input_ratio) self._manifest.sort(key=lambda x: x["feat_shape"][0]) - @property - def manifest(self): - return self._manifest - - @property - def vocab_size(self): - return self._speech_featurizer.vocab_size - - @property - def vocab_list(self): - return self._speech_featurizer.vocab_list - - @property - def vocab_dict(self): - return self._speech_featurizer.vocab_dict - - @property - def text_feature(self): - return self._speech_featurizer.text_feature - - @property - def feature_size(self): - return self._speech_featurizer.feature_size - - @property - def stride_ms(self): - return self._speech_featurizer.stride_ms - - def _parse_tar(self, file): - """Parse a tar file to get a tarfile object - and a map containing tarinfoes - """ - result = {} - f = tarfile.open(file) - for tarinfo in f.getmembers(): - result[tarinfo.name] = tarinfo - return f, result - - def _subfile_from_tar(self, file): - """Get subfile object from tar. - - It will return a subfile object from tar file - and cached tar file info for next reading request. - """ - tarpath, filename = file.split(':', 1)[1].split('#', 1) - if 'tar2info' not in self._local_data.__dict__: - self._local_data.tar2info = {} - if 'tar2object' not in self._local_data.__dict__: - self._local_data.tar2object = {} - if tarpath not in self._local_data.tar2info: - object, infoes = self._parse_tar(tarpath) - self._local_data.tar2info[tarpath] = infoes - self._local_data.tar2object[tarpath] = object - return self._local_data.tar2object[tarpath].extractfile( - self._local_data.tar2info[tarpath][filename]) - - def process_utterance(self, audio_file, transcript): - """Load, augment, featurize and normalize for speech data. - - :param audio_file: Filepath or file object of audio file. - :type audio_file: str | file - :param transcript: Transcription text. - :type transcript: str - :return: Tuple of audio feature tensor and data of transcription part, - where transcription part could be token ids or text. 
-        :rtype: tuple of (2darray, list)
-        """
-        start_time = time.time()
-        if isinstance(audio_file, str) and audio_file.startswith('tar:'):
-            speech_segment = SpeechSegment.from_file(
-                self._subfile_from_tar(audio_file), transcript)
-        else:
-            speech_segment = SpeechSegment.from_file(audio_file, transcript)
-        load_wav_time = time.time() - start_time
-        #logger.debug(f"load wav time: {load_wav_time}")
-
-        # audio augment
-        start_time = time.time()
-        self._augmentation_pipeline.transform_audio(speech_segment)
-        audio_aug_time = time.time() - start_time
-        #logger.debug(f"audio augmentation time: {audio_aug_time}")
-
-        start_time = time.time()
-        specgram, transcript_part = self._speech_featurizer.featurize(
-            speech_segment, self._keep_transcription_text)
-        if self._normalizer:
-            specgram = self._normalizer.apply(specgram)
-        feature_time = time.time() - start_time
-        #logger.debug(f"audio & test feature time: {feature_time}")
-
-        # specgram augment
-        start_time = time.time()
-        specgram = self._augmentation_pipeline.transform_feature(specgram)
-        feature_aug_time = time.time() - start_time
-        #logger.debug(f"audio feature augmentation time: {feature_aug_time}")
-        return specgram, transcript_part
-
-    def _instance_reader_creator(self, manifest):
-        """
-        Instance reader creator. Create a callable function to produce
-        instances of data.
-
-        Instance: a tuple of ndarray of audio spectrogram and a list of
-        token indices for transcript.
-        """
-
-        def reader():
-            for instance in manifest:
-                inst = self.process_utterance(instance["feat"],
-                                              instance["text"])
-                yield inst
-
-        return reader
-
     def __len__(self):
         return len(self._manifest)
 
     def __getitem__(self, idx):
         instance = self._manifest[idx]
-        feat, text =self.process_utterance(instance["feat"],
-                                           instance["text"])
-        return instance["utt"], feat, text
+        return instance["utt"], instance["feat"], instance["text"]
diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py
index bcfddaef0e4f397a901f916b59f1a31c30bf0ac8..238e2d35c5492097868c2ff8ea1ff941bd27dc9e 100644
--- a/deepspeech/models/u2.py
+++ b/deepspeech/models/u2.py
@@ -905,7 +905,6 @@
 class U2InferModel(U2Model):
     def __init__(self, configs: dict):
         super().__init__(configs)
-
     def forward(self,
                 feats,
                 feats_lengths,
diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml
index 8b08ee308743a2550d1e85926ea1001e5c8e73c8..54ce240e7c4eb1e53e81477994ab1ae27f0c1db3 100644
--- a/examples/aishell/s0/conf/deepspeech2.yaml
+++ b/examples/aishell/s0/conf/deepspeech2.yaml
@@ -5,29 +5,35 @@ data:
   test_manifest: data/manifest.test
-  mean_std_filepath: data/mean_std.json
-  vocab_filepath: data/vocab.txt
-  augmentation_config: conf/augmentation.json
-  batch_size: 64 # one gpu
   min_input_len: 0.0
   max_input_len: 27.0 # second
   min_output_len: 0.0
   max_output_len: .inf
   min_output_input_ratio: 0.00
   max_output_input_ratio: .inf
+
+
+collator:
+  mean_std_filepath: data/mean_std.json
+  unit_type: char
+  vocab_filepath: data/vocab.txt
+  augmentation_config: conf/augmentation.json
+  random_seed: 0
+  spm_model_prefix:
   specgram_type: linear
-  target_sample_rate: 16000
-  max_freq: None
-  n_fft: None
+  feat_dim:
+  delta_delta: False
   stride_ms: 10.0
   window_ms: 20.0
-  delta_delta: False
-  dither: 1.0
+  n_fft: None
+  max_freq: None
+  target_sample_rate: 16000
   use_dB_normalization: True
   target_dB: -20
-  random_seed: 0
+  dither: 1.0
   keep_transcription_text: False
   sortagrad: True
   shuffle_method: batch_shuffle
   num_workers: 0
+  batch_size: 64 # one gpu
 
 model:
   num_conv_layers: 2
diff --git a/examples/aishell/s1/conf/conformer.yaml b/examples/aishell/s1/conf/conformer.yaml
index b880f858755e1716d4304e4ee475cc8f5190f81b..116c919279134bf7ca7f3aa9c50171ca1488be82 100644
--- a/examples/aishell/s1/conf/conformer.yaml
+++ b/examples/aishell/s1/conf/conformer.yaml
@@ -3,17 +3,20 @@ data:
   train_manifest: data/manifest.train
   dev_manifest: data/manifest.dev
   test_manifest: data/manifest.test
-  vocab_filepath: data/vocab.txt
-  unit_type: 'char'
-  spm_model_prefix: ''
-  augmentation_config: conf/augmentation.json
-  batch_size: 64
   min_input_len: 0.5
   max_input_len: 20.0 # second
   min_output_len: 0.0
   max_output_len: 400.0
   min_output_input_ratio: 0.05
   max_output_input_ratio: 10.0
+
+
+collator:
+  vocab_filepath: data/vocab.txt
+  unit_type: 'char'
+  spm_model_prefix: ''
+  augmentation_config: conf/augmentation.json
+  batch_size: 64
   raw_wav: True # use raw_wav or kaldi feature
   specgram_type: fbank #linear, mfcc, fbank
   feat_dim: 80
@@ -32,7 +35,6 @@
   shuffle_method: batch_shuffle
   num_workers: 2
 
-
 # network architecture
 model:
   cmvn_file: "data/mean_std.json"
diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml
index dd9ce51f032c8b5161e853a6bdc1bd209cee4cf8..6737d1b75a88fb3b00ca8d7cbb168878de21766b 100644
--- a/examples/tiny/s0/conf/deepspeech2.yaml
+++ b/examples/tiny/s0/conf/deepspeech2.yaml
@@ -2,32 +2,38 @@ data:
   train_manifest: data/manifest.tiny
   dev_manifest: data/manifest.tiny
-  test_manifest: data/manifest.tiny
-  mean_std_filepath: data/mean_std.json
-  vocab_filepath: data/vocab.txt
-  augmentation_config: conf/augmentation.json
-  batch_size: 4
+  test_manifest: data/manifest.tiny
   min_input_len: 0.0
   max_input_len: 27.0
   min_output_len: 0.0
   max_output_len: 400.0
   min_output_input_ratio: 0.05
   max_output_input_ratio: 10.0
+
+
+collator:
+  mean_std_filepath: data/mean_std.json
+  unit_type: char
+  vocab_filepath: data/vocab.txt
+  augmentation_config: conf/augmentation.json
+  random_seed: 0
+  spm_model_prefix:
   specgram_type: linear
-  target_sample_rate: 16000
-  max_freq: None
-  n_fft: None
+  feat_dim:
+  delta_delta: False
   stride_ms: 10.0
   window_ms: 20.0
-  delta_delta: False
-  dither: 1.0
+  n_fft: None
+  max_freq: None
+  target_sample_rate: 16000
   use_dB_normalization: True
   target_dB: -20
-  random_seed: 0
+  dither: 1.0
   keep_transcription_text: False
   sortagrad: True
   shuffle_method: batch_shuffle
   num_workers: 0
+  batch_size: 4
 
 model:
   num_conv_layers: 2
@@ -37,7 +43,7 @@
   share_rnn_weights: True
 
 training:
-  n_epoch: 20
+  n_epoch: 24
   lr: 1e-5
   lr_decay: 1.0
   weight_decay: 1e-06
diff --git a/examples/tiny/s1/conf/transformer.yaml b/examples/tiny/s1/conf/transformer.yaml
index 0a7cf3be845b68a87799904f7fdf167813fb1794..250995faadc8b4e668ed717d70b9ebadcdc67b60 100644
--- a/examples/tiny/s1/conf/transformer.yaml
+++ b/examples/tiny/s1/conf/transformer.yaml
@@ -3,35 +3,37 @@ data:
   train_manifest: data/manifest.tiny
   dev_manifest: data/manifest.tiny
   test_manifest: data/manifest.tiny
-  vocab_filepath: data/vocab.txt
-  unit_type: 'spm'
-  spm_model_prefix: 'data/bpe_unigram_200'
-  mean_std_filepath: ""
-  augmentation_config: conf/augmentation.json
-  batch_size: 4
   min_input_len: 0.5 # second
   max_input_len: 20.0 # second
   min_output_len: 0.0 # tokens
   max_output_len: 400.0 # tokens
   min_output_input_ratio: 0.05
   max_output_input_ratio: 10.0
-  raw_wav: True # use raw_wav or kaldi feature
-  specgram_type: fbank #linear, mfcc, fbank
+
+collator:
+  vocab_filepath: data/vocab.txt
+  mean_std_filepath: ""
+  augmentation_config: conf/augmentation.json
+  random_seed: 0
+  unit_type: 'spm'
+  spm_model_prefix: 'data/bpe_unigram_200'
+  specgram_type: fbank
   feat_dim: 80
delta_delta: False - dither: 1.0 - target_sample_rate: 16000 - max_freq: None - n_fft: None stride_ms: 10.0 - window_ms: 25.0 + window_ms: 20.0 + n_fft: None + max_freq: None + target_sample_rate: 16000 use_dB_normalization: True target_dB: -20 - random_seed: 0 + dither: 1.0 keep_transcription_text: False + batch_size: 4 sortagrad: True shuffle_method: batch_shuffle - num_workers: 2 + num_workers: 0 #2 + raw_wav: True # use raw_wav or kaldi feature # network architecture @@ -70,7 +72,7 @@ model: training: - n_epoch: 2 + n_epoch: 21 accum_grad: 1 global_grad_clip: 5.0 optim: adam
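
Note on usage (illustrative sketch, not part of the patch): after this refactor, the `data` section of a config only drives manifest loading and length filtering, while feature extraction, tokenization, augmentation and batching options live under `collator`; model geometry is therefore read from the DataLoader's `collate_fn` instead of the dataset. A minimal wiring of the new pieces, assuming the repository layout above (the config path is one of the example files touched by this patch):

    from paddle.io import DataLoader

    from deepspeech.exps.deepspeech2.config import get_cfg_defaults
    from deepspeech.io.collator import SpeechCollator
    from deepspeech.io.dataset import ManifestDataset

    # Defaults are composed from each component's params() classmethod,
    # then overridden by the experiment YAML.
    config = get_cfg_defaults()
    config.merge_from_file("examples/tiny/s0/conf/deepspeech2.yaml")
    config.defrost()
    config.data.manifest = config.data.train_manifest
    config.collator.keep_transcription_text = False

    # The dataset now only filters and sorts manifest entries; the collator
    # owns loading, augmentation, featurization, normalization and padding.
    train_dataset = ManifestDataset.from_config(config)
    train_collator = SpeechCollator.from_config(config)

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.collator.batch_size,
        collate_fn=train_collator,
        num_workers=config.collator.num_workers)

    # Model sizes come from the collate_fn rather than the dataset:
    feat_size = train_loader.collate_fn.feature_size
    vocab_size = train_loader.collate_fn.vocab_size

One consequence worth noting: because the collator performs the feature extraction, a fresh `SpeechCollator.from_config(config)` must be built after mutating `config.collator` (as the trainers above do when they blank `augmentation_config` for the dev and test loaders); an already-constructed collator does not see later config changes.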