diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index ea1828b6b0a7123e80d8330e9fd261aa38c478d0..c9ec058cd8741cb93c39ec5fe764bcc28d8065d3 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -78,7 +78,7 @@ class ASRExecutor(BaseExecutor): default='zh', help='Choose model language. zh or en') self.parser.add_argument( - "--model_sample_rate", + "--sr", type=int, default=16000, help='Choose the audio sample rate of the model. 8000 or 16000') @@ -117,7 +117,7 @@ class ASRExecutor(BaseExecutor): def _init_from_path(self, model_type: str='wenetspeech', lang: str='zh', - model_sample_rate: int=16000, + sample_rate: int=16000, cfg_path: Optional[os.PathLike]=None, ckpt_path: Optional[os.PathLike]=None, device: str='cpu'): @@ -125,8 +125,8 @@ class ASRExecutor(BaseExecutor): Init model and other resources from a specific path. """ if cfg_path is None or ckpt_path is None: - model_sample_rate_str = '16k' if model_sample_rate == 16000 else '8k' - tag = model_type + '_' + lang + '_' + model_sample_rate_str + sample_rate_str = '16k' if sample_rate == 16000 else '8k' + tag = model_type + '_' + lang + '_' + sample_rate_str res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.cfg_path = os.path.join(res_path, pretrained_models[tag]['cfg_path']) @@ -197,8 +197,6 @@ class ASRExecutor(BaseExecutor): audio_file = input logger.info("Preprocess audio_file:" + audio_file) - config_target_sample_rate = self.config.collator.target_sample_rate - # Get the object for feature extraction if model_type == "ds2_online" or model_type == "ds2_offline": audio, _ = self.collate_fn_test.process_utterance( @@ -222,7 +220,7 @@ class ASRExecutor(BaseExecutor): preprocess_args = {"train": False} preprocessing = Transformation(preprocess_conf) logger.info("read the audio file") - audio, sample_rate = soundfile.read( + audio, audio_sample_rate = soundfile.read( audio_file, dtype="int16", always_2d=True) if self.change_format: @@ -231,17 +229,13 @@ class ASRExecutor(BaseExecutor): else: audio = audio[:, 0] audio = audio.astype("float32") - audio = librosa.resample(audio, sample_rate, - self.target_sample_rate) - sample_rate = self.target_sample_rate + audio = librosa.resample(audio, audio_sample_rate, + self.sample_rate) + audio_sample_rate = self.sample_rate audio = audio.astype("int16") else: audio = audio[:, 0] - if sample_rate != config_target_sample_rate: - logger.error( - f"sample rate error: {sample_rate}, need {self.sr} ") - sys.exit(-1) logger.info(f"audio shape: {audio.shape}") # fbank audio = preprocessing(audio, **preprocess_args) @@ -313,11 +307,11 @@ class ASRExecutor(BaseExecutor): """ return self._outputs["result"] - def _check(self, audio_file: str, model_sample_rate: int): - self.target_sample_rate = model_sample_rate - if self.target_sample_rate != 16000 and self.target_sample_rate != 8000: + def _check(self, audio_file: str, sample_rate: int): + self.sample_rate = sample_rate + if self.sample_rate != 16000 and self.sample_rate != 8000: logger.error( - "please input --model_sample_rate 8000 or --model_sample_rate 16000" + "please input --sr 8000 or --sr 16000" ) raise Exception("invalid sample rate") sys.exit(-1) @@ -328,7 +322,7 @@ class ASRExecutor(BaseExecutor): logger.info("checking the audio file format......") try: - sig, sample_rate = soundfile.read( + audio, audio_sample_rate = soundfile.read( audio_file, dtype="int16", always_2d=True) except Exception as e: logger.error(str(e)) @@ -342,15 +336,15 @@ class ASRExecutor(BaseExecutor): sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \ ") sys.exit(-1) - logger.info("The sample rate is %d" % sample_rate) - if sample_rate != self.target_sample_rate: + logger.info("The sample rate is %d" % audio_sample_rate) + if audio_sample_rate != self.sample_rate: logger.warning( "The sample rate of the input file is not {}.\n \ The program will resample the wav file to {}.\n \ If the result does not meet your expectations,\n \ Please input the 16k 16bit 1 channel wav file. \ " - .format(self.target_sample_rate, self.target_sample_rate)) + .format(self.sample_rate, self.sample_rate)) while (True): logger.info( "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream." @@ -381,14 +375,14 @@ class ASRExecutor(BaseExecutor): model = parser_args.model lang = parser_args.lang - model_sample_rate = parser_args.model_sample_rate + sample_rate = parser_args.sr config = parser_args.config ckpt_path = parser_args.ckpt_path audio_file = parser_args.input device = parser_args.device try: - res = self(model, lang, model_sample_rate, config, ckpt_path, + res = self(model, lang, sample_rate, config, ckpt_path, audio_file, device) logger.info('ASR Result: {}'.format(res)) return True @@ -396,14 +390,14 @@ class ASRExecutor(BaseExecutor): print(e) return False - def __call__(self, model, lang, model_sample_rate, config, ckpt_path, + def __call__(self, model, lang, sample_rate, config, ckpt_path, audio_file, device): """ Python API to call an executor. """ audio_file = os.path.abspath(audio_file) - self._check(audio_file, model_sample_rate) - self._init_from_path(model, lang, model_sample_rate, config, ckpt_path, + self._check(audio_file, sample_rate) + self._init_from_path(model, lang, sample_rate, config, ckpt_path, device) self.preprocess(model, audio_file) self.infer(model)