Unverified commit 4a133619, authored by Hui Zhang, committed by GitHub

Merge pull request #1356 from Jackwaterveg/CLI

[CLI] asr, Add Deepspeech2 online and offline model
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
-import io
 import os
 import sys
 from typing import List
@@ -23,9 +22,9 @@ import librosa
 import numpy as np
 import paddle
 import soundfile
-import yaml
 from yacs.config import CfgNode

+from ..download import get_path_from_url
 from ..executor import BaseExecutor
 from ..log import logger
 from ..utils import cli_register
@@ -64,14 +63,47 @@ pretrained_models = {
         'ckpt_path':
         'exp/transformer/checkpoints/avg_10',
     },
+    "deepspeech2offline_aishell-zh-16k": {
+        'url':
+        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz',
+        'md5':
+        '932c3593d62fe5c741b59b31318aa314',
+        'cfg_path':
+        'model.yaml',
+        'ckpt_path':
+        'exp/deepspeech2/checkpoints/avg_1',
+        'lm_url':
+        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
+        'lm_md5':
+        '29e02312deb2e59b3c8686c7966d4fe3'
+    },
+    "deepspeech2online_aishell-zh-16k": {
+        'url':
+        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz',
+        'md5':
+        'd5e076217cf60486519f72c217d21b9b',
+        'cfg_path':
+        'model.yaml',
+        'ckpt_path':
+        'exp/deepspeech2_online/checkpoints/avg_1',
+        'lm_url':
+        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
+        'lm_md5':
+        '29e02312deb2e59b3c8686c7966d4fe3'
+    },
 }

 model_alias = {
-    "deepspeech2offline": "paddlespeech.s2t.models.ds2:DeepSpeech2Model",
-    "deepspeech2online": "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
-    "conformer": "paddlespeech.s2t.models.u2:U2Model",
-    "transformer": "paddlespeech.s2t.models.u2:U2Model",
-    "wenetspeech": "paddlespeech.s2t.models.u2:U2Model",
+    "deepspeech2offline":
+    "paddlespeech.s2t.models.ds2:DeepSpeech2Model",
+    "deepspeech2online":
+    "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
+    "conformer":
+    "paddlespeech.s2t.models.u2:U2Model",
+    "transformer":
+    "paddlespeech.s2t.models.u2:U2Model",
+    "wenetspeech":
+    "paddlespeech.s2t.models.u2:U2Model",
 }
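The `model_alias` values follow a `module:Class` convention that the executor resolves at runtime via dynamic import. A minimal sketch of that resolution, assuming only the `module:Class` string format (the `dynamic_import` helper name here is illustrative, not necessarily the repo's exact API):

```python
import importlib


def dynamic_import(import_path: str):
    """Resolve a "module:Class" string such as
    "paddlespeech.s2t.models.u2:U2Model" into the class object."""
    module_name, class_name = import_path.split(":")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# e.g. model_class = dynamic_import(model_alias["deepspeech2online"])
```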
@@ -95,7 +127,8 @@ class ASRExecutor(BaseExecutor):
             '--lang',
             type=str,
             default='zh',
-            help='Choose model language. zh or en, zh:[conformer_wenetspeech-zh-16k], en:[transformer_librispeech-en-16k]')
+            help='Choose model language. zh or en, zh:[conformer_wenetspeech-zh-16k], en:[transformer_librispeech-en-16k]'
+        )
         self.parser.add_argument(
             "--sample_rate",
             type=int,
@@ -111,7 +144,10 @@ class ASRExecutor(BaseExecutor):
             '--decode_method',
             type=str,
             default='attention_rescoring',
-            choices=['ctc_greedy_search', 'ctc_prefix_beam_search', 'attention', 'attention_rescoring'],
+            choices=[
+                'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention',
+                'attention_rescoring'
+            ],
             help='only support transformer and conformer model')
         self.parser.add_argument(
             '--ckpt_path',
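With `--decode_method` validated through argparse `choices`, an unsupported method now fails at parse time rather than deep inside decoding. A small self-contained sketch of that behavior (a standalone parser, not the executor's own):

```python
import argparse

# Standalone parser mirroring the choices added in the hunk above.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--decode_method',
    type=str,
    default='attention_rescoring',
    choices=[
        'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention',
        'attention_rescoring'
    ])

args = parser.parse_args(['--decode_method', 'ctc_prefix_beam_search'])
print(args.decode_method)  # -> ctc_prefix_beam_search
# parser.parse_args(['--decode_method', 'bogus']) exits with a usage error.
```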
@@ -187,13 +223,21 @@ class ASRExecutor(BaseExecutor):
         if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
             from paddlespeech.s2t.io.collator import SpeechCollator
             self.vocab = self.config.vocab_filepath
-            self.config.decode.lang_model_path = os.path.join(res_path, self.config.decode.lang_model_path)
+            self.config.decode.lang_model_path = os.path.join(
+                MODEL_HOME, 'language_model',
+                self.config.decode.lang_model_path)
             self.collate_fn_test = SpeechCollator.from_config(self.config)
             self.text_feature = TextFeaturizer(
-                unit_type=self.config.unit_type,
-                vocab=self.vocab)
+                unit_type=self.config.unit_type, vocab=self.vocab)
+            lm_url = pretrained_models[tag]['lm_url']
+            lm_md5 = pretrained_models[tag]['lm_md5']
+            self.download_lm(
+                lm_url,
+                os.path.dirname(self.config.decode.lang_model_path), lm_md5)
         elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
-            self.config.spm_model_prefix = os.path.join(self.res_path, self.config.spm_model_prefix)
+            self.config.spm_model_prefix = os.path.join(
+                self.res_path, self.config.spm_model_prefix)
             self.text_feature = TextFeaturizer(
                 unit_type=self.config.unit_type,
                 vocab=self.config.vocab_filepath,
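With this change the KenLM language model is looked up per pretrained-model tag and rooted at a shared `MODEL_HOME/language_model` directory, so both DeepSpeech2 variants (which point at the same `lm_url`/`lm_md5`) can reuse one download. A minimal sketch of the lookup, with `MODEL_HOME` as a stand-in for the value the CLI actually imports and a trimmed copy of the dict above:

```python
import os

# Stand-in for the CLI's model cache location (the real value comes from
# elsewhere in the package, e.g. its utils module).
MODEL_HOME = os.path.expanduser('~/.paddlespeech/models')

pretrained_models = {
    'deepspeech2online_aishell-zh-16k': {
        'lm_url':
        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
        'lm_md5':
        '29e02312deb2e59b3c8686c7966d4fe3',
    },
}

tag = 'deepspeech2online_aishell-zh-16k'
lm_url = pretrained_models[tag]['lm_url']
lm_md5 = pretrained_models[tag]['lm_md5']

# model.yaml carries a relative lang_model_path; anchoring it under
# MODEL_HOME/language_model gives both DS2 variants one shared LM file.
lang_model_path = os.path.join(MODEL_HOME, 'language_model',
                               os.path.basename(lm_url))
lm_dir = os.path.dirname(lang_model_path)
```

The fetch itself is delegated to `get_path_from_url` (see `download_lm` in the next hunk), which is expected to skip the download when a file with the matching MD5 already exists.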
@@ -319,6 +363,13 @@ class ASRExecutor(BaseExecutor):
         """
         return self._outputs["result"]

+    def download_lm(self, url, lm_dir, md5sum):
+        download_path = get_path_from_url(
+            url=url,
+            root_dir=lm_dir,
+            md5sum=md5sum,
+            decompress=False, )
+
     def _pcm16to32(self, audio):
         assert (audio.dtype == np.int16)
         audio = audio.astype("float32")
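The hunk is truncated just after the dtype cast, but a 16-bit-PCM-to-float32 conversion conventionally finishes by scaling into [-1.0, 1.0). A self-contained sketch of that convention (an assumption about the cut-off body, not the file's exact code):

```python
import numpy as np


def pcm16_to_float32(audio: np.ndarray) -> np.ndarray:
    """Map int16 PCM samples into float32 in [-1.0, 1.0)."""
    assert audio.dtype == np.int16
    audio = audio.astype("float32")
    audio /= 2**15  # int16 full scale is 32768
    return audio


# e.g. pcm16_to_float32(np.array([0, 16384, -32768], dtype=np.int16))
# -> array([ 0. ,  0.5, -1. ], dtype=float32)
```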
@@ -411,7 +462,7 @@ class ASRExecutor(BaseExecutor):
         try:
             res = self(audio_file, model, lang, sample_rate, config, ckpt_path,
                        decode_method, force_yes, device)
             logger.info('ASR Result: {}'.format(res))
             return True
         except Exception as e:
@@ -435,7 +486,8 @@ class ASRExecutor(BaseExecutor):
         audio_file = os.path.abspath(audio_file)
         self._check(audio_file, sample_rate, force_yes)
         paddle.set_device(device)
-        self._init_from_path(model, lang, sample_rate, config, decode_method, ckpt_path)
+        self._init_from_path(model, lang, sample_rate, config, decode_method,
+                             ckpt_path)
         self.preprocess(model, audio_file)
         self.infer(model)
         res = self.postprocess()  # Retrieve result of asr.
...
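Putting the pieces together, the new executor can be driven from Python as well as from the CLI. A hedged sketch, assuming the positional signature of `self(...)` shown in the `execute` hunk above; the import path and model name are inferred from the package layout and pretrained-model tags, not confirmed by this diff:

```python
from paddlespeech.cli.asr import ASRExecutor  # import path is an assumption

asr = ASRExecutor()
# Argument order follows the self(...) call in the diff: audio_file, model,
# lang, sample_rate, config, ckpt_path, decode_method, force_yes, device.
text = asr('input_16k.wav', 'deepspeech2online_aishell', 'zh', 16000, None,
           None, 'attention_rescoring', False, 'cpu')
print(text)
```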