提交 0c5dbbee 编写于 作者: X xiongxinlei

add conformer ctc prefix beam search decoding method, test=doc

上级 9d20a10b
...@@ -12,8 +12,9 @@ ...@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
from paddlespeech.s2t.utils.utility import log_add
from typing import Optional from typing import Optional
from collections import defaultdict
import numpy as np import numpy as np
import paddle import paddle
from numpy import float32 from numpy import float32
...@@ -23,10 +24,14 @@ from paddlespeech.cli.asr.infer import ASRExecutor ...@@ -23,10 +24,14 @@ from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.asr.infer import model_alias from paddlespeech.cli.asr.infer import model_alias
from paddlespeech.cli.asr.infer import pretrained_models from paddlespeech.cli.asr.infer import pretrained_models
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.cli.utils import MODEL_HOME from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.frontend.speech import SpeechSegment
from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.modules.mask import mask_finished_preds
from paddlespeech.s2t.modules.mask import mask_finished_scores
from paddlespeech.s2t.modules.mask import subsequent_mask
from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
...@@ -57,17 +62,17 @@ pretrained_models = { ...@@ -57,17 +62,17 @@ pretrained_models = {
}, },
"conformer2online_aishell-zh-16k": { "conformer2online_aishell-zh-16k": {
'url': 'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr1_chunk_conformer_aishell_ckpt_0.1.2.model.tar.gz', 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.0.model.tar.gz',
'md5': 'md5':
'4814e52e0fc2fd48899373f95c84b0c9', '7989b3248c898070904cf042fd656003',
'cfg_path': 'cfg_path':
'exp/chunk_conformer//conf/config.yaml', 'model.yaml',
'ckpt_path': 'ckpt_path':
'exp/chunk_conformer/checkpoints/avg_30/', 'exp/chunk_conformer/checkpoints/multi_cn',
'model': 'model':
'exp/chunk_conformer/checkpoints/avg_30.pdparams', 'exp/chunk_conformer/checkpoints/multi_cn.pdparams',
'params': 'params':
'exp/chunk_conformer/checkpoints/avg_30.pdparams', 'exp/chunk_conformer/checkpoints/multi_cn.pdparams',
'lm_url': 'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5': 'lm_md5':
...@@ -81,6 +86,23 @@ class ASRServerExecutor(ASRExecutor): ...@@ -81,6 +86,23 @@ class ASRServerExecutor(ASRExecutor):
super().__init__() super().__init__()
pass pass
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and returns pretrained resources path of current task.
"""
support_models = list(pretrained_models.keys())
assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
tag, '\n\t\t'.join(support_models))
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(pretrained_models[tag],
res_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
def _init_from_path(self, def _init_from_path(self,
model_type: str='wenetspeech', model_type: str='wenetspeech',
am_model: Optional[os.PathLike]=None, am_model: Optional[os.PathLike]=None,
...@@ -101,7 +123,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -101,7 +123,7 @@ class ASRServerExecutor(ASRExecutor):
logger.info(f"Load the pretrained model, tag = {tag}") logger.info(f"Load the pretrained model, tag = {tag}")
res_path = self._get_pretrained_path(tag) # wenetspeech_zh res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.res_path = res_path self.res_path = res_path
self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/paddlespeech/server/tests/asr/online/conf/config.yaml" self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml"
# self.cfg_path = os.path.join(res_path, # self.cfg_path = os.path.join(res_path,
# pretrained_models[tag]['cfg_path']) # pretrained_models[tag]['cfg_path'])
...@@ -147,8 +169,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -147,8 +169,7 @@ class ASRServerExecutor(ASRExecutor):
if self.config.spm_model_prefix: if self.config.spm_model_prefix:
self.config.spm_model_prefix = os.path.join( self.config.spm_model_prefix = os.path.join(
self.res_path, self.config.spm_model_prefix) self.res_path, self.config.spm_model_prefix)
self.config.vocab_filepath = os.path.join( self.vocab = self.config.vocab_filepath
self.res_path, self.config.vocab_filepath)
self.text_feature = TextFeaturizer( self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type, unit_type=self.config.unit_type,
vocab=self.config.vocab_filepath, vocab=self.config.vocab_filepath,
...@@ -203,19 +224,31 @@ class ASRServerExecutor(ASRExecutor): ...@@ -203,19 +224,31 @@ class ASRServerExecutor(ASRExecutor):
model_conf = self.config model_conf = self.config
model = model_class.from_config(model_conf) model = model_class.from_config(model_conf)
self.model = model self.model = model
self.model.eval()
# load model
model_dict = paddle.load(self.am_model)
self.model.set_state_dict(model_dict)
logger.info("create the transformer like model success") logger.info("create the transformer like model success")
# update the ctc decoding
self.searcher = None
self.transformer_decode_reset()
def reset_decoder_and_chunk(self): def reset_decoder_and_chunk(self):
"""reset decoder and chunk state for an new audio """reset decoder and chunk state for an new audio
""" """
self.decoder.reset_decoder(batch_size=1) if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
# init state box, for new audio request self.decoder.reset_decoder(batch_size=1)
self.chunk_state_h_box = np.zeros( # init state box, for new audio request
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size), self.chunk_state_h_box = np.zeros(
dtype=float32) (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
self.chunk_state_c_box = np.zeros( dtype=float32)
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size), self.chunk_state_c_box = np.zeros(
dtype=float32) (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
self.transformer_decode_reset()
def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str): def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
"""decode one chunk """decode one chunk
...@@ -275,24 +308,137 @@ class ASRServerExecutor(ASRExecutor): ...@@ -275,24 +308,137 @@ class ASRServerExecutor(ASRExecutor):
logger.info( logger.info(
f"we will use the transformer like model : {self.model_type}" f"we will use the transformer like model : {self.model_type}"
) )
cfg = self.config.decode self.advanced_decoding(x_chunk, x_chunk_lens)
result_transcripts = self.model.decode( self.update_result()
x_chunk,
x_chunk_lens, return self.result_transcripts[0]
text_feature=self.text_feature,
decoding_method=cfg.decoding_method,
beam_size=cfg.beam_size,
ctc_weight=cfg.ctc_weight,
decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks,
simulate_streaming=cfg.simulate_streaming)
return result_transcripts[0][0]
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
else: else:
raise Exception("invalid model name") raise Exception("invalid model name")
def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens):
logger.info("start to decode with advanced_decoding method")
encoder_out, encoder_mask = self.decode_forward(xs)
self.ctc_prefix_beam_search(xs, encoder_out, encoder_mask)
def decode_forward(self, xs):
logger.info("get the model out from the feat")
cfg = self.config.decode
decoding_chunk_size = cfg.decoding_chunk_size
num_decoding_left_chunks = cfg.num_decoding_left_chunks
assert decoding_chunk_size > 0
subsampling = self.model.encoder.embed.subsampling_rate
context = self.model.encoder.embed.right_context + 1
stride = subsampling * decoding_chunk_size
# decoding window for model
decoding_window = (decoding_chunk_size - 1) * subsampling + context
num_frames = xs.shape[1]
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
logger.info("start to do model forward")
outputs = []
# num_frames - context + 1 ensure that current frame can get context window
for cur in range(0, num_frames - context + 1, stride):
end = min(cur + decoding_window, num_frames)
chunk_xs = xs[:, cur:end, :]
(y, self.subsampling_cache, self.elayers_output_cache,
self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
chunk_xs, self.offset, required_cache_size,
self.subsampling_cache, self.elayers_output_cache,
self.conformer_cnn_cache)
outputs.append(y)
self.offset += y.shape[1]
ys = paddle.cat(outputs, 1)
masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1)
return ys, masks
def transformer_decode_reset(self):
self.subsampling_cache = None
self.elayers_output_cache = None
self.conformer_cnn_cache = None
self.hyps = None
self.offset = 0
self.cur_hyps = None
self.hyps = None
def ctc_prefix_beam_search(self, xs, encoder_out, encoder_mask, blank_id=0):
# decode
logger.info("start to ctc prefix search")
device = xs.place
cfg = self.config.decode
batch_size = xs.shape[0]
beam_size = cfg.beam_size
maxlen = encoder_out.shape[1]
ctc_probs = self.model.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# blank_ending_score and none_blank_ending_score in ln domain
if self.cur_hyps is None:
self.cur_hyps = [(tuple(), (0.0, -float('inf')))]
# 2. CTC beam search step by step
for t in range(0, maxlen):
logp = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
# 2.1 First beam prune: select topk best
# do token passing process
top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,)
for s in top_k_index:
s = s.item()
ps = logp[s].item()
for prefix, (pb, pnb) in self.cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == blank_id: # blank
n_pb, n_pnb = next_hyps[prefix]
n_pb = log_add([n_pb, pb + ps, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
elif s == last:
# Update *ss -> *s;
n_pb, n_pnb = next_hyps[prefix]
n_pnb = log_add([n_pnb, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
# Update *s-s -> *ss, - is for blank
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
else:
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
# 2.2 Second beam prune
next_hyps = sorted(
next_hyps.items(),
key=lambda x: log_add(list(x[1])),
reverse=True)
self.cur_hyps = next_hyps[:beam_size]
hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps]
self.hyps = [hyps[0][0]]
logger.info("ctc prefix search success")
return hyps, encoder_out
def update_result(self):
logger.info("update the final result")
self.result_transcripts = [
self.text_feature.defeaturize(hyp) for hyp in self.hyps
]
self.result_tokenids = [hyp for hyp in self.hyps]
def extract_feat(self, samples, sample_rate): def extract_feat(self, samples, sample_rate):
"""extract feat """extract feat
...@@ -304,9 +450,10 @@ class ASRServerExecutor(ASRExecutor): ...@@ -304,9 +450,10 @@ class ASRServerExecutor(ASRExecutor):
x_chunk (numpy.array): shape[B, T, D] x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B] x_chunk_lens (numpy.array): shape[B]
""" """
# pcm16 -> pcm 32
samples = pcm2float(samples)
if "deepspeech2online" in self.model_type: if "deepspeech2online" in self.model_type:
# pcm16 -> pcm 32
samples = pcm2float(samples)
# read audio # read audio
speech_segment = SpeechSegment.from_pcm( speech_segment = SpeechSegment.from_pcm(
samples, sample_rate, transcript=" ") samples, sample_rate, transcript=" ")
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
import json import json
import numpy as np import numpy as np
import json
from fastapi import APIRouter from fastapi import APIRouter
from fastapi import WebSocket from fastapi import WebSocket
from fastapi import WebSocketDisconnect from fastapi import WebSocketDisconnect
...@@ -86,16 +85,21 @@ async def websocket_endpoint(websocket: WebSocket): ...@@ -86,16 +85,21 @@ async def websocket_endpoint(websocket: WebSocket):
engine_pool = get_engine_pool() engine_pool = get_engine_pool()
asr_engine = engine_pool['asr'] asr_engine = engine_pool['asr']
asr_results = "" asr_results = ""
frames = chunk_buffer.frame_generator(message) # frames = chunk_buffer.frame_generator(message)
for frame in frames: # for frame in frames:
# get the pcm data from the bytes # # get the pcm data from the bytes
samples = np.frombuffer(frame.bytes, dtype=np.int16) # samples = np.frombuffer(frame.bytes, dtype=np.int16)
sample_rate = asr_engine.config.sample_rate # sample_rate = asr_engine.config.sample_rate
x_chunk, x_chunk_lens = asr_engine.preprocess(samples, # x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
sample_rate) # sample_rate)
asr_engine.run(x_chunk, x_chunk_lens) # asr_engine.run(x_chunk, x_chunk_lens)
asr_results = asr_engine.postprocess() # asr_results = asr_engine.postprocess()
samples = np.frombuffer(message, dtype=np.int16)
sample_rate = asr_engine.config.sample_rate
x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
sample_rate)
asr_engine.run(x_chunk, x_chunk_lens)
# asr_results = asr_engine.postprocess()
asr_results = asr_engine.postprocess() asr_results = asr_engine.postprocess()
resp = {'asr_results': asr_results} resp = {'asr_results': asr_results}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册