infer.py 21.1 KB
Newer Older
K
KP 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
H
huangyuxin 已提交
16
import sys
K
KP 已提交
17
from collections import OrderedDict
K
KP 已提交
18 19 20 21
from typing import List
from typing import Optional
from typing import Union

H
huangyuxin 已提交
22
import librosa
H
huangyuxin 已提交
23
import numpy as np
K
KP 已提交
24
import paddle
H
huangyuxin 已提交
25
import soundfile
H
huangyuxin 已提交
26
from yacs.config import CfgNode
K
KP 已提交
27

28
from ..download import get_path_from_url
K
KP 已提交
29
from ..executor import BaseExecutor
K
KP 已提交
30
from ..log import logger
K
KP 已提交
31
from ..utils import cli_register
K
KP 已提交
32 33
from ..utils import download_and_decompress
from ..utils import MODEL_HOME
K
KP 已提交
34
from ..utils import stats_wrapper
H
huangyuxin 已提交
35 36 37 38
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.utility import UpdateConfig
K
KP 已提交
39

K
KP 已提交
40
# Public API of this module: only the executor class is exported.
__all__ = ["ASRExecutor"]
K
KP 已提交
41

K
KP 已提交
42
pretrained_models = {
    # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
    # e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k".
    # Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
    # "paddlespeech asr --model conformer_wenetspeech --lang zh --sample_rate 16000 --input ./input.wav"
    #
    # Each entry holds the archive 'url' and its 'md5', plus 'cfg_path' and
    # 'ckpt_path' relative to the directory the archive is extracted into
    # ('ckpt_path' is stored without its ".pdparams" suffix). deepspeech2
    # entries additionally carry the language model location 'lm_url'/'lm_md5'.
    "conformer_wenetspeech-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz',
        'md5':
        '76cb19ed857e6623856b7cd7ebbfeda4',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/conformer/checkpoints/wenetspeech',
    },
    "transformer_librispeech-en-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz',
        'md5':
        '2c667da24922aad391eacafe37bc1660',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/transformer/checkpoints/avg_10',
    },
    "deepspeech2offline_aishell-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz',
        'md5':
        '932c3593d62fe5c741b59b31318aa314',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/deepspeech2/checkpoints/avg_1',
        'lm_url':
        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
        'lm_md5':
        '29e02312deb2e59b3c8686c7966d4fe3'
    },
    "deepspeech2online_aishell-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz',
        'md5':
        'd5e076217cf60486519f72c217d21b9b',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/deepspeech2_online/checkpoints/avg_1',
        'lm_url':
        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
        'lm_md5':
        '29e02312deb2e59b3c8686c7966d4fe3'
    },
    "deepspeech2offline_librispeech-en-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz',
        'md5':
        'f5666c81ad015c8de03aac2bc92e5762',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/deepspeech2/checkpoints/avg_1',
        'lm_url':
        'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm',
        'lm_md5':
        '099a601759d467cd0a8523ff939819c5'
    },
}

H
huangyuxin 已提交
111
# Model name (the "--model" value without its "_{dataset}" suffix) mapped to a
# "module.path:ClassName" string; the executor resolves it via dynamic_import.
# All U2-family names (conformer / transformer / wenetspeech) share one class.
model_alias = {
    "deepspeech2offline": "paddlespeech.s2t.models.ds2:DeepSpeech2Model",
    "deepspeech2online":
    "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
    **dict.fromkeys(("conformer", "transformer", "wenetspeech"),
                    "paddlespeech.s2t.models.u2:U2Model"),
}

K
KP 已提交
124

K
KP 已提交
125
@cli_register(
    name='paddlespeech.asr', description='Speech to text infer command.')
class ASRExecutor(BaseExecutor):
    """Speech-to-text executor backing `paddlespeech asr` and the Python API.

    Supports deepspeech2 (online/offline) and U2-family (conformer,
    transformer, wenetspeech) models listed in `pretrained_models`, or a
    user-supplied config/checkpoint pair.
    """

    def __init__(self):
        super(ASRExecutor, self).__init__()

        self.parser = argparse.ArgumentParser(
            prog='paddlespeech.asr', add_help=True)
        self.parser.add_argument(
            '--input', type=str, default=None, help='Audio file to recognize.')
        self.parser.add_argument(
            '--model',
            type=str,
            default='conformer_wenetspeech',
            # Strip the "-{lang}-{sr}" suffix of each pretrained tag.
            choices=[tag[:tag.index('-')] for tag in pretrained_models.keys()],
            help='Choose model type of asr task.')
        self.parser.add_argument(
            '--lang',
            type=str,
            default='zh',
            help='Choose model language. zh or en, zh:[conformer_wenetspeech-zh-16k], en:[transformer_librispeech-en-16k]'
        )
        self.parser.add_argument(
            "--sample_rate",
            type=int,
            default=16000,
            choices=[8000, 16000],
            help='Choose the audio sample rate of the model. 8000 or 16000')
        self.parser.add_argument(
            '--config',
            type=str,
            default=None,
            help='Config of asr task. Use deault config when it is None.')
        self.parser.add_argument(
            '--decode_method',
            type=str,
            default='attention_rescoring',
            choices=[
                'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention',
                'attention_rescoring'
            ],
            help='only support transformer and conformer model')
        self.parser.add_argument(
            '--ckpt_path',
            type=str,
            default=None,
            help='Checkpoint file of model.')
        self.parser.add_argument(
            '--yes',
            '-y',
            action="store_true",
            default=False,
            help='No additional parameters required. Once set this parameter, it means accepting the request of the program by default, which includes transforming the audio sample rate'
        )
        self.parser.add_argument(
            '--device',
            type=str,
            default=paddle.get_device(),
            help='Choose device to execute model inference.')
        self.parser.add_argument(
            '-d',
            '--job_dump_result',
            action='store_true',
            help='Save job result into file.')
        self.parser.add_argument(
            '-v',
            '--verbose',
            action='store_true',
            help='Increase logger verbosity of current task.')

    @staticmethod
    def _is_deepspeech2(model_type: str) -> bool:
        """Return True if `model_type` names a deepspeech2 (online/offline) model."""
        return ("deepspeech2online" in model_type or
                "deepspeech2offline" in model_type)

    @staticmethod
    def _is_u2(model_type: str) -> bool:
        """Return True if `model_type` names a conformer/transformer/wenetspeech model."""
        return ("conformer" in model_type or "transformer" in model_type or
                "wenetspeech" in model_type)

    def _get_pretrained_path(self, tag: str) -> os.PathLike:
        """
        Download and returns pretrained resources path of current task.

        Args:
            tag: Key of `pretrained_models`, e.g. "conformer_wenetspeech-zh-16k".

        Returns:
            Absolute path of the decompressed resources under MODEL_HOME/tag.
        """
        support_models = list(pretrained_models.keys())
        assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
            tag, '\n\t\t'.join(support_models))

        res_path = os.path.join(MODEL_HOME, tag)
        decompressed_path = download_and_decompress(pretrained_models[tag],
                                                    res_path)
        decompressed_path = os.path.abspath(decompressed_path)
        logger.info(
            'Use pretrained model stored in: {}'.format(decompressed_path))

        return decompressed_path

    def _init_from_path(self,
                        model_type: str='wenetspeech',
                        lang: str='zh',
                        sample_rate: int=16000,
                        cfg_path: Optional[os.PathLike]=None,
                        decode_method: str='attention_rescoring',
                        ckpt_path: Optional[os.PathLike]=None):
        """
        Init model and other resources from a specific path.

        If either `cfg_path` or `ckpt_path` is None, the pretrained resources
        tagged "{model_type}-{lang}-{16k|8k}" are downloaded and used for both;
        otherwise the given files are used. Does nothing if a model has
        already been initialized on this executor.

        Args:
            model_type: "{model_name}_{dataset}", e.g. "conformer_wenetspeech".
            lang: Language part of the pretrained tag, e.g. "zh" or "en".
            sample_rate: 16000 selects the "16k" tag, any other value "8k".
            cfg_path: Model config (yaml) file.
            decode_method: Decoding method for conformer/transformer models.
            ckpt_path: Checkpoint path WITHOUT the ".pdparams" suffix.

        Raises:
            Exception: If `model_type` is neither deepspeech2 nor U2-family.
        """
        if hasattr(self, 'model'):
            logger.info('Model had been initialized.')
            return

        # The tag is also needed below to locate the deepspeech2 language
        # model, so compute it even when user-provided paths are used.
        sample_rate_str = '16k' if sample_rate == 16000 else '8k'
        tag = model_type + '-' + lang + '-' + sample_rate_str

        if cfg_path is None or ckpt_path is None:
            res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
            self.res_path = res_path
            self.cfg_path = os.path.join(res_path,
                                         pretrained_models[tag]['cfg_path'])
            self.ckpt_path = os.path.join(
                res_path, pretrained_models[tag]['ckpt_path'] + ".pdparams")
            logger.info(res_path)
            logger.info(self.cfg_path)
            logger.info(self.ckpt_path)
        else:
            self.cfg_path = os.path.abspath(cfg_path)
            # os.fspath lets pathlib.Path inputs be suffixed like plain strings.
            self.ckpt_path = os.path.abspath(
                os.fspath(ckpt_path) + ".pdparams")
            self.res_path = os.path.dirname(
                os.path.dirname(os.path.abspath(self.cfg_path)))

        # Init body.
        self.config = CfgNode(new_allowed=True)
        self.config.merge_from_file(self.cfg_path)

        with UpdateConfig(self.config):
            if self._is_deepspeech2(model_type):
                from paddlespeech.s2t.io.collator import SpeechCollator
                self.vocab = self.config.vocab_filepath
                self.config.decode.lang_model_path = os.path.join(
                    MODEL_HOME, 'language_model',
                    self.config.decode.lang_model_path)
                self.collate_fn_test = SpeechCollator.from_config(self.config)
                self.text_feature = TextFeaturizer(
                    unit_type=self.config.unit_type, vocab=self.vocab)
                # A custom config may not match any pretrained tag; only
                # download the language model when its location is known.
                model_info = pretrained_models.get(tag, {})
                if 'lm_url' in model_info:
                    self.download_lm(
                        model_info['lm_url'],
                        os.path.dirname(self.config.decode.lang_model_path),
                        model_info['lm_md5'])
                else:
                    logger.warning(
                        "No pretrained language model registered for {}, "
                        "expecting it at {}".format(
                            tag, self.config.decode.lang_model_path))

            elif self._is_u2(model_type):
                self.config.spm_model_prefix = os.path.join(
                    self.res_path, self.config.spm_model_prefix)
                self.text_feature = TextFeaturizer(
                    unit_type=self.config.unit_type,
                    vocab=self.config.vocab_filepath,
                    spm_model_prefix=self.config.spm_model_prefix)
                self.config.decode.decoding_method = decode_method

            else:
                raise Exception("wrong type")

        # model_type: {model_name}_{dataset}. rsplit keeps names without a
        # dataset suffix (e.g. "wenetspeech") instead of raising ValueError.
        model_name = model_type.rsplit('_', 1)[0]
        model_class = dynamic_import(model_name, model_alias)
        self.model = model_class.from_config(self.config)
        self.model.eval()

        # load model
        model_dict = paddle.load(self.ckpt_path)
        self.model.set_state_dict(model_dict)

    def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
        """
        Input preprocess and store the features in self._inputs.

        Sets self._inputs["audio"] (batch of one, float32) and
        self._inputs["audio_len"].

        Args:
            model_type: Same value passed to _init_from_path.
            input: Path of the audio file to recognize.

        Raises:
            Exception: If `model_type` is neither deepspeech2 nor U2-family.
        """
        audio_file = input
        if isinstance(audio_file, (str, os.PathLike)):
            # os.fspath: concatenating a pathlib.Path to a str raises TypeError.
            logger.info("Preprocess audio_file:" + os.fspath(audio_file))

        # Get the object for feature extraction
        if self._is_deepspeech2(model_type):
            audio, _ = self.collate_fn_test.process_utterance(
                audio_file=audio_file, transcript=" ")
            audio_len = audio.shape[0]
            audio = paddle.to_tensor(audio, dtype='float32')
            audio_len = paddle.to_tensor(audio_len)
            audio = paddle.unsqueeze(audio, axis=0)
            self._inputs["audio"] = audio
            self._inputs["audio_len"] = audio_len
            logger.info(f"audio feat shape: {audio.shape}")

        elif self._is_u2(model_type):
            logger.info("get the preprocess conf")
            preprocess_conf = self.config.preprocess_config
            preprocess_args = {"train": False}
            preprocessing = Transformation(preprocess_conf)
            logger.info("read the audio file")
            audio, audio_sample_rate = soundfile.read(
                audio_file, dtype="int16", always_2d=True)

            # change_format is set by _check(); default to "no conversion".
            if getattr(self, 'change_format', False):
                if audio.shape[1] >= 2:
                    # Average in float64: an int16 accumulator (the former
                    # dtype=np.int16) overflows for loud multi-channel audio.
                    audio = audio.mean(axis=1).astype(np.int16)
                else:
                    audio = audio[:, 0]
                # pcm16 -> pcm 32
                audio = self._pcm16to32(audio)
                audio = librosa.resample(
                    audio,
                    orig_sr=audio_sample_rate,
                    target_sr=self.sample_rate)
                # pcm32 -> pcm 16
                audio = self._pcm32to16(audio)
            else:
                audio = audio[:, 0]

            logger.info(f"audio shape: {audio.shape}")
            # fbank
            audio = preprocessing(audio, **preprocess_args)

            audio_len = paddle.to_tensor(audio.shape[0])
            audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)

            self._inputs["audio"] = audio
            self._inputs["audio_len"] = audio_len
            logger.info(f"audio feat shape: {audio.shape}")

        else:
            raise Exception("wrong type")

    @paddle.no_grad()
    def infer(self, model_type: str):
        """
        Model inference and result stored in self._outputs["result"].

        Args:
            model_type: Same value passed to _init_from_path.

        Raises:
            Exception: If `model_type` is neither deepspeech2 nor U2-family.
        """
        cfg = self.config.decode
        audio = self._inputs["audio"]
        audio_len = self._inputs["audio_len"]
        if self._is_deepspeech2(model_type):
            decode_batch_size = audio.shape[0]
            self.model.decoder.init_decoder(
                decode_batch_size, self.text_feature.vocab_list,
                cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
                cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
                cfg.num_proc_bsearch)
            # Release the decoder even if decoding fails.
            try:
                result_transcripts = self.model.decode(audio, audio_len)
            finally:
                self.model.decoder.del_decoder()
            self._outputs["result"] = result_transcripts[0]

        elif self._is_u2(model_type):
            # Accepts the same model types as _init_from_path/preprocess
            # (including "wenetspeech"), which previously failed here.
            result_transcripts = self.model.decode(
                audio,
                audio_len,
                text_feature=self.text_feature,
                decoding_method=cfg.decoding_method,
                beam_size=cfg.beam_size,
                ctc_weight=cfg.ctc_weight,
                decoding_chunk_size=cfg.decoding_chunk_size,
                num_decoding_left_chunks=cfg.num_decoding_left_chunks,
                simulate_streaming=cfg.simulate_streaming)
            self._outputs["result"] = result_transcripts[0][0]
        else:
            raise Exception("invalid model name")

    def postprocess(self) -> Union[str, os.PathLike]:
        """
            Output postprocess and return human-readable results such as texts and audio files.
        """
        return self._outputs["result"]

    def download_lm(self, url, lm_dir, md5sum):
        """Download the language model file at `url` into `lm_dir` (no decompression).

        Returns:
            The value returned by get_path_from_url for the downloaded file.
        """
        return get_path_from_url(
            url=url,
            root_dir=lm_dir,
            md5sum=md5sum,
            decompress=False, )

    def _pcm16to32(self, audio):
        """Convert int16 PCM samples to float32 scaled by 1 / 2**15."""
        assert (audio.dtype == np.int16)
        audio = audio.astype("float32")
        bits = np.iinfo(np.int16).bits
        audio = audio / (2**(bits - 1))
        return audio

    def _pcm32to16(self, audio):
        """Convert float32 samples back to int16 PCM, saturating out-of-range values."""
        assert (audio.dtype == np.float32)
        int16_info = np.iinfo(np.int16)
        audio = audio * (2**(int16_info.bits - 1))
        # Resampling can push samples to or past full scale; clip so they
        # saturate instead of wrapping around when cast to int16.
        audio = np.clip(np.round(audio), int16_info.min, int16_info.max)
        return audio.astype("int16")

    def _check(self, audio_file: str, sample_rate: int, force_yes: bool):
        """Validate the sample rate and the input file; decide on resampling.

        Stores `sample_rate` in self.sample_rate and sets self.change_format
        to True when the file's sample rate differs from it (asking the user
        to confirm unless `force_yes`).

        Returns:
            True if the input can be processed, False otherwise. Exits the
            process if the user declines the conversion.
        """
        self.sample_rate = sample_rate
        if self.sample_rate not in (8000, 16000):
            logger.error(
                "invalid sample rate, please input --sample_rate 8000 or --sample_rate 16000"
            )
            return False

        if isinstance(audio_file, (str, os.PathLike)):
            if not os.path.isfile(audio_file):
                logger.error("Please input the right audio file path")
                return False

        logger.info("checking the audio file format......")
        try:
            _, audio_sample_rate = soundfile.read(
                audio_file, dtype="int16", always_2d=True)
        except Exception as e:
            logger.exception(e)
            logger.error(
                "can not open the audio file, please check the audio file format is 'wav'. \n \
                 you can try to use sox to change the file format.\n \
                 For example: \n \
                 sample rate: 16k \n \
                 sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \
                 sample rate: 8k \n \
                 sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
                 ")
            return False

        logger.info("The sample rate is %d" % audio_sample_rate)
        if audio_sample_rate != self.sample_rate:
            # Report the actual target rate instead of a hard-coded "16k".
            target_sr_str = "{}k".format(self.sample_rate // 1000)
            logger.warning(
                "The sample rate of the input file is not {}.\n"
                "The program will resample the wav file to {}.\n"
                "If the result does not meet your expectations,\n"
                "Please input the {} 16 bit 1 channel wav file.".format(
                    self.sample_rate, self.sample_rate, target_sr_str))
            if not force_yes:
                while True:
                    logger.info(
                        "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
                    )
                    content = input("Input(Y/N):").strip()
                    if content in ("Y", "y", "yes", "Yes"):
                        logger.info(
                            "change the sample rate, channel to {} and 1 channel".
                            format(target_sr_str))
                        break
                    elif content in ("N", "n", "no", "No"):
                        logger.info("Exit the program")
                        # sys.exit, not the site-module `exit` helper, which
                        # is unavailable under `python -S` / frozen builds.
                        sys.exit(1)
                    else:
                        logger.warning("Not regular input, please input again")

            self.change_format = True
        else:
            logger.info("The audio file format is right")
            self.change_format = False

        return True

    def execute(self, argv: List[str]) -> bool:
        """
            Command line entry.

            Returns False if recognizing any input raised an exception.
        """
        parser_args = self.parser.parse_args(argv)

        model = parser_args.model
        lang = parser_args.lang
        sample_rate = parser_args.sample_rate
        config = parser_args.config
        ckpt_path = parser_args.ckpt_path
        decode_method = parser_args.decode_method
        force_yes = parser_args.yes
        device = parser_args.device

        if not parser_args.verbose:
            self.disable_task_loggers()

        task_source = self.get_task_source(parser_args.input)
        task_results = OrderedDict()
        has_exceptions = False

        for id_, input_ in task_source.items():
            try:
                res = self(input_, model, lang, sample_rate, config, ckpt_path,
                           decode_method, force_yes, device)
                task_results[id_] = res
            except Exception as e:
                has_exceptions = True
                task_results[id_] = f'{e.__class__.__name__}: {e}'

        self.process_task_results(parser_args.input, task_results,
                                  parser_args.job_dump_result)

        return not has_exceptions

    @stats_wrapper
    def __call__(self,
                 audio_file: os.PathLike,
                 model: str='conformer_wenetspeech',
                 lang: str='zh',
                 sample_rate: int=16000,
                 config: os.PathLike=None,
                 ckpt_path: os.PathLike=None,
                 decode_method: str='attention_rescoring',
                 force_yes: bool=False,
                 device=paddle.get_device()):
        """
        Python API to call an executor.

        Recognizes `audio_file` and returns the transcript. Calls sys.exit(-1)
        if the input fails the checks in _check.
        """
        audio_file = os.path.abspath(audio_file)
        if not self._check(audio_file, sample_rate, force_yes):
            sys.exit(-1)
        paddle.set_device(device)
        self._init_from_path(model, lang, sample_rate, config, decode_method,
                             ckpt_path)
        self.preprocess(model, audio_file)
        self.infer(model)
        res = self.postprocess()  # Retrieve result of asr.

        return res