未验证 提交 8828210f 编写于 作者: K KP 提交者: GitHub

Merge pull request #3 from Jackwaterveg/cli_infer

revise the sample rate
......@@ -22,6 +22,7 @@ import librosa
import paddle
import soundfile
from yacs.config import CfgNode
import numpy as np
from ..executor import BaseExecutor
from ..utils import cli_register
......@@ -78,9 +79,10 @@ class ASRExecutor(BaseExecutor):
default='zh',
help='Choose model language. zh or en')
self.parser.add_argument(
"--model_sample_rate",
"--sr",
type=int,
default=16000,
choices=[8000, 16000],
help='Choose the audio sample rate of the model. 8000 or 16000')
self.parser.add_argument(
'--config',
......@@ -117,26 +119,27 @@ class ASRExecutor(BaseExecutor):
def _init_from_path(self,
model_type: str='wenetspeech',
lang: str='zh',
model_sample_rate: int=16000,
sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None):
ckpt_path: Optional[os.PathLike]=None
):
"""
Init model and other resources from a specific path.
"""
if cfg_path is None or ckpt_path is None:
model_sample_rate_str = '16k' if model_sample_rate == 16000 else '8k'
tag = model_type + '_' + lang + '_' + model_sample_rate_str
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '_' + lang + '_' + sample_rate_str
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path'])
self.ckpt_path = os.path.join(res_path,
pretrained_models[tag]['ckpt_path'])
pretrained_models[tag]['ckpt_path'] + ".pdparams")
logger.info(res_path)
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
else:
self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path)
self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
......@@ -182,8 +185,7 @@ class ASRExecutor(BaseExecutor):
self.model.eval()
# load model
params_path = self.ckpt_path + ".pdparams"
model_dict = paddle.load(params_path)
model_dict = paddle.load(self.ckpt_path)
self.model.set_state_dict(model_dict)
def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
......@@ -195,8 +197,6 @@ class ASRExecutor(BaseExecutor):
audio_file = input
logger.info("Preprocess audio_file:" + audio_file)
config_target_sample_rate = self.config.collator.target_sample_rate
# Get the object for feature extraction
if model_type == "ds2_online" or model_type == "ds2_offline":
audio, _ = self.collate_fn_test.process_utterance(
......@@ -220,7 +220,7 @@ class ASRExecutor(BaseExecutor):
preprocess_args = {"train": False}
preprocessing = Transformation(preprocess_conf)
logger.info("read the audio file")
audio, sample_rate = soundfile.read(
audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True)
if self.change_format:
......@@ -229,17 +229,13 @@ class ASRExecutor(BaseExecutor):
else:
audio = audio[:, 0]
audio = audio.astype("float32")
audio = librosa.resample(audio, sample_rate,
self.target_sample_rate)
sample_rate = self.target_sample_rate
audio = audio.astype("int16")
audio = librosa.resample(audio, audio_sample_rate,
self.sample_rate)
audio_sample_rate = self.sample_rate
audio = np.round(audio).astype("int16")
else:
audio = audio[:, 0]
if sample_rate != config_target_sample_rate:
logger.error(
f"sample rate error: {sample_rate}, need {self.sr} ")
sys.exit(-1)
logger.info(f"audio shape: {audio.shape}")
# fbank
audio = preprocessing(audio, **preprocess_args)
......@@ -311,11 +307,11 @@ class ASRExecutor(BaseExecutor):
"""
return self._outputs["result"]
def _check(self, audio_file: str, model_sample_rate: int):
self.target_sample_rate = model_sample_rate
if self.target_sample_rate != 16000 and self.target_sample_rate != 8000:
def _check(self, audio_file: str, sample_rate: int):
self.sample_rate = sample_rate
if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error(
"please input --model_sample_rate 8000 or --model_sample_rate 16000"
"please input --sr 8000 or --sr 16000"
)
raise Exception("invalid sample rate")
sys.exit(-1)
......@@ -326,7 +322,7 @@ class ASRExecutor(BaseExecutor):
logger.info("checking the audio file format......")
try:
sig, sample_rate = soundfile.read(
audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True)
except Exception as e:
logger.error(str(e))
......@@ -340,15 +336,15 @@ class ASRExecutor(BaseExecutor):
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
")
sys.exit(-1)
logger.info("The sample rate is %d" % sample_rate)
if sample_rate != self.target_sample_rate:
logger.info("The sample rate is %d" % audio_sample_rate)
if audio_sample_rate != self.sample_rate:
logger.warning(
"The sample rate of the input file is not {}.\n \
The program will resample the wav file to {}.\n \
If the result does not meet your expectations,\n \
Please input the 16k 16bit 1 channel wav file. \
"
.format(self.target_sample_rate, self.target_sample_rate))
.format(self.sample_rate, self.sample_rate))
while (True):
logger.info(
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
......@@ -379,14 +375,14 @@ class ASRExecutor(BaseExecutor):
model = parser_args.model
lang = parser_args.lang
model_sample_rate = parser_args.model_sample_rate
sample_rate = parser_args.sr
config = parser_args.config
ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input
device = parser_args.device
try:
res = self(model, lang, model_sample_rate, config, ckpt_path,
res = self(model, lang, sample_rate, config, ckpt_path,
audio_file, device)
logger.info('ASR Result: {}'.format(res))
return True
......@@ -394,16 +390,15 @@ class ASRExecutor(BaseExecutor):
print(e)
return False
def __call__(self, model, lang, model_sample_rate, config, ckpt_path,
def __call__(self, model, lang, sample_rate, config, ckpt_path,
audio_file, device):
"""
Python API to call an executor.
"""
audio_file = os.path.abspath(audio_file)
self._check(audio_file, model_sample_rate)
self._check(audio_file, sample_rate)
paddle.set_device(device)
self._init_from_path(model, lang, model_sample_rate, config, ckpt_path)
self._init_from_path(model, lang, sample_rate, config, ckpt_path)
self.preprocess(model, audio_file)
self.infer(model)
res = self.postprocess() # Retrieve result of asr.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册