Commit 317ffea5 authored by huangyuxin

simplify the code

Parent 1f050a4d
@@ -410,30 +410,42 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
     def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
         if self.args.model_type == "online":
-            output_probs_branch, output_lens_branch = self.static_forward_online(
-                audio, audio_len)
+            output_probs, output_lens = self.static_forward_online(audio,
+                                                                   audio_len)
         elif self.args.model_type == "offline":
-            output_probs_branch, output_lens_branch = self.static_forward_offline(
-                audio, audio_len)
+            output_probs, output_lens = self.static_forward_offline(audio,
+                                                                    audio_len)
         else:
             raise Exception("wrong model type")
         self.predictor.clear_intermediate_tensor()
         self.predictor.try_shrink_memory()
         self.model.decoder.init_decode(cfg.alpha, cfg.beta, cfg.lang_model_path,
                                        vocab_list, cfg.decoding_method)
         result_transcripts = self.model.decoder.decode_probs(
-            output_probs_branch.numpy(), output_lens_branch, vocab_list,
-            cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
-            cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
-            cfg.num_proc_bsearch)
+            output_probs, output_lens, vocab_list, cfg.decoding_method,
+            cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size,
+            cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch)
         return result_transcripts

-    def static_forward_online(self, audio, audio_len):
+    def static_forward_online(self, audio, audio_len,
+                              decoder_chunk_size: int=1):
+        """
+        Parameters
+        ----------
+        audio (Tensor): shape[B, T, D]
+        audio_len (Tensor): shape[B]
+        decoder_chunk_size(int): number of decoder steps per chunk
+
+        Returns
+        -------
+        output_probs(numpy.array): shape[B, T, vocab_size]
+        output_lens(numpy.array): shape[B]
+        """
         output_probs_list = []
         output_lens_list = []
-        decoder_chunk_size = 1
         subsampling_rate = self.model.encoder.conv.subsampling_rate
         receptive_field_length = self.model.encoder.conv.receptive_field_length
         chunk_stride = subsampling_rate * decoder_chunk_size
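Editor's note: `decoder_chunk_size` now arrives as a parameter (default 1) instead of being hard-coded inside the method. Below is a minimal standalone sketch of the chunk arithmetic this hunk sets up, with illustrative values in place of the real ones from `self.model.encoder.conv`; the `chunk_size` formula is inferred from the continuation line at the top of the next hunk, so treat it as an assumption.

```python
# Illustrative sketch only; the real values come from the exported model.
subsampling_rate = 4        # assumed: conv frontend subsamples time by 4x
receptive_field_length = 7  # assumed: input frames needed for one output step

def chunk_params(decoder_chunk_size=1):
    # Each decoder step consumes `subsampling_rate` new input frames.
    chunk_stride = subsampling_rate * decoder_chunk_size
    # The first output step also needs the conv receptive field, so a chunk
    # spans (steps - 1) strides plus one receptive field of input frames.
    chunk_size = (decoder_chunk_size
                  - 1) * subsampling_rate + receptive_field_length
    return chunk_stride, chunk_size

print(chunk_params(1))  # (4, 7)
print(chunk_params(8))  # (32, 35)
```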
@@ -441,41 +453,42 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
         ) * subsampling_rate + receptive_field_length
         x_batch = audio.numpy()
-        batch_size = x_batch.shape[0]
+        batch_size, Tmax, x_dim = x_batch.shape
         x_len_batch = audio_len.numpy().astype(np.int64)
-        max_len_batch = x_batch.shape[1]
-        batch_padding_len = chunk_stride - (
-            max_len_batch - chunk_size
+        padding_len_batch = chunk_stride - (
+            Tmax - chunk_size
         ) % chunk_stride  # The length of padding for the batch
         x_list = np.split(x_batch, batch_size, axis=0)
-        x_len_list = np.split(x_len_batch, x_batch.shape[0], axis=0)
+        x_len_list = np.split(x_len_batch, batch_size, axis=0)
         for x, x_len in zip(x_list, x_len_list):
             self.autolog.times.start()
             self.autolog.times.stamp()
-            assert (chunk_size <= x_len[0])
+            x_len = x_len[0]
+            assert (chunk_size <= x_len)
             eouts_chunk_list = []
             eouts_chunk_lens_list = []
-            padding_len_x = chunk_stride - (x_len[0] - chunk_size
-                                            ) % chunk_stride
+            if (x_len - chunk_size) % chunk_stride != 0:
+                padding_len_x = chunk_stride - (x_len - chunk_size
+                                                ) % chunk_stride
+            else:
+                padding_len_x = 0
             padding = np.zeros(
-                (x.shape[0], padding_len_x, x.shape[2]), dtype=np.float32)
+                (x.shape[0], padding_len_x, x.shape[2]), dtype=x.dtype)
             padded_x = np.concatenate([x, padding], axis=1)
-            num_chunk = (x_len[0] + padding_len_x - chunk_size
-                         ) / chunk_stride + 1
+            num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1
             num_chunk = int(num_chunk)
             chunk_state_h_box = np.zeros(
                 (self.config.model.num_rnn_layers, 1,
                  self.config.model.rnn_layer_size),
-                dtype=np.float32)
+                dtype=x.dtype)
             chunk_state_c_box = np.zeros(
                 (self.config.model.num_rnn_layers, 1,
                  self.config.model.rnn_layer_size),
-                dtype=np.float32)
+                dtype=x.dtype)
             input_names = self.predictor.get_input_names()
             audio_handle = self.predictor.get_input_handle(input_names[0])
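Editor's note: besides unwrapping `x_len` to a scalar once, this hunk fixes a subtle padding bug — the old one-liner padded by a full extra `chunk_stride` when `(x_len - chunk_size)` already divided evenly — and makes the padding and zero-initialized RNN state boxes follow the input dtype instead of pinning `np.float32`. A small sketch of the padding and chunk-count arithmetic, using the assumed chunk parameters from the sketch above:

```python
# Sketch of the simplified per-utterance padding and chunk-count arithmetic.
def num_chunks(x_len, chunk_size, chunk_stride):
    # Pad so the frames after the first chunk split evenly into strides;
    # unlike the old one-liner, an exact fit now gets zero padding.
    if (x_len - chunk_size) % chunk_stride != 0:
        padding_len_x = chunk_stride - (x_len - chunk_size) % chunk_stride
    else:
        padding_len_x = 0
    # One chunk covers the first `chunk_size` frames; every further chunk
    # advances by `chunk_stride`.
    num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1
    assert num_chunk == int(num_chunk)  # padding guarantees an integer
    return int(num_chunk), padding_len_x

print(num_chunks(100, 35, 32))  # (4, 31): 100 frames padded to 131
print(num_chunks(99, 35, 32))   # (3, 0): exact fit, no extra padding
```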
@@ -489,16 +502,15 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
                 start = i * chunk_stride
                 end = start + chunk_size
                 x_chunk = padded_x[:, start:end, :]
-                x_len_left = np.where(x_len - i * chunk_stride < 0,
-                                      np.zeros_like(x_len, dtype=np.int64),
-                                      x_len - i * chunk_stride)
-                x_chunk_len_tmp = np.ones_like(
-                    x_len, dtype=np.int64) * chunk_size
-                x_chunk_lens = np.where(x_len_left < x_chunk_len_tmp,
-                                        x_len_left, x_chunk_len_tmp)
-                if (x_chunk_lens[0] <
+                if x_len < i * chunk_stride:
+                    x_chunk_lens = 0
+                else:
+                    x_chunk_lens = min(x_len - i * chunk_stride, chunk_size)
+                if (x_chunk_lens <
                         receptive_field_length):  #means the number of input frames in the chunk is not enough for predicting one prob
                     break
+                x_chunk_lens = np.array([x_chunk_lens])
                 audio_handle.reshape(x_chunk.shape)
                 audio_handle.copy_from_cpu(x_chunk)
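Editor's note: the chunk-length computation drops three `np.where` array expressions in favor of plain scalar arithmetic, which works because each `x` inside the loop holds a single utterance; the scalar is wrapped back into a length-1 array only right before it is fed to the predictor. A sketch checking that the two formulations agree, under the assumed chunk parameters used above:

```python
import numpy as np

# Old, array-based chunk length (batch of lengths) vs. the new scalar one.
def chunk_len_old(x_len, i, chunk_stride, chunk_size):
    x_len_left = np.where(x_len - i * chunk_stride < 0,
                          np.zeros_like(x_len, dtype=np.int64),
                          x_len - i * chunk_stride)
    x_chunk_len_tmp = np.ones_like(x_len, dtype=np.int64) * chunk_size
    return np.where(x_len_left < x_chunk_len_tmp, x_len_left, x_chunk_len_tmp)

def chunk_len_new(x_len, i, chunk_stride, chunk_size):
    if x_len < i * chunk_stride:
        return 0  # the utterance ended before this chunk starts
    return min(x_len - i * chunk_stride, chunk_size)

# With a single utterance per loop iteration the two agree for every chunk.
for i in range(6):
    assert chunk_len_old(np.array([100]), i, 32, 35)[0] == \
           chunk_len_new(100, i, 32, 35)
```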
@@ -530,11 +542,13 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
                 probs_chunk_lens_list.append(output_chunk_lens)
             output_probs = np.concatenate(probs_chunk_list, axis=1)
             output_lens = np.sum(probs_chunk_lens_list, axis=0)
-            output_probs_padding_len = max_len_batch + batch_padding_len - output_probs.shape[
+            vocab_size = output_probs.shape[2]
+            output_probs_padding_len = Tmax + padding_len_batch - output_probs.shape[
                 1]
             output_probs_padding = np.zeros(
-                (1, output_probs_padding_len, output_probs.shape[2]),
-                dtype=np.float32)  # The prob padding for a piece of utterance
+                (1, output_probs_padding_len, vocab_size),
+                dtype=output_probs.
+                dtype)  # The prob padding for a piece of utterance
             output_probs = np.concatenate(
                 [output_probs, output_probs_padding], axis=1)
             output_probs_list.append(output_probs)
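Editor's note: each utterance's concatenated chunk probabilities are zero-padded along the time axis up to the common padded batch length (`Tmax + padding_len_batch`), so the per-utterance results can later be stacked along the batch axis. A toy illustration with assumed shapes (the real `Tmax`, padding, and vocabulary size come from the batch):

```python
import numpy as np

# Assumed toy shapes for one utterance's decoded probabilities.
Tmax, padding_len_batch, vocab_size = 100, 31, 29
output_probs = np.random.rand(1, 24, vocab_size).astype(np.float32)

# Pad time axis up to the common padded batch length.
output_probs_padding_len = Tmax + padding_len_batch - output_probs.shape[1]
output_probs_padding = np.zeros(
    (1, output_probs_padding_len, vocab_size), dtype=output_probs.dtype)
output_probs = np.concatenate([output_probs, output_probs_padding], axis=1)
assert output_probs.shape == (1, Tmax + padding_len_batch, vocab_size)
```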
@@ -542,13 +556,22 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
             self.autolog.times.stamp()
             self.autolog.times.stamp()
             self.autolog.times.end()
-        output_probs_branch = np.concatenate(output_probs_list, axis=0)
-        output_lens_branch = np.concatenate(output_lens_list, axis=0)
-        output_probs_branch = paddle.to_tensor(output_probs_branch)
-        output_lens_branch = paddle.to_tensor(output_lens_branch)
-        return output_probs_branch, output_lens_branch
+        output_probs = np.concatenate(output_probs_list, axis=0)
+        output_lens = np.concatenate(output_lens_list, axis=0)
+        return output_probs, output_lens

     def static_forward_offline(self, audio, audio_len):
+        """
+        Parameters
+        ----------
+        audio (Tensor): shape[B, T, D]
+        audio_len (Tensor): shape[B]
+
+        Returns
+        -------
+        output_probs(numpy.array): shape[B, T, vocab_size]
+        output_lens(numpy.array): shape[B]
+        """
         x = audio.numpy()
         x_len = audio_len.numpy().astype(np.int64)
@@ -574,9 +597,7 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
         output_lens_handle = self.predictor.get_output_handle(output_names[1])
         output_probs = output_handle.copy_to_cpu()
         output_lens = output_lens_handle.copy_to_cpu()
-        output_probs_branch = paddle.to_tensor(output_probs)
-        output_lens_branch = paddle.to_tensor(output_lens)
-        return output_probs_branch, output_lens_branch
+        return output_probs, output_lens

     def run_test(self):
         try:
......
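Editor's note: taken together, the simplification is that both `static_forward_online` and `static_forward_offline` now hand back plain numpy arrays and `decode_probs` consumes them directly; the old code's numpy-to-paddle-to-numpy round trip carried no information. A toy illustration of the removed round trip (assuming a working paddle install; shapes are made up):

```python
import numpy as np
import paddle

# The deleted pattern: wrap predictor output in a tensor, then immediately
# convert back to numpy at the decode_probs call site.
output_probs = np.random.rand(1, 10, 29).astype(np.float32)
round_tripped = paddle.to_tensor(output_probs).numpy()
assert np.array_equal(round_tripped, output_probs)  # pure overhead
```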