未验证 提交 8828210f 编写于 作者: K KP 提交者: GitHub

Merge pull request #3 from Jackwaterveg/cli_infer

revise the sample rate
...@@ -22,6 +22,7 @@ import librosa ...@@ -22,6 +22,7 @@ import librosa
import paddle import paddle
import soundfile import soundfile
from yacs.config import CfgNode from yacs.config import CfgNode
import numpy as np
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..utils import cli_register from ..utils import cli_register
...@@ -78,9 +79,10 @@ class ASRExecutor(BaseExecutor): ...@@ -78,9 +79,10 @@ class ASRExecutor(BaseExecutor):
default='zh', default='zh',
help='Choose model language. zh or en') help='Choose model language. zh or en')
self.parser.add_argument( self.parser.add_argument(
"--model_sample_rate", "--sr",
type=int, type=int,
default=16000, default=16000,
choices=[8000, 16000],
help='Choose the audio sample rate of the model. 8000 or 16000') help='Choose the audio sample rate of the model. 8000 or 16000')
self.parser.add_argument( self.parser.add_argument(
'--config', '--config',
...@@ -117,26 +119,27 @@ class ASRExecutor(BaseExecutor): ...@@ -117,26 +119,27 @@ class ASRExecutor(BaseExecutor):
def _init_from_path(self, def _init_from_path(self,
model_type: str='wenetspeech', model_type: str='wenetspeech',
lang: str='zh', lang: str='zh',
model_sample_rate: int=16000, sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None, cfg_path: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None): ckpt_path: Optional[os.PathLike]=None
):
""" """
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
if cfg_path is None or ckpt_path is None: if cfg_path is None or ckpt_path is None:
model_sample_rate_str = '16k' if model_sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '_' + lang + '_' + model_sample_rate_str tag = model_type + '_' + lang + '_' + sample_rate_str
res_path = self._get_pretrained_path(tag) # wenetspeech_zh res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.cfg_path = os.path.join(res_path, self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path']) pretrained_models[tag]['cfg_path'])
self.ckpt_path = os.path.join(res_path, self.ckpt_path = os.path.join(res_path,
pretrained_models[tag]['ckpt_path']) pretrained_models[tag]['ckpt_path'] + ".pdparams")
logger.info(res_path) logger.info(res_path)
logger.info(self.cfg_path) logger.info(self.cfg_path)
logger.info(self.ckpt_path) logger.info(self.ckpt_path)
else: else:
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path) self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
res_path = os.path.dirname( res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path))) os.path.dirname(os.path.abspath(self.cfg_path)))
...@@ -182,8 +185,7 @@ class ASRExecutor(BaseExecutor): ...@@ -182,8 +185,7 @@ class ASRExecutor(BaseExecutor):
self.model.eval() self.model.eval()
# load model # load model
params_path = self.ckpt_path + ".pdparams" model_dict = paddle.load(self.ckpt_path)
model_dict = paddle.load(params_path)
self.model.set_state_dict(model_dict) self.model.set_state_dict(model_dict)
def preprocess(self, model_type: str, input: Union[str, os.PathLike]): def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
...@@ -195,8 +197,6 @@ class ASRExecutor(BaseExecutor): ...@@ -195,8 +197,6 @@ class ASRExecutor(BaseExecutor):
audio_file = input audio_file = input
logger.info("Preprocess audio_file:" + audio_file) logger.info("Preprocess audio_file:" + audio_file)
config_target_sample_rate = self.config.collator.target_sample_rate
# Get the object for feature extraction # Get the object for feature extraction
if model_type == "ds2_online" or model_type == "ds2_offline": if model_type == "ds2_online" or model_type == "ds2_offline":
audio, _ = self.collate_fn_test.process_utterance( audio, _ = self.collate_fn_test.process_utterance(
...@@ -220,7 +220,7 @@ class ASRExecutor(BaseExecutor): ...@@ -220,7 +220,7 @@ class ASRExecutor(BaseExecutor):
preprocess_args = {"train": False} preprocess_args = {"train": False}
preprocessing = Transformation(preprocess_conf) preprocessing = Transformation(preprocess_conf)
logger.info("read the audio file") logger.info("read the audio file")
audio, sample_rate = soundfile.read( audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True) audio_file, dtype="int16", always_2d=True)
if self.change_format: if self.change_format:
...@@ -229,17 +229,13 @@ class ASRExecutor(BaseExecutor): ...@@ -229,17 +229,13 @@ class ASRExecutor(BaseExecutor):
else: else:
audio = audio[:, 0] audio = audio[:, 0]
audio = audio.astype("float32") audio = audio.astype("float32")
audio = librosa.resample(audio, sample_rate, audio = librosa.resample(audio, audio_sample_rate,
self.target_sample_rate) self.sample_rate)
sample_rate = self.target_sample_rate audio_sample_rate = self.sample_rate
audio = audio.astype("int16") audio = np.round(audio).astype("int16")
else: else:
audio = audio[:, 0] audio = audio[:, 0]
if sample_rate != config_target_sample_rate:
logger.error(
f"sample rate error: {sample_rate}, need {self.sr} ")
sys.exit(-1)
logger.info(f"audio shape: {audio.shape}") logger.info(f"audio shape: {audio.shape}")
# fbank # fbank
audio = preprocessing(audio, **preprocess_args) audio = preprocessing(audio, **preprocess_args)
...@@ -311,11 +307,11 @@ class ASRExecutor(BaseExecutor): ...@@ -311,11 +307,11 @@ class ASRExecutor(BaseExecutor):
""" """
return self._outputs["result"] return self._outputs["result"]
def _check(self, audio_file: str, model_sample_rate: int): def _check(self, audio_file: str, sample_rate: int):
self.target_sample_rate = model_sample_rate self.sample_rate = sample_rate
if self.target_sample_rate != 16000 and self.target_sample_rate != 8000: if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error( logger.error(
"please input --model_sample_rate 8000 or --model_sample_rate 16000" "please input --sr 8000 or --sr 16000"
) )
raise Exception("invalid sample rate") raise Exception("invalid sample rate")
sys.exit(-1) sys.exit(-1)
...@@ -326,7 +322,7 @@ class ASRExecutor(BaseExecutor): ...@@ -326,7 +322,7 @@ class ASRExecutor(BaseExecutor):
logger.info("checking the audio file format......") logger.info("checking the audio file format......")
try: try:
sig, sample_rate = soundfile.read( audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True) audio_file, dtype="int16", always_2d=True)
except Exception as e: except Exception as e:
logger.error(str(e)) logger.error(str(e))
...@@ -340,15 +336,15 @@ class ASRExecutor(BaseExecutor): ...@@ -340,15 +336,15 @@ class ASRExecutor(BaseExecutor):
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \ sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
") ")
sys.exit(-1) sys.exit(-1)
logger.info("The sample rate is %d" % sample_rate) logger.info("The sample rate is %d" % audio_sample_rate)
if sample_rate != self.target_sample_rate: if audio_sample_rate != self.sample_rate:
logger.warning( logger.warning(
"The sample rate of the input file is not {}.\n \ "The sample rate of the input file is not {}.\n \
The program will resample the wav file to {}.\n \ The program will resample the wav file to {}.\n \
If the result does not meet your expectations,\n \ If the result does not meet your expectations,\n \
Please input the 16k 16bit 1 channel wav file. \ Please input the 16k 16bit 1 channel wav file. \
" "
.format(self.target_sample_rate, self.target_sample_rate)) .format(self.sample_rate, self.sample_rate))
while (True): while (True):
logger.info( logger.info(
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream." "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
...@@ -379,14 +375,14 @@ class ASRExecutor(BaseExecutor): ...@@ -379,14 +375,14 @@ class ASRExecutor(BaseExecutor):
model = parser_args.model model = parser_args.model
lang = parser_args.lang lang = parser_args.lang
model_sample_rate = parser_args.model_sample_rate sample_rate = parser_args.sr
config = parser_args.config config = parser_args.config
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input audio_file = parser_args.input
device = parser_args.device device = parser_args.device
try: try:
res = self(model, lang, model_sample_rate, config, ckpt_path, res = self(model, lang, sample_rate, config, ckpt_path,
audio_file, device) audio_file, device)
logger.info('ASR Result: {}'.format(res)) logger.info('ASR Result: {}'.format(res))
return True return True
...@@ -394,16 +390,15 @@ class ASRExecutor(BaseExecutor): ...@@ -394,16 +390,15 @@ class ASRExecutor(BaseExecutor):
print(e) print(e)
return False return False
def __call__(self, model, lang, model_sample_rate, config, ckpt_path, def __call__(self, model, lang, sample_rate, config, ckpt_path,
audio_file, device): audio_file, device):
""" """
Python API to call an executor. Python API to call an executor.
""" """
audio_file = os.path.abspath(audio_file) audio_file = os.path.abspath(audio_file)
self._check(audio_file, model_sample_rate) self._check(audio_file, sample_rate)
paddle.set_device(device) paddle.set_device(device)
self._init_from_path(model, lang, model_sample_rate, config, ckpt_path) self._init_from_path(model, lang, sample_rate, config, ckpt_path)
self.preprocess(model, audio_file) self.preprocess(model, audio_file)
self.infer(model) self.infer(model)
res = self.postprocess() # Retrieve result of asr. res = self.postprocess() # Retrieve result of asr.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册