Commit 90d648a6, authored by huangyuxin

Support invoking the executor directly via __call__

Parent aecb5f56
@@ -119,7 +119,8 @@ class ASRExecutor(BaseExecutor):
                         lang: str='zh',
                         model_sample_rate: int=16000,
                         cfg_path: Optional[os.PathLike]=None,
-                        ckpt_path: Optional[os.PathLike]=None):
+                        ckpt_path: Optional[os.PathLike]=None,
+                        device: str='cpu'):
         """
         Init model and other resources from a specific path.
         """
@@ -140,12 +141,8 @@ class ASRExecutor(BaseExecutor):
             res_path = os.path.dirname(
                 os.path.dirname(os.path.abspath(self.cfg_path)))
 
-        # Enter the path of model root
-        os.chdir(res_path)
-
         #Init body.
-        parser_args = self.parser_args
-        paddle.set_device(parser_args.device)
+        paddle.set_device(device)
         self.config = CfgNode(new_allowed=True)
         self.config.merge_from_file(self.cfg_path)
         self.config.decoding.decoding_method = "attention_rescoring"
@@ -153,29 +150,35 @@ class ASRExecutor(BaseExecutor):
         logger.info(model_conf)
         with UpdateConfig(model_conf):
-            if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline":
+            if model_type == "ds2_online" or model_type == "ds2_offline":
                 from paddlespeech.s2t.io.collator import SpeechCollator
                 self.config.collator.vocab_filepath = os.path.join(
                     res_path, self.config.collator.vocab_filepath)
-                self.config.collator.vocab_filepath = os.path.join(
+                self.config.collator.mean_std_filepath = os.path.join(
                     res_path, self.config.collator.cmvn_path)
                 self.collate_fn_test = SpeechCollator.from_config(self.config)
+                text_feature = TextFeaturizer(
+                    unit_type=self.config.collator.unit_type,
+                    vocab_filepath=self.config.collator.vocab_filepath,
+                    spm_model_prefix=self.config.collator.spm_model_prefix)
                 model_conf.input_dim = self.collate_fn_test.feature_size
-                model_conf.output_dim = self.text_feature.vocab_size
-            elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech":
+                model_conf.output_dim = text_feature.vocab_size
+            elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
                 self.config.collator.vocab_filepath = os.path.join(
                     res_path, self.config.collator.vocab_filepath)
-                self.text_feature = TextFeaturizer(
+                text_feature = TextFeaturizer(
                     unit_type=self.config.collator.unit_type,
                     vocab_filepath=self.config.collator.vocab_filepath,
                     spm_model_prefix=self.config.collator.spm_model_prefix)
                 model_conf.input_dim = self.config.collator.feat_dim
-                model_conf.output_dim = self.text_feature.vocab_size
+                model_conf.output_dim = text_feature.vocab_size
             else:
                 raise Exception("wrong type")
         self.config.freeze()
-        model_class = dynamic_import(parser_args.model, model_alias)
+        # Enter the path of model root
+        os.chdir(res_path)
+        model_class = dynamic_import(model_type, model_alias)
         model = model_class.from_config(model_conf)
         self.model = model
         self.model.eval()
@@ -185,31 +188,31 @@ class ASRExecutor(BaseExecutor):
         model_dict = paddle.load(params_path)
         self.model.set_state_dict(model_dict)
 
-    def preprocess(self, input: Union[str, os.PathLike]):
+    def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
         """
         Input preprocess and return paddle.Tensor stored in self.input.
         Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet).
         """
-        parser_args = self.parser_args
-        config = self.config
         audio_file = input
         logger.info("Preprocess audio_file:" + audio_file)
-        self.sr = config.collator.target_sample_rate
+        config_target_sample_rate = self.config.collator.target_sample_rate
 
         # Get the object for feature extraction
-        if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline":
+        if model_type == "ds2_online" or model_type == "ds2_offline":
             audio, _ = self.collate_fn_test.process_utterance(
                 audio_file=audio_file, transcript=" ")
             audio_len = audio.shape[0]
             audio = paddle.to_tensor(audio, dtype='float32')
-            self.audio_len = paddle.to_tensor(audio_len)
-            self.audio = paddle.unsqueeze(audio, axis=0)
-            self.vocab_list = collate_fn_test.vocab_list
-            logger.info(f"audio feat shape: {self.audio.shape}")
+            audio_len = paddle.to_tensor(audio_len)
+            audio = paddle.unsqueeze(audio, axis=0)
+            vocab_list = collate_fn_test.vocab_list
+            self._inputs["audio"] = audio
+            self._inputs["audio_len"] = audio_len
+            logger.info(f"audio feat shape: {audio.shape}")
 
-        elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech":
+        elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
             logger.info("get the preprocess conf")
             preprocess_conf = os.path.join(
                 os.path.dirname(os.path.abspath(self.cfg_path)),
@@ -235,7 +238,7 @@ class ASRExecutor(BaseExecutor):
             else:
                 audio = audio[:, 0]
 
-            if sample_rate != self.sr:
+            if sample_rate != config_target_sample_rate:
                 logger.error(
                     f"sample rate error: {sample_rate}, need {self.sr} ")
                 sys.exit(-1)
@@ -243,29 +246,36 @@ class ASRExecutor(BaseExecutor):
 
             # fbank
             audio = preprocessing(audio, **preprocess_args)
 
-            self.audio_len = paddle.to_tensor(audio.shape[0])
-            self.audio = paddle.to_tensor(
-                audio, dtype='float32').unsqueeze(axis=0)
-            logger.info(f"audio feat shape: {self.audio.shape}")
+            audio_len = paddle.to_tensor(audio.shape[0])
+            audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
+            text_feature = TextFeaturizer(
+                unit_type=self.config.collator.unit_type,
+                vocab_filepath=self.config.collator.vocab_filepath,
+                spm_model_prefix=self.config.collator.spm_model_prefix)
+            self._inputs["audio"] = audio
+            self._inputs["audio_len"] = audio_len
+            logger.info(f"audio feat shape: {audio.shape}")
 
         else:
             raise Exception("wrong type")
 
     @paddle.no_grad()
-    def infer(self):
+    def infer(self, model_type: str):
         """
         Model inference and result stored in self.output.
         """
+        text_feature = TextFeaturizer(
+            unit_type=self.config.collator.unit_type,
+            vocab_filepath=self.config.collator.vocab_filepath,
+            spm_model_prefix=self.config.collator.spm_model_prefix)
         cfg = self.config.decoding
-        parser_args = self.parser_args
-        audio = self.audio
-        audio_len = self.audio_len
-        if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline":
-            vocab_list = self.vocab_list
+        audio = self._inputs["audio"]
+        audio_len = self._inputs["audio_len"]
+        if model_type == "ds2_online" or model_type == "ds2_offline":
             result_transcripts = self.model.decode(
                 audio,
                 audio_len,
-                vocab_list,
+                text_feature.vocab_list,
                 decoding_method=cfg.decoding_method,
                 lang_model_path=cfg.lang_model_path,
                 beam_alpha=cfg.alpha,
@@ -274,14 +284,13 @@ class ASRExecutor(BaseExecutor):
                 cutoff_prob=cfg.cutoff_prob,
                 cutoff_top_n=cfg.cutoff_top_n,
                 num_processes=cfg.num_proc_bsearch)
-            self.result_transcripts = result_transcripts[0]
-        elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech":
-            text_feature = self.text_feature
+            self._outputs["result"] = result_transcripts[0]
+        elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
             result_transcripts = self.model.decode(
                 audio,
                 audio_len,
-                text_feature=self.text_feature,
+                text_feature=text_feature,
                 decoding_method=cfg.decoding_method,
                 lang_model_path=cfg.lang_model_path,
                 beam_alpha=cfg.alpha,
@@ -294,23 +303,22 @@ class ASRExecutor(BaseExecutor):
                 decoding_chunk_size=cfg.decoding_chunk_size,
                 num_decoding_left_chunks=cfg.num_decoding_left_chunks,
                 simulate_streaming=cfg.simulate_streaming)
-            self.result_transcripts = result_transcripts[0][0]
+            self._outputs["result"] = result_transcripts[0][0]
         else:
             raise Exception("invalid model name")
-        pass
 
     def postprocess(self) -> Union[str, os.PathLike]:
         """
         Output postprocess and return human-readable results such as texts and audio files.
         """
-        return self.result_transcripts
+        return self._outputs["result"]
 
     def _check(self, audio_file: str, model_sample_rate: int):
         self.target_sample_rate = model_sample_rate
         if self.target_sample_rate != 16000 and self.target_sample_rate != 8000:
             logger.error(
-                "please input --model_sample_rate 8000 or --model_sample_rate 16000")
+                "please input --model_sample_rate 8000 or --model_sample_rate 16000"
+            )
             raise Exception("invalid sample rate")
             sys.exit(-1)
@@ -336,11 +344,13 @@ class ASRExecutor(BaseExecutor):
                 sys.exit(-1)
             logger.info("The sample rate is %d" % sample_rate)
             if sample_rate != self.target_sample_rate:
-                logger.warning("The sample rate of the input file is not {}.\n \
+                logger.warning(
+                    "The sample rate of the input file is not {}.\n \
                             The program will resample the wav file to {}.\n \
                             If the result does not meet your expectations,\n \
                             Please input the 16k 16bit 1 channel wav file. \
-                            ".format(self.target_sample_rate, self.target_sample_rate))
+                            "
+                    .format(self.target_sample_rate, self.target_sample_rate))
                 while (True):
                     logger.info(
                         "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
@@ -367,34 +377,36 @@ class ASRExecutor(BaseExecutor):
         """
         Command line entry.
         """
-        self.parser_args = self.parser.parse_args(argv)
+        parser_args = self.parser.parse_args(argv)
 
-        model = self.parser_args.model
-        lang = self.parser_args.lang
-        model_sample_rate = self.parser_args.model_sample_rate
-        config = self.parser_args.config
-        ckpt_path = self.parser_args.ckpt_path
-        audio_file = os.path.abspath(self.parser_args.input)
-        device = self.parser_args.device
+        model = parser_args.model
+        lang = parser_args.lang
+        model_sample_rate = parser_args.model_sample_rate
+        config = parser_args.config
+        ckpt_path = parser_args.ckpt_path
+        audio_file = parser_args.input
+        device = parser_args.device
 
         try:
-            res = self(model, lang, model_sample_rate, config, ckpt_path, audio_file,
-                       device)
+            res = self(model, lang, model_sample_rate, config, ckpt_path,
+                       audio_file, device)
             logger.info('ASR Result: {}'.format(res))
             return True
         except Exception as e:
             print(e)
             return False
 
-    def __call__(self, model, lang, model_sample_rate, config, ckpt_path, audio_file,
-                 device):
+    def __call__(self, model, lang, model_sample_rate, config, ckpt_path,
+                 audio_file, device):
         """
         Python API to call an executor.
         """
+        audio_file = os.path.abspath(audio_file)
         self._check(audio_file, model_sample_rate)
-        self._init_from_path(model, lang, model_sample_rate, config, ckpt_path)
-        self.preprocess(audio_file)
-        self.infer()
+        self._init_from_path(model, lang, model_sample_rate, config, ckpt_path,
+                             device)
+        self.preprocess(model, audio_file)
+        self.infer(model)
         res = self.postprocess()  # Retrieve result of asr.
         return res
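
After this change the executor no longer reads state from self.parser_args, so it can be driven straight from Python as well as from the command line: __call__ runs _check, _init_from_path, preprocess, infer, and postprocess in order, exactly as defined above. A minimal usage sketch follows; the import path, the no-argument constructor, and the example file name are assumptions for illustration, not part of this commit:

# Hypothetical driver script; module path and file name are assumed.
from paddlespeech.cli.asr.infer import ASRExecutor

asr_executor = ASRExecutor()  # assumes a no-argument constructor

# Positional arguments mirror execute(): model, lang, model_sample_rate,
# config, ckpt_path, audio_file, device. Passing None for config and
# ckpt_path presumably selects the default resources resolved inside
# _init_from_path (not shown in these hunks).
text = asr_executor(
    'conformer',      # model type; also 'transformer', 'wenetspeech', 'ds2_online', 'ds2_offline'
    'zh',             # lang
    16000,            # model_sample_rate; _check() accepts only 8000 or 16000
    None,             # config (cfg_path)
    None,             # ckpt_path
    'input_16k.wav',  # audio_file; made absolute via os.path.abspath inside __call__
    'cpu')            # device; forwarded to paddle.set_device
print(text)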