提交 b7a77eeb 编写于 作者: X xiongxinlei

update the time stamp type, test=doc

上级 43582f50
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import copy import copy
import os import os
import time
from typing import Optional from typing import Optional
import numpy as np import numpy as np
...@@ -297,7 +296,8 @@ class PaddleASRConnectionHanddler: ...@@ -297,7 +296,8 @@ class PaddleASRConnectionHanddler:
self.chunk_num = 0 self.chunk_num = 0
self.global_frame_offset = 0 self.global_frame_offset = 0
self.result_transcripts = [''] self.result_transcripts = ['']
self.word_time_stamp = None self.word_time_stamp = []
self.time_stamp = []
self.first_char_occur_elapsed = None self.first_char_occur_elapsed = None
def decode(self, is_finished=False): def decode(self, is_finished=False):
...@@ -515,9 +515,6 @@ class PaddleASRConnectionHanddler: ...@@ -515,9 +515,6 @@ class PaddleASRConnectionHanddler:
return '' return ''
def get_word_time_stamp(self): def get_word_time_stamp(self):
if self.word_time_stamp is None:
return []
else:
return self.word_time_stamp return self.word_time_stamp
@paddle.no_grad() @paddle.no_grad()
...@@ -582,7 +579,18 @@ class PaddleASRConnectionHanddler: ...@@ -582,7 +579,18 @@ class PaddleASRConnectionHanddler:
best_index = i best_index = i
# update the one best result # update the one best result
# hyps stores the beam search results; the fields of each entry are listed below:
logger.info(f"best index: {best_index}") logger.info(f"best index: {best_index}")
# logger.info(f'best result: {hyps[best_index]}')
# the fields of each hyps entry are:
# hyps[0][0]: the word ids of the sentence in the vocab, as a tuple
# hyps[0][1]: the sentence decoding probability with all paths
# hyps[0][2]: viterbi_blank ending probability
# hyps[0][3]: viterbi_non_blank probability
# hyps[0][4]: current_token_prob,
# hyps[0][5]: times_viterbi_blank,
# hyps[0][6]: times_viterbi_non_blank
self.hyps = [hyps[best_index][0]] self.hyps = [hyps[best_index][0]]
# update the hyps time stamp # update the hyps time stamp
......
...@@ -27,7 +27,7 @@ class CTCPrefixBeamSearch: ...@@ -27,7 +27,7 @@ class CTCPrefixBeamSearch:
"""Implement the ctc prefix beam search """Implement the ctc prefix beam search
Args: Args:
config (yacs.config.CfgNode): _description_ config (yacs.config.CfgNode): the ctc prefix beam search configuration
""" """
self.config = config self.config = config
self.reset() self.reset()
...@@ -69,7 +69,6 @@ class CTCPrefixBeamSearch: ...@@ -69,7 +69,6 @@ class CTCPrefixBeamSearch:
# 2. CTC beam search step by step # 2. CTC beam search step by step
for t in range(0, maxlen): for t in range(0, maxlen):
logp = ctc_probs[t] # (vocab_size,) logp = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
# next_hyps = defaultdict(lambda: (-float('inf'), -float('inf'))) # next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
next_hyps = defaultdict( next_hyps = defaultdict(
lambda: (-float('inf'), -float('inf'), -float('inf'), -float('inf'), -float('inf'), [], [])) lambda: (-float('inf'), -float('inf'), -float('inf'), -float('inf'), -float('inf'), [], []))
...@@ -80,7 +79,7 @@ class CTCPrefixBeamSearch: ...@@ -80,7 +79,7 @@ class CTCPrefixBeamSearch:
for s in top_k_index: for s in top_k_index:
s = s.item() s = s.item()
ps = logp[s].item() ps = logp[s].item()
for prefix, (pb, pnb, v_s, v_ns, cur_token_prob, times_s, for prefix, (pb, pnb, v_b_s, v_nb_s, cur_token_prob, times_s,
times_ns) in self.cur_hyps: times_ns) in self.cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None last = prefix[-1] if len(prefix) > 0 else None
if s == blank_id: # blank if s == blank_id: # blank
...@@ -88,9 +87,9 @@ class CTCPrefixBeamSearch: ...@@ -88,9 +87,9 @@ class CTCPrefixBeamSearch:
prefix] prefix]
n_pb = log_add([n_pb, pb + ps, pnb + ps]) n_pb = log_add([n_pb, pb + ps, pnb + ps])
pre_times = times_s if v_s > v_ns else times_ns pre_times = times_s if v_b_s > v_nb_s else times_ns
n_times_s = copy.deepcopy(pre_times) n_times_s = copy.deepcopy(pre_times)
viterbi_score = v_s if v_s > v_ns else v_ns viterbi_score = v_b_s if v_b_s > v_nb_s else v_nb_s
n_v_s = viterbi_score + ps n_v_s = viterbi_score + ps
next_hyps[prefix] = (n_pb, n_pnb, n_v_s, n_v_ns, next_hyps[prefix] = (n_pb, n_pnb, n_v_s, n_v_ns,
n_cur_token_prob, n_times_s, n_cur_token_prob, n_times_s,
...@@ -101,8 +100,8 @@ class CTCPrefixBeamSearch: ...@@ -101,8 +100,8 @@ class CTCPrefixBeamSearch:
n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[ n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[
prefix] prefix]
n_pnb = log_add([n_pnb, pnb + ps]) n_pnb = log_add([n_pnb, pnb + ps])
if n_v_ns < v_ns + ps: if n_v_ns < v_nb_s + ps:
n_v_ns = v_ns + ps n_v_ns = v_nb_s + ps
if n_cur_token_prob < ps: if n_cur_token_prob < ps:
n_cur_token_prob = ps n_cur_token_prob = ps
n_times_ns = copy.deepcopy(times_ns) n_times_ns = copy.deepcopy(times_ns)
...@@ -117,8 +116,8 @@ class CTCPrefixBeamSearch: ...@@ -117,8 +116,8 @@ class CTCPrefixBeamSearch:
n_prefix = prefix + (s, ) n_prefix = prefix + (s, )
n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[ n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[
n_prefix] n_prefix]
if n_v_ns < v_s + ps: if n_v_ns < v_b_s + ps:
n_v_ns = v_s + ps n_v_ns = v_b_s + ps
n_cur_token_prob = ps n_cur_token_prob = ps
n_times_ns = copy.deepcopy(times_s) n_times_ns = copy.deepcopy(times_s)
n_times_ns.append(self.abs_time_step) n_times_ns.append(self.abs_time_step)
...@@ -129,10 +128,10 @@ class CTCPrefixBeamSearch: ...@@ -129,10 +128,10 @@ class CTCPrefixBeamSearch:
else: else:
# Case 3: *a + b => *ab, *aε + b => *ab # Case 3: *a + b => *ab, *aε + b => *ab
n_prefix = prefix + (s, ) n_prefix = prefix + (s, )
n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_n = next_hyps[ n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[
n_prefix] n_prefix]
viterbi_score = v_s if v_s > v_ns else v_ns viterbi_score = v_b_s if v_b_s > v_nb_s else v_nb_s
pre_times = times_s if v_s > v_ns else times_ns pre_times = times_s if v_b_s > v_nb_s else times_ns
if n_v_ns < viterbi_score + ps: if n_v_ns < viterbi_score + ps:
n_v_ns = viterbi_score + ps n_v_ns = viterbi_score + ps
n_cur_token_prob = ps n_cur_token_prob = ps
...@@ -153,7 +152,7 @@ class CTCPrefixBeamSearch: ...@@ -153,7 +152,7 @@ class CTCPrefixBeamSearch:
# 2.3 update the absolute time step # 2.3 update the absolute time step
self.abs_time_step += 1 self.abs_time_step += 1
# self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps]
self.hyps = [(y[0], log_add([y[1][0], y[1][1]]), y[1][2], y[1][3], self.hyps = [(y[0], log_add([y[1][0], y[1][1]]), y[1][2], y[1][3],
y[1][4], y[1][5], y[1][6]) for y in self.cur_hyps] y[1][4], y[1][5], y[1][6]) for y in self.cur_hyps]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册