提交 6f7917b7 编写于 作者: H Hui Zhang

fix streaming asr

上级 f07f57a3
......@@ -4,7 +4,7 @@
# SERVER SETTING #
#################################################################################
host: 0.0.0.0
port: 8090
port: 8091
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online']
......
......@@ -53,7 +53,7 @@ class PaddleASRConnectionHanddler:
logger.info(
"create an paddle asr connection handler to process the websocket connection"
)
self.config = asr_engine.config
self.config = asr_engine.config # server config
self.model_config = asr_engine.executor.config
self.asr_engine = asr_engine
......@@ -249,10 +249,13 @@ class PaddleASRConnectionHanddler:
def reset(self):
if "deepspeech2" in self.model_type:
# for deepspeech2
self.chunk_state_h_box = copy.deepcopy(
self.asr_engine.executor.chunk_state_h_box)
self.chunk_state_c_box = copy.deepcopy(
self.asr_engine.executor.chunk_state_c_box)
# init state
self.chunk_state_h_box = np.zeros(
(self.model_config .num_rnn_layers, 1, self.model_config.rnn_layer_size),
dtype=float32)
self.chunk_state_c_box = np.zeros(
(self.model_config.num_rnn_layers, 1, self.model_config.rnn_layer_size),
dtype=float32)
self.decoder.reset_decoder(batch_size=1)
self.device = None
......@@ -803,36 +806,6 @@ class ASRServerExecutor(ASRExecutor):
model_file=self.am_model,
params_file=self.am_params,
predictor_conf=self.am_predictor_conf)
# decoder
logger.info("ASR engine start to create the ctc decoder instance")
self.decoder = CTCDecoder(
odim=self.config.output_dim, # <blank> is in vocab
enc_n_units=self.config.rnn_layer_size * 2,
blank_id=self.config.blank_id,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True, # sum / batch_size
grad_norm_type=self.config.get('ctc_grad_norm_type', None))
# init decoder
logger.info("ASR engine start to init the ctc decoder")
cfg = self.config.decode
decode_batch_size = 1 # for online
self.decoder.init_decoder(
decode_batch_size, self.text_feature.vocab_list,
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)
# init state box
self.chunk_state_h_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
self.chunk_state_c_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
elif "conformer" in model_type or "transformer" in model_type:
model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset}
......@@ -847,15 +820,11 @@ class ASRServerExecutor(ASRExecutor):
model_dict = paddle.load(self.am_model)
self.model.set_state_dict(model_dict)
logger.info("create the transformer like model success")
# update the ctc decoding
self.searcher = CTCPrefixBeamSearch(self.config.decode)
self.transformer_decode_reset()
else:
raise ValueError(f"Not support: {model_type}")
return True
class ASREngine(BaseEngine):
"""ASR server resource
......@@ -881,8 +850,8 @@ class ASREngine(BaseEngine):
self.executor = ASRServerExecutor()
try:
default_dev = paddle.get_device()
paddle.set_device(self.config.get("device", default_dev))
self.device = self.config.get("device", paddle.get_device())
paddle.set_device(self.device)
except BaseException as e:
logger.error(
f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册