# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import time
from typing import Optional

import numpy as np
import paddle
from numpy import float32
from yacs.config import CfgNode

from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.asr.infer import model_alias
from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.speech import SpeechSegment
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
from paddlespeech.s2t.utils.tensor_utils import pad_sequence
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import pcm2float
from paddlespeech.server.utils.paddle_predictor import init_predictor

__all__ = ['ASREngine']

pretrained_models = {
    "deepspeech2online_aishell-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz',
        'md5':
        '98b87b171b7240b7cae6e07d8d0bc9be',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/deepspeech2_online/checkpoints/avg_1',
        'model':
        'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
        'params':
        'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
        'lm_url':
        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
        'lm_md5':
        '29e02312deb2e59b3c8686c7966d4fe3'
    },
    "conformer_online_multicn-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz',
        'md5':
        '0ac93d390552336f2a906aec9e33c5fa',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/chunk_conformer/checkpoints/multi_cn',
        'model':
        'exp/chunk_conformer/checkpoints/multi_cn.pdparams',
        'params':
        'exp/chunk_conformer/checkpoints/multi_cn.pdparams',
        'lm_url':
        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
        'lm_md5':
        '29e02312deb2e59b3c8686c7966d4fe3'
    },
}
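# Note: each key in pretrained_models is the resource tag that _get_pretrained_path()
# looks up; _init_from_path() builds it as "{model_type}-{lang}-{sample_rate_str}",
# e.g. model_type="deepspeech2online_aishell", lang="zh", sample_rate=16000
# -> "deepspeech2online_aishell-zh-16k".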


# ASR server connection process class
class PaddleASRConnectionHanddler:
    def __init__(self, asr_engine):
        """Init a Paddle ASR Connection Handler instance

        Args:
            asr_engine (ASREngine): the global asr engine
        """
        super().__init__()
        logger.info(
            "create a paddle asr connection handler to process the websocket connection"
        )
        self.config = asr_engine.config
        self.model_config = asr_engine.executor.config
        self.asr_engine = asr_engine

        self.init()
        self.reset()

    def init(self):
        # model_type, sample_rate and text_feature are shared by deepspeech2 and conformer
        self.model_type = self.asr_engine.executor.model_type
        self.sample_rate = self.asr_engine.executor.sample_rate
        # tokens to text
        self.text_feature = self.asr_engine.executor.text_feature

        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
            from paddlespeech.s2t.io.collator import SpeechCollator
            self.am_predictor = self.asr_engine.executor.am_predictor

            self.collate_fn_test = SpeechCollator.from_config(self.model_config)
            self.decoder = CTCDecoder(
                odim=self.model_config.output_dim,  # <blank> is in  vocab
                enc_n_units=self.model_config.rnn_layer_size * 2,
                blank_id=self.model_config.blank_id,
                dropout_rate=0.0,
                reduction=True,  # sum
                batch_average=True,  # sum / batch_size
                grad_norm_type=self.model_config.get('ctc_grad_norm_type',
                                                     None))

            cfg = self.model_config.decode
            decode_batch_size = 1  # for online
            self.decoder.init_decoder(
                decode_batch_size, self.text_feature.vocab_list,
                cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
                cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
                cfg.num_proc_bsearch)
            # frame window samples length and frame shift samples length

            self.win_length = int(self.model_config.window_ms / 1000 *
                                  self.sample_rate)
            self.n_shift = int(self.model_config.stride_ms / 1000 *
                               self.sample_rate)
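            # for example, a 25 ms window with a 10 ms stride at 16 kHz would give
            # win_length = int(25 / 1000 * 16000) = 400 samples and
            # n_shift = int(10 / 1000 * 16000) = 160 samples (actual values come from model.yaml)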

        elif "conformer" in self.model_type or "transformer" in self.model_type:
            # acoustic model
            self.model = self.asr_engine.executor.model

            # ctc decoding config
            self.ctc_decode_config = self.asr_engine.executor.config.decode
            self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config)

            # extract feat; currently only fbank is used in the conformer model
            self.preprocess_conf = self.model_config.preprocess_config
            self.preprocess_args = {"train": False}
            self.preprocessing = Transformation(self.preprocess_conf)

            # frame window samples length and frame shift samples length
            self.win_length = self.preprocess_conf.process[0]['win_length']
            self.n_shift = self.preprocess_conf.process[0]['n_shift']

    def extract_feat(self, samples):

        # we compute the elapsed time of the first char occurring
        # and we record the start time when the first pcm sample arrives
        # if self.first_char_occur_elapsed is not None:
        #     self.first_char_occur_elapsed = time.time()

        if "deepspeech2online" in self.model_type:
            # self.remained_wav stores all the samples,
            # including the original remained_wav and the samples of this package
            samples = np.frombuffer(samples, dtype=np.int16)
            assert samples.ndim == 1

            # pcm16 -> pcm32
            # pcm2float will change the original samples,
            # so we should do pcm2float before concatenating
            samples = pcm2float(samples)

            if self.remained_wav is None:
                self.remained_wav = samples
            else:
                assert self.remained_wav.ndim == 1
                self.remained_wav = np.concatenate([self.remained_wav, samples])
            logger.info(
                f"The connection has remaining audio samples: {self.remained_wav.shape}"
            )

            # read audio
            speech_segment = SpeechSegment.from_pcm(
                self.remained_wav, self.sample_rate, transcript=" ")
            # audio augment
            self.collate_fn_test.augmentation.transform_audio(speech_segment)

            # extract speech feature
            spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
                speech_segment, self.collate_fn_test.keep_transcription_text)
            # CMVN spectrum
            if self.collate_fn_test._normalizer:
                spectrum = self.collate_fn_test._normalizer.apply(spectrum)

            # spectrum augment
            audio = self.collate_fn_test.augmentation.transform_feature(
                spectrum)

            audio_len = audio.shape[0]
            audio = paddle.to_tensor(audio, dtype='float32')
            # audio_len = paddle.to_tensor(audio_len)
            audio = paddle.unsqueeze(audio, axis=0)

            if self.cached_feat is None:
                self.cached_feat = audio
            else:
                assert (len(audio.shape) == 3)
                assert (len(self.cached_feat.shape) == 3)
                self.cached_feat = paddle.concat(
                    [self.cached_feat, audio], axis=1)

            # set the feat device
            if self.device is None:
                self.device = self.cached_feat.place

            self.num_frames += audio_len
            self.remained_wav = self.remained_wav[self.n_shift * audio_len:]

            logger.info(
                f"processed the audio feature successfully, the connection feat shape: {self.cached_feat.shape}"
            )
            logger.info(
                f"After extracting feat, the connection has remaining audio samples: {self.remained_wav.shape}"
            )
        elif "conformer_online" in self.model_type:
            logger.info("Online ASR extract the feat")
            samples = np.frombuffer(samples, dtype=np.int16)
            assert samples.ndim == 1

            logger.info(f"This package contains {samples.shape[0]} pcm samples")
            self.num_samples += samples.shape[0]

            # self.remained_wav stores all the samples,
            # including the original remained_wav and the samples of this package
            if self.remained_wav is None:
                self.remained_wav = samples
            else:
                assert self.remained_wav.ndim == 1
                self.remained_wav = np.concatenate([self.remained_wav, samples])
            logger.info(
                f"The connection has remaining audio samples: {self.remained_wav.shape}"
            )
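            # at least one full analysis window of samples is needed before
            # fbank extraction can produce a frame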
            if len(self.remained_wav) < self.win_length:
                return 0

            # fbank
            x_chunk = self.preprocessing(self.remained_wav,
                                         **self.preprocess_args)
            x_chunk = paddle.to_tensor(
                x_chunk, dtype="float32").unsqueeze(axis=0)
            if self.cached_feat is None:
                self.cached_feat = x_chunk
            else:
                assert (len(x_chunk.shape) == 3)
                assert (len(self.cached_feat.shape) == 3)
                self.cached_feat = paddle.concat(
                    [self.cached_feat, x_chunk], axis=1)

            # set the feat device
            if self.device is None:
                self.device = self.cached_feat.place

            num_frames = x_chunk.shape[1]
            self.num_frames += num_frames
            self.remained_wav = self.remained_wav[self.n_shift * num_frames:]

            logger.info(
                f"processed the audio feature successfully, the connection feat shape: {self.cached_feat.shape}"
            )
            logger.info(
                f"After extracting feat, the connection has remaining audio samples: {self.remained_wav.shape}"
            )
            # logger.info(f"accumulate samples: {self.num_samples}")       

    def reset(self):
        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
            # for deepspeech2 
            self.chunk_state_h_box = copy.deepcopy(
                self.asr_engine.executor.chunk_state_h_box)
            self.chunk_state_c_box = copy.deepcopy(
                self.asr_engine.executor.chunk_state_c_box)
            self.decoder.reset_decoder(batch_size=1)

        # for conformer online
        self.subsampling_cache = None
        self.elayers_output_cache = None
        self.conformer_cnn_cache = None
        self.encoder_out = None
        self.cached_feat = None
        self.remained_wav = None
        self.offset = 0
        self.num_samples = 0
        self.device = None
        self.hyps = []
        self.num_frames = 0
        self.chunk_num = 0
        self.global_frame_offset = 0
        self.result_transcripts = ['']
        self.first_char_occur_elapsed = None

    def decode(self, is_finished=False):
        if "deepspeech2online" in self.model_type:
            # x_chunk is the feature data
            decoding_chunk_size = 1  # decoding_chunk_size=1 in deepspeech2 model
            context = 7  # context=7 in deepspeech2 model
            subsampling = 4  # subsampling=4 in deepspeech2 model
            stride = subsampling * decoding_chunk_size
            cached_feature_num = context - subsampling
            # decoding window for model
            decoding_window = (decoding_chunk_size - 1) * subsampling + context
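            # with the fixed values above: stride = 4 * 1 = 4, cached_feature_num = 7 - 4 = 3,
            # decoding_window = (1 - 1) * 4 + 7 = 7 frames per forward step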

            if self.cached_feat is None:
                logger.info("no audio feat, please input more pcm data")
                return

            num_frames = self.cached_feat.shape[1]
            logger.info(
                f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
            )
            # the cached feat must be at least decoding_window frames
            if num_frames < decoding_window and not is_finished:
                logger.info(
                    f"frame feat num is less than {decoding_window}, please input more pcm data"
                )
                return None, None

            # if is_finished=True, we need at least context frames
            if num_frames < context:
                logger.info(
                    f"last {num_frames} is less than context {context} frames, and we cannot do model forward"
                )
                return None, None
            logger.info("start to do model forward")
            # num_frames - context + 1 ensures that the current frame can get a full context window
            if is_finished:
                # if get the finished chunk, we need process the last context
                left_frames = context
            else:
                # we only process decoding_window frames for one chunk
                left_frames = decoding_window

            for cur in range(0, num_frames - left_frames + 1, stride):
                end = min(cur + decoding_window, num_frames)
                # extract the audio
                x_chunk = self.cached_feat[:, cur:end, :].numpy()
                x_chunk_lens = np.array([x_chunk.shape[1]])
                trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens)

            self.result_transcripts = [trans_best]

            self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]
            # return trans_best[0]            
        elif "conformer" in self.model_type or "transformer" in self.model_type:
            try:
                logger.info(
                    f"we will use the transformer-like model: {self.model_type}"
                )
                self.advance_decoding(is_finished)
                self.update_result()

            except Exception as e:
                logger.exception(e)
        else:
            raise Exception("invalid model name")

    @paddle.no_grad()
    def decode_one_chunk(self, x_chunk, x_chunk_lens):
        logger.info("start to decode one chunk with deepspeech2 model")
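        # the exported deepspeech2 online predictor takes four inputs (audio chunk,
        # chunk length, RNN state h box, RNN state c box) and returns the chunk CTC
        # probs, their lengths and the updated states, which are carried to the next chunk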
        input_names = self.am_predictor.get_input_names()
        audio_handle = self.am_predictor.get_input_handle(input_names[0])
        audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
        h_box_handle = self.am_predictor.get_input_handle(input_names[2])
        c_box_handle = self.am_predictor.get_input_handle(input_names[3])

        audio_handle.reshape(x_chunk.shape)
        audio_handle.copy_from_cpu(x_chunk)

        audio_len_handle.reshape(x_chunk_lens.shape)
        audio_len_handle.copy_from_cpu(x_chunk_lens)

        h_box_handle.reshape(self.chunk_state_h_box.shape)
        h_box_handle.copy_from_cpu(self.chunk_state_h_box)

        c_box_handle.reshape(self.chunk_state_c_box.shape)
        c_box_handle.copy_from_cpu(self.chunk_state_c_box)

        output_names = self.am_predictor.get_output_names()
        output_handle = self.am_predictor.get_output_handle(output_names[0])
        output_lens_handle = self.am_predictor.get_output_handle(
            output_names[1])
        output_state_h_handle = self.am_predictor.get_output_handle(
            output_names[2])
        output_state_c_handle = self.am_predictor.get_output_handle(
            output_names[3])

        self.am_predictor.run()

        output_chunk_probs = output_handle.copy_to_cpu()
        output_chunk_lens = output_lens_handle.copy_to_cpu()
        self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
        self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()

        self.decoder.next(output_chunk_probs, output_chunk_lens)
        trans_best, trans_beam = self.decoder.decode()
        logger.info(f"decode one best result: {trans_best[0]}")
        return trans_best[0]

    @paddle.no_grad()
    def advance_decoding(self, is_finished=False):
        logger.info("start to decode with advanced_decoding method")
        cfg = self.ctc_decode_config
        decoding_chunk_size = cfg.decoding_chunk_size
        num_decoding_left_chunks = cfg.num_decoding_left_chunks

        assert decoding_chunk_size > 0
        subsampling = self.model.encoder.embed.subsampling_rate
        context = self.model.encoder.embed.right_context + 1
        stride = subsampling * decoding_chunk_size
        cached_feature_num = context - subsampling  # processed chunk feature cached for next chunk

        # decoding window for model
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
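        # e.g. a conformer with subsampling_rate 4 and right_context 6 (context = 7)
        # decoding with decoding_chunk_size 16 would give stride = 64 and
        # decoding_window = (16 - 1) * 4 + 7 = 67 frames (actual values depend on the model config)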
        if self.cached_feat is None:
            logger.info("no audio feat, please input more pcm data")
            return

        num_frames = self.cached_feat.shape[1]
        logger.info(
            f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
        )

        # the cached feat must be at least decoding_window frames
        if num_frames < decoding_window and not is_finished:
            logger.info(
                f"frame feat num is less than {decoding_window}, please input more pcm data"
            )
            return None, None

        # if is_finished=True, we need at least context frames
        if num_frames < context:
            logger.info(
                f"last {num_frames} is less than context {context} frames, and we cannot do model forward"
            )
            return None, None

        logger.info("start to do model forward")
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks
        outputs = []

        # num_frames - context + 1 ensures that the current frame can get a full context window
        if is_finished:
            # if get the finished chunk, we need process the last context
            left_frames = context
        else:
            # we only process decoding_window frames for one chunk
            left_frames = decoding_window

        # record the end for removing the processed feat
        end = None
        for cur in range(0, num_frames - left_frames + 1, stride):
            end = min(cur + decoding_window, num_frames)

            self.chunk_num += 1
            chunk_xs = self.cached_feat[:, cur:end, :]
            (y, self.subsampling_cache, self.elayers_output_cache,
             self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
                 chunk_xs, self.offset, required_cache_size,
                 self.subsampling_cache, self.elayers_output_cache,
                 self.conformer_cnn_cache)
            outputs.append(y)

            # update the offset
            self.offset += y.shape[1]

        ys = paddle.cat(outputs, 1)
        if self.encoder_out is None:
            self.encoder_out = ys
        else:
            self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1)

        # get the ctc probs
        ctc_probs = self.model.ctc.log_softmax(ys)  # (1, maxlen, vocab_size)
        ctc_probs = ctc_probs.squeeze(0)

        self.searcher.search(ctc_probs, self.cached_feat.place)

        self.hyps = self.searcher.get_one_best_hyps()
        assert self.cached_feat.shape[0] == 1
        assert end >= cached_feature_num

        self.cached_feat = self.cached_feat[0, end -
                                            cached_feature_num:, :].unsqueeze(0)
        assert len(
            self.cached_feat.shape
        ) == 3, f"current cache feat shape is: {self.cached_feat.shape}"

        logger.info(
            f"This connection handler encoder out shape: {self.encoder_out.shape}"
        )

    def update_result(self):
        logger.info("update the final result")
        hyps = self.hyps
        self.result_transcripts = [
            self.text_feature.defeaturize(hyp) for hyp in hyps
        ]
        self.result_tokenids = [hyp for hyp in hyps]

    def get_result(self):
        if len(self.result_transcripts) > 0:
            return self.result_transcripts[0]
        else:
            return ''

    @paddle.no_grad()
    def rescoring(self):
        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
            return

        logger.info("rescoring the final result")
        if "attention_rescoring" != self.ctc_decode_config.decoding_method:
            return

        self.searcher.finalize_search()
        self.update_result()

        beam_size = self.ctc_decode_config.beam_size
        hyps = self.searcher.get_hyps()
        if hyps is None or len(hyps) == 0:
            return

        # assert len(hyps) == beam_size
        hyp_list = []
        for hyp in hyps:
            hyp_content = hyp[0]
            # Prevent the hyp from being empty
            if len(hyp_content) == 0:
                hyp_content = (self.model.ctc.blank_id, )
            hyp_content = paddle.to_tensor(
                hyp_content, place=self.device, dtype=paddle.long)
            hyp_list.append(hyp_content)
        hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
        hyps_lens = paddle.to_tensor(
            [len(hyp[0]) for hyp in hyps], place=self.device,
            dtype=paddle.long)  # (beam_size,)
        hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
                                  self.model.ignore_id)
        hyps_lens = hyps_lens + 1  # Add <sos> at beginning

        encoder_out = self.encoder_out.repeat(beam_size, 1, 1)
        encoder_mask = paddle.ones(
            (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
        decoder_out, _ = self.model.decoder(
            encoder_out, encoder_mask, hyps_pad,
            hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
        # ctc score in ln domain
        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
        decoder_out = decoder_out.numpy()

        # Only use decoder score for rescoring
        best_score = -float('inf')
        best_index = 0
        # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
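        # total score of hyp i = sum_j decoder_out[i][j][w_j]
        #                        + decoder_out[i][len(hyp)][eos]
        #                        + ctc_weight * ctc_score(hyp)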
        for i, hyp in enumerate(hyps):
            score = 0.0
            for j, w in enumerate(hyp[0]):
                score += decoder_out[i][j][w]
            # the last decoder output token is `eos`, for the last decoder input token.
            score += decoder_out[i][len(hyp[0])][self.model.eos]
            # add ctc score (which is in ln domain)
            score += hyp[1] * self.ctc_decode_config.ctc_weight
            if score > best_score:
                best_score = score
                best_index = i

        # update the one best result
        logger.info(f"best index: {best_index}")
        self.hyps = [hyps[best_index][0]]
        self.update_result()

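# A sketch of the expected call order for one websocket connection (the actual
# receive loop lives in the websocket server handler, not in this file):
#
#   handler = PaddleASRConnectionHanddler(asr_engine)
#   for chunk in pcm_chunks:            # int16 PCM bytes from the client
#       handler.extract_feat(chunk)
#       handler.decode(is_finished=False)
#       partial = handler.get_result()
#   handler.decode(is_finished=True)
#   handler.rescoring()                 # only changes the result for conformer/transformer models
#   final = handler.get_result()
#   handler.reset()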

class ASRServerExecutor(ASRExecutor):
    def __init__(self):
        super().__init__()
        pass

    def _get_pretrained_path(self, tag: str) -> os.PathLike:
        """
        Download and return the pretrained resources path of the current task.
        """
        support_models = list(pretrained_models.keys())
        assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
            tag, '\n\t\t'.join(support_models))

        res_path = os.path.join(MODEL_HOME, tag)
        decompressed_path = download_and_decompress(pretrained_models[tag],
                                                    res_path)
        decompressed_path = os.path.abspath(decompressed_path)
        logger.info(
            'Use pretrained model stored in: {}'.format(decompressed_path))

        return decompressed_path

    def _init_from_path(self,
                        model_type: str='deepspeech2online_aishell',
                        am_model: Optional[os.PathLike]=None,
                        am_params: Optional[os.PathLike]=None,
                        lang: str='zh',
                        sample_rate: int=16000,
                        cfg_path: Optional[os.PathLike]=None,
                        decode_method: str='attention_rescoring',
                        am_predictor_conf: dict=None):
        """
        Init model and other resources from a specific path.
        """
        self.model_type = model_type
        self.sample_rate = sample_rate
        if cfg_path is None or am_model is None or am_params is None:
            sample_rate_str = '16k' if sample_rate == 16000 else '8k'
            tag = model_type + '-' + lang + '-' + sample_rate_str
            logger.info(f"Load the pretrained model, tag = {tag}")
            res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
            self.res_path = res_path

            self.cfg_path = os.path.join(res_path,
                                         pretrained_models[tag]['cfg_path'])

            self.am_model = os.path.join(res_path,
                                         pretrained_models[tag]['model'])
            self.am_params = os.path.join(res_path,
                                          pretrained_models[tag]['params'])
            logger.info(res_path)
        else:
            self.cfg_path = os.path.abspath(cfg_path)
            self.am_model = os.path.abspath(am_model)
            self.am_params = os.path.abspath(am_params)
            self.res_path = os.path.dirname(
                os.path.dirname(os.path.abspath(self.cfg_path)))

        logger.info(self.cfg_path)
        logger.info(self.am_model)
        logger.info(self.am_params)

        #Init body.
        self.config = CfgNode(new_allowed=True)
        self.config.merge_from_file(self.cfg_path)

        with UpdateConfig(self.config):
            if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
                from paddlespeech.s2t.io.collator import SpeechCollator
                self.vocab = self.config.vocab_filepath
                self.config.decode.lang_model_path = os.path.join(
                    MODEL_HOME, 'language_model',
                    self.config.decode.lang_model_path)
                self.collate_fn_test = SpeechCollator.from_config(self.config)
                self.text_feature = TextFeaturizer(
                    unit_type=self.config.unit_type, vocab=self.vocab)

                lm_url = pretrained_models[tag]['lm_url']
                lm_md5 = pretrained_models[tag]['lm_md5']
                logger.info(f"Start to load language model {lm_url}")
                self.download_lm(
                    lm_url,
                    os.path.dirname(self.config.decode.lang_model_path), lm_md5)
            elif "conformer" in model_type or "transformer" in model_type:
                logger.info("start to create the stream conformer asr engine")
                if self.config.spm_model_prefix:
                    self.config.spm_model_prefix = os.path.join(
                        self.res_path, self.config.spm_model_prefix)
                self.vocab = self.config.vocab_filepath
                self.text_feature = TextFeaturizer(
                    unit_type=self.config.unit_type,
                    vocab=self.config.vocab_filepath,
                    spm_model_prefix=self.config.spm_model_prefix)
                # update the decoding method
                if decode_method:
                    self.config.decode.decoding_method = decode_method

                # we only support ctc_prefix_beam_search and attention_rescoring decoding methods
                # Generally we set the decoding_method to attention_rescoring
                if self.config.decode.decoding_method not in [
                        "ctc_prefix_beam_search", "attention_rescoring"
                ]:
                    logger.info(
                        "we set the decoding_method to attention_rescoring")
                    self.config.decode.decoding_method = "attention_rescoring"
                assert self.config.decode.decoding_method in [
                    "ctc_prefix_beam_search", "attention_rescoring"
                ], f"we only support ctc_prefix_beam_search and attention_rescoring decoding methods, current decoding method is {self.config.decode.decoding_method}"
            else:
                raise Exception("wrong type")
        if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
            # AM predictor
            logger.info("ASR engine start to init the am predictor")
            self.am_predictor_conf = am_predictor_conf
            self.am_predictor = init_predictor(
                model_file=self.am_model,
                params_file=self.am_params,
                predictor_conf=self.am_predictor_conf)

            # decoder
            logger.info("ASR engine start to create the ctc decoder instance")
            self.decoder = CTCDecoder(
                odim=self.config.output_dim,  # <blank> is in  vocab
                enc_n_units=self.config.rnn_layer_size * 2,
                blank_id=self.config.blank_id,
                dropout_rate=0.0,
                reduction=True,  # sum
                batch_average=True,  # sum / batch_size
                grad_norm_type=self.config.get('ctc_grad_norm_type', None))

            # init decoder
            logger.info("ASR engine start to init the ctc decoder")
            cfg = self.config.decode
            decode_batch_size = 1  # for online
            self.decoder.init_decoder(
                decode_batch_size, self.text_feature.vocab_list,
                cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
                cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
                cfg.num_proc_bsearch)

            # init state box
            self.chunk_state_h_box = np.zeros(
                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                dtype=float32)
            self.chunk_state_c_box = np.zeros(
                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                dtype=float32)
        elif "conformer" in model_type or "transformer" in model_type:
            model_name = model_type[:model_type.rindex(
                '_')]  # model_type: {model_name}_{dataset}
            logger.info(f"model name: {model_name}")
            model_class = dynamic_import(model_name, model_alias)
            model_conf = self.config
            model = model_class.from_config(model_conf)
            self.model = model
            self.model.eval()

            # load model
            model_dict = paddle.load(self.am_model)
            self.model.set_state_dict(model_dict)
            logger.info("created the transformer-like model successfully")

            # update the ctc decoding
            self.searcher = CTCPrefixBeamSearch(self.config.decode)
            self.transformer_decode_reset()

    def reset_decoder_and_chunk(self):
        """reset decoder and chunk state for a new audio
        """
        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
            self.decoder.reset_decoder(batch_size=1)
            # init state box, for new audio request
            self.chunk_state_h_box = np.zeros(
                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                dtype=float32)
            self.chunk_state_c_box = np.zeros(
                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                dtype=float32)
        elif "conformer" in self.model_type or "transformer" in self.model_type:
            self.transformer_decode_reset()

    def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
        """decode one chunk

        Args:
            x_chunk (numpy.array): shape[B, T, D]
            x_chunk_lens (numpy.array): shape[B]
            model_type (str): online model type

        Returns:
            str: one best result
        """
        logger.info("start to decode chunk by chunk")
        if "deepspeech2online" in model_type:
            input_names = self.am_predictor.get_input_names()
            audio_handle = self.am_predictor.get_input_handle(input_names[0])
            audio_len_handle = self.am_predictor.get_input_handle(
                input_names[1])
            h_box_handle = self.am_predictor.get_input_handle(input_names[2])
            c_box_handle = self.am_predictor.get_input_handle(input_names[3])

            audio_handle.reshape(x_chunk.shape)
            audio_handle.copy_from_cpu(x_chunk)

            audio_len_handle.reshape(x_chunk_lens.shape)
            audio_len_handle.copy_from_cpu(x_chunk_lens)

            h_box_handle.reshape(self.chunk_state_h_box.shape)
            h_box_handle.copy_from_cpu(self.chunk_state_h_box)

            c_box_handle.reshape(self.chunk_state_c_box.shape)
            c_box_handle.copy_from_cpu(self.chunk_state_c_box)

            output_names = self.am_predictor.get_output_names()
            output_handle = self.am_predictor.get_output_handle(output_names[0])
            output_lens_handle = self.am_predictor.get_output_handle(
                output_names[1])
            output_state_h_handle = self.am_predictor.get_output_handle(
                output_names[2])
            output_state_c_handle = self.am_predictor.get_output_handle(
                output_names[3])

            self.am_predictor.run()

            output_chunk_probs = output_handle.copy_to_cpu()
            output_chunk_lens = output_lens_handle.copy_to_cpu()
            self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
            self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()

            self.decoder.next(output_chunk_probs, output_chunk_lens)
            trans_best, trans_beam = self.decoder.decode()
            logger.info(f"decode one best result: {trans_best[0]}")
            return trans_best[0]

        elif "conformer" in model_type or "transformer" in model_type:
            try:
                logger.info(
                    f"we will use the transformer-like model: {self.model_type}"
                )
                self.advanced_decoding(x_chunk, x_chunk_lens)
                self.update_result()

                return self.result_transcripts[0]
            except Exception as e:
                logger.exception(e)
        else:
            raise Exception("invalid model name")

    def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens):
        logger.info("start to decode with advanced_decoding method")
        encoder_out, encoder_mask = self.encoder_forward(xs)
        ctc_probs = self.model.ctc.log_softmax(
            encoder_out)  # (1, maxlen, vocab_size)
        ctc_probs = ctc_probs.squeeze(0)
        self.searcher.search(ctc_probs, xs.place)
        # update the one best result
        self.hyps = self.searcher.get_one_best_hyps()

        # now we support ctc_prefix_beam_search and attention_rescoring
        if "attention_rescoring" in self.config.decode.decoding_method:
            self.rescoring(encoder_out, xs.place)

    def encoder_forward(self, xs):
        logger.info("get the model out from the feat")
        cfg = self.config.decode
        decoding_chunk_size = cfg.decoding_chunk_size
        num_decoding_left_chunks = cfg.num_decoding_left_chunks

        assert decoding_chunk_size > 0
        subsampling = self.model.encoder.embed.subsampling_rate
        context = self.model.encoder.embed.right_context + 1
        stride = subsampling * decoding_chunk_size

        # decoding window for model
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
        num_frames = xs.shape[1]
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks

        logger.info("start to do model forward")
        outputs = []

        # num_frames - context + 1 ensures that the current frame can get a full context window
        for cur in range(0, num_frames - context + 1, stride):
            end = min(cur + decoding_window, num_frames)
            chunk_xs = xs[:, cur:end, :]
            (y, self.subsampling_cache, self.elayers_output_cache,
             self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
                 chunk_xs, self.offset, required_cache_size,
                 self.subsampling_cache, self.elayers_output_cache,
                 self.conformer_cnn_cache)
            outputs.append(y)
            self.offset += y.shape[1]

        ys = paddle.cat(outputs, 1)
        masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
        masks = masks.unsqueeze(1)
        return ys, masks

    def rescoring(self, encoder_out, device):
        logger.info("start to rescoring the hyps")
        beam_size = self.config.decode.beam_size
        hyps = self.searcher.get_hyps()
        assert len(hyps) == beam_size

        hyp_list = []
        for hyp in hyps:
            hyp_content = hyp[0]
            # Prevent the hyp from being empty
            if len(hyp_content) == 0:
                hyp_content = (self.model.ctc.blank_id, )
            hyp_content = paddle.to_tensor(
                hyp_content, place=device, dtype=paddle.long)
            hyp_list.append(hyp_content)
        hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
        hyps_lens = paddle.to_tensor(
            [len(hyp[0]) for hyp in hyps], place=device,
            dtype=paddle.long)  # (beam_size,)
        hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
                                  self.model.ignore_id)
        hyps_lens = hyps_lens + 1  # Add <sos> at beginning

        encoder_out = encoder_out.repeat(beam_size, 1, 1)
        encoder_mask = paddle.ones(
            (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
        decoder_out, _ = self.model.decoder(
            encoder_out, encoder_mask, hyps_pad,
            hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
        # ctc score in ln domain
        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
        decoder_out = decoder_out.numpy()

        # Only use decoder score for rescoring
        best_score = -float('inf')
        best_index = 0
        # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
        for i, hyp in enumerate(hyps):
            score = 0.0
            for j, w in enumerate(hyp[0]):
                score += decoder_out[i][j][w]
            # the last decoder output token is `eos`, for the last decoder input token.
            score += decoder_out[i][len(hyp[0])][self.model.eos]
            # add ctc score (which is in ln domain)
            score += hyp[1] * self.config.decode.ctc_weight
            if score > best_score:
                best_score = score
                best_index = i

        # update the one best result
        self.hyps = [hyps[best_index][0]]
        return hyps[best_index][0]

    def transformer_decode_reset(self):
        self.subsampling_cache = None
        self.elayers_output_cache = None
        self.conformer_cnn_cache = None
        self.offset = 0
        # decoding reset
        self.searcher.reset()

    def update_result(self):
        logger.info("update the final result")
        hyps = self.hyps
        self.result_transcripts = [
            self.text_feature.defeaturize(hyp) for hyp in hyps
        ]
        self.result_tokenids = [hyp for hyp in hyps]

    def extract_feat(self, samples, sample_rate):
        """extract feat

        Args:
            samples (numpy.array): numpy.float32
            sample_rate (int): sample rate

        Returns:
            x_chunk (numpy.array): shape[B, T, D]
            x_chunk_lens (numpy.array): shape[B]
        """

        if "deepspeech2online" in self.model_type:
            # pcm16 -> pcm 32
            samples = pcm2float(samples)
            # read audio
            speech_segment = SpeechSegment.from_pcm(
                samples, sample_rate, transcript=" ")
            # audio augment
            self.collate_fn_test.augmentation.transform_audio(speech_segment)

            # extract speech feature
            spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
                speech_segment, self.collate_fn_test.keep_transcription_text)
            # CMVN spectrum
            if self.collate_fn_test._normalizer:
                spectrum = self.collate_fn_test._normalizer.apply(spectrum)

            # spectrum augment
            audio = self.collate_fn_test.augmentation.transform_feature(
                spectrum)

            audio_len = audio.shape[0]
            audio = paddle.to_tensor(audio, dtype='float32')
            # audio_len = paddle.to_tensor(audio_len)
            audio = paddle.unsqueeze(audio, axis=0)

            x_chunk = audio.numpy()
            x_chunk_lens = np.array([audio_len])

            return x_chunk, x_chunk_lens
        elif "conformer_online" in self.model_type:

            if sample_rate != self.sample_rate:
                logger.info(f"audio sample rate {sample_rate} does not match "
                            f"the model sample_rate {self.sample_rate}")
            logger.info(f"ASR Engine use the {self.model_type} to process")
            logger.info("Create the preprocess instance")
            preprocess_conf = self.config.preprocess_config
            preprocess_args = {"train": False}
            preprocessing = Transformation(preprocess_conf)

            logger.info("Read the audio file")
            logger.info(f"audio shape: {samples.shape}")
            # fbank
            x_chunk = preprocessing(samples, **preprocess_args)
            x_chunk_lens = paddle.to_tensor(x_chunk.shape[0])
            x_chunk = paddle.to_tensor(
                x_chunk, dtype="float32").unsqueeze(axis=0)
            logger.info(
                f"process the audio feature success, feat shape: {x_chunk.shape}"
            )
            return x_chunk, x_chunk_lens


class ASREngine(BaseEngine):
    """ASR server engine

    Args:
        metaclass: Defaults to Singleton.
    """

    def __init__(self):
        super(ASREngine, self).__init__()
        logger.info("create the online asr engine instance")

    def init(self, config: dict) -> bool:
        """init engine resource

        Args:
            config (dict): config dict

        Returns:
            bool: init failed or success
        """
        self.input = None
        self.output = ""
        self.executor = ASRServerExecutor()
        self.config = config
        try:
            if self.config.get("device", None):
                self.device = self.config.device
            else:
                self.device = paddle.get_device()
            logger.info(f"paddlespeech_server set the device: {self.device}")
            paddle.set_device(self.device)
        except BaseException:
            logger.error(
                "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
            )

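        # the attributes read from self.config below (model_type, am_model, am_params,
        # lang, sample_rate, cfg_path, decode_method, am_predictor_conf) are expected to
        # come from the ASR engine section of the server yaml config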
        self.executor._init_from_path(
            model_type=self.config.model_type,
            am_model=self.config.am_model,
            am_params=self.config.am_params,
            lang=self.config.lang,
            sample_rate=self.config.sample_rate,
            cfg_path=self.config.cfg_path,
            decode_method=self.config.decode_method,
            am_predictor_conf=self.config.am_predictor_conf)

        logger.info("Initialize ASR server engine successfully.")
        return True

    def preprocess(self,
                   samples,
                   sample_rate,
                   model_type="deepspeech2online_aishell-zh-16k"):
        """preprocess

        Args:
            samples (numpy.array): numpy.float32
            sample_rate (int): sample rate

        Returns:
            x_chunk (numpy.array): shape[B, T, D]
            x_chunk_lens (numpy.array): shape[B]
        """
        # if "deepspeech" in model_type:
        x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate)
        return x_chunk, x_chunk_lens

    def run(self, x_chunk, x_chunk_lens, decoder_chunk_size=1):
        """run online engine

        Args:
            x_chunk (numpy.array): shape[B, T, D]
            x_chunk_lens (numpy.array): shape[B]
            decoder_chunk_size(int)
        """
        self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens,
                                                     self.config.model_type)

    def postprocess(self):
        """postprocess
        """
        return self.output

    def reset(self):
        """reset engine decoder and inference state
        """
        self.executor.reset_decoder_and_chunk()
        self.output = ""