提交 b0356ae4 编写于 作者: H huangyuxin

revise

上级 957f2e3a
...@@ -22,6 +22,7 @@ import librosa ...@@ -22,6 +22,7 @@ import librosa
import paddle import paddle
import soundfile import soundfile
from yacs.config import CfgNode from yacs.config import CfgNode
import numpy as np
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..utils import cli_register from ..utils import cli_register
...@@ -81,6 +82,7 @@ class ASRExecutor(BaseExecutor): ...@@ -81,6 +82,7 @@ class ASRExecutor(BaseExecutor):
"--sr", "--sr",
type=int, type=int,
default=16000, default=16000,
choices=[8000, 16000],
help='Choose the audio sample rate of the model. 8000 or 16000') help='Choose the audio sample rate of the model. 8000 or 16000')
self.parser.add_argument( self.parser.add_argument(
'--config', '--config',
...@@ -131,13 +133,13 @@ class ASRExecutor(BaseExecutor): ...@@ -131,13 +133,13 @@ class ASRExecutor(BaseExecutor):
self.cfg_path = os.path.join(res_path, self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path']) pretrained_models[tag]['cfg_path'])
self.ckpt_path = os.path.join(res_path, self.ckpt_path = os.path.join(res_path,
pretrained_models[tag]['ckpt_path']) pretrained_models[tag]['ckpt_path'] + ".pdparams")
logger.info(res_path) logger.info(res_path)
logger.info(self.cfg_path) logger.info(self.cfg_path)
logger.info(self.ckpt_path) logger.info(self.ckpt_path)
else: else:
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path) self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
res_path = os.path.dirname( res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path))) os.path.dirname(os.path.abspath(self.cfg_path)))
...@@ -183,8 +185,7 @@ class ASRExecutor(BaseExecutor): ...@@ -183,8 +185,7 @@ class ASRExecutor(BaseExecutor):
self.model.eval() self.model.eval()
# load model # load model
params_path = self.ckpt_path + ".pdparams" model_dict = paddle.load(self.ckpt_path)
model_dict = paddle.load(params_path)
self.model.set_state_dict(model_dict) self.model.set_state_dict(model_dict)
def preprocess(self, model_type: str, input: Union[str, os.PathLike]): def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
...@@ -227,11 +228,16 @@ class ASRExecutor(BaseExecutor): ...@@ -227,11 +228,16 @@ class ASRExecutor(BaseExecutor):
audio = audio.mean(axis=1) audio = audio.mean(axis=1)
else: else:
audio = audio[:, 0] audio = audio[:, 0]
# pcm16 -> pcm 32
audio = audio.astype("float32") audio = audio.astype("float32")
bits = np.iinfo(np.int16).bits
audio = audio / (2**(bits - 1))
audio = librosa.resample(audio, audio_sample_rate, audio = librosa.resample(audio, audio_sample_rate,
self.sample_rate) self.sample_rate)
audio_sample_rate = self.sample_rate audio_sample_rate = self.sample_rate
audio = audio.astype("int16") # pcm 32 -> pcm16
audio = audio * (2**(bits - 1))
audio = np.round(audio).astype("int16")
else: else:
audio = audio[:, 0] audio = audio[:, 0]
...@@ -341,7 +347,7 @@ class ASRExecutor(BaseExecutor): ...@@ -341,7 +347,7 @@ class ASRExecutor(BaseExecutor):
"The sample rate of the input file is not {}.\n \ "The sample rate of the input file is not {}.\n \
The program will resample the wav file to {}.\n \ The program will resample the wav file to {}.\n \
If the result does not meet your expectations,\n \ If the result does not meet your expectations,\n \
Please input the 16k 16bit 1 channel wav file. \ Please input the 16k 16 bit 1 channel wav file. \
" "
.format(self.sample_rate, self.sample_rate)) .format(self.sample_rate, self.sample_rate))
while (True): while (True):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册