infer.py 3.1 KB
Newer Older
K
KP 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from typing import List
from typing import Optional
from typing import Union

import paddle

from ..executor import BaseExecutor
from ..utils import cli_register

# Public API of this module: the names exported by a star-import.
__all__ = ['S2TExecutor']


K
KP 已提交
28 29
@cli_register(
    name='paddlespeech.s2t', description='Speech to text infer command.')
K
KP 已提交
30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
class S2TExecutor(BaseExecutor):
    """
        Command-line executor for the ``paddlespeech.s2t`` task.

        ``execute`` parses the CLI arguments and drives the pipeline stages
        (``_init_from_cfg`` -> ``preprocess`` -> ``infer`` -> ``postprocess``).
        In this revision every stage is still a placeholder that does nothing
        and returns None.
    """

    def __init__(self):
        # Zero-argument super() does not depend on the module-level name
        # ``S2TExecutor``, which a class decorator is free to rebind.
        super().__init__()

        self.parser = argparse.ArgumentParser(
            prog='paddlespeech.s2t', add_help=True)
        self.parser.add_argument(
            '--config',
            type=str,
            default=None,
            help='Config of s2t task. Use default config when it is None.')
        self.parser.add_argument(
            '--input', type=str, help='Audio file to recognize.')
        self.parser.add_argument(
            '--device',
            type=str,
            default='cpu',
            help='Choose device to execute model inference.')

    def _get_default_cfg_path(self):
        """
            Returns a default config file path of current task.

            Placeholder: not implemented yet, currently returns None.
        """
        pass

    def _init_from_cfg(self, cfg_path: Optional[os.PathLike]=None):
        """
            Init model from a specific config file.

            Placeholder: not implemented yet, ``cfg_path`` is currently ignored.
        """
        pass

    def preprocess(self, input: Union[str, os.PathLike]):
        """
            Input preprocess. Intended to produce a paddle.Tensor stored in
            self.input. Input content can be a text(t2s), a file(s2t, cls) or
            a streaming(not supported yet).

            Placeholder: not implemented yet, ``input`` is currently ignored.
        """
        pass

    @paddle.no_grad()
    def infer(self):
        """
            Model inference, run with gradient computation disabled. The result
            is intended to be stored in self.output.

            Placeholder: not implemented yet.
        """
        pass

    def postprocess(self) -> Union[str, os.PathLike]:
        """
            Output postprocess. Intended to return human-readable results such
            as texts and audio files.

            Placeholder: not implemented yet, currently returns None.
        """
        pass

    def execute(self, argv: List[str]) -> bool:
        """
            Parse ``argv`` and run the whole s2t pipeline once.

            Args:
                argv: Command-line arguments to parse (options only, no
                    program name).

            Returns:
                True when every pipeline stage finished, False when any stage
                (or selecting the device) raised an exception; the exception
                is printed rather than propagated.

            Raises:
                AssertionError: ``--config`` was given but is not an existing
                    file.
        """
        parser_args = self.parser.parse_args(argv)
        print(parser_args)

        config = parser_args.config
        audio_file = parser_args.input
        device = parser_args.device

        if config is None:
            config = self._get_default_cfg_path()
        elif not os.path.isfile(config):
            # Explicit raise instead of ``assert`` so the check is not
            # stripped under ``python -O``; the exception type and message
            # stay the same for existing callers.
            raise AssertionError('Config file is not valid.')

        try:
            # Honour --device; previously the option was parsed but never
            # applied. An unsupported device string is reported like any
            # other pipeline failure.
            paddle.set_device(device)
            self._init_from_cfg(config)
            self.preprocess(audio_file)
            self.infer()
            res = self.postprocess()  # Retrieve result of s2t.
            print(res)
            return True
        except Exception as e:
            print(e)
            return False