diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index e5c64e9ab38d7dc9b4e421b67043c1be1b33fbc4..a0ae535070b86d9d4635a34f4d41fdbf999b14e3 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -18,17 +18,17 @@ from typing import List from typing import Optional from typing import Union +import librosa import paddle import soundfile +from yacs.config import CfgNode from ..executor import BaseExecutor from ..utils import cli_register from ..utils import download_and_decompress from ..utils import logger from ..utils import MODEL_HOME -from paddlespeech.s2t.exps.u2.config import get_cfg_defaults from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer -from paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.utility import UpdateConfig @@ -36,7 +36,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig __all__ = ['ASRExecutor'] pretrained_models = { - "wenetspeech_zh": { + "wenetspeech_zh_16k": { 'url': 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz', 'md5': @@ -73,7 +73,15 @@ class ASRExecutor(BaseExecutor): default='wenetspeech', help='Choose model type of asr task.') self.parser.add_argument( - '--lang', type=str, default='zh', help='Choose model language.') + '--lang', + type=str, + default='zh', + help='Choose model language. zh or en') + self.parser.add_argument( + "--model_sample_rate", + type=int, + default=16000, + help='Choose the audio sample rate of the model. 8000 or 16000') self.parser.add_argument( '--config', type=str, @@ -109,13 +117,15 @@ class ASRExecutor(BaseExecutor): def _init_from_path(self, model_type: str='wenetspeech', lang: str='zh', + model_sample_rate: int=16000, cfg_path: Optional[os.PathLike]=None, ckpt_path: Optional[os.PathLike]=None): """ Init model and other resources from a specific path. """ if cfg_path is None or ckpt_path is None: - tag = model_type + '_' + lang + model_sample_rate_str = '16k' if model_sample_rate == 16000 else '8k' + tag = model_type + '_' + lang + '_' + model_sample_rate_str res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.cfg_path = os.path.join(res_path, pretrained_models[tag]['cfg_path']) @@ -136,23 +146,24 @@ class ASRExecutor(BaseExecutor): #Init body. parser_args = self.parser_args paddle.set_device(parser_args.device) - self.config = get_cfg_defaults() + self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) self.config.decoding.decoding_method = "attention_rescoring" - #self.config.freeze() model_conf = self.config.model logger.info(model_conf) with UpdateConfig(model_conf): if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline": + from paddlespeech.s2t.io.collator import SpeechCollator self.config.collator.vocab_filepath = os.path.join( res_path, self.config.collator.vocab_filepath) self.config.collator.vocab_filepath = os.path.join( res_path, self.config.collator.cmvn_path) self.collate_fn_test = SpeechCollator.from_config(self.config) - model_conf.feat_size = self.collate_fn_test.feature_size - model_conf.dict_size = self.text_feature.vocab_size + model_conf.input_dim = self.collate_fn_test.feature_size + model_conf.output_dim = self.text_feature.vocab_size elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech": + self.config.collator.vocab_filepath = os.path.join( res_path, self.config.collator.vocab_filepath) self.text_feature = TextFeaturizer( @@ -163,6 +174,7 @@ class ASRExecutor(BaseExecutor): model_conf.output_dim = self.text_feature.vocab_size else: raise Exception("wrong type") + self.config.freeze() model_class = dynamic_import(parser_args.model, model_alias) model = model_class.from_config(model_conf) self.model = model @@ -182,13 +194,13 @@ class ASRExecutor(BaseExecutor): parser_args = self.parser_args config = self.config audio_file = input - logger.info("audio_file" + audio_file) + logger.info("Preprocess audio_file:" + audio_file) self.sr = config.collator.target_sample_rate # Get the object for feature extraction if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline": - audio, _ = collate_fn_test.process_utterance( + audio, _ = self.collate_fn_test.process_utterance( audio_file=audio_file, transcript=" ") audio_len = audio.shape[0] audio = paddle.to_tensor(audio, dtype='float32') @@ -203,18 +215,30 @@ class ASRExecutor(BaseExecutor): os.path.dirname(os.path.abspath(self.cfg_path)), "preprocess.yaml") - cmvn_path: data / mean_std.json - logger.info(preprocess_conf) preprocess_args = {"train": False} preprocessing = Transformation(preprocess_conf) + logger.info("read the audio file") audio, sample_rate = soundfile.read( audio_file, dtype="int16", always_2d=True) + + if self.change_format: + if audio.shape[1] >= 2: + audio = audio.mean(axis=1) + else: + audio = audio[:, 0] + audio = audio.astype("float32") + audio = librosa.resample(audio, sample_rate, + self.target_sample_rate) + sample_rate = self.target_sample_rate + audio = audio.astype("int16") + else: + audio = audio[:, 0] + if sample_rate != self.sr: logger.error( f"sample rate error: {sample_rate}, need {self.sr} ") sys.exit(-1) - audio = audio[:, 0] logger.info(f"audio shape: {audio.shape}") # fbank audio = preprocessing(audio, **preprocess_args) @@ -282,6 +306,63 @@ class ASRExecutor(BaseExecutor): """ return self.result_transcripts + def _check(self, audio_file: str, model_sample_rate: int): + self.target_sample_rate = model_sample_rate + if self.target_sample_rate != 16000 and self.target_sample_rate != 8000: + logger.error( + "please input --model_sample_rate 8000 or --model_sample_rate 16000") + raise Exception("invalid sample rate") + sys.exit(-1) + + if not os.path.isfile(audio_file): + logger.error("Please input the right audio file path") + sys.exit(-1) + + logger.info("checking the audio file format......") + try: + sig, sample_rate = soundfile.read( + audio_file, dtype="int16", always_2d=True) + except Exception as e: + logger.error(str(e)) + logger.error( + "can not open the audio file, please check the audio file format is 'wav'. \n \ + you can try to use sox to change the file format.\n \ + For example: \n \ + sample rate: 16k \n \ + sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \ + sample rate: 8k \n \ + sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \ + ") + sys.exit(-1) + logger.info("The sample rate is %d" % sample_rate) + if sample_rate != self.target_sample_rate: + logger.warning("The sample rate of the input file is not {}.\n \ + The program will resample the wav file to {}.\n \ + If the result does not meet your expectations,\n \ + Please input the 16k 16bit 1 channel wav file. \ + ".format(self.target_sample_rate, self.target_sample_rate)) + while (True): + logger.info( + "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream." + ) + content = input("Input(Y/N):") + if content.strip() == "Y" or content.strip( + ) == "y" or content.strip() == "yes" or content.strip() == "Yes": + logger.info( + "change the sampele rate, channel to 16k and 1 channel") + break + elif content.strip() == "N" or content.strip( + ) == "n" or content.strip() == "no" or content.strip() == "No": + logger.info("Exit the program") + exit(1) + else: + logger.warning("Not regular input, please input again") + + self.change_format = True + else: + logger.info("The audio file format is right") + self.change_format = False + def execute(self, argv: List[str]) -> bool: """ Command line entry. @@ -290,24 +371,28 @@ class ASRExecutor(BaseExecutor): model = self.parser_args.model lang = self.parser_args.lang + model_sample_rate = self.parser_args.model_sample_rate config = self.parser_args.config ckpt_path = self.parser_args.ckpt_path audio_file = os.path.abspath(self.parser_args.input) device = self.parser_args.device try: - res = self(model, lang, config, ckpt_path, audio_file, device) + res = self(model, lang, model_sample_rate, config, ckpt_path, audio_file, + device) logger.info('ASR Result: {}'.format(res)) return True except Exception as e: print(e) return False - def __call__(self, model, lang, config, ckpt_path, audio_file, device): + def __call__(self, model, lang, model_sample_rate, config, ckpt_path, audio_file, + device): """ Python API to call an executor. """ - self._init_from_path(model, lang, config, ckpt_path) + self._check(audio_file, model_sample_rate) + self._init_from_path(model, lang, model_sample_rate, config, ckpt_path) self.preprocess(audio_file) self.infer() res = self.postprocess() # Retrieve result of asr.