diff --git a/paddlespeech/cli/__init__.py b/paddlespeech/cli/__init__.py index 0aedab576241b403c2e03286030a07c92117c33a..246d0f381279dd7263415c9e1a6c7ad76b8dbb53 100644 --- a/paddlespeech/cli/__init__.py +++ b/paddlespeech/cli/__init__.py @@ -15,3 +15,4 @@ from .asr import ASRExecutor from .base_commands import BaseCommand from .base_commands import HelpCommand from .cls import CLSExecutor +from .st import STExecutor diff --git a/paddlespeech/cli/st/__init__.py b/paddlespeech/cli/st/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8cdb4e34826d50d9898d79e2ffe09c369026b3e5 --- /dev/null +++ b/paddlespeech/cli/st/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .infer import STExecutor diff --git a/paddlespeech/cli/st/infer.py b/paddlespeech/cli/st/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..534b9e3b9458f75fd89710ee20dd4b06feb38c4d --- /dev/null +++ b/paddlespeech/cli/st/infer.py @@ -0,0 +1,356 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import os +import subprocess +from typing import List +from typing import Optional +from typing import Union + +import kaldi_io +import numpy as np +import paddle +import soundfile +from kaldiio import WriteHelper +from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.s2t.utils.utility import UpdateConfig +from yacs.config import CfgNode + +from ..executor import BaseExecutor +from ..utils import cli_register +from ..utils import download_and_decompress +from ..utils import logger +from ..utils import MODEL_HOME + +__all__ = ["STExecutor"] + +pretrained_models = { + "fat_st_ted_en-zh": { + "url": + "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/fat_st_ted-en-zh.tar.gz", + "md5": + "fa0a7425b91b4f8d259c70b2aca5ae67", + "cfg_path": + "conf/transformer_mtl_noam.yaml", + "ckpt_path": + "exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams", + } +} + +model_alias = {"fat_st_ted": "paddlespeech.s2t.models.u2_st:U2STModel"} + +kaldi_bins = { + "url": + "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz", + "md5": + "c0682303b3f3393dbf6ed4c4e35a53eb", +} + + +@cli_register( + name="paddlespeech.st", description="Speech translation infer command.") +class STExecutor(BaseExecutor): + def __init__(self): + super(STExecutor, self).__init__() + + self.parser = argparse.ArgumentParser( + prog="paddlespeech.st", add_help=True) + self.parser.add_argument( + "--input", type=str, required=True, help="Audio file to translate.") + self.parser.add_argument( + "--model_type", + type=str, + default="fat_st_ted", + help="Choose model type of st task.") + self.parser.add_argument( + "--src_lang", + type=str, + default="en", + help="Choose model source language.") + self.parser.add_argument( + "--tgt_lang", + type=str, + default="zh", + help="Choose model target language.") + self.parser.add_argument( + "--sample_rate", + type=int, + default=16000, + choices=[16000], + help='Choose the audio sample rate of the model. 8000 or 16000') + self.parser.add_argument( + "--cfg_path", + type=str, + default=None, + help="Config of st task. Use deault config when it is None.") + self.parser.add_argument( + "--ckpt_path", + type=str, + default=None, + help="Checkpoint file of model.") + self.parser.add_argument( + "--device", + type=str, + default=paddle.get_device(), + help="Choose device to execute model inference.") + + def _get_pretrained_path(self, tag: str) -> os.PathLike: + """ + Download and returns pretrained resources path of current task. + """ + assert tag in pretrained_models, "Can not find pretrained resources of {}.".format( + tag) + + res_path = os.path.join(MODEL_HOME, tag) + decompressed_path = download_and_decompress(pretrained_models[tag], + res_path) + decompressed_path = os.path.abspath(decompressed_path) + logger.info( + "Use pretrained model stored in: {}".format(decompressed_path)) + + return decompressed_path + + def _set_kaldi_bins(self) -> os.PathLike: + """ + Download and returns kaldi_bins resources path of current task. + """ + decompressed_path = download_and_decompress(kaldi_bins, MODEL_HOME) + decompressed_path = os.path.abspath(decompressed_path) + logger.info("Kaldi_bins stored in: {}".format(decompressed_path)) + if "LD_LIBRARY_PATH" in os.environ: + os.environ["LD_LIBRARY_PATH"] += f":{decompressed_path}" + else: + os.environ["LD_LIBRARY_PATH"] = f"{decompressed_path}" + os.environ["PATH"] += f":{decompressed_path}" + return decompressed_path + + def _init_from_path(self, + model_type: str="fat_st_ted", + src_lang: str="en", + tgt_lang: str="zh", + cfg_path: Optional[os.PathLike]=None, + ckpt_path: Optional[os.PathLike]=None): + """ + Init model and other resources from a specific path. + """ + if hasattr(self, 'model'): + logger.info('Model had been initialized.') + return + + if cfg_path is None or ckpt_path is None: + tag = model_type + "_" + src_lang + "-" + tgt_lang + res_path = self._get_pretrained_path(tag) + self.cfg_path = os.path.join(res_path, + pretrained_models[tag]["cfg_path"]) + self.ckpt_path = os.path.join(res_path, + pretrained_models[tag]["ckpt_path"]) + logger.info(res_path) + logger.info(self.cfg_path) + logger.info(self.ckpt_path) + else: + self.cfg_path = os.path.abspath(cfg_path) + self.ckpt_path = os.path.abspath(ckpt_path) + res_path = os.path.dirname( + os.path.dirname(os.path.abspath(self.cfg_path))) + + #Init body. + self.config = CfgNode(new_allowed=True) + self.config.merge_from_file(self.cfg_path) + self.config.decoding.decoding_method = "fullsentence" + + with UpdateConfig(self.config): + self.config.collator.vocab_filepath = os.path.join( + res_path, self.config.collator.vocab_filepath) + self.config.collator.cmvn_path = os.path.join( + res_path, self.config.collator.cmvn_path) + self.config.collator.spm_model_prefix = os.path.join( + res_path, self.config.collator.spm_model_prefix) + self.text_feature = TextFeaturizer( + unit_type=self.config.collator.unit_type, + vocab_filepath=self.config.collator.vocab_filepath, + spm_model_prefix=self.config.collator.spm_model_prefix) + self.config.model.input_dim = self.config.collator.feat_dim + self.config.model.output_dim = self.text_feature.vocab_size + + model_conf = self.config.model + logger.info(model_conf) + model_class = dynamic_import(model_type, model_alias) + self.model = model_class.from_config(model_conf) + self.model.eval() + + # load model + params_path = self.ckpt_path + model_dict = paddle.load(params_path) + self.model.set_state_dict(model_dict) + + # set kaldi bins + self._set_kaldi_bins() + + def _check(self, audio_file: str, sample_rate: int): + _, audio_sample_rate = soundfile.read( + audio_file, dtype="int16", always_2d=True) + if audio_sample_rate != sample_rate: + raise Exception("invalid sample rate") + sys.exit(-1) + + def preprocess(self, wav_file: Union[str, os.PathLike], model_type: str): + """ + Input preprocess and return paddle.Tensor stored in self.input. + Input content can be a file(wav). + """ + audio_file = os.path.abspath(wav_file) + logger.info("Preprocess audio_file:" + audio_file) + + if model_type == "fat_st_ted": + cmvn = self.config.collator.cmvn_path + utt_name = "_tmp" + + # Get the object for feature extraction + fbank_extract_command = [ + "compute-fbank-feats", "--num-mel-bins=80", "--verbose=2", + "--sample-frequency=16000", "scp:-", "ark:-" + ] + fbank_extract_process = subprocess.Popen( + fbank_extract_command, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + fbank_extract_process.stdin.write( + f"{utt_name} {wav_file}".encode("utf8")) + fbank_extract_process.stdin.close() + fbank_feat = dict( + kaldi_io.read_mat_ark(fbank_extract_process.stdout))[utt_name] + + extract_command = ["compute-kaldi-pitch-feats", "scp:-", "ark:-"] + pitch_extract_process = subprocess.Popen( + extract_command, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + pitch_extract_process.stdin.write( + f"{utt_name} {wav_file}".encode("utf8")) + process_command = ["process-kaldi-pitch-feats", "ark:", "ark:-"] + pitch_process = subprocess.Popen( + process_command, + stdin=pitch_extract_process.stdout, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + pitch_extract_process.stdin.close() + pitch_feat = dict( + kaldi_io.read_mat_ark(pitch_process.stdout))[utt_name] + concated_feat = np.concatenate((fbank_feat, pitch_feat), axis=1) + raw_feat = f"{utt_name}.raw" + with WriteHelper( + f"ark,scp:{raw_feat}.ark,{raw_feat}.scp") as writer: + writer(utt_name, concated_feat) + cmvn_command = [ + "apply-cmvn", "--norm-vars=true", cmvn, f"scp:{raw_feat}.scp", + "ark:-" + ] + cmvn_process = subprocess.Popen( + cmvn_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + process_command = [ + "copy-feats", "--compress=true", "ark:-", "ark:-" + ] + process = subprocess.Popen( + process_command, + stdin=cmvn_process.stdout, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + norm_feat = dict(kaldi_io.read_mat_ark(process.stdout))[utt_name] + self._inputs["audio"] = paddle.to_tensor(norm_feat).unsqueeze(0) + self._inputs["audio_len"] = paddle.to_tensor( + self._inputs["audio"].shape[1], dtype="int64") + else: + raise ValueError("Wrong model type.") + + @paddle.no_grad() + def infer(self, model_type: str): + """ + Model inference and result stored in self.output. + """ + cfg = self.config.decoding + audio = self._inputs["audio"] + audio_len = self._inputs["audio_len"] + if model_type == "fat_st_ted": + hyps = self.model.decode( + audio, + audio_len, + text_feature=self.text_feature, + decoding_method=cfg.decoding_method, + lang_model_path=None, + beam_alpha=cfg.alpha, + beam_beta=cfg.beta, + beam_size=cfg.beam_size, + cutoff_prob=cfg.cutoff_prob, + cutoff_top_n=cfg.cutoff_top_n, + num_processes=cfg.num_proc_bsearch, + ctc_weight=cfg.ctc_weight, + word_reward=cfg.word_reward, + decoding_chunk_size=cfg.decoding_chunk_size, + num_decoding_left_chunks=cfg.num_decoding_left_chunks, + simulate_streaming=cfg.simulate_streaming) + self._outputs["result"] = hyps + else: + raise ValueError("Wrong model type.") + + def postprocess(self, model_type: str) -> Union[str, os.PathLike]: + """ + Output postprocess and return human-readable results such as texts and audio files. + """ + if model_type == "fat_st_ted": + return self._outputs["result"] + else: + raise ValueError("Wrong model type.") + + def execute(self, argv: List[str]) -> bool: + """ + Command line entry. + """ + parser_args = self.parser.parse_args(argv) + + model_type = parser_args.model_type + src_lang = parser_args.src_lang + tgt_lang = parser_args.tgt_lang + sample_rate = parser_args.sample_rate + cfg_path = parser_args.cfg_path + ckpt_path = parser_args.ckpt_path + audio_file = parser_args.input + device = parser_args.device + + try: + res = self(model_type, src_lang, tgt_lang, sample_rate, cfg_path, + ckpt_path, audio_file, device) + logger.info("ST Result: {}".format(res)) + return True + except Exception as e: + print(e) + return False + + def __call__(self, model_type, src_lang, tgt_lang, sample_rate, cfg_path, + ckpt_path, audio_file, device): + """ + Python API to call an executor. + """ + audio_file = os.path.abspath(audio_file) + self._check(audio_file, sample_rate) + paddle.set_device(device) + self._init_from_path(model_type, src_lang, tgt_lang, cfg_path, + ckpt_path) + self.preprocess(audio_file, model_type) + self.infer(model_type) + res = self.postprocess(model_type) + + return res diff --git a/setup.py b/setup.py index 7720ba3fd9d4e64ef4f85f1b37919f896bb546d2..c1ae00ce8fda97e0add24213359bdfd0fc9ad87e 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,7 @@ requirements = { "jieba", "jsonlines", "kaldiio", + "kaldi_io", "librosa", "loguru", "matplotlib",