diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 51ef1de47988c39d9720431989e1880f84995ac6..de5ff5f44d5e403f4b64d55664367ec4055610be 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -367,20 +367,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): dtype='int64'), # audio_length, [B] ]) elif self.args.model_type == 'online': - static_model = paddle.jit.to_static( - infer_model, - input_spec=[ - paddle.static.InputSpec( - shape=[None, None, - feat_dim], #[B, chunk_size, feat_dim] - dtype='float32'), # audio, [B,T,D] - paddle.static.InputSpec(shape=[None], - dtype='int64'), # audio_length, [B] - paddle.static.InputSpec( - shape=[None, None, None], dtype='float32'), - paddle.static.InputSpec( - shape=[None, None, None], dtype='float32') - ]) + static_model = DeepSpeech2InferModelOnline.export(infer_model, + feat_dim) else: raise Exception("wrong model type") logger.info(f"Export code: {static_model.forward.code}") diff --git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py index b42ac8ec17215bed9f29492ec9a77e518ddf8b4e..ad8a0506f4fb229e9b022045f315b520c125f1f3 100644 --- a/deepspeech/models/ds2_online/deepspeech2.py +++ b/deepspeech/models/ds2_online/deepspeech2.py @@ -424,3 +424,20 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline): audio_chunk, audio_chunk_lens, chunk_state_h_box, chunk_state_c_box) probs_chunk = self.decoder.softmax(eouts_chunk) return probs_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box + + @classmethod + def export(self, infer_model, feat_dim): + static_model = paddle.jit.to_static( + infer_model, + input_spec=[ + paddle.static.InputSpec( + shape=[None, None, feat_dim], #[B, chunk_size, feat_dim] + dtype='float32'), # audio, [B,T,D] + paddle.static.InputSpec(shape=[None], + dtype='int64'), # audio_length, [B] + paddle.static.InputSpec( + shape=[None, None, None], dtype='float32'), + paddle.static.InputSpec( + shape=[None, None, None], dtype='float32') + ]) + return static_model diff --git a/examples/aishell/s0/local/export.sh b/examples/aishell/s0/local/export.sh index f99a15bade1c89f968e84a6c10d500466f884d5b..2e09e5f5e76a7f7cdf9cca8fbb91d66bb48aea0c 100755 --- a/examples/aishell/s0/local/export.sh +++ b/examples/aishell/s0/local/export.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 3 ];then - echo "usage: $0 config_path ckpt_prefix jit_model_path" +if [ $# != 4 ];then + echo "usage: $0 config_path ckpt_prefix jit_model_path model_type" exit -1 fi @@ -11,6 +11,7 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_path_prefix=$2 jit_model_export_path=$3 +model_type=$4 device=gpu if [ ${ngpu} == 0 ];then @@ -22,8 +23,8 @@ python3 -u ${BIN_DIR}/export.py \ --nproc ${ngpu} \ --config ${config_path} \ --checkpoint_path ${ckpt_path_prefix} \ ---export_path ${jit_model_export_path} - +--export_path ${jit_model_export_path} \ +--model_type ${model_type} if [ $? -ne 0 ]; then echo "Failed in export!" diff --git a/examples/aishell/s0/local/test.sh b/examples/aishell/s0/local/test.sh index fd9cb5661fad7b05356a98635080ab0d2e3327c3..9fd0bc8d5bcdded1d33990a8ae20101d0a538441 100755 --- a/examples/aishell/s0/local/test.sh +++ b/examples/aishell/s0/local/test.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 2 ];then - echo "usage: ${0} config_path ckpt_path_prefix" +if [ $# != 3 ];then + echo "usage: ${0} config_path ckpt_path_prefix model_type" exit -1 fi @@ -14,6 +14,7 @@ if [ ${ngpu} == 0 ];then fi config_path=$1 ckpt_prefix=$2 +model_type=$3 # download language model bash local/download_lm_ch.sh @@ -26,7 +27,8 @@ python3 -u ${BIN_DIR}/test.py \ --nproc 1 \ --config ${config_path} \ --result_file ${ckpt_prefix}.rsl \ ---checkpoint_path ${ckpt_prefix} +--checkpoint_path ${ckpt_prefix} \ +--model_type ${model_type} if [ $? -ne 0 ]; then echo "Failed in evaluation!" diff --git a/examples/aishell/s0/local/train.sh b/examples/aishell/s0/local/train.sh index f6bd2c98359b67292f9654c69ed11fc1a6720046..c6a631800378640434dbe952035840895c1b23b5 100755 --- a/examples/aishell/s0/local/train.sh +++ b/examples/aishell/s0/local/train.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 2 ];then - echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" +if [ $# != 3 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type" exit -1 fi @@ -10,6 +10,7 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_name=$2 +model_type=$3 device=gpu if [ ${ngpu} == 0 ];then @@ -22,7 +23,8 @@ python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--model_type ${model_type} if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh index c9708dcc90df357568282ace7700ceb13b7883d3..7cd63999ce57c27ca9f915601f69435284f87233 100755 --- a/examples/aishell/s0/run.sh +++ b/examples/aishell/s0/run.sh @@ -7,6 +7,7 @@ stage=0 stop_stage=100 conf_path=conf/deepspeech2.yaml avg_num=1 +model_type=offline source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; @@ -21,7 +22,7 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then @@ -31,10 +32,10 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type}|| exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # export ckpt avg_n - CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit + CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type} fi diff --git a/examples/librispeech/s0/local/export.sh b/examples/librispeech/s0/local/export.sh index f99a15bade1c89f968e84a6c10d500466f884d5b..2e09e5f5e76a7f7cdf9cca8fbb91d66bb48aea0c 100755 --- a/examples/librispeech/s0/local/export.sh +++ b/examples/librispeech/s0/local/export.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 3 ];then - echo "usage: $0 config_path ckpt_prefix jit_model_path" +if [ $# != 4 ];then + echo "usage: $0 config_path ckpt_prefix jit_model_path model_type" exit -1 fi @@ -11,6 +11,7 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_path_prefix=$2 jit_model_export_path=$3 +model_type=$4 device=gpu if [ ${ngpu} == 0 ];then @@ -22,8 +23,8 @@ python3 -u ${BIN_DIR}/export.py \ --nproc ${ngpu} \ --config ${config_path} \ --checkpoint_path ${ckpt_path_prefix} \ ---export_path ${jit_model_export_path} - +--export_path ${jit_model_export_path} \ +--model_type ${model_type} if [ $? -ne 0 ]; then echo "Failed in export!" diff --git a/examples/librispeech/s0/local/test.sh b/examples/librispeech/s0/local/test.sh index 16a5e9ef0d2045f48d4acafe18dec3cfef992e52..b5b68c599c45ab50aa12ee35c120e02fb68740b4 100755 --- a/examples/librispeech/s0/local/test.sh +++ b/examples/librispeech/s0/local/test.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 2 ];then - echo "usage: ${0} config_path ckpt_path_prefix" +if [ $# != 3 ];then + echo "usage: ${0} config_path ckpt_path_prefix model_type" exit -1 fi @@ -14,6 +14,7 @@ if [ ${ngpu} == 0 ];then fi config_path=$1 ckpt_prefix=$2 +model_type=$3 # download language model bash local/download_lm_en.sh @@ -26,7 +27,8 @@ python3 -u ${BIN_DIR}/test.py \ --nproc 1 \ --config ${config_path} \ --result_file ${ckpt_prefix}.rsl \ ---checkpoint_path ${ckpt_prefix} +--checkpoint_path ${ckpt_prefix} \ +--model_type ${model_type} if [ $? -ne 0 ]; then echo "Failed in evaluation!" diff --git a/examples/librispeech/s0/local/train.sh b/examples/librispeech/s0/local/train.sh index f3eb98daf6da6c818de1cd2fc2f56e96ff6a926c..039b9cea456b41dd345d5594d142abe7ca165385 100755 --- a/examples/librispeech/s0/local/train.sh +++ b/examples/librispeech/s0/local/train.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 2 ];then - echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" +if [ $# != 3 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type" exit -1 fi @@ -10,6 +10,7 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_name=$2 +model_type=$3 device=gpu if [ ${ngpu} == 0 ];then @@ -23,7 +24,8 @@ python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--model_type ${model_type} if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/librispeech/s0/run.sh b/examples/librispeech/s0/run.sh index 6553e073ded9ba4105fd9b69cac520a9e87fa963..c7902a56a882abecddd36b6440f9c88e134d628b 100755 --- a/examples/librispeech/s0/run.sh +++ b/examples/librispeech/s0/run.sh @@ -6,6 +6,7 @@ stage=0 stop_stage=100 conf_path=conf/deepspeech2.yaml avg_num=30 +model_type=offline source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; avg_ckpt=avg_${avg_num} @@ -19,7 +20,7 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt} + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt} ${model_type} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then @@ -29,10 +30,10 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # export ckpt avg_n - CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit + CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type} fi