diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 00216356c7d5f2124a08abb01517467285ca08de..e9d8c0b11b3adb1a348966e0dda22276e49230da 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -22,6 +22,7 @@ import librosa import paddle import soundfile from yacs.config import CfgNode +import numpy as np from ..executor import BaseExecutor from ..utils import cli_register @@ -78,9 +79,10 @@ class ASRExecutor(BaseExecutor): default='zh', help='Choose model language. zh or en') self.parser.add_argument( - "--model_sample_rate", + "--sr", type=int, default=16000, + choices=[8000, 16000], help='Choose the audio sample rate of the model. 8000 or 16000') self.parser.add_argument( '--config', @@ -117,26 +119,27 @@ class ASRExecutor(BaseExecutor): def _init_from_path(self, model_type: str='wenetspeech', lang: str='zh', - model_sample_rate: int=16000, + sample_rate: int=16000, cfg_path: Optional[os.PathLike]=None, - ckpt_path: Optional[os.PathLike]=None): + ckpt_path: Optional[os.PathLike]=None + ): """ Init model and other resources from a specific path. """ if cfg_path is None or ckpt_path is None: - model_sample_rate_str = '16k' if model_sample_rate == 16000 else '8k' - tag = model_type + '_' + lang + '_' + model_sample_rate_str + sample_rate_str = '16k' if sample_rate == 16000 else '8k' + tag = model_type + '_' + lang + '_' + sample_rate_str res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.cfg_path = os.path.join(res_path, pretrained_models[tag]['cfg_path']) self.ckpt_path = os.path.join(res_path, - pretrained_models[tag]['ckpt_path']) + pretrained_models[tag]['ckpt_path'] + ".pdparams") logger.info(res_path) logger.info(self.cfg_path) logger.info(self.ckpt_path) else: self.cfg_path = os.path.abspath(cfg_path) - self.ckpt_path = os.path.abspath(ckpt_path) + self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) @@ -182,8 +185,7 @@ class ASRExecutor(BaseExecutor): self.model.eval() # load model - params_path = self.ckpt_path + ".pdparams" - model_dict = paddle.load(params_path) + model_dict = paddle.load(self.ckpt_path) self.model.set_state_dict(model_dict) def preprocess(self, model_type: str, input: Union[str, os.PathLike]): @@ -195,8 +197,6 @@ class ASRExecutor(BaseExecutor): audio_file = input logger.info("Preprocess audio_file:" + audio_file) - config_target_sample_rate = self.config.collator.target_sample_rate - # Get the object for feature extraction if model_type == "ds2_online" or model_type == "ds2_offline": audio, _ = self.collate_fn_test.process_utterance( @@ -220,7 +220,7 @@ class ASRExecutor(BaseExecutor): preprocess_args = {"train": False} preprocessing = Transformation(preprocess_conf) logger.info("read the audio file") - audio, sample_rate = soundfile.read( + audio, audio_sample_rate = soundfile.read( audio_file, dtype="int16", always_2d=True) if self.change_format: @@ -229,17 +229,13 @@ class ASRExecutor(BaseExecutor): else: audio = audio[:, 0] audio = audio.astype("float32") - audio = librosa.resample(audio, sample_rate, - self.target_sample_rate) - sample_rate = self.target_sample_rate - audio = audio.astype("int16") + audio = librosa.resample(audio, audio_sample_rate, + self.sample_rate) + audio_sample_rate = self.sample_rate + audio = np.round(audio).astype("int16") else: audio = audio[:, 0] - if sample_rate != config_target_sample_rate: - logger.error( - f"sample rate error: {sample_rate}, need {self.sr} ") - sys.exit(-1) logger.info(f"audio shape: {audio.shape}") # fbank audio = preprocessing(audio, **preprocess_args) @@ -311,11 +307,11 @@ class ASRExecutor(BaseExecutor): """ return self._outputs["result"] - def _check(self, audio_file: str, model_sample_rate: int): - self.target_sample_rate = model_sample_rate - if self.target_sample_rate != 16000 and self.target_sample_rate != 8000: + def _check(self, audio_file: str, sample_rate: int): + self.sample_rate = sample_rate + if self.sample_rate != 16000 and self.sample_rate != 8000: logger.error( - "please input --model_sample_rate 8000 or --model_sample_rate 16000" + "please input --sr 8000 or --sr 16000" ) raise Exception("invalid sample rate") sys.exit(-1) @@ -326,7 +322,7 @@ class ASRExecutor(BaseExecutor): logger.info("checking the audio file format......") try: - sig, sample_rate = soundfile.read( + audio, audio_sample_rate = soundfile.read( audio_file, dtype="int16", always_2d=True) except Exception as e: logger.error(str(e)) @@ -340,15 +336,15 @@ class ASRExecutor(BaseExecutor): sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \ ") sys.exit(-1) - logger.info("The sample rate is %d" % sample_rate) - if sample_rate != self.target_sample_rate: + logger.info("The sample rate is %d" % audio_sample_rate) + if audio_sample_rate != self.sample_rate: logger.warning( "The sample rate of the input file is not {}.\n \ The program will resample the wav file to {}.\n \ If the result does not meet your expectations,\n \ Please input the 16k 16bit 1 channel wav file. \ " - .format(self.target_sample_rate, self.target_sample_rate)) + .format(self.sample_rate, self.sample_rate)) while (True): logger.info( "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream." @@ -379,14 +375,14 @@ class ASRExecutor(BaseExecutor): model = parser_args.model lang = parser_args.lang - model_sample_rate = parser_args.model_sample_rate + sample_rate = parser_args.sr config = parser_args.config ckpt_path = parser_args.ckpt_path audio_file = parser_args.input device = parser_args.device try: - res = self(model, lang, model_sample_rate, config, ckpt_path, + res = self(model, lang, sample_rate, config, ckpt_path, audio_file, device) logger.info('ASR Result: {}'.format(res)) return True @@ -394,16 +390,15 @@ class ASRExecutor(BaseExecutor): print(e) return False - def __call__(self, model, lang, model_sample_rate, config, ckpt_path, + def __call__(self, model, lang, sample_rate, config, ckpt_path, audio_file, device): """ Python API to call an executor. """ audio_file = os.path.abspath(audio_file) - self._check(audio_file, model_sample_rate) - + self._check(audio_file, sample_rate) paddle.set_device(device) - self._init_from_path(model, lang, model_sample_rate, config, ckpt_path) + self._init_from_path(model, lang, sample_rate, config, ckpt_path) self.preprocess(model, audio_file) self.infer(model) res = self.postprocess() # Retrieve result of asr.