diff --git a/deepspeech/exps/deepspeech2/config.py b/deepspeech/exps/deepspeech2/config.py index 050a50b00a70e2a9546886b28d530a6b5e694fe3..7d2250fc705c9652bcdf745b2aa4c1f115dac2b7 100644 --- a/deepspeech/exps/deepspeech2/config.py +++ b/deepspeech/exps/deepspeech2/config.py @@ -11,80 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from yacs.config import CfgNode as CN +from yacs.config import CfgNode from deepspeech.models.deepspeech2 import DeepSpeech2Model +from deepspeech.io.dataset import ManifestDataset +from deepspeech.io.collator import SpeechCollator +from deepspeech.exps.deepspeech2.model import DeepSpeech2Trainer +from deepspeech.exps.deepspeech2.model import DeepSpeech2Tester -_C = CN() -_C.data = CN( - dict( - train_manifest="", - dev_manifest="", - test_manifest="", - max_duration=float('inf'), - min_duration=0.0, - )) -_C.collator =CN( - dict( - unit_type="char", - vocab_filepath="", - spm_model_prefix="", - mean_std_filepath="", - augmentation_config="", - random_seed=0, - specgram_type='linear', # 'linear', 'mfcc', 'fbank' - feat_dim=0, # 'mfcc', 'fbank' - delta_delta=False, # 'mfcc', 'fbank' - stride_ms=10.0, # ms - window_ms=20.0, # ms - n_fft=None, # fft points - max_freq=None, # None for samplerate/2 - target_sample_rate=16000, # target sample rate - use_dB_normalization=True, - target_dB=-20, - dither=1.0, # feature dither - keep_transcription_text=False, - batch_size=32, # batch size - num_workers=0, # data loader workers - sortagrad=False, # sorted in first epoch when True - shuffle_method="batch_shuffle", # 'batch_shuffle', 'instance_shuffle' - )) +_C = CfgNode() -_C.model = CN( - dict( - num_conv_layers=2, #Number of stacking convolution layers. - num_rnn_layers=3, #Number of stacking RNN layers. - rnn_layer_size=1024, #RNN layer size (number of RNN cells). - use_gru=True, #Use gru if set True. Use simple rnn if set False. - share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported. - )) +_C.data = ManifestDataset.params() +_C.collator = SpeechCollator.params() -DeepSpeech2Model.params(_C.model) +_C.model = DeepSpeech2Model.params() -_C.training = CN( - dict( - lr=5e-4, # learning rate - lr_decay=1.0, # learning rate decay - weight_decay=1e-6, # the coeff of weight decay - global_grad_clip=5.0, # the global norm clip - n_epoch=50, # train epochs - )) +_C.training = DeepSpeech2Trainer.params() -_C.decoding = CN( - dict( - alpha=2.5, # Coef of LM for beam search. - beta=0.3, # Coef of WC for beam search. - cutoff_prob=1.0, # Cutoff probability for pruning. - cutoff_top_n=40, # Cutoff number for pruning. - lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. - decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy - error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' - num_proc_bsearch=8, # # of CPUs for beam search. - beam_size=500, # Beam search width. - batch_size=128, # decoding batch size - )) +_C.decoding = DeepSpeech2Tester.params() def get_cfg_defaults(): diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 1eefc871bb6036f85afbcc0a61ca7013a41eb91d..c11d1e2591df334e8b9616ac9c8a1440a475c469 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -34,10 +34,28 @@ from deepspeech.utils import layer_tools from deepspeech.utils import mp_tools from deepspeech.utils.log import Log +from typing import Optional +from yacs.config import CfgNode logger = Log(__name__).getlog() class DeepSpeech2Trainer(Trainer): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # training config + default = CfgNode( + dict( + lr=5e-4, # learning rate + lr_decay=1.0, # learning rate decay + weight_decay=1e-6, # the coeff of weight decay + global_grad_clip=5.0, # the global norm clip + n_epoch=50, # train epochs + )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + def __init__(self, config, args): super().__init__(config, args) @@ -184,6 +202,27 @@ class DeepSpeech2Trainer(Trainer): class DeepSpeech2Tester(DeepSpeech2Trainer): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # testing config + default = CfgNode( + dict( + alpha=2.5, # Coef of LM for beam search. + beta=0.3, # Coef of WC for beam search. + cutoff_prob=1.0, # Cutoff probability for pruning. + cutoff_top_n=40, # Cutoff number for pruning. + lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. + decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy + error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' + num_proc_bsearch=8, # # of CPUs for beam search. + beam_size=500, # Beam search width. + batch_size=128, # decoding batch size + )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + def __init__(self, config, args): super().__init__(config, args)