# https://yaml.org/type/float.html
# network architecture
model:
    cmvn_file:
    cmvn_file_type: "json"
    # encoder related
    encoder: transformer
    encoder_conf:
        output_size: 256    # dimension of attention
        attention_heads: 4
        linear_units: 2048  # the number of units of position-wise feed forward
        num_blocks: 12      # the number of encoder blocks
        dropout_rate: 0.1
        positional_dropout_rate: 0.1
        attention_dropout_rate: 0.0
        input_layer: conv2d # encoder input type; you can choose conv2d, conv2d6 or conv2d8
        normalize_before: true
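        # Note (assumed behavior of ESPnet-style front ends, not verified against this repo):
        # conv2d stacks two stride-2 convolutions, subsampling the input frames by ~4x in
        # time before the attention blocks; conv2d6 and conv2d8 subsample by ~6x and ~8x.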

    # decoder related
    decoder: transformer
    decoder_conf:
        attention_heads: 4
        linear_units: 2048
        num_blocks: 6
        dropout_rate: 0.1
        positional_dropout_rate: 0.1
        self_attention_dropout_rate: 0.0
        src_attention_dropout_rate: 0.0

    # hybrid CTC/attention
    model_conf:
        ctc_weight: 0.3
        ctc_dropoutrate: 0.0
        ctc_grad_norm_type: batch
        lsm_weight: 0.1     # label smoothing option
        length_normalized_loss: false
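        # A minimal sketch of how ctc_weight is assumed to combine the two training
        # branches (standard hybrid CTC/attention; not necessarily this repo's exact code):
        #   loss = ctc_weight * loss_ctc + (1.0 - ctc_weight) * loss_att
        # lsm_weight is the label-smoothing factor applied to the attention cross-entropy.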

data:
  train_manifest: data/manifest.train
  dev_manifest: data/manifest.dev
  test_manifest: data/manifest.test-clean

collator:
  vocab_filepath: data/lang_char/train_960_unigram5000_units.txt
  unit_type: spm
  spm_model_prefix: data/lang_char/train_960_unigram5000
  feat_dim: 83
  stride_ms: 10.0
  window_ms: 25.0
  sortagrad: 0 # feed samples from shortest to longest; -1: enabled for all epochs, 0: disabled, N: enabled for the first N epochs
  batch_size: 30
  maxlen_in: 512  # if input length > maxlen_in, the batch size is automatically reduced
  maxlen_out: 150 # if output length > maxlen_out, the batch size is automatically reduced
  minibatches: 0 # for debug
  batch_count: auto
  batch_bins: 0
  batch_frames_in: 0
  batch_frames_out: 0
  batch_frames_inout: 0
  augmentation_config: conf/augmentation.json
  num_workers: 0
  subsampling_factor: 1
  num_encs: 1
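  # Note (assumption based on common Kaldi/ESPnet LibriSpeech recipes): a feat_dim of 83
  # typically means 80-dim filterbank features plus 3-dim pitch features.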

training:
  n_epoch: 120
  accum_grad: 2
  log_interval: 100
  checkpoint:
    kbest_n: 50
    latest_n: 5
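  # With accum_grad: 2, gradients are accumulated over two forward/backward passes before
  # each optimizer step, so the effective batch size is roughly
  # batch_size * accum_grad = 30 * 2 = 60 utterances per update (assuming the collator
  # batch_size above is used as-is).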

optim: adam
optim_conf:
  global_grad_clip: 5.0
  weight_decay: 1.0e-06
scheduler: warmuplr     # pytorch v1.1.0+ required
scheduler_conf:
  lr: 0.004
  warmup_steps: 25000
  lr_decay: 1.0
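  # A hedged sketch of the warmuplr schedule (assumed Noam-style warmup as in
  # ESPnet/WeNet; the exact implementation here may differ):
  #   lr_t = lr * warmup_steps**0.5 * min(step**-0.5, step * warmup_steps**-1.5)
  # i.e. the rate ramps up over the first warmup_steps updates, then decays as step**-0.5.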

decoding:
  batch_size: 1
  error_rate_type: wer
  decoding_method: attention  # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
  lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
  alpha: 2.5
  beta: 0.3
  beam_size: 10
  cutoff_prob: 1.0
  cutoff_top_n: 0
  num_proc_bsearch: 8
  ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
  decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
      # <0: for decoding, use full chunk.
      # >0: for decoding, use fixed chunk size as set.
      # 0: used during training; not allowed at decoding time.
  num_decoding_left_chunks: -1  # number of left chunks for decoding. Defaults to -1.
  simulate_streaming: False  # simulate streaming inference. Defaults to False.
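  # A rough sketch of the attention_rescoring mode (assumed WeNet-style two-pass decoding;
  # not necessarily this repo's exact scoring code):
  #   1. ctc_prefix_beam_search produces beam_size n-best hypotheses with CTC scores;
  #   2. the attention decoder re-scores each hypothesis;
  #   3. final_score = ctc_weight * ctc_score + (1.0 - ctc_weight) * attention_score,
  #      and the best-scoring hypothesis is returned.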