Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
cbd8383d
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
1 年多 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cbd8383d
编写于
5月 04, 2022
作者:
X
xiongxinlei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
streaming asr server add time stamp, test=doc
上级
774ec8b0
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
110 addition
and
16 deletion
+110
-16
paddlespeech/server/engine/asr/online/asr_engine.py
paddlespeech/server/engine/asr/online/asr_engine.py
+34
-1
paddlespeech/server/engine/asr/online/ctc_search.py
paddlespeech/server/engine/asr/online/ctc_search.py
+73
-14
paddlespeech/server/ws/asr_socket.py
paddlespeech/server/ws/asr_socket.py
+3
-1
未找到文件。
paddlespeech/server/engine/asr/online/asr_engine.py
浏览文件 @
cbd8383d
...
...
@@ -13,7 +13,6 @@
# limitations under the License.
import
copy
import
os
import
time
from
typing
import
Optional
import
numpy
as
np
...
...
@@ -298,6 +297,7 @@ class PaddleASRConnectionHanddler:
self
.
global_frame_offset
=
0
self
.
result_transcripts
=
[
''
]
self
.
first_char_occur_elapsed
=
None
self
.
word_time_stamp
=
None
def
decode
(
self
,
is_finished
=
False
):
if
"deepspeech2online"
in
self
.
model_type
:
...
...
@@ -513,6 +513,12 @@ class PaddleASRConnectionHanddler:
else
:
return
''
def
get_word_time_stamp
(
self
):
if
self
.
word_time_stamp
is
None
:
return
[]
else
:
return
self
.
word_time_stamp
@
paddle
.
no_grad
()
def
rescoring
(
self
):
if
"deepspeech2online"
in
self
.
model_type
or
"deepspeech2offline"
in
self
.
model_type
:
...
...
@@ -577,8 +583,35 @@ class PaddleASRConnectionHanddler:
# update the one best result
logger
.
info
(
f
"best index:
{
best_index
}
"
)
self
.
hyps
=
[
hyps
[
best_index
][
0
]]
# update the hyps time stamp
self
.
time_stamp
=
hyps
[
best_index
][
5
]
if
hyps
[
best_index
][
2
]
>
hyps
[
best_index
][
3
]
else
hyps
[
best_index
][
6
]
logger
.
info
(
f
"time stamp:
{
self
.
time_stamp
}
"
)
self
.
update_result
()
# update each word start and end time stamp
frame_shift_in_ms
=
self
.
model
.
encoder
.
embed
.
subsampling_rate
*
self
.
n_shift
/
self
.
sample_rate
logger
.
info
(
f
"frame shift ms:
{
frame_shift_in_ms
}
"
)
word_time_stamp
=
[]
for
idx
,
_
in
enumerate
(
self
.
time_stamp
):
start
=
(
self
.
time_stamp
[
idx
-
1
]
+
self
.
time_stamp
[
idx
]
)
/
2.0
if
idx
>
0
else
0
start
=
start
*
frame_shift_in_ms
end
=
(
self
.
time_stamp
[
idx
]
+
self
.
time_stamp
[
idx
+
1
]
)
/
2.0
if
idx
<
len
(
self
.
time_stamp
)
-
1
else
self
.
offset
end
=
end
*
frame_shift_in_ms
word_time_stamp
.
append
({
"w"
:
self
.
result_transcripts
[
0
][
idx
],
"bg"
:
start
,
"ed"
:
end
})
# logger.info(f"{self.result_transcripts[0][idx]}, start: {start}, end: {end}")
self
.
word_time_stamp
=
word_time_stamp
logger
.
info
(
f
"word time stamp:
{
self
.
word_time_stamp
}
"
)
class
ASRServerExecutor
(
ASRExecutor
):
def
__init__
(
self
):
...
...
paddlespeech/server/engine/asr/online/ctc_search.py
浏览文件 @
cbd8383d
...
...
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
copy
from
collections
import
defaultdict
import
paddle
...
...
@@ -54,14 +55,24 @@ class CTCPrefixBeamSearch:
assert
len
(
ctc_probs
.
shape
)
==
2
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# blank_ending_score and none_blank_ending_score in ln domain
# 0. blank_ending_score,
# 1. none_blank_ending_score,
# 2. viterbi_blank ending,
# 3. viterbi_non_blank,
# 4. current_token_prob,
# 5. times_viterbi_blank,
# 6. times_titerbi_non_blank
if
self
.
cur_hyps
is
None
:
self
.
cur_hyps
=
[(
tuple
(),
(
0.0
,
-
float
(
'inf'
)))]
self
.
cur_hyps
=
[(
tuple
(),
(
0.0
,
-
float
(
'inf'
),
0.0
,
0.0
,
-
float
(
'inf'
),
[],
[]))]
# self.cur_hyps = [(tuple(), (0.0, -float('inf')))]
# 2. CTC beam search step by step
for
t
in
range
(
0
,
maxlen
):
logp
=
ctc_probs
[
t
]
# (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps
=
defaultdict
(
lambda
:
(
-
float
(
'inf'
),
-
float
(
'inf'
)))
# next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
next_hyps
=
defaultdict
(
lambda
:
(
-
float
(
'inf'
),
-
float
(
'inf'
),
-
float
(
'inf'
),
-
float
(
'inf'
),
-
float
(
'inf'
),
[],
[]))
# 2.1 First beam prune: select topk best
# do token passing process
...
...
@@ -69,36 +80,83 @@ class CTCPrefixBeamSearch:
for
s
in
top_k_index
:
s
=
s
.
item
()
ps
=
logp
[
s
].
item
()
for
prefix
,
(
pb
,
pnb
)
in
self
.
cur_hyps
:
for
prefix
,
(
pb
,
pnb
,
v_s
,
v_ns
,
cur_token_prob
,
times_s
,
times_ns
)
in
self
.
cur_hyps
:
last
=
prefix
[
-
1
]
if
len
(
prefix
)
>
0
else
None
if
s
==
blank_id
:
# blank
n_pb
,
n_pnb
=
next_hyps
[
prefix
]
n_pb
,
n_pnb
,
n_v_s
,
n_v_ns
,
n_cur_token_prob
,
n_times_s
,
n_times_ns
=
next_hyps
[
prefix
]
n_pb
=
log_add
([
n_pb
,
pb
+
ps
,
pnb
+
ps
])
next_hyps
[
prefix
]
=
(
n_pb
,
n_pnb
)
pre_times
=
times_s
if
v_s
>
v_ns
else
times_ns
n_times_s
=
copy
.
deepcopy
(
pre_times
)
viterbi_score
=
v_s
if
v_s
>
v_ns
else
v_ns
n_v_s
=
viterbi_score
+
ps
next_hyps
[
prefix
]
=
(
n_pb
,
n_pnb
,
n_v_s
,
n_v_ns
,
n_cur_token_prob
,
n_times_s
,
n_times_ns
)
elif
s
==
last
:
# Update *ss -> *s;
n_pb
,
n_pnb
=
next_hyps
[
prefix
]
# case1: *a + a => *a
n_pb
,
n_pnb
,
n_v_s
,
n_v_ns
,
n_cur_token_prob
,
n_times_s
,
n_times_ns
=
next_hyps
[
prefix
]
n_pnb
=
log_add
([
n_pnb
,
pnb
+
ps
])
next_hyps
[
prefix
]
=
(
n_pb
,
n_pnb
)
if
n_v_ns
<
v_ns
+
ps
:
n_v_ns
=
v_ns
+
ps
if
n_cur_token_prob
<
ps
:
n_cur_token_prob
=
ps
n_times_ns
=
copy
.
deepcopy
(
times_ns
)
n_times_ns
[
-
1
]
=
self
.
abs_time_step
# 注意,这里要重新使用绝对时间
next_hyps
[
prefix
]
=
(
n_pb
,
n_pnb
,
n_v_s
,
n_v_ns
,
n_cur_token_prob
,
n_times_s
,
n_times_ns
)
# Update *s-s -> *ss, - is for blank
# Case 2: *aε + a => *aa
n_prefix
=
prefix
+
(
s
,
)
n_pb
,
n_pnb
=
next_hyps
[
n_prefix
]
n_pb
,
n_pnb
,
n_v_s
,
n_v_ns
,
n_cur_token_prob
,
n_times_s
,
n_times_ns
=
next_hyps
[
n_prefix
]
if
n_v_ns
<
v_s
+
ps
:
n_v_ns
=
v_s
+
ps
n_cur_token_prob
=
ps
n_times_ns
=
copy
.
deepcopy
(
times_s
)
n_times_ns
.
append
(
self
.
abs_time_step
)
n_pnb
=
log_add
([
n_pnb
,
pb
+
ps
])
next_hyps
[
n_prefix
]
=
(
n_pb
,
n_pnb
)
next_hyps
[
n_prefix
]
=
(
n_pb
,
n_pnb
,
n_v_s
,
n_v_ns
,
n_cur_token_prob
,
n_times_s
,
n_times_ns
)
else
:
# Case 3: *a + b => *ab, *aε + b => *ab
n_prefix
=
prefix
+
(
s
,
)
n_pb
,
n_pnb
=
next_hyps
[
n_prefix
]
n_pb
,
n_pnb
,
n_v_s
,
n_v_ns
,
n_cur_token_prob
,
n_times_s
,
n_times_n
=
next_hyps
[
n_prefix
]
viterbi_score
=
v_s
if
v_s
>
v_ns
else
v_ns
pre_times
=
times_s
if
v_s
>
v_ns
else
times_ns
if
n_v_ns
<
viterbi_score
+
ps
:
n_v_ns
=
viterbi_score
+
ps
n_cur_token_prob
=
ps
n_times_ns
=
copy
.
deepcopy
(
pre_times
)
n_times_ns
.
append
(
self
.
abs_time_step
)
n_pnb
=
log_add
([
n_pnb
,
pb
+
ps
,
pnb
+
ps
])
next_hyps
[
n_prefix
]
=
(
n_pb
,
n_pnb
)
next_hyps
[
n_prefix
]
=
(
n_pb
,
n_pnb
,
n_v_s
,
n_v_ns
,
n_cur_token_prob
,
n_times_s
,
n_times_ns
)
# 2.2 Second beam prune
next_hyps
=
sorted
(
next_hyps
.
items
(),
key
=
lambda
x
:
log_add
(
list
(
x
[
1
])
),
key
=
lambda
x
:
log_add
(
[
x
[
1
][
0
],
x
[
1
][
1
]]
),
reverse
=
True
)
self
.
cur_hyps
=
next_hyps
[:
beam_size
]
self
.
hyps
=
[(
y
[
0
],
log_add
([
y
[
1
][
0
],
y
[
1
][
1
]]))
for
y
in
self
.
cur_hyps
]
# 2.3 update the absolute time step
self
.
abs_time_step
+=
1
# self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps]
self
.
hyps
=
[(
y
[
0
],
log_add
([
y
[
1
][
0
],
y
[
1
][
1
]]),
y
[
1
][
2
],
y
[
1
][
3
],
y
[
1
][
4
],
y
[
1
][
5
],
y
[
1
][
6
])
for
y
in
self
.
cur_hyps
]
logger
.
info
(
"ctc prefix search success"
)
return
self
.
hyps
...
...
@@ -123,6 +181,7 @@ class CTCPrefixBeamSearch:
"""
self
.
cur_hyps
=
None
self
.
hyps
=
None
self
.
abs_time_step
=
0
def
finalize_search
(
self
):
"""do nothing in ctc_prefix_beam_search
...
...
paddlespeech/server/ws/asr_socket.py
浏览文件 @
cbd8383d
...
...
@@ -78,12 +78,14 @@ async def websocket_endpoint(websocket: WebSocket):
connection_handler
.
decode
(
is_finished
=
True
)
connection_handler
.
rescoring
()
asr_results
=
connection_handler
.
get_result
()
word_time_stamp
=
connection_handler
.
get_word_time_stamp
()
connection_handler
.
reset
()
resp
=
{
"status"
:
"ok"
,
"signal"
:
"finished"
,
'result'
:
asr_results
'result'
:
asr_results
,
'times'
:
word_time_stamp
}
await
websocket
.
send_json
(
resp
)
break
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录