diff --git a/examples/librispeech/README.md b/examples/librispeech/README.md index 697cb91d46d03ba7a62707d0eb68f4a1048665ff..f46749b7f1a5f13b1f3a9037389a9bd9f62ed33b 100644 --- a/examples/librispeech/README.md +++ b/examples/librispeech/README.md @@ -1 +1,3 @@ -* s0 is for deepspeech +# ASR +* s0 is for deepspeech2 +* s1 is for U2 diff --git a/examples/tiny/s0/local/data.sh b/examples/tiny/s0/local/data.sh old mode 100644 new mode 100755 diff --git a/examples/tiny/s0/local/download_lm_en.sh b/examples/tiny/s0/local/download_lm_en.sh old mode 100644 new mode 100755 diff --git a/examples/tiny/s1/.gitignore b/examples/tiny/s1/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..08728cdc9c3e26ab42b8c941ef16fdb7fd0ae653 --- /dev/null +++ b/examples/tiny/s1/.gitignore @@ -0,0 +1,3 @@ +data +exp +log diff --git a/examples/tiny/s1/conf/augmentation.json b/examples/tiny/s1/conf/augmentation.json index 1987ad4245dcf5542f1e22a545c36899659acef9..c1078393d2f2f57fbcb3b48ce0975c2612c39dcb 100644 --- a/examples/tiny/s1/conf/augmentation.json +++ b/examples/tiny/s1/conf/augmentation.json @@ -1,4 +1,12 @@ [ + { + "type": "shift", + "params": { + "min_shift_ms": -5, + "max_shift_ms": 5 + }, + "prob": 1.0 + }, { "type": "speed", "params": { @@ -8,14 +16,6 @@ }, "prob": 0.0 }, - { - "type": "shift", - "params": { - "min_shift_ms": -5, - "max_shift_ms": 5 - }, - "prob": 1.0 - }, { "type": "specaug", "params": { diff --git a/examples/tiny/s1/conf/chunk_confermer.yaml b/examples/tiny/s1/conf/chunk_confermer.yaml index 60d0205bf7cee91df444cf4c610736126f03bddb..bd4279e2be5adc71937afb524c5739c689bae074 100644 --- a/examples/tiny/s1/conf/chunk_confermer.yaml +++ b/examples/tiny/s1/conf/chunk_confermer.yaml @@ -1,90 +1,115 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.tiny + dev_manifest: data/manifest.tiny + test_manifest: data/manifest.tiny + vocab_filepath: data/vocab.txt + unit_type: 'spm' + spm_model_prefix: 'data/bpe_unigram_200' + mean_std_filepath: "" + augmentation_config: conf/augmentation.json + batch_size: 4 + min_input_len: 0.5 + max_input_len: 20.0 + min_output_len: 0.0 + max_output_len: 400.0 + min_output_input_ratio: 0.05 + max_output_input_ratio: 10.0 + raw_wav: True # use raw_wav or kaldi feature + specgram_type: fbank #linear, mfcc, fbank + feat_dim: 80 + delta_delta: False + dither: 1.0 + target_sample_rate: 16000 + max_freq: None + n_fft: None + stride_ms: 10.0 + window_ms: 25.0 + use_dB_normalization: True + target_dB: -20 + random_seed: 0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 2 + + # network architecture -# encoder related -encoder: conformer -encoder_conf: - output_size: 256 # dimension of attention - attention_heads: 4 - linear_units: 2048 # the number of units of position-wise feed forward - num_blocks: 12 # the number of encoder blocks - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - attention_dropout_rate: 0.0 - input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 - normalize_before: true - cnn_module_kernel: 15 - use_cnn_module: True - activation_type: 'swish' - pos_enc_layer_type: 'rel_pos' - selfattention_layer_type: 'rel_selfattn' - causal: true - use_dynamic_chunk: true - cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster - use_dynamic_left_chunk: false +model: + cmvn_file: "data/mean_std.json" + cmvn_file_type: "json" + # encoder related + encoder: conformer + encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: True + use_cnn_module: True + cnn_module_kernel: 15 + activation_type: 'swish' + pos_enc_layer_type: 'rel_pos' + selfattention_layer_type: 'rel_selfattn' + causal: True + use_dynamic_chunk: True + cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster + use_dynamic_left_chunk: false -# decoder related -decoder: transformer -decoder_conf: - attention_heads: 4 - linear_units: 2048 - num_blocks: 6 - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - self_attention_dropout_rate: 0.0 - src_attention_dropout_rate: 0.0 + # decoder related + decoder: transformer + decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 -# hybrid CTC/attention -model_conf: - ctc_weight: 0.3 - lsm_weight: 0.1 # label smoothing option - length_normalized_loss: false + # hybrid CTC/attention + model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false -# use raw_wav or kaldi feature -raw_wav: true +training: + n_epoch: 20 + accum_grad: 1 + global_grad_clip: 5.0 + optim: adam + optim_conf: + lr: 0.001 + weight_decay: 1e-06 + scheduler: warmuplr # pytorch v1.1.0+ required + scheduler_conf: + warmup_steps: 25000 + lr_decay: 1.0 + log_interval: 1 -# feature extraction -collate_conf: - # waveform level config - wav_distortion_conf: - wav_dither: 1.0 - wav_distortion_rate: 0.0 - distortion_methods: [] - speed_perturb: true - feature_extraction_conf: - feature_type: 'fbank' - mel_bins: 80 - frame_shift: 10 - frame_length: 25 - using_pitch: false - # spec level config - # spec_swap: false - feature_dither: 0.0 # add dither [-feature_dither,feature_dither] on fbank feature - spec_aug: true - spec_aug_conf: - warp_for_time: False - num_t_mask: 2 - num_f_mask: 2 - max_t: 50 - max_f: 10 - max_w: 80 -# dataset related -dataset_conf: - max_length: 40960 - min_length: 0 - batch_type: 'static' # static or dynamic - # the size of batch_size should be set according to your gpu memory size, here we used 2080ti gpu whose memory size is 11GB - batch_size: 16 - sort: true +decoding: + batch_size: 64 + error_rate_type: wer + decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' + lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm + alpha: 2.5 + beta: 0.3 + beam_size: 10 + cutoff_prob: 1.0 + cutoff_top_n: 0 + num_proc_bsearch: 8 + ctc_weight: 0.0 # ctc weight for attention rescoring decode mode. + decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. + # <0: for decoding, use full chunk. + # >0: for decoding, use fixed chunk size as set. + # 0: used for training, it's prohibited here. + num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. + simulate_streaming: False # simulate streaming inference. Defaults to False. -grad_clip: 5 -accum_grad: 1 -max_epoch: 180 -log_interval: 100 -optim: adam -optim_conf: - lr: 0.001 -scheduler: warmuplr # pytorch v1.1.0+ required -scheduler_conf: - warmup_steps: 25000 \ No newline at end of file diff --git a/examples/tiny/s1/conf/chunk_transformer.yaml b/examples/tiny/s1/conf/chunk_transformer.yaml index 8de0738115ae17b49a22161bfcf2c4967df2e796..ba60c273564540a627e7da3e9f4502ed4f2bb28a 100644 --- a/examples/tiny/s1/conf/chunk_transformer.yaml +++ b/examples/tiny/s1/conf/chunk_transformer.yaml @@ -1,83 +1,108 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.tiny + dev_manifest: data/manifest.tiny + test_manifest: data/manifest.tiny + vocab_filepath: data/vocab.txt + unit_type: 'spm' + spm_model_prefix: 'data/bpe_unigram_200' + mean_std_filepath: "" + augmentation_config: conf/augmentation.json + batch_size: 4 + min_input_len: 0.5 # second + max_input_len: 20.0 # second + min_output_len: 0.0 # tokens + max_output_len: 400.0 # tokens + min_output_input_ratio: 0.05 + max_output_input_ratio: 10.0 + raw_wav: True # use raw_wav or kaldi feature + specgram_type: fbank #linear, mfcc, fbank + feat_dim: 80 + delta_delta: False + dither: 1.0 + target_sample_rate: 16000 + max_freq: None + n_fft: None + stride_ms: 10.0 + window_ms: 25.0 + use_dB_normalization: True + target_dB: -20 + random_seed: 0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 2 + + # network architecture -# encoder related -encoder: transformer -encoder_conf: - output_size: 256 # dimension of attention - attention_heads: 4 - linear_units: 2048 # the number of units of position-wise feed forward - num_blocks: 12 # the number of encoder blocks - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - attention_dropout_rate: 0.0 - input_layer: conv2d # encoder architecture type - normalize_before: true - use_dynamic_chunk: true - use_dynamic_left_chunk: false +model: + cmvn_file: "data/mean_std.json" + cmvn_file_type: "json" + # encoder related + encoder: transformer + encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: true + use_dynamic_chunk: true + use_dynamic_left_chunk: false -# decoder related -decoder: transformer -decoder_conf: - attention_heads: 4 - linear_units: 2048 - num_blocks: 6 - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - self_attention_dropout_rate: 0.0 - src_attention_dropout_rate: 0.0 + # decoder related + decoder: transformer + decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 -# hybrid CTC/attention -model_conf: - ctc_weight: 0.3 - lsm_weight: 0.1 # label smoothing option - length_normalized_loss: false + # hybrid CTC/attention + model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false -# use raw_wav or kaldi feature -raw_wav: true -# feature extraction -collate_conf: - # waveform level config - wav_distortion_conf: - wav_dither: 0.0 - wav_distortion_rate: 0.0 - distortion_methods: [] - speed_perturb: false - feature_extraction_conf: - feature_type: 'fbank' - mel_bins: 80 - frame_shift: 10 - frame_length: 25 - using_pitch: false - # spec level config - # spec_swap: false - feature_dither: 0.0 # add dither [-feature_dither,feature_dither] on fbank feature - spec_aug: true - spec_aug_conf: - warp_for_time: False - num_t_mask: 2 - num_f_mask: 2 - max_t: 50 - max_f: 10 - max_w: 80 +training: + n_epoch: 20 + accum_grad: 1 + global_grad_clip: 5.0 + optim: adam + optim_conf: + lr: 0.002 + weight_decay: 1e-06 + scheduler: warmuplr # pytorch v1.1.0+ required + scheduler_conf: + warmup_steps: 25000 + lr_decay: 1.0 + log_interval: 1 -# dataset related -dataset_conf: - max_length: 40960 - min_length: 0 - batch_type: 'static' # static or dynamic - # the size of batch_size should be set according to your gpu memory size, here we used 2080ti gpu whose memory size is 11GB - batch_size: 16 - sort: true +decoding: + batch_size: 64 + error_rate_type: wer + decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' + lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm + alpha: 2.5 + beta: 0.3 + beam_size: 10 + cutoff_prob: 1.0 + cutoff_top_n: 0 + num_proc_bsearch: 8 + ctc_weight: 0.0 # ctc weight for attention rescoring decode mode. + decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. + # <0: for decoding, use full chunk. + # >0: for decoding, use fixed chunk size as set. + # 0: used for training, it's prohibited here. + num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. + simulate_streaming: False # simulate streaming inference. Defaults to False. -grad_clip: 5 -accum_grad: 1 -max_epoch: 180 -log_interval: 100 -optim: adam -optim_conf: - lr: 0.002 -scheduler: warmuplr # pytorch v1.1.0+ required -scheduler_conf: - warmup_steps: 25000 \ No newline at end of file diff --git a/examples/tiny/s1/conf/conformer.yaml b/examples/tiny/s1/conf/conformer.yaml index 27de3360539b2a839c65b2d43eaa2618b0969acf..83f4f5af46100735f5fc78761c4cc21901fe15e7 100644 --- a/examples/tiny/s1/conf/conformer.yaml +++ b/examples/tiny/s1/conf/conformer.yaml @@ -31,46 +31,7 @@ data: keep_transcription_text: False sortagrad: True shuffle_method: batch_shuffle - num_workers: 0 - - -# # feature extraction -# collate_conf: -# # waveform level config -# wav_distortion_conf: -# wav_dither: 0.1 -# wav_distortion_rate: 0.0 -# distortion_methods: [] -# speed_perturb: true -# feature_extraction_conf: -# feature_type: 'fbank' -# mel_bins: 80 -# frame_shift: 10 -# frame_length: 25 -# using_pitch: false -# # spec level config -# # spec_swap: false -# feature_dither: 0.0 # add dither [-feature_dither,feature_dither] on fbank feature -# spec_aug: true -# spec_aug_conf: -# warp_for_time: False -# num_t_mask: 2 -# num_f_mask: 2 -# max_t: 50 -# max_f: 10 -# max_w: 80 - - -# # dataset related -# dataset_conf: -# max_length: 40960 -# min_length: 0 -# batch_type: 'static' # static or dynamic -# # the size of batch_size should be set according to your gpu memory size, here we used 2080ti gpu whose memory size is 11GB -# batch_size: 16 -# sort: true - - + num_workers: 2 # network architecture @@ -129,7 +90,7 @@ training: decoding: - batch_size: 16 + batch_size: 64 error_rate_type: wer decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm diff --git a/examples/tiny/s1/conf/transformer.yaml b/examples/tiny/s1/conf/transformer.yaml index 219389da0b6d7299aad189521fa44b45aefc71b6..3f3170bdfab088a85b7d365834e8f659c42f31bd 100644 --- a/examples/tiny/s1/conf/transformer.yaml +++ b/examples/tiny/s1/conf/transformer.yaml @@ -1,80 +1,106 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.tiny + dev_manifest: data/manifest.tiny + test_manifest: data/manifest.tiny + vocab_filepath: data/vocab.txt + unit_type: 'spm' + spm_model_prefix: 'data/bpe_unigram_200' + mean_std_filepath: "" + augmentation_config: conf/augmentation.json + batch_size: 4 + min_input_len: 0.5 # second + max_input_len: 20.0 # second + min_output_len: 0.0 # tokens + max_output_len: 400.0 # tokens + min_output_input_ratio: 0.05 + max_output_input_ratio: 10.0 + raw_wav: True # use raw_wav or kaldi feature + specgram_type: fbank #linear, mfcc, fbank + feat_dim: 80 + delta_delta: False + dither: 1.0 + target_sample_rate: 16000 + max_freq: None + n_fft: None + stride_ms: 10.0 + window_ms: 25.0 + use_dB_normalization: True + target_dB: -20 + random_seed: 0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 2 + + # network architecture -# encoder related -encoder: transformer -encoder_conf: - output_size: 256 # dimension of attention - attention_heads: 4 - linear_units: 2048 # the number of units of position-wise feed forward - num_blocks: 12 # the number of encoder blocks - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - attention_dropout_rate: 0.0 - input_layer: conv2d # encoder architecture type - normalize_before: true +model: + cmvn_file: "data/mean_std.json" + cmvn_file_type: "json" + # encoder related + encoder: transformer + encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: true -# decoder related -decoder: transformer -decoder_conf: - attention_heads: 4 - linear_units: 2048 - num_blocks: 6 - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - self_attention_dropout_rate: 0.0 - src_attention_dropout_rate: 0.0 + # decoder related + decoder: transformer + decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 -# hybrid CTC/attention -model_conf: - ctc_weight: 0.3 - lsm_weight: 0.1 # label smoothing option - length_normalized_loss: false + # hybrid CTC/attention + model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false -# use raw_wav or kaldi feature -raw_wav: true -# feature extraction -collate_conf: - # waveform level config - wav_distortion_conf: - wav_dither: 0.1 - wav_distortion_rate: 0.0 - distortion_methods: [] - speed_perturb: true - feature_extraction_conf: - feature_type: 'fbank' - mel_bins: 80 - frame_shift: 10 - frame_length: 25 - using_pitch: false - # spec level config - feature_dither: 0.0 # add dither [-feature_dither,feature_dither] on fbank feature - spec_aug: true - spec_aug_conf: - warp_for_time: False - num_t_mask: 2 - num_f_mask: 2 - max_t: 50 - max_f: 10 - max_w: 80 +training: + n_epoch: 20 + accum_grad: 1 + global_grad_clip: 5.0 + optim: adam + optim_conf: + lr: 0.002 + weight_decay: 1e-06 + scheduler: warmuplr # pytorch v1.1.0+ required + scheduler_conf: + warmup_steps: 25000 + lr_decay: 1.0 + log_interval: 1 -# dataset related -dataset_conf: - max_length: 40960 - min_length: 0 - batch_type: 'static' # static or dynamic - # the size of batch_size should be set according to your gpu memory size, here we used 2080ti gpu whose memory size is 11GB - batch_size: 26 - sort: true +decoding: + batch_size: 64 + error_rate_type: wer + decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' + lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm + alpha: 2.5 + beta: 0.3 + beam_size: 10 + cutoff_prob: 1.0 + cutoff_top_n: 0 + num_proc_bsearch: 8 + ctc_weight: 0.0 # ctc weight for attention rescoring decode mode. + decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. + # <0: for decoding, use full chunk. + # >0: for decoding, use fixed chunk size as set. + # 0: used for training, it's prohibited here. + num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. + simulate_streaming: False # simulate streaming inference. Defaults to False. -grad_clip: 5 -accum_grad: 1 -max_epoch: 240 -log_interval: 100 -optim: adam -optim_conf: - lr: 0.002 -scheduler: warmuplr # pytorch v1.1.0+ required -scheduler_conf: - warmup_steps: 25000 \ No newline at end of file diff --git a/examples/tiny/s1/local/avg.sh b/examples/tiny/s1/local/avg.sh new file mode 100755 index 0000000000000000000000000000000000000000..8589e35308c97fdcb1abbb13bcf35ca4042c3ae2 --- /dev/null +++ b/examples/tiny/s1/local/avg.sh @@ -0,0 +1,23 @@ +#! /usr/bin/env bash + +if [ $# != 2 ];then + echo "usage: ${0} ckpt_dir avg_num" + exit -1 +fi + +ckpt_dir=${1} +average_num=${2} +decode_checkpoint=${ckpt_dir}/avg_${average_num}.pdparams + +python3 -u ${MAIN_ROOT}/utils/avg_model.py \ +--dst_model ${decode_checkpoint} \ +--ckpt_dir ${ckpt_dir} \ +--num ${average_num} \ +--val_best + +if [ $? -ne 0 ]; then + echo "Failed in avg ckpt!" + exit 1 +fi + +exit 0 \ No newline at end of file diff --git a/examples/tiny/s1/local/export.sh b/examples/tiny/s1/local/export.sh old mode 100644 new mode 100755 index 864ecb2d5bf754c1af1ea2c4787f7d24e66c83ef..b83a13a980d63473448d06cf39e6002fa0c55881 --- a/examples/tiny/s1/local/export.sh +++ b/examples/tiny/s1/local/export.sh @@ -1,18 +1,22 @@ #! /usr/bin/env bash -if [ $# != 2 ];then - echo "usage: export ckpt_path jit_model_path" +if [ $# != 3 ];then + echo "usage: $0 config_path ckpt_prefix jit_model_path" exit -1 fi +config_path=$1 +ckpt_path_prefix=$2 +jit_model_export_path=$3 + python3 -u ${BIN_DIR}/export.py \ ---config conf/conformer.yaml \ ---checkpoint_path ${1} \ ---export_path ${2} +--config ${config_path} \ +--checkpoint_path ${ckpt_path_prefix} \ +--export_path ${jit_model_export_path} if [ $? -ne 0 ]; then - echo "Failed in evaluation!" + echo "Failed in export!" exit 1 fi diff --git a/examples/tiny/s1/local/test.sh b/examples/tiny/s1/local/test.sh old mode 100644 new mode 100755 index e7ecc9b40d15971fdcc0df06688c6e9b89f8a76b..c5e61bff70d04e18ae1fcc1cab7708b6f3c7af51 --- a/examples/tiny/s1/local/test.sh +++ b/examples/tiny/s1/local/test.sh @@ -1,5 +1,20 @@ #! /usr/bin/env bash +if [ $# != 2 ];then + echo "usage: ${0} config_path ckpt_path_prefix" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +device=gpu +if [ ngpu != 0 ];then + device=cpu +fi +config_path=$1 +ckpt_prefix=$2 + # download language model #bash local/download_lm_en.sh #if [ $? -ne 0 ]; then @@ -7,11 +22,11 @@ #fi python3 -u ${BIN_DIR}/test.py \ ---device 'gpu' \ +--device ${device} \ --nproc 1 \ ---config conf/conformer.yaml \ ---result_file data/asr.result \ ---output ckpt +--config ${config_path} \ +--result_file ${ckpt_prefix}.rsl \ +--checkpoint_path ${ckpt_prefix} if [ $? -ne 0 ]; then echo "Failed in evaluation!" diff --git a/examples/tiny/s1/local/train.sh b/examples/tiny/s1/local/train.sh old mode 100644 new mode 100755 index 2f73e9be9a7e8cb5c0517527613359f77e5db3b4..3ed5338088ca550092c7c01b4a4b6ea7d7f47549 --- a/examples/tiny/s1/local/train.sh +++ b/examples/tiny/s1/local/train.sh @@ -1,18 +1,31 @@ #! /usr/bin/env bash +if [ $# != 2 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" + exit -1 +fi + ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." +config_path=$1 +ckpt_name=$2 +device=gpu +if [ ngpu != 0 ];then + device=cpu +fi + +mkdir -p exp + python3 -u ${BIN_DIR}/train.py \ ---device 'gpu' \ +--device ${device} \ --nproc ${ngpu} \ ---config conf/conformer.yaml \ ---output ckpt-${1} +--config ${config_path} \ +--output exp/${ckpt_name} if [ $? -ne 0 ]; then echo "Failed in training!" exit 1 fi - exit 0 diff --git a/examples/tiny/s1/run.sh b/examples/tiny/s1/run.sh index 2b5ed5308993a904b7ee499f8ddc4fe25839d1d4..3ee16e3fc8234d4d5599cfc9e356adc34a429e11 100644 --- a/examples/tiny/s1/run.sh +++ b/examples/tiny/s1/run.sh @@ -2,15 +2,19 @@ set -e source path.sh +source ${MAIN_ROOT}/utils/parse_options.sh # prepare data bash ./local/data.sh -# train model -bash ./local/train.sh +# train model, all `ckpt` under `exp` dir +CUDA_VISIBLE_DEVICES=0 ./local/train.sh conf/conformer.yaml test -# test model -bash ./local/test.sh +# test ckpt 1 +CUDA_VISIBLE_DEVICES=0 ./local/test.sh conf/conformer.yaml exp/test/checkpoints/1 -# infer model -bash ./local/infer.sh +# avg 1 best model +./local/avg.sh exp/test/checkpoints 1 + +# export ckpt 1 +./local/export.sh conf/conformer.yaml exp/test/checkpoints/1 exp/test/checkpoints/1.jit.model \ No newline at end of file