提交 1b57d05d 编写于 作者: H huangyuxin

Remove the os.chdir call from the CLI ASR executor

上级 021311c7
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import io
import os import os
import sys import sys
from typing import List from typing import List
...@@ -19,10 +20,11 @@ from typing import Optional ...@@ -19,10 +20,11 @@ from typing import Optional
from typing import Union from typing import Union
import librosa import librosa
import numpy as np
import paddle import paddle
import soundfile import soundfile
import yaml
from yacs.config import CfgNode from yacs.config import CfgNode
import numpy as np
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..utils import cli_register from ..utils import cli_register
...@@ -131,8 +133,7 @@ class ASRExecutor(BaseExecutor): ...@@ -131,8 +133,7 @@ class ASRExecutor(BaseExecutor):
lang: str='zh', lang: str='zh',
sample_rate: int=16000, sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None, cfg_path: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None ckpt_path: Optional[os.PathLike]=None):
):
""" """
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
...@@ -140,10 +141,11 @@ class ASRExecutor(BaseExecutor): ...@@ -140,10 +141,11 @@ class ASRExecutor(BaseExecutor):
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '_' + lang + '_' + sample_rate_str tag = model_type + '_' + lang + '_' + sample_rate_str
res_path = self._get_pretrained_path(tag) # wenetspeech_zh res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.res_path = res_path
self.cfg_path = os.path.join(res_path, self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path']) pretrained_models[tag]['cfg_path'])
self.ckpt_path = os.path.join(res_path, self.ckpt_path = os.path.join(
pretrained_models[tag]['ckpt_path'] + ".pdparams") res_path, pretrained_models[tag]['ckpt_path'] + ".pdparams")
logger.info(res_path) logger.info(res_path)
logger.info(self.cfg_path) logger.info(self.cfg_path)
logger.info(self.ckpt_path) logger.info(self.ckpt_path)
...@@ -157,10 +159,8 @@ class ASRExecutor(BaseExecutor): ...@@ -157,10 +159,8 @@ class ASRExecutor(BaseExecutor):
self.config = CfgNode(new_allowed=True) self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path) self.config.merge_from_file(self.cfg_path)
self.config.decoding.decoding_method = "attention_rescoring" self.config.decoding.decoding_method = "attention_rescoring"
model_conf = self.config.model
logger.info(model_conf)
with UpdateConfig(model_conf): with UpdateConfig(self.config):
if model_type == "ds2_online" or model_type == "ds2_offline": if model_type == "ds2_online" or model_type == "ds2_offline":
from paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.io.collator import SpeechCollator
self.config.collator.vocab_filepath = os.path.join( self.config.collator.vocab_filepath = os.path.join(
...@@ -172,24 +172,29 @@ class ASRExecutor(BaseExecutor): ...@@ -172,24 +172,29 @@ class ASRExecutor(BaseExecutor):
unit_type=self.config.collator.unit_type, unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath, vocab_filepath=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix) spm_model_prefix=self.config.collator.spm_model_prefix)
model_conf.input_dim = self.collate_fn_test.feature_size self.config.model.input_dim = self.collate_fn_test.feature_size
model_conf.output_dim = text_feature.vocab_size self.config.model.output_dim = text_feature.vocab_size
elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech": elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
self.config.collator.vocab_filepath = os.path.join( self.config.collator.vocab_filepath = os.path.join(
res_path, self.config.collator.vocab_filepath) res_path, self.config.collator.vocab_filepath)
self.config.collator.augmentation_config = os.path.join(
res_path, self.config.collator.augmentation_config)
self.config.collator.spm_model_prefix = os.path.join(
res_path, self.config.collator.spm_model_prefix)
text_feature = TextFeaturizer( text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type, unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath, vocab_filepath=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix) spm_model_prefix=self.config.collator.spm_model_prefix)
model_conf.input_dim = self.config.collator.feat_dim self.config.model.input_dim = self.config.collator.feat_dim
model_conf.output_dim = text_feature.vocab_size self.config.model.output_dim = text_feature.vocab_size
else: else:
raise Exception("wrong type") raise Exception("wrong type")
self.config.freeze()
# Enter the path of model root # Enter the path of model root
os.chdir(res_path)
model_class = dynamic_import(model_type, model_alias) model_class = dynamic_import(model_type, model_alias)
model_conf = self.config.model
logger.info(model_conf)
model = model_class.from_config(model_conf) model = model_class.from_config(model_conf)
self.model = model self.model = model
self.model.eval() self.model.eval()
...@@ -222,10 +227,17 @@ class ASRExecutor(BaseExecutor): ...@@ -222,10 +227,17 @@ class ASRExecutor(BaseExecutor):
elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech": elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
logger.info("get the preprocess conf") logger.info("get the preprocess conf")
preprocess_conf = os.path.join( preprocess_conf_file = self.config.collator.augmentation_config
os.path.dirname(os.path.abspath(self.cfg_path)), # redirect the cmvn path
"preprocess.yaml") with io.open(preprocess_conf_file, encoding="utf-8") as f:
preprocess_conf = yaml.safe_load(f)
for idx, process in enumerate(preprocess_conf["process"]):
if process['type'] == "cmvn_json":
preprocess_conf["process"][idx][
"cmvn_path"] = os.path.join(
self.res_path,
preprocess_conf["process"][idx]["cmvn_path"])
break
logger.info(preprocess_conf) logger.info(preprocess_conf)
preprocess_args = {"train": False} preprocess_args = {"train": False}
preprocessing = Transformation(preprocess_conf) preprocessing = Transformation(preprocess_conf)
...@@ -320,14 +332,14 @@ class ASRExecutor(BaseExecutor): ...@@ -320,14 +332,14 @@ class ASRExecutor(BaseExecutor):
return self._outputs["result"] return self._outputs["result"]
def _pcm16to32(self, audio): def _pcm16to32(self, audio):
assert(audio.dtype == np.int16) assert (audio.dtype == np.int16)
audio = audio.astype("float32") audio = audio.astype("float32")
bits = np.iinfo(np.int16).bits bits = np.iinfo(np.int16).bits
audio = audio / (2**(bits - 1)) audio = audio / (2**(bits - 1))
return audio return audio
def _pcm32to16(self, audio): def _pcm32to16(self, audio):
assert(audio.dtype == np.float32) assert (audio.dtype == np.float32)
bits = np.iinfo(np.int16).bits bits = np.iinfo(np.int16).bits
audio = audio * (2**(bits - 1)) audio = audio * (2**(bits - 1))
audio = np.round(audio).astype("int16") audio = np.round(audio).astype("int16")
...@@ -336,9 +348,7 @@ class ASRExecutor(BaseExecutor): ...@@ -336,9 +348,7 @@ class ASRExecutor(BaseExecutor):
def _check(self, audio_file: str, sample_rate: int): def _check(self, audio_file: str, sample_rate: int):
self.sample_rate = sample_rate self.sample_rate = sample_rate
if self.sample_rate != 16000 and self.sample_rate != 8000: if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error( logger.error("please input --sr 8000 or --sr 16000")
"please input --sr 8000 or --sr 16000"
)
raise Exception("invalid sample rate") raise Exception("invalid sample rate")
sys.exit(-1) sys.exit(-1)
...@@ -364,13 +374,11 @@ class ASRExecutor(BaseExecutor): ...@@ -364,13 +374,11 @@ class ASRExecutor(BaseExecutor):
sys.exit(-1) sys.exit(-1)
logger.info("The sample rate is %d" % audio_sample_rate) logger.info("The sample rate is %d" % audio_sample_rate)
if audio_sample_rate != self.sample_rate: if audio_sample_rate != self.sample_rate:
logger.warning( logger.warning("The sample rate of the input file is not {}.\n \
"The sample rate of the input file is not {}.\n \
The program will resample the wav file to {}.\n \ The program will resample the wav file to {}.\n \
If the result does not meet your expectations,\n \ If the result does not meet your expectations,\n \
Please input the 16k 16 bit 1 channel wav file. \ Please input the 16k 16 bit 1 channel wav file. \
" ".format(self.sample_rate, self.sample_rate))
.format(self.sample_rate, self.sample_rate))
while (True): while (True):
logger.info( logger.info(
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream." "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
...@@ -408,16 +416,16 @@ class ASRExecutor(BaseExecutor): ...@@ -408,16 +416,16 @@ class ASRExecutor(BaseExecutor):
device = parser_args.device device = parser_args.device
try: try:
res = self(model, lang, sample_rate, config, ckpt_path, res = self(model, lang, sample_rate, config, ckpt_path, audio_file,
audio_file, device) device)
logger.info('ASR Result: {}'.format(res)) logger.info('ASR Result: {}'.format(res))
return True return True
except Exception as e: except Exception as e:
print(e) print(e)
return False return False
def __call__(self, model, lang, sample_rate, config, ckpt_path, def __call__(self, model, lang, sample_rate, config, ckpt_path, audio_file,
audio_file, device): device):
""" """
Python API to call an executor. Python API to call an executor.
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册