提交 a9d206c1 作者: huangyuxin

revise

上级 957f2e3a
...@@ -22,6 +22,7 @@ import librosa ...@@ -22,6 +22,7 @@ import librosa
import paddle import paddle
import soundfile import soundfile
from yacs.config import CfgNode from yacs.config import CfgNode
import numpy as np
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..utils import cli_register from ..utils import cli_register
...@@ -81,6 +82,7 @@ class ASRExecutor(BaseExecutor): ...@@ -81,6 +82,7 @@ class ASRExecutor(BaseExecutor):
"--sr", "--sr",
type=int, type=int,
default=16000, default=16000,
choices=[8000, 16000],
help='Choose the audio sample rate of the model. 8000 or 16000') help='Choose the audio sample rate of the model. 8000 or 16000')
self.parser.add_argument( self.parser.add_argument(
'--config', '--config',
...@@ -131,13 +133,13 @@ class ASRExecutor(BaseExecutor): ...@@ -131,13 +133,13 @@ class ASRExecutor(BaseExecutor):
self.cfg_path = os.path.join(res_path, self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path']) pretrained_models[tag]['cfg_path'])
self.ckpt_path = os.path.join(res_path, self.ckpt_path = os.path.join(res_path,
pretrained_models[tag]['ckpt_path']) pretrained_models[tag]['ckpt_path'] + ".pdparams")
logger.info(res_path) logger.info(res_path)
logger.info(self.cfg_path) logger.info(self.cfg_path)
logger.info(self.ckpt_path) logger.info(self.ckpt_path)
else: else:
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path) self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
res_path = os.path.dirname( res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path))) os.path.dirname(os.path.abspath(self.cfg_path)))
...@@ -183,8 +185,7 @@ class ASRExecutor(BaseExecutor): ...@@ -183,8 +185,7 @@ class ASRExecutor(BaseExecutor):
self.model.eval() self.model.eval()
# load model # load model
params_path = self.ckpt_path + ".pdparams" model_dict = paddle.load(self.ckpt_path)
model_dict = paddle.load(params_path)
self.model.set_state_dict(model_dict) self.model.set_state_dict(model_dict)
def preprocess(self, model_type: str, input: Union[str, os.PathLike]): def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
...@@ -231,7 +232,7 @@ class ASRExecutor(BaseExecutor): ...@@ -231,7 +232,7 @@ class ASRExecutor(BaseExecutor):
audio = librosa.resample(audio, audio_sample_rate, audio = librosa.resample(audio, audio_sample_rate,
self.sample_rate) self.sample_rate)
audio_sample_rate = self.sample_rate audio_sample_rate = self.sample_rate
audio = audio.astype("int16") audio = np.round(audio).astype("int16")
else: else:
audio = audio[:, 0] audio = audio[:, 0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册