提交 2a42421a 编写于 作者: H huangyuxin

cli add ds2-librispeech offline, fix versionm, test=asr

上级 4128f4d6
...@@ -91,6 +91,20 @@ pretrained_models = { ...@@ -91,6 +91,20 @@ pretrained_models = {
'lm_md5': 'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3' '29e02312deb2e59b3c8686c7966d4fe3'
}, },
"deepspeech2offline_librispeech-en-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz',
'md5':
'f5666c81ad015c8de03aac2bc92e5762',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm',
'lm_md5':
'099a601759d467cd0a8523ff939819c5'
},
} }
model_alias = { model_alias = {
...@@ -328,18 +342,15 @@ class ASRExecutor(BaseExecutor): ...@@ -328,18 +342,15 @@ class ASRExecutor(BaseExecutor):
audio = self._inputs["audio"] audio = self._inputs["audio"]
audio_len = self._inputs["audio_len"] audio_len = self._inputs["audio_len"]
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
result_transcripts = self.model.decode( decode_batch_size = audio.shape[0]
audio, self.model.decoder.init_decoder(
audio_len, decode_batch_size, self.text_feature.vocab_list,
self.text_feature.vocab_list, cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
decoding_method=cfg.decoding_method, cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
lang_model_path=cfg.lang_model_path, cfg.num_proc_bsearch)
beam_alpha=cfg.alpha,
beam_beta=cfg.beta, result_transcripts = self.model.decode(audio, audio_len)
beam_size=cfg.beam_size, self.model.decoder.del_decoder()
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch)
self._outputs["result"] = result_transcripts[0] self._outputs["result"] = result_transcripts[0]
elif "conformer" in model_type or "transformer" in model_type: elif "conformer" in model_type or "transformer" in model_type:
......
...@@ -34,7 +34,7 @@ from .entry import commands ...@@ -34,7 +34,7 @@ from .entry import commands
try: try:
from .. import __version__ from .. import __version__
except ImportError: except ImportError:
__version__ = 0.0.0 # for develop branch __version__ = "0.0.0" # for develop branch
requests.adapters.DEFAULT_RETRIES = 3 requests.adapters.DEFAULT_RETRIES = 3
......
...@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False): ...@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
""" """
rng = np.random.RandomState(epoch) rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1) shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size)) batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices) rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch] batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False assert clipped is False
......
...@@ -41,4 +41,4 @@ def repeat(N, fn): ...@@ -41,4 +41,4 @@ def repeat(N, fn):
MultiSequential MultiSequential
Repeated model instance. Repeated model instance.
""" """
return MultiSequential(* [fn(n) for n in range(N)]) return MultiSequential(*[fn(n) for n in range(N)])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册