diff --git a/deepspeech/decoders/recog.py b/deepspeech/decoders/recog.py index eb39636c8850df92c4c77e8fbdbdd0b2401fb94b..bc48e692c9bf40fe0ea89debbcba10f79cfe56b0 100644 --- a/deepspeech/decoders/recog.py +++ b/deepspeech/decoders/recog.py @@ -24,6 +24,7 @@ from .utils import add_results_to_json from deepspeech.exps import dynamic_import_tester from deepspeech.io.reader import LoadInputsAndTargets from deepspeech.models.asr_interface import ASRInterface +from deepspeech.models.lm_interface import dynamic_import_lm from deepspeech.utils.log import Log logger = Log(__name__).getlog() @@ -31,11 +32,15 @@ logger = Log(__name__).getlog() # NOTE: you need this func to generate our sphinx doc +def get_config(config_path): + confs = CfgNode(new_allowed=True) + confs.merge_from_file(config_path) + return confs + + def load_trained_model(args): args.nprocs = args.ngpu - confs = CfgNode() - confs.set_new_allowed(True) - confs.merge_from_file(args.model_conf) + confs = get_config(args.model_conf) class_obj = dynamic_import_tester(args.model_name) exp = class_obj(confs, args) with exp.eval(): @@ -46,19 +51,11 @@ def load_trained_model(args): return model, char_list, exp, confs -def get_config(config_path): - stream = open(config_path, mode='r', encoding="utf-8") - config = yaml.load(stream, Loader=yaml.FullLoader) - stream.close() - return config - - def load_trained_lm(args): lm_args = get_config(args.rnnlm_conf) - # NOTE: for a compatibility with less than 0.5.0 version models - lm_model_module = getattr(lm_args, "model_module", "default") + lm_model_module = lm_args.model_module lm_class = dynamic_import_lm(lm_model_module) - lm = lm_class(lm_args.model) + lm = lm_class(**lm_args.model) model_dict = paddle.load(args.rnnlm) lm.set_state_dict(model_dict) return lm diff --git a/examples/librispeech/s2/local/recog.sh b/examples/librispeech/s2/local/recog.sh index f0e961097544fa3e25ed5fd73277502aba3e8aae..e2578ba637403807b1ffd79f60b37240ec1870b9 100755 --- a/examples/librispeech/s2/local/recog.sh +++ b/examples/librispeech/s2/local/recog.sh @@ -11,9 +11,9 @@ tag= decode_config=conf/decode/decode.yaml # lm params -lang_model=transformerLM.pdparams -lmexpdir=exp/lm/transformer rnnlm_config_path=conf/lm/transformer.yaml +lmexpdir=exp/lm +lang_model=rnnlm.pdparams lmtag='transformer' train_set=train_960 @@ -53,6 +53,9 @@ if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then fi echo "chunk mode: ${chunk_mode}" echo "decode conf: ${decode_config}" +echo "lm conf: ${rnnlm_config_path}" +echo "lm model: ${lmexpdir}/${lang_model}" + # download language model #bash local/download_lm_en.sh @@ -61,6 +64,13 @@ echo "decode conf: ${decode_config}" #fi +# download rnnlm +mkdir -p ${lmexpdir} +if [ ! -f ${lmexpdir}/${lang_model} ]; then + wget -c -O ${lmexpdir}/${lang_model} https://deepspeech.bj.bcebos.com/transformer_lm/transformerLM.pdparams +fi + + pids=() # initialize pids for dmethd in join_ctc; do diff --git a/examples/librispeech/s2/run.sh b/examples/librispeech/s2/run.sh index 61172d25c9ff1ed73124a7a90d3e9b9cf3144a53..146f133d8c310f3ff5a05aaed54f8cce369e0886 100755 --- a/examples/librispeech/s2/run.sh +++ b/examples/librispeech/s2/run.sh @@ -37,12 +37,9 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then ./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 fi -if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ] && ${use_lm} == true; then +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # join ctc decoder, use transformerlm to score - if [ ! -f exp/lm/transformer/transformerLM.pdparams ]; then - wget https://deepspeech.bj.bcebos.com/transformer_lm/transformerLM.pdparams exp/lm/transformer/ - fi - bash local/recog.sh --ckpt_prefix exp/${ckpt}/checkpoints/${avg_ckpt} + ./local/recog.sh --ckpt_prefix exp/${ckpt}/checkpoints/${avg_ckpt} fi if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then