提交 cbd8383d 编写于 作者: X xiongxinlei

streaming asr server add time stamp, test=doc

上级 774ec8b0
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import copy import copy
import os import os
import time
from typing import Optional from typing import Optional
import numpy as np import numpy as np
...@@ -298,6 +297,7 @@ class PaddleASRConnectionHanddler: ...@@ -298,6 +297,7 @@ class PaddleASRConnectionHanddler:
self.global_frame_offset = 0 self.global_frame_offset = 0
self.result_transcripts = [''] self.result_transcripts = ['']
self.first_char_occur_elapsed = None self.first_char_occur_elapsed = None
self.word_time_stamp = None
def decode(self, is_finished=False): def decode(self, is_finished=False):
if "deepspeech2online" in self.model_type: if "deepspeech2online" in self.model_type:
...@@ -513,6 +513,12 @@ class PaddleASRConnectionHanddler: ...@@ -513,6 +513,12 @@ class PaddleASRConnectionHanddler:
else: else:
return '' return ''
def get_word_time_stamp(self):
if self.word_time_stamp is None:
return []
else:
return self.word_time_stamp
@paddle.no_grad() @paddle.no_grad()
def rescoring(self): def rescoring(self):
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
...@@ -577,8 +583,35 @@ class PaddleASRConnectionHanddler: ...@@ -577,8 +583,35 @@ class PaddleASRConnectionHanddler:
# update the one best result # update the one best result
logger.info(f"best index: {best_index}") logger.info(f"best index: {best_index}")
self.hyps = [hyps[best_index][0]] self.hyps = [hyps[best_index][0]]
# update the hyps time stamp
self.time_stamp = hyps[best_index][5] if hyps[best_index][2] > hyps[
best_index][3] else hyps[best_index][6]
logger.info(f"time stamp: {self.time_stamp}")
self.update_result() self.update_result()
# update each word start and end time stamp
frame_shift_in_ms = self.model.encoder.embed.subsampling_rate * self.n_shift / self.sample_rate
logger.info(f"frame shift ms: {frame_shift_in_ms}")
word_time_stamp = []
for idx, _ in enumerate(self.time_stamp):
start = (self.time_stamp[idx - 1] + self.time_stamp[idx]
) / 2.0 if idx > 0 else 0
start = start * frame_shift_in_ms
end = (self.time_stamp[idx] + self.time_stamp[idx + 1]
) / 2.0 if idx < len(self.time_stamp) - 1 else self.offset
end = end * frame_shift_in_ms
word_time_stamp.append({
"w": self.result_transcripts[0][idx],
"bg": start,
"ed": end
})
# logger.info(f"{self.result_transcripts[0][idx]}, start: {start}, end: {end}")
self.word_time_stamp = word_time_stamp
logger.info(f"word time stamp: {self.word_time_stamp}")
class ASRServerExecutor(ASRExecutor): class ASRServerExecutor(ASRExecutor):
def __init__(self): def __init__(self):
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
from collections import defaultdict from collections import defaultdict
import paddle import paddle
...@@ -54,14 +55,24 @@ class CTCPrefixBeamSearch: ...@@ -54,14 +55,24 @@ class CTCPrefixBeamSearch:
assert len(ctc_probs.shape) == 2 assert len(ctc_probs.shape) == 2
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# blank_ending_score and none_blank_ending_score in ln domain # 0. blank_ending_score,
# 1. none_blank_ending_score,
# 2. viterbi_blank ending,
# 3. viterbi_non_blank,
# 4. current_token_prob,
# 5. times_viterbi_blank,
# 6. times_titerbi_non_blank
if self.cur_hyps is None: if self.cur_hyps is None:
self.cur_hyps = [(tuple(), (0.0, -float('inf')))] self.cur_hyps = [(tuple(), (0.0, -float('inf'), 0.0, 0.0,
-float('inf'), [], []))]
# self.cur_hyps = [(tuple(), (0.0, -float('inf')))]
# 2. CTC beam search step by step # 2. CTC beam search step by step
for t in range(0, maxlen): for t in range(0, maxlen):
logp = ctc_probs[t] # (vocab_size,) logp = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf) # key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps = defaultdict(lambda: (-float('inf'), -float('inf'))) # next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
next_hyps = defaultdict(
lambda: (-float('inf'), -float('inf'), -float('inf'), -float('inf'), -float('inf'), [], []))
# 2.1 First beam prune: select topk best # 2.1 First beam prune: select topk best
# do token passing process # do token passing process
...@@ -69,36 +80,83 @@ class CTCPrefixBeamSearch: ...@@ -69,36 +80,83 @@ class CTCPrefixBeamSearch:
for s in top_k_index: for s in top_k_index:
s = s.item() s = s.item()
ps = logp[s].item() ps = logp[s].item()
for prefix, (pb, pnb) in self.cur_hyps: for prefix, (pb, pnb, v_s, v_ns, cur_token_prob, times_s,
times_ns) in self.cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None last = prefix[-1] if len(prefix) > 0 else None
if s == blank_id: # blank if s == blank_id: # blank
n_pb, n_pnb = next_hyps[prefix] n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[
prefix]
n_pb = log_add([n_pb, pb + ps, pnb + ps]) n_pb = log_add([n_pb, pb + ps, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
pre_times = times_s if v_s > v_ns else times_ns
n_times_s = copy.deepcopy(pre_times)
viterbi_score = v_s if v_s > v_ns else v_ns
n_v_s = viterbi_score + ps
next_hyps[prefix] = (n_pb, n_pnb, n_v_s, n_v_ns,
n_cur_token_prob, n_times_s,
n_times_ns)
elif s == last: elif s == last:
# Update *ss -> *s; # Update *ss -> *s;
n_pb, n_pnb = next_hyps[prefix] # case1: *a + a => *a
n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[
prefix]
n_pnb = log_add([n_pnb, pnb + ps]) n_pnb = log_add([n_pnb, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb) if n_v_ns < v_ns + ps:
n_v_ns = v_ns + ps
if n_cur_token_prob < ps:
n_cur_token_prob = ps
n_times_ns = copy.deepcopy(times_ns)
n_times_ns[
-1] = self.abs_time_step # 注意,这里要重新使用绝对时间
next_hyps[prefix] = (n_pb, n_pnb, n_v_s, n_v_ns,
n_cur_token_prob, n_times_s,
n_times_ns)
# Update *s-s -> *ss, - is for blank # Update *s-s -> *ss, - is for blank
# Case 2: *aε + a => *aa
n_prefix = prefix + (s, ) n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix] n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[
n_prefix]
if n_v_ns < v_s + ps:
n_v_ns = v_s + ps
n_cur_token_prob = ps
n_times_ns = copy.deepcopy(times_s)
n_times_ns.append(self.abs_time_step)
n_pnb = log_add([n_pnb, pb + ps]) n_pnb = log_add([n_pnb, pb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb) next_hyps[n_prefix] = (n_pb, n_pnb, n_v_s, n_v_ns,
n_cur_token_prob, n_times_s,
n_times_ns)
else: else:
# Case 3: *a + b => *ab, *aε + b => *ab
n_prefix = prefix + (s, ) n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix] n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_n = next_hyps[
n_prefix]
viterbi_score = v_s if v_s > v_ns else v_ns
pre_times = times_s if v_s > v_ns else times_ns
if n_v_ns < viterbi_score + ps:
n_v_ns = viterbi_score + ps
n_cur_token_prob = ps
n_times_ns = copy.deepcopy(pre_times)
n_times_ns.append(self.abs_time_step)
n_pnb = log_add([n_pnb, pb + ps, pnb + ps]) n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb) next_hyps[n_prefix] = (n_pb, n_pnb, n_v_s, n_v_ns,
n_cur_token_prob, n_times_s,
n_times_ns)
# 2.2 Second beam prune # 2.2 Second beam prune
next_hyps = sorted( next_hyps = sorted(
next_hyps.items(), next_hyps.items(),
key=lambda x: log_add(list(x[1])), key=lambda x: log_add([x[1][0], x[1][1]]),
reverse=True) reverse=True)
self.cur_hyps = next_hyps[:beam_size] self.cur_hyps = next_hyps[:beam_size]
self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps] # 2.3 update the absolute time step
self.abs_time_step += 1
# self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps]
self.hyps = [(y[0], log_add([y[1][0], y[1][1]]), y[1][2], y[1][3],
y[1][4], y[1][5], y[1][6]) for y in self.cur_hyps]
logger.info("ctc prefix search success") logger.info("ctc prefix search success")
return self.hyps return self.hyps
...@@ -123,6 +181,7 @@ class CTCPrefixBeamSearch: ...@@ -123,6 +181,7 @@ class CTCPrefixBeamSearch:
""" """
self.cur_hyps = None self.cur_hyps = None
self.hyps = None self.hyps = None
self.abs_time_step = 0
def finalize_search(self): def finalize_search(self):
"""do nothing in ctc_prefix_beam_search """do nothing in ctc_prefix_beam_search
......
...@@ -78,12 +78,14 @@ async def websocket_endpoint(websocket: WebSocket): ...@@ -78,12 +78,14 @@ async def websocket_endpoint(websocket: WebSocket):
connection_handler.decode(is_finished=True) connection_handler.decode(is_finished=True)
connection_handler.rescoring() connection_handler.rescoring()
asr_results = connection_handler.get_result() asr_results = connection_handler.get_result()
word_time_stamp = connection_handler.get_word_time_stamp()
connection_handler.reset() connection_handler.reset()
resp = { resp = {
"status": "ok", "status": "ok",
"signal": "finished", "signal": "finished",
'result': asr_results 'result': asr_results,
'times': word_time_stamp
} }
await websocket.send_json(resp) await websocket.send_json(resp)
break break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册