提交 b0356ae4 编写于 作者: H huangyuxin

revise

上级 957f2e3a
......@@ -22,6 +22,7 @@ import librosa
import paddle
import soundfile
from yacs.config import CfgNode
import numpy as np
from ..executor import BaseExecutor
from ..utils import cli_register
......@@ -81,6 +82,7 @@ class ASRExecutor(BaseExecutor):
"--sr",
type=int,
default=16000,
choices=[8000, 16000],
help='Choose the audio sample rate of the model. 8000 or 16000')
self.parser.add_argument(
'--config',
......@@ -131,13 +133,13 @@ class ASRExecutor(BaseExecutor):
self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path'])
self.ckpt_path = os.path.join(res_path,
pretrained_models[tag]['ckpt_path'])
pretrained_models[tag]['ckpt_path'] + ".pdparams")
logger.info(res_path)
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
else:
self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path)
self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
......@@ -183,8 +185,7 @@ class ASRExecutor(BaseExecutor):
self.model.eval()
# load model
params_path = self.ckpt_path + ".pdparams"
model_dict = paddle.load(params_path)
model_dict = paddle.load(self.ckpt_path)
self.model.set_state_dict(model_dict)
def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
......@@ -227,11 +228,16 @@ class ASRExecutor(BaseExecutor):
audio = audio.mean(axis=1)
else:
audio = audio[:, 0]
# pcm16 -> pcm 32
audio = audio.astype("float32")
bits = np.iinfo(np.int16).bits
audio = audio / (2**(bits - 1))
audio = librosa.resample(audio, audio_sample_rate,
self.sample_rate)
audio_sample_rate = self.sample_rate
audio = audio.astype("int16")
# pcm 32 -> pcm16 (scale the float samples back to the int16 range)
audio = audio * (2**(bits - 1))
audio = np.round(audio).astype("int16")
else:
audio = audio[:, 0]
......@@ -341,7 +347,7 @@ class ASRExecutor(BaseExecutor):
"The sample rate of the input file is not {}.\n \
The program will resample the wav file to {}.\n \
If the result does not meet your expectations,\n \
Please input the 16k 16bit 1 channel wav file. \
Please input the 16k 16 bit 1 channel wav file. \
"
.format(self.sample_rate, self.sample_rate))
while (True):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册