Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
97d31f9a
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
97d31f9a
编写于
4月 17, 2022
作者:
X
xiongxinlei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update the attention_rescoring method, test=doc
上级
0c5dbbee
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
287 addition
and
120 deletion
+287
-120
paddlespeech/server/conf/ws_application.yaml
paddlespeech/server/conf/ws_application.yaml
+5
-11
paddlespeech/server/engine/asr/online/asr_engine.py
paddlespeech/server/engine/asr/online/asr_engine.py
+92
-85
paddlespeech/server/engine/asr/online/ctc_search.py
paddlespeech/server/engine/asr/online/ctc_search.py
+119
-0
paddlespeech/server/tests/__init__.py
paddlespeech/server/tests/__init__.py
+13
-0
paddlespeech/server/tests/asr/__init__.py
paddlespeech/server/tests/asr/__init__.py
+13
-0
paddlespeech/server/tests/asr/offline/__init__.py
paddlespeech/server/tests/asr/offline/__init__.py
+13
-0
paddlespeech/server/tests/asr/online/__init__.py
paddlespeech/server/tests/asr/online/__init__.py
+13
-0
paddlespeech/server/ws/asr_socket.py
paddlespeech/server/ws/asr_socket.py
+19
-24
未找到文件。
paddlespeech/server/conf/ws_application.yaml
浏览文件 @
97d31f9a
...
@@ -71,15 +71,9 @@ asr_online:
...
@@ -71,15 +71,9 @@ asr_online:
summary
:
True
# False -> do not show predictor config
summary
:
True
# False -> do not show predictor config
chunk_buffer_conf
:
chunk_buffer_conf
:
frame_duration_ms
:
85
window_n
:
7
# frame
shift_ms
:
40
shift_n
:
4
# frame
window_ms
:
25
# ms
shift_ms
:
10
# ms
sample_rate
:
16000
sample_rate
:
16000
sample_width
:
2
sample_width
:
2
\ No newline at end of file
# vad_conf:
# aggressiveness: 2
# sample_rate: 16000
# frame_duration_ms: 20
# sample_width: 2
# padding_ms: 200
# padding_ratio: 0.9
\ No newline at end of file
paddlespeech/server/engine/asr/online/asr_engine.py
浏览文件 @
97d31f9a
...
@@ -12,9 +12,8 @@
...
@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
os
import
os
from
paddlespeech.s2t.utils.utility
import
log_add
from
typing
import
Optional
from
typing
import
Optional
from
collections
import
defaultdict
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
from
numpy
import
float32
from
numpy
import
float32
...
@@ -22,19 +21,18 @@ from yacs.config import CfgNode
...
@@ -22,19 +21,18 @@ from yacs.config import CfgNode
from
paddlespeech.cli.asr.infer
import
ASRExecutor
from
paddlespeech.cli.asr.infer
import
ASRExecutor
from
paddlespeech.cli.asr.infer
import
model_alias
from
paddlespeech.cli.asr.infer
import
model_alias
from
paddlespeech.cli.asr.infer
import
pretrained_models
from
paddlespeech.cli.log
import
logger
from
paddlespeech.cli.log
import
logger
from
paddlespeech.cli.utils
import
download_and_decompress
from
paddlespeech.cli.utils
import
download_and_decompress
from
paddlespeech.cli.utils
import
MODEL_HOME
from
paddlespeech.cli.utils
import
MODEL_HOME
from
paddlespeech.s2t.frontend.featurizer.text_featurizer
import
TextFeaturizer
from
paddlespeech.s2t.frontend.featurizer.text_featurizer
import
TextFeaturizer
from
paddlespeech.s2t.frontend.speech
import
SpeechSegment
from
paddlespeech.s2t.frontend.speech
import
SpeechSegment
from
paddlespeech.s2t.modules.ctc
import
CTCDecoder
from
paddlespeech.s2t.modules.ctc
import
CTCDecoder
from
paddlespeech.s2t.modules.mask
import
mask_finished_preds
from
paddlespeech.s2t.modules.mask
import
mask_finished_scores
from
paddlespeech.s2t.modules.mask
import
subsequent_mask
from
paddlespeech.s2t.transform.transformation
import
Transformation
from
paddlespeech.s2t.transform.transformation
import
Transformation
from
paddlespeech.s2t.utils.dynamic_import
import
dynamic_import
from
paddlespeech.s2t.utils.dynamic_import
import
dynamic_import
from
paddlespeech.s2t.utils.tensor_utils
import
add_sos_eos
from
paddlespeech.s2t.utils.tensor_utils
import
pad_sequence
from
paddlespeech.s2t.utils.utility
import
UpdateConfig
from
paddlespeech.s2t.utils.utility
import
UpdateConfig
from
paddlespeech.server.engine.asr.online.ctc_search
import
CTCPrefixBeamSearch
from
paddlespeech.server.engine.base_engine
import
BaseEngine
from
paddlespeech.server.engine.base_engine
import
BaseEngine
from
paddlespeech.server.utils.audio_process
import
pcm2float
from
paddlespeech.server.utils.audio_process
import
pcm2float
from
paddlespeech.server.utils.paddle_predictor
import
init_predictor
from
paddlespeech.server.utils.paddle_predictor
import
init_predictor
...
@@ -62,9 +60,9 @@ pretrained_models = {
...
@@ -62,9 +60,9 @@ pretrained_models = {
},
},
"conformer2online_aishell-zh-16k"
:
{
"conformer2online_aishell-zh-16k"
:
{
'url'
:
'url'
:
'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.
0
.model.tar.gz'
,
'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.
1
.model.tar.gz'
,
'md5'
:
'md5'
:
'
7989b3248c898070904cf042fd656003
'
,
'
b450d5dfaea0ac227c595ce58d18b637
'
,
'cfg_path'
:
'cfg_path'
:
'model.yaml'
,
'model.yaml'
,
'ckpt_path'
:
'ckpt_path'
:
...
@@ -123,9 +121,9 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -123,9 +121,9 @@ class ASRServerExecutor(ASRExecutor):
logger
.
info
(
f
"Load the pretrained model, tag =
{
tag
}
"
)
logger
.
info
(
f
"Load the pretrained model, tag =
{
tag
}
"
)
res_path
=
self
.
_get_pretrained_path
(
tag
)
# wenetspeech_zh
res_path
=
self
.
_get_pretrained_path
(
tag
)
# wenetspeech_zh
self
.
res_path
=
res_path
self
.
res_path
=
res_path
self
.
cfg_path
=
"/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml"
#
self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml"
#
self.cfg_path = os.path.join(res_path,
self
.
cfg_path
=
os
.
path
.
join
(
res_path
,
#
pretrained_models[tag]['cfg_path'])
pretrained_models
[
tag
][
'cfg_path'
])
self
.
am_model
=
os
.
path
.
join
(
res_path
,
self
.
am_model
=
os
.
path
.
join
(
res_path
,
pretrained_models
[
tag
][
'model'
])
pretrained_models
[
tag
][
'model'
])
...
@@ -177,6 +175,18 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -177,6 +175,18 @@ class ASRServerExecutor(ASRExecutor):
# update the decoding method
# update the decoding method
if
decode_method
:
if
decode_method
:
self
.
config
.
decode
.
decoding_method
=
decode_method
self
.
config
.
decode
.
decoding_method
=
decode_method
# we only support ctc_prefix_beam_search and attention_rescoring dedoding method
# Generally we set the decoding_method to attention_rescoring
if
self
.
config
.
decode
.
decoding_method
not
in
[
"ctc_prefix_beam_search"
,
"attention_rescoring"
]:
logger
.
info
(
"we set the decoding_method to attention_rescoring"
)
self
.
config
.
decode
.
decoding
=
"attention_rescoring"
assert
self
.
config
.
decode
.
decoding_method
in
[
"ctc_prefix_beam_search"
,
"attention_rescoring"
],
f
"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is
{
self
.
config
.
decode
.
decoding_method
}
"
else
:
else
:
raise
Exception
(
"wrong type"
)
raise
Exception
(
"wrong type"
)
if
"deepspeech2online"
in
model_type
or
"deepspeech2offline"
in
model_type
:
if
"deepspeech2online"
in
model_type
or
"deepspeech2offline"
in
model_type
:
...
@@ -232,7 +242,7 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -232,7 +242,7 @@ class ASRServerExecutor(ASRExecutor):
logger
.
info
(
"create the transformer like model success"
)
logger
.
info
(
"create the transformer like model success"
)
# update the ctc decoding
# update the ctc decoding
self
.
searcher
=
None
self
.
searcher
=
CTCPrefixBeamSearch
(
self
.
config
.
decode
)
self
.
transformer_decode_reset
()
self
.
transformer_decode_reset
()
def
reset_decoder_and_chunk
(
self
):
def
reset_decoder_and_chunk
(
self
):
...
@@ -320,7 +330,16 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -320,7 +330,16 @@ class ASRServerExecutor(ASRExecutor):
def
advanced_decoding
(
self
,
xs
:
paddle
.
Tensor
,
x_chunk_lens
):
def
advanced_decoding
(
self
,
xs
:
paddle
.
Tensor
,
x_chunk_lens
):
logger
.
info
(
"start to decode with advanced_decoding method"
)
logger
.
info
(
"start to decode with advanced_decoding method"
)
encoder_out
,
encoder_mask
=
self
.
decode_forward
(
xs
)
encoder_out
,
encoder_mask
=
self
.
decode_forward
(
xs
)
self
.
ctc_prefix_beam_search
(
xs
,
encoder_out
,
encoder_mask
)
ctc_probs
=
self
.
model
.
ctc
.
log_softmax
(
encoder_out
)
# (1, maxlen, vocab_size)
ctc_probs
=
ctc_probs
.
squeeze
(
0
)
self
.
searcher
.
search
(
xs
,
ctc_probs
,
xs
.
place
)
# update the one best result
self
.
hyps
=
self
.
searcher
.
get_one_best_hyps
()
# now we supprot ctc_prefix_beam_search and attention_rescoring
if
"attention_rescoring"
in
self
.
config
.
decode
.
decoding_method
:
self
.
rescoring
(
encoder_out
,
xs
.
place
)
def
decode_forward
(
self
,
xs
):
def
decode_forward
(
self
,
xs
):
logger
.
info
(
"get the model out from the feat"
)
logger
.
info
(
"get the model out from the feat"
)
...
@@ -338,7 +357,6 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -338,7 +357,6 @@ class ASRServerExecutor(ASRExecutor):
num_frames
=
xs
.
shape
[
1
]
num_frames
=
xs
.
shape
[
1
]
required_cache_size
=
decoding_chunk_size
*
num_decoding_left_chunks
required_cache_size
=
decoding_chunk_size
*
num_decoding_left_chunks
logger
.
info
(
"start to do model forward"
)
logger
.
info
(
"start to do model forward"
)
outputs
=
[]
outputs
=
[]
...
@@ -359,85 +377,74 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -359,85 +377,74 @@ class ASRServerExecutor(ASRExecutor):
masks
=
masks
.
unsqueeze
(
1
)
masks
=
masks
.
unsqueeze
(
1
)
return
ys
,
masks
return
ys
,
masks
def
rescoring
(
self
,
encoder_out
,
device
):
logger
.
info
(
"start to rescoring the hyps"
)
beam_size
=
self
.
config
.
decode
.
beam_size
hyps
=
self
.
searcher
.
get_hyps
()
assert
len
(
hyps
)
==
beam_size
hyp_list
=
[]
for
hyp
in
hyps
:
hyp_content
=
hyp
[
0
]
# Prevent the hyp is empty
if
len
(
hyp_content
)
==
0
:
hyp_content
=
(
self
.
model
.
ctc
.
blank_id
,
)
hyp_content
=
paddle
.
to_tensor
(
hyp_content
,
place
=
device
,
dtype
=
paddle
.
long
)
hyp_list
.
append
(
hyp_content
)
hyps_pad
=
pad_sequence
(
hyp_list
,
True
,
self
.
model
.
ignore_id
)
hyps_lens
=
paddle
.
to_tensor
(
[
len
(
hyp
[
0
])
for
hyp
in
hyps
],
place
=
device
,
dtype
=
paddle
.
long
)
# (beam_size,)
hyps_pad
,
_
=
add_sos_eos
(
hyps_pad
,
self
.
model
.
sos
,
self
.
model
.
eos
,
self
.
model
.
ignore_id
)
hyps_lens
=
hyps_lens
+
1
# Add <sos> at begining
encoder_out
=
encoder_out
.
repeat
(
beam_size
,
1
,
1
)
encoder_mask
=
paddle
.
ones
(
(
beam_size
,
1
,
encoder_out
.
shape
[
1
]),
dtype
=
paddle
.
bool
)
decoder_out
,
_
=
self
.
model
.
decoder
(
encoder_out
,
encoder_mask
,
hyps_pad
,
hyps_lens
)
# (beam_size, max_hyps_len, vocab_size)
# ctc score in ln domain
decoder_out
=
paddle
.
nn
.
functional
.
log_softmax
(
decoder_out
,
axis
=-
1
)
decoder_out
=
decoder_out
.
numpy
()
# Only use decoder score for rescoring
best_score
=
-
float
(
'inf'
)
best_index
=
0
# hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
for
i
,
hyp
in
enumerate
(
hyps
):
score
=
0.0
for
j
,
w
in
enumerate
(
hyp
[
0
]):
score
+=
decoder_out
[
i
][
j
][
w
]
# last decoder output token is `eos`, for laste decoder input token.
score
+=
decoder_out
[
i
][
len
(
hyp
[
0
])][
self
.
model
.
eos
]
# add ctc score (which in ln domain)
score
+=
hyp
[
1
]
*
self
.
config
.
decode
.
ctc_weight
if
score
>
best_score
:
best_score
=
score
best_index
=
i
# update the one best result
self
.
hyps
=
[
hyps
[
best_index
][
0
]]
return
hyps
[
best_index
][
0
]
def
transformer_decode_reset
(
self
):
def
transformer_decode_reset
(
self
):
self
.
subsampling_cache
=
None
self
.
subsampling_cache
=
None
self
.
elayers_output_cache
=
None
self
.
elayers_output_cache
=
None
self
.
conformer_cnn_cache
=
None
self
.
conformer_cnn_cache
=
None
self
.
hyps
=
None
self
.
offset
=
0
self
.
offset
=
0
self
.
cur_hyps
=
None
# decoding reset
self
.
hyps
=
None
self
.
searcher
.
reset
()
def
ctc_prefix_beam_search
(
self
,
xs
,
encoder_out
,
encoder_mask
,
blank_id
=
0
):
# decode
logger
.
info
(
"start to ctc prefix search"
)
device
=
xs
.
place
cfg
=
self
.
config
.
decode
batch_size
=
xs
.
shape
[
0
]
beam_size
=
cfg
.
beam_size
maxlen
=
encoder_out
.
shape
[
1
]
ctc_probs
=
self
.
model
.
ctc
.
log_softmax
(
encoder_out
)
# (1, maxlen, vocab_size)
ctc_probs
=
ctc_probs
.
squeeze
(
0
)
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# blank_ending_score and none_blank_ending_score in ln domain
if
self
.
cur_hyps
is
None
:
self
.
cur_hyps
=
[(
tuple
(),
(
0.0
,
-
float
(
'inf'
)))]
# 2. CTC beam search step by step
for
t
in
range
(
0
,
maxlen
):
logp
=
ctc_probs
[
t
]
# (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps
=
defaultdict
(
lambda
:
(
-
float
(
'inf'
),
-
float
(
'inf'
)))
# 2.1 First beam prune: select topk best
# do token passing process
top_k_logp
,
top_k_index
=
logp
.
topk
(
beam_size
)
# (beam_size,)
for
s
in
top_k_index
:
s
=
s
.
item
()
ps
=
logp
[
s
].
item
()
for
prefix
,
(
pb
,
pnb
)
in
self
.
cur_hyps
:
last
=
prefix
[
-
1
]
if
len
(
prefix
)
>
0
else
None
if
s
==
blank_id
:
# blank
n_pb
,
n_pnb
=
next_hyps
[
prefix
]
n_pb
=
log_add
([
n_pb
,
pb
+
ps
,
pnb
+
ps
])
next_hyps
[
prefix
]
=
(
n_pb
,
n_pnb
)
elif
s
==
last
:
# Update *ss -> *s;
n_pb
,
n_pnb
=
next_hyps
[
prefix
]
n_pnb
=
log_add
([
n_pnb
,
pnb
+
ps
])
next_hyps
[
prefix
]
=
(
n_pb
,
n_pnb
)
# Update *s-s -> *ss, - is for blank
n_prefix
=
prefix
+
(
s
,
)
n_pb
,
n_pnb
=
next_hyps
[
n_prefix
]
n_pnb
=
log_add
([
n_pnb
,
pb
+
ps
])
next_hyps
[
n_prefix
]
=
(
n_pb
,
n_pnb
)
else
:
n_prefix
=
prefix
+
(
s
,
)
n_pb
,
n_pnb
=
next_hyps
[
n_prefix
]
n_pnb
=
log_add
([
n_pnb
,
pb
+
ps
,
pnb
+
ps
])
next_hyps
[
n_prefix
]
=
(
n_pb
,
n_pnb
)
# 2.2 Second beam prune
next_hyps
=
sorted
(
next_hyps
.
items
(),
key
=
lambda
x
:
log_add
(
list
(
x
[
1
])),
reverse
=
True
)
self
.
cur_hyps
=
next_hyps
[:
beam_size
]
hyps
=
[(
y
[
0
],
log_add
([
y
[
1
][
0
],
y
[
1
][
1
]]))
for
y
in
self
.
cur_hyps
]
self
.
hyps
=
[
hyps
[
0
][
0
]]
logger
.
info
(
"ctc prefix search success"
)
return
hyps
,
encoder_out
def
update_result
(
self
):
def
update_result
(
self
):
logger
.
info
(
"update the final result"
)
logger
.
info
(
"update the final result"
)
hyps
=
self
.
hyps
self
.
result_transcripts
=
[
self
.
result_transcripts
=
[
self
.
text_feature
.
defeaturize
(
hyp
)
for
hyp
in
self
.
hyps
self
.
text_feature
.
defeaturize
(
hyp
)
for
hyp
in
hyps
]
]
self
.
result_tokenids
=
[
hyp
for
hyp
in
self
.
hyps
]
self
.
result_tokenids
=
[
hyp
for
hyp
in
hyps
]
def
extract_feat
(
self
,
samples
,
sample_rate
):
def
extract_feat
(
self
,
samples
,
sample_rate
):
"""extract feat
"""extract feat
...
@@ -483,9 +490,9 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -483,9 +490,9 @@ class ASRServerExecutor(ASRExecutor):
elif
"conformer2online"
in
self
.
model_type
:
elif
"conformer2online"
in
self
.
model_type
:
if
sample_rate
!=
self
.
sample_rate
:
if
sample_rate
!=
self
.
sample_rate
:
logger
.
info
(
f
"audio sample rate
{
sample_rate
}
is not match,"
\
logger
.
info
(
f
"audio sample rate
{
sample_rate
}
is not match,"
"the model sample_rate is {self.sample_rate}"
)
"the model sample_rate is {self.sample_rate}"
)
logger
.
info
(
f
"ASR Engine use the
{
self
.
model_type
}
to process"
)
logger
.
info
(
"ASR Engine use the {self.model_type} to process"
)
logger
.
info
(
"Create the preprocess instance"
)
logger
.
info
(
"Create the preprocess instance"
)
preprocess_conf
=
self
.
config
.
preprocess_config
preprocess_conf
=
self
.
config
.
preprocess_config
preprocess_args
=
{
"train"
:
False
}
preprocess_args
=
{
"train"
:
False
}
...
...
paddlespeech/server/engine/asr/online/ctc_search.py
0 → 100644
浏览文件 @
97d31f9a
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
collections
import
defaultdict
from
paddlespeech.cli.log
import
logger
from
paddlespeech.s2t.utils.utility
import
log_add
__all__
=
[
'CTCPrefixBeamSearch'
]
class
CTCPrefixBeamSearch
:
def
__init__
(
self
,
config
):
"""Implement the ctc prefix beam search
Args:
config (_type_): _description_
"""
self
.
config
=
config
self
.
reset
()
def
search
(
self
,
xs
,
ctc_probs
,
device
,
blank_id
=
0
):
"""ctc prefix beam search method decode a chunk feature
Args:
xs (paddle.Tensor): feature data
ctc_probs (paddle.Tensor): the ctc probability of all the tokens
encoder_out (paddle.Tensor): _description_
encoder_mask (_type_): _description_
blank_id (int, optional): the blank id in the vocab. Defaults to 0.
Returns:
list: the search result
"""
# decode
logger
.
info
(
"start to ctc prefix search"
)
# device = xs.place
batch_size
=
xs
.
shape
[
0
]
beam_size
=
self
.
config
.
beam_size
maxlen
=
ctc_probs
.
shape
[
0
]
assert
len
(
ctc_probs
.
shape
)
==
2
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# blank_ending_score and none_blank_ending_score in ln domain
if
self
.
cur_hyps
is
None
:
self
.
cur_hyps
=
[(
tuple
(),
(
0.0
,
-
float
(
'inf'
)))]
# 2. CTC beam search step by step
for
t
in
range
(
0
,
maxlen
):
logp
=
ctc_probs
[
t
]
# (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps
=
defaultdict
(
lambda
:
(
-
float
(
'inf'
),
-
float
(
'inf'
)))
# 2.1 First beam prune: select topk best
# do token passing process
top_k_logp
,
top_k_index
=
logp
.
topk
(
beam_size
)
# (beam_size,)
for
s
in
top_k_index
:
s
=
s
.
item
()
ps
=
logp
[
s
].
item
()
for
prefix
,
(
pb
,
pnb
)
in
self
.
cur_hyps
:
last
=
prefix
[
-
1
]
if
len
(
prefix
)
>
0
else
None
if
s
==
blank_id
:
# blank
n_pb
,
n_pnb
=
next_hyps
[
prefix
]
n_pb
=
log_add
([
n_pb
,
pb
+
ps
,
pnb
+
ps
])
next_hyps
[
prefix
]
=
(
n_pb
,
n_pnb
)
elif
s
==
last
:
# Update *ss -> *s;
n_pb
,
n_pnb
=
next_hyps
[
prefix
]
n_pnb
=
log_add
([
n_pnb
,
pnb
+
ps
])
next_hyps
[
prefix
]
=
(
n_pb
,
n_pnb
)
# Update *s-s -> *ss, - is for blank
n_prefix
=
prefix
+
(
s
,
)
n_pb
,
n_pnb
=
next_hyps
[
n_prefix
]
n_pnb
=
log_add
([
n_pnb
,
pb
+
ps
])
next_hyps
[
n_prefix
]
=
(
n_pb
,
n_pnb
)
else
:
n_prefix
=
prefix
+
(
s
,
)
n_pb
,
n_pnb
=
next_hyps
[
n_prefix
]
n_pnb
=
log_add
([
n_pnb
,
pb
+
ps
,
pnb
+
ps
])
next_hyps
[
n_prefix
]
=
(
n_pb
,
n_pnb
)
# 2.2 Second beam prune
next_hyps
=
sorted
(
next_hyps
.
items
(),
key
=
lambda
x
:
log_add
(
list
(
x
[
1
])),
reverse
=
True
)
self
.
cur_hyps
=
next_hyps
[:
beam_size
]
self
.
hyps
=
[(
y
[
0
],
log_add
([
y
[
1
][
0
],
y
[
1
][
1
]]))
for
y
in
self
.
cur_hyps
]
logger
.
info
(
"ctc prefix search success"
)
return
self
.
hyps
def
get_one_best_hyps
(
self
):
"""Return the one best result
Returns:
list: the one best result
"""
return
[
self
.
hyps
[
0
][
0
]]
def
get_hyps
(
self
):
return
self
.
hyps
def
reset
(
self
):
"""Rest the search cache value
"""
self
.
cur_hyps
=
None
self
.
hyps
=
None
paddlespeech/server/tests/__init__.py
0 → 100644
浏览文件 @
97d31f9a
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/server/tests/asr/__init__.py
0 → 100644
浏览文件 @
97d31f9a
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/server/tests/asr/offline/__init__.py
0 → 100644
浏览文件 @
97d31f9a
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/server/tests/asr/online/__init__.py
0 → 100644
浏览文件 @
97d31f9a
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/server/ws/asr_socket.py
浏览文件 @
97d31f9a
...
@@ -34,17 +34,17 @@ async def websocket_endpoint(websocket: WebSocket):
...
@@ -34,17 +34,17 @@ async def websocket_endpoint(websocket: WebSocket):
engine_pool
=
get_engine_pool
()
engine_pool
=
get_engine_pool
()
asr_engine
=
engine_pool
[
'asr'
]
asr_engine
=
engine_pool
[
'asr'
]
# init buffer
# init buffer
# each websocekt connection has its own chunk buffer
chunk_buffer_conf
=
asr_engine
.
config
.
chunk_buffer_conf
chunk_buffer_conf
=
asr_engine
.
config
.
chunk_buffer_conf
chunk_buffer
=
ChunkBuffer
(
chunk_buffer
=
ChunkBuffer
(
window_n
=
7
,
window_n
=
chunk_buffer_conf
.
window_n
,
shift_n
=
4
,
shift_n
=
chunk_buffer_conf
.
shift_n
,
window_ms
=
20
,
window_ms
=
chunk_buffer_conf
.
window_ms
,
shift_ms
=
10
,
shift_ms
=
chunk_buffer_conf
.
shift_ms
,
sample_rate
=
chunk_buffer_conf
[
'sample_rate'
],
sample_rate
=
chunk_buffer_conf
.
sample_rate
,
sample_width
=
chunk_buffer_conf
[
'sample_width'
])
sample_width
=
chunk_buffer_conf
.
sample_width
)
# init vad
# init vad
# print(asr_engine.config)
# print(type(asr_engine.config))
vad_conf
=
asr_engine
.
config
.
get
(
'vad_conf'
,
None
)
vad_conf
=
asr_engine
.
config
.
get
(
'vad_conf'
,
None
)
if
vad_conf
:
if
vad_conf
:
vad
=
VADAudio
(
vad
=
VADAudio
(
...
@@ -72,7 +72,7 @@ async def websocket_endpoint(websocket: WebSocket):
...
@@ -72,7 +72,7 @@ async def websocket_endpoint(websocket: WebSocket):
engine_pool
=
get_engine_pool
()
engine_pool
=
get_engine_pool
()
asr_engine
=
engine_pool
[
'asr'
]
asr_engine
=
engine_pool
[
'asr'
]
# reset single engine for an new connection
# reset single engine for an new connection
#
asr_engine.reset()
asr_engine
.
reset
()
resp
=
{
"status"
:
"ok"
,
"signal"
:
"finished"
}
resp
=
{
"status"
:
"ok"
,
"signal"
:
"finished"
}
await
websocket
.
send_json
(
resp
)
await
websocket
.
send_json
(
resp
)
break
break
...
@@ -85,21 +85,16 @@ async def websocket_endpoint(websocket: WebSocket):
...
@@ -85,21 +85,16 @@ async def websocket_endpoint(websocket: WebSocket):
engine_pool
=
get_engine_pool
()
engine_pool
=
get_engine_pool
()
asr_engine
=
engine_pool
[
'asr'
]
asr_engine
=
engine_pool
[
'asr'
]
asr_results
=
""
asr_results
=
""
# frames = chunk_buffer.frame_generator(message)
frames
=
chunk_buffer
.
frame_generator
(
message
)
# for frame in frames:
for
frame
in
frames
:
# # get the pcm data from the bytes
# get the pcm data from the bytes
# samples = np.frombuffer(frame.bytes, dtype=np.int16)
samples
=
np
.
frombuffer
(
frame
.
bytes
,
dtype
=
np
.
int16
)
# sample_rate = asr_engine.config.sample_rate
sample_rate
=
asr_engine
.
config
.
sample_rate
# x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
x_chunk
,
x_chunk_lens
=
asr_engine
.
preprocess
(
samples
,
# sample_rate)
sample_rate
)
# asr_engine.run(x_chunk, x_chunk_lens)
asr_engine
.
run
(
x_chunk
,
x_chunk_lens
)
# asr_results = asr_engine.postprocess()
asr_results
=
asr_engine
.
postprocess
()
samples
=
np
.
frombuffer
(
message
,
dtype
=
np
.
int16
)
sample_rate
=
asr_engine
.
config
.
sample_rate
x_chunk
,
x_chunk_lens
=
asr_engine
.
preprocess
(
samples
,
sample_rate
)
asr_engine
.
run
(
x_chunk
,
x_chunk_lens
)
# asr_results = asr_engine.postprocess()
asr_results
=
asr_engine
.
postprocess
()
asr_results
=
asr_engine
.
postprocess
()
resp
=
{
'asr_results'
:
asr_results
}
resp
=
{
'asr_results'
:
asr_results
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录