diff --git a/paddlespeech/cli/__init__.py b/paddlespeech/cli/__init__.py index ddf0359bc5fcb7ff80b437a65112869d7faa12eb..ca6993f2b003054062cb99f37675ad7009f70d32 100644 --- a/paddlespeech/cli/__init__.py +++ b/paddlespeech/cli/__init__.py @@ -13,14 +13,7 @@ # limitations under the License. import _locale -from .asr import ASRExecutor from .base_commands import BaseCommand from .base_commands import HelpCommand -from .cls import CLSExecutor -from .st import STExecutor -from .stats import StatsExecutor -from .text import TextExecutor -from .tts import TTSExecutor -from .vector import VectorExecutor _locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8']) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 2d74afa6d72e166aacbe98003ba4db3e80c4b130..09e8202fd7d54e59b9c535b6b5a598123147f539 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -29,7 +29,6 @@ from yacs.config import CfgNode from ..download import get_path_from_url from ..executor import BaseExecutor from ..log import logger -from ..utils import cli_register from ..utils import CLI_TIMER from ..utils import MODEL_HOME from ..utils import stats_wrapper @@ -45,8 +44,6 @@ __all__ = ['ASRExecutor'] @timer_register -@cli_register( - name='paddlespeech.asr', description='Speech to text infer command.') class ASRExecutor(BaseExecutor): def __init__(self): super().__init__() diff --git a/paddlespeech/cli/base_commands.py b/paddlespeech/cli/base_commands.py index 0a26b12030a0b25fe169be5ad3bc61e82c500fa7..4d4d2cc69b0eec34b626a84ee237c1c4c4c540a2 100644 --- a/paddlespeech/cli/base_commands.py +++ b/paddlespeech/cli/base_commands.py @@ -15,6 +15,7 @@ from typing import List from .entry import commands from .utils import cli_register +from .utils import explicit_command_register from .utils import get_command __all__ = [ @@ -73,3 +74,20 @@ class VersionCommand: print(msg) return True + + +# Dynamic import when running specific command +_commands = { + 'asr': ['Speech to text infer command.', 'ASRExecutor'], + 'cls': ['Audio classification infer command.', 'CLSExecutor'], + 'st': ['Speech translation infer command.', 'STExecutor'], + 'text': ['Text command.', 'TextExecutor'], + 'tts': ['Text to Speech infer command.', 'TTSExecutor'], + 'vector': ['Speech to vector embedding infer command.', 'VectorExecutor'], +} + +for com, info in _commands.items(): + explicit_command_register( + name='paddlespeech.{}'.format(com), + description=info[0], + cls='paddlespeech.cli.{}.{}'.format(com, info[1])) diff --git a/paddlespeech/cli/cls/infer.py b/paddlespeech/cli/cls/infer.py index 40072d9974e5798dc5d74b921efb905230e06246..3d807b60b3d03d4620875582b41f26f5b699c45b 100644 --- a/paddlespeech/cli/cls/infer.py +++ b/paddlespeech/cli/cls/infer.py @@ -27,7 +27,6 @@ from paddlespeech.utils.dynamic_import import dynamic_import from ..executor import BaseExecutor from ..log import logger -from ..utils import cli_register from ..utils import stats_wrapper from .pretrained_models import model_alias from .pretrained_models import pretrained_models @@ -36,8 +35,6 @@ from .pretrained_models import pretrained_models __all__ = ['CLSExecutor'] -@cli_register( - name='paddlespeech.cls', description='Audio classification infer command.') class CLSExecutor(BaseExecutor): def __init__(self): super().__init__() @@ -246,4 +243,4 @@ class CLSExecutor(BaseExecutor): self.infer() res = self.postprocess(topk) # Retrieve result of cls. - return res \ No newline at end of file + return res diff --git a/paddlespeech/cli/entry.py b/paddlespeech/cli/entry.py index 32123ece750457dac8ca90aff1a8731fea569188..e0c306d62a7d55b8a48a147fa7d13dcda866ab79 100644 --- a/paddlespeech/cli/entry.py +++ b/paddlespeech/cli/entry.py @@ -34,6 +34,11 @@ def _execute(): # The method 'execute' of a command instance returns 'True' for a success # while 'False' for a failure. Here converts this result into a exit status # in bash: 0 for a success and 1 for a failure. + if not callable(com['_entry']): + i = com['_entry'].rindex('.') + module, cls = com['_entry'][:i], com['_entry'][i + 1:] + exec("from {} import {}".format(module, cls)) + com['_entry'] = locals()[cls] status = 0 if com['_entry']().execute(sys.argv[idx:]) else 1 return status diff --git a/paddlespeech/cli/st/infer.py b/paddlespeech/cli/st/infer.py index 4f210fbe685df50236379e196639395b2ce4adf2..ae188b349632bd7af811b363143ed845db012392 100644 --- a/paddlespeech/cli/st/infer.py +++ b/paddlespeech/cli/st/infer.py @@ -28,7 +28,6 @@ from yacs.config import CfgNode from ..executor import BaseExecutor from ..log import logger -from ..utils import cli_register from ..utils import download_and_decompress from ..utils import MODEL_HOME from ..utils import stats_wrapper @@ -42,8 +41,6 @@ from paddlespeech.utils.dynamic_import import dynamic_import __all__ = ["STExecutor"] -@cli_register( - name="paddlespeech.st", description="Speech translation infer command.") class STExecutor(BaseExecutor): def __init__(self): super().__init__() diff --git a/paddlespeech/cli/text/infer.py b/paddlespeech/cli/text/infer.py index 97f3bbe21346dfa6651c773e44eb293c4baa841a..be5b5a10d474c2a50d448d955ae8cead3a13202b 100644 --- a/paddlespeech/cli/text/infer.py +++ b/paddlespeech/cli/text/infer.py @@ -23,7 +23,6 @@ import paddle from ..executor import BaseExecutor from ..log import logger -from ..utils import cli_register from ..utils import stats_wrapper from .pretrained_models import model_alias from .pretrained_models import pretrained_models @@ -33,7 +32,6 @@ from paddlespeech.utils.dynamic_import import dynamic_import __all__ = ['TextExecutor'] -@cli_register(name='paddlespeech.text', description='Text infer command.') class TextExecutor(BaseExecutor): def __init__(self): super().__init__() diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py index efab9cb258a34c8ee7f003fd6f1a21fe8e77a126..5fa9b3ed0f0a32ced3e28672210a5d3318a16d69 100644 --- a/paddlespeech/cli/tts/infer.py +++ b/paddlespeech/cli/tts/infer.py @@ -28,7 +28,6 @@ from yacs.config import CfgNode from ..executor import BaseExecutor from ..log import logger -from ..utils import cli_register from ..utils import stats_wrapper from .pretrained_models import model_alias from .pretrained_models import pretrained_models @@ -40,8 +39,6 @@ from paddlespeech.utils.dynamic_import import dynamic_import __all__ = ['TTSExecutor'] -@cli_register( - name='paddlespeech.tts', description='Text to Speech infer command.') class TTSExecutor(BaseExecutor): def __init__(self): super().__init__() diff --git a/paddlespeech/cli/utils.py b/paddlespeech/cli/utils.py index e7b499f728c3d93ecdfc3bd8fdf92559ce59845a..128767e627091dc636da2900e5e65b58bdd650ca 100644 --- a/paddlespeech/cli/utils.py +++ b/paddlespeech/cli/utils.py @@ -41,6 +41,7 @@ requests.adapters.DEFAULT_RETRIES = 3 __all__ = [ 'timer_register', 'cli_register', + 'explicit_command_register', 'get_command', 'download_and_decompress', 'load_state_dict_from_url', @@ -70,6 +71,16 @@ def cli_register(name: str, description: str='') -> Any: return _warpper +def explicit_command_register(name: str, description: str='', cls: str=''): + items = name.split('.') + com = commands + for item in items: + com = com[item] + com['_entry'] = cls + if description: + com['_description'] = description + + def get_command(name: str) -> Any: items = name.split('.') com = commands diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py index cc664369fa2a82f929e8112610955e1aa727a6d2..07fb73a4c839864a63a1561324a67da55e9df80d 100644 --- a/paddlespeech/cli/vector/infer.py +++ b/paddlespeech/cli/vector/infer.py @@ -28,7 +28,6 @@ from yacs.config import CfgNode from ..executor import BaseExecutor from ..log import logger -from ..utils import cli_register from ..utils import stats_wrapper from .pretrained_models import model_alias from .pretrained_models import pretrained_models @@ -37,9 +36,6 @@ from paddlespeech.vector.io.batch import feature_normalize from paddlespeech.vector.modules.sid_model import SpeakerIdetification -@cli_register( - name="paddlespeech.vector", - description="Speech to vector embedding infer command.") class VectorExecutor(BaseExecutor): def __init__(self): super().__init__() @@ -476,4 +472,4 @@ class VectorExecutor(BaseExecutor): else: logger.info("The audio file format is right") - return True \ No newline at end of file + return True