Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
97d31f9a
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
206
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
97d31f9a
编写于
4月 17, 2022
作者:
X
xiongxinlei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update the attention_rescoring method, test=doc
上级
0c5dbbee
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
287 addition
and
120 deletion
+287
-120
paddlespeech/server/conf/ws_application.yaml
paddlespeech/server/conf/ws_application.yaml
+5
-11
paddlespeech/server/engine/asr/online/asr_engine.py
paddlespeech/server/engine/asr/online/asr_engine.py
+92
-85
paddlespeech/server/engine/asr/online/ctc_search.py
paddlespeech/server/engine/asr/online/ctc_search.py
+119
-0
paddlespeech/server/tests/__init__.py
paddlespeech/server/tests/__init__.py
+13
-0
paddlespeech/server/tests/asr/__init__.py
paddlespeech/server/tests/asr/__init__.py
+13
-0
paddlespeech/server/tests/asr/offline/__init__.py
paddlespeech/server/tests/asr/offline/__init__.py
+13
-0
paddlespeech/server/tests/asr/online/__init__.py
paddlespeech/server/tests/asr/online/__init__.py
+13
-0
paddlespeech/server/ws/asr_socket.py
paddlespeech/server/ws/asr_socket.py
+19
-24
未找到文件。
paddlespeech/server/conf/ws_application.yaml
浏览文件 @
97d31f9a
...
...
@@ -71,15 +71,9 @@ asr_online:
summary
:
True
# False -> do not show predictor config
chunk_buffer_conf
:
frame_duration_ms
:
85
shift_ms
:
40
window_n
:
7
# frame
shift_n
:
4
# frame
window_ms
:
25
# ms
shift_ms
:
10
# ms
sample_rate
:
16000
sample_width
:
2
\ No newline at end of file
# vad_conf:
# aggressiveness: 2
# sample_rate: 16000
# frame_duration_ms: 20
# sample_width: 2
# padding_ms: 200
# padding_ratio: 0.9
\ No newline at end of file
paddlespeech/server/engine/asr/online/asr_engine.py
浏览文件 @
97d31f9a
...
...
@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
paddlespeech.s2t.utils.utility
import
log_add
from
typing
import
Optional
from
collections
import
defaultdict
import
numpy
as
np
import
paddle
from
numpy
import
float32
...
...
@@ -22,19 +21,18 @@ from yacs.config import CfgNode
from
paddlespeech.cli.asr.infer
import
ASRExecutor
from
paddlespeech.cli.asr.infer
import
model_alias
from
paddlespeech.cli.asr.infer
import
pretrained_models
from
paddlespeech.cli.log
import
logger
from
paddlespeech.cli.utils
import
download_and_decompress
from
paddlespeech.cli.utils
import
MODEL_HOME
from
paddlespeech.s2t.frontend.featurizer.text_featurizer
import
TextFeaturizer
from
paddlespeech.s2t.frontend.speech
import
SpeechSegment
from
paddlespeech.s2t.modules.ctc
import
CTCDecoder
from
paddlespeech.s2t.modules.mask
import
mask_finished_preds
from
paddlespeech.s2t.modules.mask
import
mask_finished_scores
from
paddlespeech.s2t.modules.mask
import
subsequent_mask
from
paddlespeech.s2t.transform.transformation
import
Transformation
from
paddlespeech.s2t.utils.dynamic_import
import
dynamic_import
from
paddlespeech.s2t.utils.tensor_utils
import
add_sos_eos
from
paddlespeech.s2t.utils.tensor_utils
import
pad_sequence
from
paddlespeech.s2t.utils.utility
import
UpdateConfig
from
paddlespeech.server.engine.asr.online.ctc_search
import
CTCPrefixBeamSearch
from
paddlespeech.server.engine.base_engine
import
BaseEngine
from
paddlespeech.server.utils.audio_process
import
pcm2float
from
paddlespeech.server.utils.paddle_predictor
import
init_predictor
...
...
@@ -62,9 +60,9 @@ pretrained_models = {
},
"conformer2online_aishell-zh-16k"
:
{
'url'
:
'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.
0
.model.tar.gz'
,
'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.
1
.model.tar.gz'
,
'md5'
:
'
7989b3248c898070904cf042fd656003
'
,
'
b450d5dfaea0ac227c595ce58d18b637
'
,
'cfg_path'
:
'model.yaml'
,
'ckpt_path'
:
...
...
@@ -123,9 +121,9 @@ class ASRServerExecutor(ASRExecutor):
logger
.
info
(
f
"Load the pretrained model, tag =
{
tag
}
"
)
res_path
=
self
.
_get_pretrained_path
(
tag
)
# wenetspeech_zh
self
.
res_path
=
res_path
self
.
cfg_path
=
"/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml"
#
self.cfg_path = os.path.join(res_path,
#
pretrained_models[tag]['cfg_path'])
#
self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml"
self
.
cfg_path
=
os
.
path
.
join
(
res_path
,
pretrained_models
[
tag
][
'cfg_path'
])
self
.
am_model
=
os
.
path
.
join
(
res_path
,
pretrained_models
[
tag
][
'model'
])
...
...
@@ -177,6 +175,18 @@ class ASRServerExecutor(ASRExecutor):
# update the decoding method
if
decode_method
:
self
.
config
.
decode
.
decoding_method
=
decode_method
# we only support the ctc_prefix_beam_search and attention_rescoring decoding methods
# Generally we set the decoding_method to attention_rescoring
if
self
.
config
.
decode
.
decoding_method
not
in
[
"ctc_prefix_beam_search"
,
"attention_rescoring"
]:
logger
.
info
(
"we set the decoding_method to attention_rescoring"
)
self
.
config
.
decode
.
decoding
=
"attention_rescoring"
assert
self
.
config
.
decode
.
decoding_method
in
[
"ctc_prefix_beam_search"
,
"attention_rescoring"
],
f
"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is
{
self
.
config
.
decode
.
decoding_method
}
"
else
:
raise
Exception
(
"wrong type"
)
if
"deepspeech2online"
in
model_type
or
"deepspeech2offline"
in
model_type
:
...
...
@@ -232,7 +242,7 @@ class ASRServerExecutor(ASRExecutor):
logger
.
info
(
"create the transformer like model success"
)
# update the ctc decoding
self
.
searcher
=
None
self
.
searcher
=
CTCPrefixBeamSearch
(
self
.
config
.
decode
)
self
.
transformer_decode_reset
()
def
reset_decoder_and_chunk
(
self
):
...
...
@@ -320,7 +330,16 @@ class ASRServerExecutor(ASRExecutor):
def
advanced_decoding
(
self
,
xs
:
paddle
.
Tensor
,
x_chunk_lens
):
logger
.
info
(
"start to decode with advanced_decoding method"
)
encoder_out
,
encoder_mask
=
self
.
decode_forward
(
xs
)
self
.
ctc_prefix_beam_search
(
xs
,
encoder_out
,
encoder_mask
)
ctc_probs
=
self
.
model
.
ctc
.
log_softmax
(
encoder_out
)
# (1, maxlen, vocab_size)
ctc_probs
=
ctc_probs
.
squeeze
(
0
)
self
.
searcher
.
search
(
xs
,
ctc_probs
,
xs
.
place
)
# update the one best result
self
.
hyps
=
self
.
searcher
.
get_one_best_hyps
()
# now we support ctc_prefix_beam_search and attention_rescoring
if
"attention_rescoring"
in
self
.
config
.
decode
.
decoding_method
:
self
.
rescoring
(
encoder_out
,
xs
.
place
)
def
decode_forward
(
self
,
xs
):
logger
.
info
(
"get the model out from the feat"
)
...
...
@@ -338,7 +357,6 @@ class ASRServerExecutor(ASRExecutor):
num_frames
=
xs
.
shape
[
1
]
required_cache_size
=
decoding_chunk_size
*
num_decoding_left_chunks
logger
.
info
(
"start to do model forward"
)
outputs
=
[]
...
...
@@ -359,85 +377,74 @@ class ASRServerExecutor(ASRExecutor):
masks
=
masks
.
unsqueeze
(
1
)
return
ys
,
masks
def
rescoring
(
self
,
encoder_out
,
device
):
logger
.
info
(
"start to rescoring the hyps"
)
beam_size
=
self
.
config
.
decode
.
beam_size
hyps
=
self
.
searcher
.
get_hyps
()
assert
len
(
hyps
)
==
beam_size
hyp_list
=
[]
for
hyp
in
hyps
:
hyp_content
=
hyp
[
0
]
# Prevent the hyp is empty
if
len
(
hyp_content
)
==
0
:
hyp_content
=
(
self
.
model
.
ctc
.
blank_id
,
)
hyp_content
=
paddle
.
to_tensor
(
hyp_content
,
place
=
device
,
dtype
=
paddle
.
long
)
hyp_list
.
append
(
hyp_content
)
hyps_pad
=
pad_sequence
(
hyp_list
,
True
,
self
.
model
.
ignore_id
)
hyps_lens
=
paddle
.
to_tensor
(
[
len
(
hyp
[
0
])
for
hyp
in
hyps
],
place
=
device
,
dtype
=
paddle
.
long
)
# (beam_size,)
hyps_pad
,
_
=
add_sos_eos
(
hyps_pad
,
self
.
model
.
sos
,
self
.
model
.
eos
,
self
.
model
.
ignore_id
)
hyps_lens
=
hyps_lens
+
1
# Add <sos> at beginning
encoder_out
=
encoder_out
.
repeat
(
beam_size
,
1
,
1
)
encoder_mask
=
paddle
.
ones
(
(
beam_size
,
1
,
encoder_out
.
shape
[
1
]),
dtype
=
paddle
.
bool
)
decoder_out
,
_
=
self
.
model
.
decoder
(
encoder_out
,
encoder_mask
,
hyps_pad
,
hyps_lens
)
# (beam_size, max_hyps_len, vocab_size)
# ctc score in ln domain
decoder_out
=
paddle
.
nn
.
functional
.
log_softmax
(
decoder_out
,
axis
=-
1
)
decoder_out
=
decoder_out
.
numpy
()
# Only use decoder score for rescoring
best_score
=
-
float
(
'inf'
)
best_index
=
0
# hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
for
i
,
hyp
in
enumerate
(
hyps
):
score
=
0.0
for
j
,
w
in
enumerate
(
hyp
[
0
]):
score
+=
decoder_out
[
i
][
j
][
w
]
# last decoder output token is `eos`, for the last decoder input token.
score
+=
decoder_out
[
i
][
len
(
hyp
[
0
])][
self
.
model
.
eos
]
# add ctc score (which in ln domain)
score
+=
hyp
[
1
]
*
self
.
config
.
decode
.
ctc_weight
if
score
>
best_score
:
best_score
=
score
best_index
=
i
# update the one best result
self
.
hyps
=
[
hyps
[
best_index
][
0
]]
return
hyps
[
best_index
][
0
]
def
transformer_decode_reset
(
self
):
self
.
subsampling_cache
=
None
self
.
elayers_output_cache
=
None
self
.
conformer_cnn_cache
=
None
self
.
hyps
=
None
self
.
offset
=
0
self
.
cur_hyps
=
None
self
.
hyps
=
None
def
ctc_prefix_beam_search
(
self
,
xs
,
encoder_out
,
encoder_mask
,
blank_id
=
0
):
# decode
logger
.
info
(
"start to ctc prefix search"
)
device
=
xs
.
place
cfg
=
self
.
config
.
decode
batch_size
=
xs
.
shape
[
0
]
beam_size
=
cfg
.
beam_size
maxlen
=
encoder_out
.
shape
[
1
]
ctc_probs
=
self
.
model
.
ctc
.
log_softmax
(
encoder_out
)
# (1, maxlen, vocab_size)
ctc_probs
=
ctc_probs
.
squeeze
(
0
)
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# blank_ending_score and none_blank_ending_score in ln domain
if
self
.
cur_hyps
is
None
:
self
.
cur_hyps
=
[(
tuple
(),
(
0.0
,
-
float
(
'inf'
)))]
# 2. CTC beam search step by step
for
t
in
range
(
0
,
maxlen
):
logp
=
ctc_probs
[
t
]
# (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps
=
defaultdict
(
lambda
:
(
-
float
(
'inf'
),
-
float
(
'inf'
)))
# 2.1 First beam prune: select topk best
# do token passing process
top_k_logp
,
top_k_index
=
logp
.
topk
(
beam_size
)
# (beam_size,)
for
s
in
top_k_index
:
s
=
s
.
item
()
ps
=
logp
[
s
].
item
()
for
prefix
,
(
pb
,
pnb
)
in
self
.
cur_hyps
:
last
=
prefix
[
-
1
]
if
len
(
prefix
)
>
0
else
None
if
s
==
blank_id
:
# blank
n_pb
,
n_pnb
=
next_hyps
[
prefix
]
n_pb
=
log_add
([
n_pb
,
pb
+
ps
,
pnb
+
ps
])
next_hyps
[
prefix
]
=
(
n_pb
,
n_pnb
)
elif
s
==
last
:
# Update *ss -> *s;
n_pb
,
n_pnb
=
next_hyps
[
prefix
]
n_pnb
=
log_add
([
n_pnb
,
pnb
+
ps
])
next_hyps
[
prefix
]
=
(
n_pb
,
n_pnb
)
# Update *s-s -> *ss, - is for blank
n_prefix
=
prefix
+
(
s
,
)
n_pb
,
n_pnb
=
next_hyps
[
n_prefix
]
n_pnb
=
log_add
([
n_pnb
,
pb
+
ps
])
next_hyps
[
n_prefix
]
=
(
n_pb
,
n_pnb
)
else
:
n_prefix
=
prefix
+
(
s
,
)
n_pb
,
n_pnb
=
next_hyps
[
n_prefix
]
n_pnb
=
log_add
([
n_pnb
,
pb
+
ps
,
pnb
+
ps
])
next_hyps
[
n_prefix
]
=
(
n_pb
,
n_pnb
)
# 2.2 Second beam prune
next_hyps
=
sorted
(
next_hyps
.
items
(),
key
=
lambda
x
:
log_add
(
list
(
x
[
1
])),
reverse
=
True
)
self
.
cur_hyps
=
next_hyps
[:
beam_size
]
hyps
=
[(
y
[
0
],
log_add
([
y
[
1
][
0
],
y
[
1
][
1
]]))
for
y
in
self
.
cur_hyps
]
self
.
hyps
=
[
hyps
[
0
][
0
]]
logger
.
info
(
"ctc prefix search success"
)
return
hyps
,
encoder_out
# decoding reset
self
.
searcher
.
reset
()
def
update_result
(
self
):
logger
.
info
(
"update the final result"
)
hyps
=
self
.
hyps
self
.
result_transcripts
=
[
self
.
text_feature
.
defeaturize
(
hyp
)
for
hyp
in
self
.
hyps
self
.
text_feature
.
defeaturize
(
hyp
)
for
hyp
in
hyps
]
self
.
result_tokenids
=
[
hyp
for
hyp
in
self
.
hyps
]
self
.
result_tokenids
=
[
hyp
for
hyp
in
hyps
]
def
extract_feat
(
self
,
samples
,
sample_rate
):
"""extract feat
...
...
@@ -483,9 +490,9 @@ class ASRServerExecutor(ASRExecutor):
elif
"conformer2online"
in
self
.
model_type
:
if
sample_rate
!=
self
.
sample_rate
:
logger
.
info
(
f
"audio sample rate
{
sample_rate
}
is not match,"
\
logger
.
info
(
f
"audio sample rate
{
sample_rate
}
is not match,"
"the model sample_rate is {self.sample_rate}"
)
logger
.
info
(
f
"ASR Engine use the
{
self
.
model_type
}
to process"
)
logger
.
info
(
"ASR Engine use the {self.model_type} to process"
)
logger
.
info
(
"Create the preprocess instance"
)
preprocess_conf
=
self
.
config
.
preprocess_config
preprocess_args
=
{
"train"
:
False
}
...
...
paddlespeech/server/engine/asr/online/ctc_search.py
0 → 100644
浏览文件 @
97d31f9a
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
collections
import
defaultdict
from
paddlespeech.cli.log
import
logger
from
paddlespeech.s2t.utils.utility
import
log_add
__all__
=
[
'CTCPrefixBeamSearch'
]
class CTCPrefixBeamSearch:
    """Chunk-by-chunk CTC prefix beam search.

    The beam (``cur_hyps``) is stored on the instance and is only
    initialised when it is ``None``, so consecutive ``search`` calls extend
    the same set of prefixes. Call ``reset`` to start a new search.
    """

    def __init__(self, config):
        """Store the decoding configuration and clear the search state.

        Args:
            config: decoding configuration object; ``search`` reads
                ``config.beam_size`` from it.
        """
        self.config = config
        self.reset()

    def search(self, xs, ctc_probs, device, blank_id=0):
        """Run CTC prefix beam search over one chunk of CTC scores.

        Args:
            xs: feature data of the chunk. It is not read by the search and
                is kept only so existing callers keep working.
            ctc_probs: 2-D tensor indexed as ``[frame, token]``. The values
                are summed and merged with ``log_add``, i.e. they are treated
                as scores in the ln domain.
            device: not used by the search; kept for interface
                compatibility.
            blank_id (int, optional): the blank id in the vocab.
                Defaults to 0.

        Returns:
            list: ``(prefix, score)`` tuples for the current beam, where
            ``prefix`` is a tuple of token ids and ``score`` is
            ``log_add([blank_ending_score, none_blank_ending_score])``.
            The same list is stored in ``self.hyps``.
        """
        # decode
        logger.info("start to ctc prefix search")
        beam_size = self.config.beam_size
        maxlen = ctc_probs.shape[0]
        assert len(ctc_probs.shape) == 2

        # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
        # blank_ending_score and none_blank_ending_score in ln domain
        if self.cur_hyps is None:
            self.cur_hyps = [(tuple(), (0.0, -float('inf')))]

        # 2. CTC beam search step by step
        for t in range(maxlen):
            logp = ctc_probs[t]  # (vocab_size,)
            # key: prefix, value (pb, pnb), default value(-inf, -inf)
            next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))

            # 2.1 First beam prune: select topk best
            #     do token passing process
            _, top_k_index = logp.topk(beam_size)  # (beam_size,)
            for s in top_k_index:
                s = s.item()
                ps = logp[s].item()
                for prefix, (pb, pnb) in self.cur_hyps:
                    last = prefix[-1] if prefix else None
                    if s == blank_id:
                        # blank: the prefix is unchanged and now ends in blank
                        n_pb, n_pnb = next_hyps[prefix]
                        n_pb = log_add([n_pb, pb + ps, pnb + ps])
                        next_hyps[prefix] = (n_pb, n_pnb)
                    elif s == last:
                        # Update *ss -> *s;
                        n_pb, n_pnb = next_hyps[prefix]
                        n_pnb = log_add([n_pnb, pnb + ps])
                        next_hyps[prefix] = (n_pb, n_pnb)
                        # Update *s-s -> *ss, - is for blank
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb = next_hyps[n_prefix]
                        n_pnb = log_add([n_pnb, pb + ps])
                        next_hyps[n_prefix] = (n_pb, n_pnb)
                    else:
                        # a new, different token extends the prefix
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb = next_hyps[n_prefix]
                        n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
                        next_hyps[n_prefix] = (n_pb, n_pnb)

            # 2.2 Second beam prune: keep the beam_size best prefixes,
            #     ranked by the combined (pb, pnb) score
            next_hyps = sorted(
                next_hyps.items(),
                key=lambda x: log_add(list(x[1])),
                reverse=True)
            self.cur_hyps = next_hyps[:beam_size]

        self.hyps = [(y[0], log_add([y[1][0], y[1][1]]))
                     for y in self.cur_hyps]
        logger.info("ctc prefix search success")
        return self.hyps

    def get_one_best_hyps(self):
        """Return the one best result

        ``self.hyps`` must already be populated (``search`` sets it);
        it is ``None`` right after ``reset``.

        Returns:
            list: a one-element list holding the prefix of the first
            hypothesis in ``self.hyps``
        """
        return [self.hyps[0][0]]

    def get_hyps(self):
        """Return the stored hypotheses

        Returns:
            list: ``self.hyps`` as set by the last ``search`` call, or
            ``None`` if no search has run since the last ``reset``
        """
        return self.hyps

    def reset(self):
        """Reset the search cache value
        """
        self.cur_hyps = None
        self.hyps = None
paddlespeech/server/tests/__init__.py
0 → 100644
浏览文件 @
97d31f9a
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/server/tests/asr/__init__.py
0 → 100644
浏览文件 @
97d31f9a
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/server/tests/asr/offline/__init__.py
0 → 100644
浏览文件 @
97d31f9a
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/server/tests/asr/online/__init__.py
0 → 100644
浏览文件 @
97d31f9a
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/server/ws/asr_socket.py
浏览文件 @
97d31f9a
...
...
@@ -34,17 +34,17 @@ async def websocket_endpoint(websocket: WebSocket):
engine_pool
=
get_engine_pool
()
asr_engine
=
engine_pool
[
'asr'
]
# init buffer
# each websocket connection has its own chunk buffer
chunk_buffer_conf
=
asr_engine
.
config
.
chunk_buffer_conf
chunk_buffer
=
ChunkBuffer
(
window_n
=
7
,
shift_n
=
4
,
window_ms
=
20
,
shift_ms
=
10
,
sample_rate
=
chunk_buffer_conf
[
'sample_rate'
],
sample_width
=
chunk_buffer_conf
[
'sample_width'
])
window_n
=
chunk_buffer_conf
.
window_n
,
shift_n
=
chunk_buffer_conf
.
shift_n
,
window_ms
=
chunk_buffer_conf
.
window_ms
,
shift_ms
=
chunk_buffer_conf
.
shift_ms
,
sample_rate
=
chunk_buffer_conf
.
sample_rate
,
sample_width
=
chunk_buffer_conf
.
sample_width
)
# init vad
# print(asr_engine.config)
# print(type(asr_engine.config))
vad_conf
=
asr_engine
.
config
.
get
(
'vad_conf'
,
None
)
if
vad_conf
:
vad
=
VADAudio
(
...
...
@@ -72,7 +72,7 @@ async def websocket_endpoint(websocket: WebSocket):
engine_pool
=
get_engine_pool
()
asr_engine
=
engine_pool
[
'asr'
]
# reset single engine for a new connection
#
asr_engine.reset()
asr_engine
.
reset
()
resp
=
{
"status"
:
"ok"
,
"signal"
:
"finished"
}
await
websocket
.
send_json
(
resp
)
break
...
...
@@ -85,21 +85,16 @@ async def websocket_endpoint(websocket: WebSocket):
engine_pool
=
get_engine_pool
()
asr_engine
=
engine_pool
[
'asr'
]
asr_results
=
""
# frames = chunk_buffer.frame_generator(message)
# for frame in frames:
# # get the pcm data from the bytes
# samples = np.frombuffer(frame.bytes, dtype=np.int16)
# sample_rate = asr_engine.config.sample_rate
# x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
# sample_rate)
# asr_engine.run(x_chunk, x_chunk_lens)
# asr_results = asr_engine.postprocess()
samples
=
np
.
frombuffer
(
message
,
dtype
=
np
.
int16
)
frames
=
chunk_buffer
.
frame_generator
(
message
)
for
frame
in
frames
:
# get the pcm data from the bytes
samples
=
np
.
frombuffer
(
frame
.
bytes
,
dtype
=
np
.
int16
)
sample_rate
=
asr_engine
.
config
.
sample_rate
x_chunk
,
x_chunk_lens
=
asr_engine
.
preprocess
(
samples
,
sample_rate
)
asr_engine
.
run
(
x_chunk
,
x_chunk_lens
)
# asr_results = asr_engine.postprocess()
asr_results
=
asr_engine
.
postprocess
()
asr_results
=
asr_engine
.
postprocess
()
resp
=
{
'asr_results'
:
asr_results
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录