提交 ab92e2c9 编写于 作者: T tianhao zhang

fix deepspeech2 decode_wav

上级 ed16f96a
......@@ -20,8 +20,8 @@ import paddle
import soundfile
from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.io.collator import SpeechCollator
from paddlespeech.s2t.models.ds2 import DeepSpeech2Model
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils import mp_tools
......@@ -38,24 +38,24 @@ class DeepSpeech2Tester_hub():
self.args = args
self.config = config
self.audio_file = args.audio_file
self.collate_fn_test = SpeechCollator.from_config(config)
self._text_featurizer = TextFeaturizer(
unit_type=config.unit_type, vocab=None)
def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
result_transcripts = self.model.decode(
audio,
audio_len,
vocab_list,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch)
self.preprocess_conf = config.preprocess_config
self.preprocess_args = {"train": False}
self.preprocessing = Transformation(self.preprocess_conf)
self.text_feature = TextFeaturizer(
unit_type=config.unit_type,
vocab=config.vocab_filepath,
spm_model_prefix=config.spm_model_prefix)
paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
decode_batch_size = cfg.decode_batch_size
self.model.decoder.init_decoder(
decode_batch_size, vocab_list, cfg.decoding_method,
cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size,
cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch)
result_transcripts = self.model.decode(audio, audio_len)
return result_transcripts
@mp_tools.rank_zero_only
......@@ -64,16 +64,23 @@ class DeepSpeech2Tester_hub():
self.model.eval()
cfg = self.config
audio_file = self.audio_file
collate_fn_test = self.collate_fn_test
audio, _ = collate_fn_test.process_utterance(
audio_file=audio_file, transcript=" ")
audio_len = audio.shape[0]
audio = paddle.to_tensor(audio, dtype='float32')
audio_len = paddle.to_tensor(audio_len)
audio = paddle.unsqueeze(audio, axis=0)
vocab_list = collate_fn_test.vocab_list
audio, sample_rate = soundfile.read(
self.audio_file, dtype="int16", always_2d=True)
audio = audio[:, 0]
logger.info(f"audio shape: {audio.shape}")
# fbank
feat = self.preprocessing(audio, **self.preprocess_args)
logger.info(f"feat shape: {feat.shape}")
audio_len = paddle.to_tensor(feat.shape[0])
audio = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0)
result_transcripts = self.compute_result_transcripts(
audio, audio_len, vocab_list, cfg.decode)
audio, audio_len, self.text_feature.vocab_list, cfg.decode)
logger.info("result_transcripts: " + result_transcripts[0])
def run_test(self):
......@@ -109,11 +116,9 @@ class DeepSpeech2Tester_hub():
def setup_model(self):
config = self.config.clone()
with UpdateConfig(config):
config.input_dim = self.collate_fn_test.feature_size
config.output_dim = self.collate_fn_test.vocab_size
config.input_dim = config.feat_dim
config.output_dim = self.text_feature.vocab_size
model = DeepSpeech2Model.from_config(config)
self.model = model
def setup_checkpointer(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册