Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
8641608f
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8641608f
编写于
6月 08, 2022
作者:
Y
YangZhou
提交者:
GitHub
6月 08, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2015 from zh794390558/endpoint
[server][asr] support endpoint for conformer streaming model
上级
f3132ce2
dfdf450b
变更
52
隐藏空白更改
内联
并排
Showing
52 changed file
with
547 addition
and
369 deletion
+547
-369
.pre-commit-config.yaml
.pre-commit-config.yaml
+6
-6
demos/streaming_asr_server/conf/application.yaml
demos/streaming_asr_server/conf/application.yaml
+2
-0
demos/streaming_asr_server/conf/ws_conformer_application.yaml
...s/streaming_asr_server/conf/ws_conformer_application.yaml
+3
-0
demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml
...asr_server/conf/ws_conformer_wenetspeech_application.yaml
+2
-0
demos/streaming_asr_server/server.sh
demos/streaming_asr_server/server.sh
+2
-1
demos/streaming_asr_server/test.sh
demos/streaming_asr_server/test.sh
+2
-1
examples/wenetspeech/asr1/local/extract_meta.py
examples/wenetspeech/asr1/local/extract_meta.py
+0
-1
paddlespeech/cli/base_commands.py
paddlespeech/cli/base_commands.py
+0
-1
paddlespeech/cli/cls/infer.py
paddlespeech/cli/cls/infer.py
+2
-2
paddlespeech/cli/vector/infer.py
paddlespeech/cli/vector/infer.py
+2
-2
paddlespeech/resource/model_alias.py
paddlespeech/resource/model_alias.py
+1
-2
paddlespeech/s2t/decoders/scorers/ctc_prefix_score.py
paddlespeech/s2t/decoders/scorers/ctc_prefix_score.py
+2
-1
paddlespeech/s2t/exps/deepspeech2/model.py
paddlespeech/s2t/exps/deepspeech2/model.py
+4
-6
paddlespeech/s2t/models/ds2/__init__.py
paddlespeech/s2t/models/ds2/__init__.py
+2
-1
paddlespeech/s2t/models/ds2/deepspeech2.py
paddlespeech/s2t/models/ds2/deepspeech2.py
+8
-4
paddlespeech/s2t/models/lm/transformer.py
paddlespeech/s2t/models/lm/transformer.py
+1
-1
paddlespeech/s2t/models/u2/updater.py
paddlespeech/s2t/models/u2/updater.py
+0
-1
paddlespeech/s2t/modules/ctc.py
paddlespeech/s2t/modules/ctc.py
+1
-1
paddlespeech/s2t/utils/tensor_utils.py
paddlespeech/s2t/utils/tensor_utils.py
+2
-1
paddlespeech/server/conf/ws_application.yaml
paddlespeech/server/conf/ws_application.yaml
+1
-0
paddlespeech/server/conf/ws_conformer_application.yaml
paddlespeech/server/conf/ws_conformer_application.yaml
+2
-0
paddlespeech/server/engine/asr/online/asr_engine.py
paddlespeech/server/engine/asr/online/asr_engine.py
+233
-211
paddlespeech/server/engine/asr/online/ctc_endpoint.py
paddlespeech/server/engine/asr/online/ctc_endpoint.py
+118
-0
paddlespeech/server/engine/asr/online/ctc_search.py
paddlespeech/server/engine/asr/online/ctc_search.py
+35
-15
paddlespeech/server/engine/tts/online/python/tts_engine.py
paddlespeech/server/engine/tts/online/python/tts_engine.py
+0
-1
paddlespeech/server/ws/asr_api.py
paddlespeech/server/ws/asr_api.py
+28
-5
paddlespeech/t2s/exps/synthesize.py
paddlespeech/t2s/exps/synthesize.py
+2
-8
paddlespeech/t2s/exps/synthesize_e2e.py
paddlespeech/t2s/exps/synthesize_e2e.py
+2
-8
paddlespeech/t2s/exps/voice_cloning.py
paddlespeech/t2s/exps/voice_cloning.py
+2
-8
paddlespeech/t2s/models/vits/__init__.py
paddlespeech/t2s/models/vits/__init__.py
+1
-1
paddlespeech/t2s/models/vits/vits_updater.py
paddlespeech/t2s/models/vits/vits_updater.py
+4
-2
paddlespeech/t2s/modules/losses.py
paddlespeech/t2s/modules/losses.py
+9
-9
speechx/examples/README.md
speechx/examples/README.md
+0
-1
speechx/examples/ds2_ol/README.md
speechx/examples/ds2_ol/README.md
+1
-1
speechx/speechx/codelab/README.md
speechx/speechx/codelab/README.md
+0
-1
speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc
...hx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc
+2
-2
speechx/speechx/decoder/ctc_tlg_decoder.cc
speechx/speechx/decoder/ctc_tlg_decoder.cc
+1
-1
speechx/speechx/decoder/param.h
speechx/speechx/decoder/param.h
+1
-1
speechx/speechx/decoder/tlg_decoder_main.cc
speechx/speechx/decoder/tlg_decoder_main.cc
+2
-2
speechx/speechx/frontend/audio/assembler.cc
speechx/speechx/frontend/audio/assembler.cc
+15
-14
speechx/speechx/frontend/audio/assembler.h
speechx/speechx/frontend/audio/assembler.h
+3
-7
speechx/speechx/frontend/audio/audio_cache.h
speechx/speechx/frontend/audio/audio_cache.h
+1
-1
speechx/speechx/frontend/audio/fbank.cc
speechx/speechx/frontend/audio/fbank.cc
+4
-4
speechx/speechx/frontend/audio/feature_cache.cc
speechx/speechx/frontend/audio/feature_cache.cc
+2
-2
speechx/speechx/frontend/audio/feature_cache.h
speechx/speechx/frontend/audio/feature_cache.h
+1
-3
speechx/speechx/frontend/audio/feature_common.h
speechx/speechx/frontend/audio/feature_common.h
+4
-3
speechx/speechx/frontend/audio/feature_common_inl.h
speechx/speechx/frontend/audio/feature_common_inl.h
+12
-10
speechx/speechx/frontend/audio/feature_pipeline.h
speechx/speechx/frontend/audio/feature_pipeline.h
+1
-1
speechx/speechx/frontend/audio/linear_spectrogram.cc
speechx/speechx/frontend/audio/linear_spectrogram.cc
+4
-5
speechx/speechx/nnet/nnet_forward_main.cc
speechx/speechx/nnet/nnet_forward_main.cc
+8
-5
speechx/speechx/protocol/websocket/websocket_client.h
speechx/speechx/protocol/websocket/websocket_client.h
+2
-2
speechx/speechx/protocol/websocket/websocket_server.cc
speechx/speechx/protocol/websocket/websocket_server.cc
+4
-3
未找到文件。
.pre-commit-config.yaml
浏览文件 @
8641608f
...
...
@@ -51,12 +51,12 @@ repos:
language
:
system
files
:
\.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
exclude
:
(?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
-
id
:
copyright_checker
name
:
copyright_checker
entry
:
python .pre-commit-hooks/copyright-check.hook
language
:
system
files
:
\.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude
:
(?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
#
- id: copyright_checker
#
name: copyright_checker
#
entry: python .pre-commit-hooks/copyright-check.hook
#
language: system
#
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
#
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
-
repo
:
https://github.com/asottile/reorder_python_imports
rev
:
v2.4.0
hooks
:
...
...
demos/streaming_asr_server/conf/application.yaml
浏览文件 @
8641608f
...
...
@@ -31,6 +31,8 @@ asr_online:
force_yes
:
True
device
:
'
cpu'
# cpu or gpu:id
decode_method
:
"
attention_rescoring"
continuous_decoding
:
True
# enable continue decoding when endpoint detected
am_predictor_conf
:
device
:
# set 'gpu:id' or 'cpu'
switch_ir_optim
:
True
...
...
demos/streaming_asr_server/conf/ws_conformer_application.yaml
浏览文件 @
8641608f
...
...
@@ -30,6 +30,9 @@ asr_online:
decode_method
:
force_yes
:
True
device
:
'
cpu'
# cpu or gpu:id
decode_method
:
"
attention_rescoring"
continuous_decoding
:
True
# enable continue decoding when endpoint detected
am_predictor_conf
:
device
:
# set 'gpu:id' or 'cpu'
switch_ir_optim
:
True
...
...
demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml
浏览文件 @
8641608f
...
...
@@ -31,6 +31,8 @@ asr_online:
force_yes
:
True
device
:
'
cpu'
# cpu or gpu:id
decode_method
:
"
attention_rescoring"
continuous_decoding
:
True
# enable continue decoding when endpoint detected
am_predictor_conf
:
device
:
# set 'gpu:id' or 'cpu'
switch_ir_optim
:
True
...
...
demos/streaming_asr_server/server.sh
浏览文件 @
8641608f
...
...
@@ -5,4 +5,5 @@ export CUDA_VISIBLE_DEVICE=0,1,2,3
paddlespeech_server start
--config_file
conf/punc_application.yaml &> punc.log &
# nohup python3 streaming_asr_server.py --config_file conf/ws_conformer_application.yaml > streaming_asr.log 2>&1 &
paddlespeech_server start
--config_file
conf/ws_conformer_application.yaml &> streaming_asr.log &
\ No newline at end of file
paddlespeech_server start
--config_file
conf/ws_conformer_application.yaml &> streaming_asr.log &
demos/streaming_asr_server/test.sh
浏览文件 @
8641608f
...
...
@@ -9,4 +9,5 @@ paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --input ./zh.wa
# read the wav and call streaming and punc service
# If `127.0.0.1` is not accessible, you need to use the actual service IP address.
# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav
paddlespeech_client asr_online
--server_ip
127.0.0.1
--port
8290
--punc
.server_ip 127.0.0.1
--punc
.port 8190
--input
./zh.wav
\ No newline at end of file
paddlespeech_client asr_online
--server_ip
127.0.0.1
--port
8290
--punc
.server_ip 127.0.0.1
--punc
.port 8190
--input
./zh.wav
examples/wenetspeech/asr1/local/extract_meta.py
浏览文件 @
8641608f
...
...
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
json
import
os
...
...
paddlespeech/cli/base_commands.py
浏览文件 @
8641608f
...
...
@@ -145,4 +145,3 @@ for com, info in _commands.items():
name
=
'paddlespeech.{}'
.
format
(
com
),
description
=
info
[
0
],
cls
=
'paddlespeech.cli.{}.{}'
.
format
(
com
,
info
[
1
]))
\ No newline at end of file
paddlespeech/cli/cls/infer.py
浏览文件 @
8641608f
...
...
@@ -21,12 +21,12 @@ from typing import Union
import
numpy
as
np
import
paddle
import
yaml
from
paddleaudio
import
load
from
paddleaudio.features
import
LogMelSpectrogram
from
..executor
import
BaseExecutor
from
..log
import
logger
from
..utils
import
stats_wrapper
from
paddleaudio
import
load
from
paddleaudio.features
import
LogMelSpectrogram
__all__
=
[
'CLSExecutor'
]
...
...
paddlespeech/cli/vector/infer.py
浏览文件 @
8641608f
...
...
@@ -22,13 +22,13 @@ from typing import Union
import
paddle
import
soundfile
from
paddleaudio.backends
import
load
as
load_audio
from
paddleaudio.compliance.librosa
import
melspectrogram
from
yacs.config
import
CfgNode
from
..executor
import
BaseExecutor
from
..log
import
logger
from
..utils
import
stats_wrapper
from
paddleaudio.backends
import
load
as
load_audio
from
paddleaudio.compliance.librosa
import
melspectrogram
from
paddlespeech.vector.io.batch
import
feature_normalize
from
paddlespeech.vector.modules.sid_model
import
SpeakerIdetification
...
...
paddlespeech/resource/model_alias.py
浏览文件 @
8641608f
...
...
@@ -22,8 +22,7 @@ model_alias = {
# -------------- ASR --------------
# ---------------------------------
"deepspeech2offline"
:
[
"paddlespeech.s2t.models.ds2:DeepSpeech2Model"
],
"deepspeech2online"
:
[
"paddlespeech.s2t.models.ds2:DeepSpeech2Model"
],
"deepspeech2online"
:
[
"paddlespeech.s2t.models.ds2:DeepSpeech2Model"
],
"conformer"
:
[
"paddlespeech.s2t.models.u2:U2Model"
],
"conformer_online"
:
[
"paddlespeech.s2t.models.u2:U2Model"
],
"transformer"
:
[
"paddlespeech.s2t.models.u2:U2Model"
],
...
...
paddlespeech/s2t/decoders/scorers/ctc_prefix_score.py
浏览文件 @
8641608f
...
...
@@ -76,7 +76,8 @@ class CTCPrefixScorePD():
last_ids
=
[
yi
[
-
1
]
for
yi
in
y
]
# last output label ids
n_bh
=
len
(
last_ids
)
# batch * hyps
n_hyps
=
n_bh
//
self
.
batch
# assuming each utterance has the same # of hyps
self
.
scoring_num
=
paddle
.
shape
(
scoring_ids
)[
-
1
]
if
scoring_ids
is
not
None
else
0
self
.
scoring_num
=
paddle
.
shape
(
scoring_ids
)[
-
1
]
if
scoring_ids
is
not
None
else
0
# prepare state info
if
state
is
None
:
r_prev
=
paddle
.
full
(
...
...
paddlespeech/s2t/exps/deepspeech2/model.py
浏览文件 @
8641608f
...
...
@@ -22,11 +22,9 @@ import numpy as np
import
paddle
from
paddle
import
distributed
as
dist
from
paddle
import
inference
from
paddle.io
import
DataLoader
from
paddlespeech.s2t.io.dataloader
import
BatchDataLoader
from
paddlespeech.s2t.frontend.featurizer.text_featurizer
import
TextFeaturizer
from
paddlespeech.s2t.io.data
set
import
ManifestDataset
from
paddlespeech.s2t.io.data
loader
import
BatchDataLoader
from
paddlespeech.s2t.models.ds2
import
DeepSpeech2InferModel
from
paddlespeech.s2t.models.ds2
import
DeepSpeech2Model
from
paddlespeech.s2t.training.gradclip
import
ClipGradByGlobalNormWithLog
...
...
@@ -238,8 +236,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
def
__init__
(
self
,
config
,
args
):
super
().
__init__
(
config
,
args
)
self
.
_text_featurizer
=
TextFeaturizer
(
unit_type
=
config
.
unit_type
,
vocab
=
config
.
vocab_filepath
)
unit_type
=
config
.
unit_type
,
vocab
=
config
.
vocab_filepath
)
self
.
vocab_list
=
self
.
_text_featurizer
.
vocab_list
def
ordid2token
(
self
,
texts
,
texts_len
):
...
...
@@ -248,7 +245,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
for
text
,
n
in
zip
(
texts
,
texts_len
):
n
=
n
.
numpy
().
item
()
ids
=
text
[:
n
]
trans
.
append
(
self
.
_text_featurizer
.
defeaturize
(
ids
.
numpy
().
tolist
()))
trans
.
append
(
self
.
_text_featurizer
.
defeaturize
(
ids
.
numpy
().
tolist
()))
return
trans
def
compute_metrics
(
self
,
...
...
paddlespeech/s2t/models/ds2/__init__.py
浏览文件 @
8641608f
...
...
@@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
from
.deepspeech2
import
DeepSpeech2InferModel
from
.deepspeech2
import
DeepSpeech2Model
from
paddlespeech.s2t.utils
import
dynamic_pip_install
import
sys
try
:
import
paddlespeech_ctcdecoders
...
...
paddlespeech/s2t/models/ds2/deepspeech2.py
浏览文件 @
8641608f
...
...
@@ -372,11 +372,15 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
def
forward
(
self
,
audio_chunk
,
audio_chunk_lens
,
chunk_state_h_box
=
None
,
def
forward
(
self
,
audio_chunk
,
audio_chunk_lens
,
chunk_state_h_box
=
None
,
chunk_state_c_box
=
None
):
if
self
.
encoder
.
rnn_direction
==
"forward"
:
eouts_chunk
,
eouts_chunk_lens
,
final_state_h_box
,
final_state_c_box
=
self
.
encoder
(
audio_chunk
,
audio_chunk_lens
,
chunk_state_h_box
,
chunk_state_c_box
)
audio_chunk
,
audio_chunk_lens
,
chunk_state_h_box
,
chunk_state_c_box
)
probs_chunk
=
self
.
decoder
.
softmax
(
eouts_chunk
)
return
probs_chunk
,
eouts_chunk_lens
,
final_state_h_box
,
final_state_c_box
elif
self
.
encoder
.
rnn_direction
==
"bidirect"
:
...
...
@@ -392,8 +396,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
self
,
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
None
,
self
.
encoder
.
feat_size
],
#[B, chunk_size, feat_dim]
shape
=
[
None
,
None
,
self
.
encoder
.
feat_size
],
#[B, chunk_size, feat_dim]
dtype
=
'float32'
),
paddle
.
static
.
InputSpec
(
shape
=
[
None
],
dtype
=
'int64'
),
# audio_length, [B]
...
...
paddlespeech/s2t/models/lm/transformer.py
浏览文件 @
8641608f
...
...
@@ -90,7 +90,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
def
_target_mask
(
self
,
ys_in_pad
):
ys_mask
=
ys_in_pad
!=
0
m
=
subsequent_mask
(
paddle
.
shape
(
ys_mask
)[
-
1
])
)
.
unsqueeze
(
0
)
m
=
subsequent_mask
(
paddle
.
shape
(
ys_mask
)[
-
1
]).
unsqueeze
(
0
)
return
ys_mask
.
unsqueeze
(
-
2
)
&
m
def
forward
(
self
,
x
:
paddle
.
Tensor
,
t
:
paddle
.
Tensor
...
...
paddlespeech/s2t/models/u2/updater.py
浏览文件 @
8641608f
...
...
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
contextlib
import
nullcontext
import
paddle
...
...
paddlespeech/s2t/modules/ctc.py
浏览文件 @
8641608f
...
...
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
from
typing
import
Union
import
paddle
...
...
@@ -22,7 +23,6 @@ from paddlespeech.s2t.modules.align import Linear
from
paddlespeech.s2t.modules.loss
import
CTCLoss
from
paddlespeech.s2t.utils
import
ctc_utils
from
paddlespeech.s2t.utils.log
import
Log
import
sys
logger
=
Log
(
__name__
).
getlog
()
...
...
paddlespeech/s2t/utils/tensor_utils.py
浏览文件 @
8641608f
...
...
@@ -82,7 +82,8 @@ def pad_sequence(sequences: List[paddle.Tensor],
max_size
=
paddle
.
shape
(
sequences
[
0
])
# (TODO Hui Zhang): slice not supprot `end==start`
# trailing_dims = max_size[1:]
trailing_dims
=
tuple
(
max_size
[
1
:].
numpy
().
tolist
())
if
sequences
[
0
].
ndim
>=
2
else
()
trailing_dims
=
tuple
(
max_size
[
1
:].
numpy
().
tolist
())
if
sequences
[
0
].
ndim
>=
2
else
()
max_len
=
max
([
s
.
shape
[
0
]
for
s
in
sequences
])
if
batch_first
:
out_dims
=
(
len
(
sequences
),
max_len
)
+
trailing_dims
...
...
paddlespeech/server/conf/ws_application.yaml
浏览文件 @
8641608f
...
...
@@ -29,6 +29,7 @@ asr_online:
cfg_path
:
decode_method
:
force_yes
:
True
device
:
# cpu or gpu:id
am_predictor_conf
:
device
:
# set 'gpu:id' or 'cpu'
...
...
paddlespeech/server/conf/ws_conformer_application.yaml
浏览文件 @
8641608f
...
...
@@ -30,6 +30,8 @@ asr_online:
decode_method
:
force_yes
:
True
device
:
# cpu or gpu:id
continuous_decoding
:
True
# enable continue decoding when endpoint detected
am_predictor_conf
:
device
:
# set 'gpu:id' or 'cpu'
switch_ir_optim
:
True
...
...
paddlespeech/server/engine/asr/online/asr_engine.py
浏览文件 @
8641608f
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import
os
import
sys
from
typing
import
ByteString
from
typing
import
Optional
import
numpy
as
np
...
...
@@ -30,9 +31,10 @@ from paddlespeech.s2t.transform.transformation import Transformation
from
paddlespeech.s2t.utils.tensor_utils
import
add_sos_eos
from
paddlespeech.s2t.utils.tensor_utils
import
pad_sequence
from
paddlespeech.s2t.utils.utility
import
UpdateConfig
from
paddlespeech.server.engine.asr.online.ctc_endpoint
import
OnlineCTCEndpoingOpt
from
paddlespeech.server.engine.asr.online.ctc_endpoint
import
OnlineCTCEndpoint
from
paddlespeech.server.engine.asr.online.ctc_search
import
CTCPrefixBeamSearch
from
paddlespeech.server.engine.base_engine
import
BaseEngine
from
paddlespeech.server.utils.audio_process
import
pcm2float
from
paddlespeech.server.utils.paddle_predictor
import
init_predictor
__all__
=
[
'PaddleASRConnectionHanddler'
,
'ASRServerExecutor'
,
'ASREngine'
]
...
...
@@ -54,24 +56,35 @@ class PaddleASRConnectionHanddler:
self
.
model_config
=
asr_engine
.
executor
.
config
self
.
asr_engine
=
asr_engine
self
.
init
()
self
.
reset
()
def
init
(
self
):
# model_type, sample_rate and text_feature is shared for deepspeech2 and conformer
self
.
model_type
=
self
.
asr_engine
.
executor
.
model_type
self
.
sample_rate
=
self
.
asr_engine
.
executor
.
sample_rate
# tokens to text
self
.
text_feature
=
self
.
asr_engine
.
executor
.
text_feature
# extract feat, new only fbank in conformer model
self
.
preprocess_conf
=
self
.
model_config
.
preprocess_config
self
.
preprocess_args
=
{
"train"
:
False
}
self
.
preprocessing
=
Transformation
(
self
.
preprocess_conf
)
# frame window and frame shift, in samples unit
self
.
win_length
=
self
.
preprocess_conf
.
process
[
0
][
'win_length'
]
self
.
n_shift
=
self
.
preprocess_conf
.
process
[
0
][
'n_shift'
]
assert
self
.
preprocess_conf
.
process
[
0
][
'fs'
]
==
self
.
sample_rate
,
(
self
.
sample_rate
,
self
.
preprocess_conf
.
process
[
0
][
'fs'
])
self
.
frame_shift_in_ms
=
int
(
self
.
n_shift
/
self
.
preprocess_conf
.
process
[
0
][
'fs'
]
*
1000
)
self
.
continuous_decoding
=
self
.
config
.
get
(
"continuous_decoding"
,
False
)
self
.
init_decoder
()
self
.
reset
()
def
init_decoder
(
self
):
if
"deepspeech2"
in
self
.
model_type
:
assert
self
.
continuous_decoding
is
False
,
"ds2 model not support endpoint"
self
.
am_predictor
=
self
.
asr_engine
.
executor
.
am_predictor
# extract feat, new only fbank in conformer model
self
.
preprocess_conf
=
self
.
model_config
.
preprocess_config
self
.
preprocess_args
=
{
"train"
:
False
}
self
.
preprocessing
=
Transformation
(
self
.
preprocess_conf
)
self
.
decoder
=
CTCDecoder
(
odim
=
self
.
model_config
.
output_dim
,
# <blank> is in vocab
enc_n_units
=
self
.
model_config
.
rnn_layer_size
*
2
,
...
...
@@ -90,142 +103,65 @@ class PaddleASRConnectionHanddler:
cfg
.
beam_size
,
cfg
.
cutoff_prob
,
cfg
.
cutoff_top_n
,
cfg
.
num_proc_bsearch
)
# frame window and frame shift, in samples unit
self
.
win_length
=
self
.
preprocess_conf
.
process
[
0
][
'win_length'
]
self
.
n_shift
=
self
.
preprocess_conf
.
process
[
0
][
'n_shift'
]
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
:
# acoustic model
self
.
model
=
self
.
asr_engine
.
executor
.
model
self
.
continuous_decoding
=
self
.
config
.
continuous_decoding
logger
.
info
(
f
"continue decoding:
{
self
.
continuous_decoding
}
"
)
# ctc decoding config
self
.
ctc_decode_config
=
self
.
asr_engine
.
executor
.
config
.
decode
self
.
searcher
=
CTCPrefixBeamSearch
(
self
.
ctc_decode_config
)
# extract feat, new only fbank in conformer model
self
.
preprocess_conf
=
self
.
model_config
.
preprocess_config
self
.
preprocess_args
=
{
"train"
:
False
}
self
.
preprocessing
=
Transformation
(
self
.
preprocess_conf
)
# frame window and frame shift, in samples unit
self
.
win_length
=
self
.
preprocess_conf
.
process
[
0
][
'win_length'
]
self
.
n_shift
=
self
.
preprocess_conf
.
process
[
0
][
'n_shift'
]
# ctc endpoint
self
.
endpoint_opt
=
OnlineCTCEndpoingOpt
(
frame_shift_in_ms
=
self
.
frame_shift_in_ms
,
blank
=
0
)
self
.
endpointer
=
OnlineCTCEndpoint
(
self
.
endpoint_opt
)
else
:
raise
ValueError
(
f
"Not supported:
{
self
.
model_type
}
"
)
def
extract_feat
(
self
,
samples
):
# we compute the elapsed time of first char occuring
# and we record the start time at the first pcm sample arraving
if
"deepspeech2online"
in
self
.
model_type
:
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
samples
=
np
.
frombuffer
(
samples
,
dtype
=
np
.
int16
)
assert
samples
.
ndim
==
1
if
self
.
remained_wav
is
None
:
self
.
remained_wav
=
samples
else
:
assert
self
.
remained_wav
.
ndim
==
1
self
.
remained_wav
=
np
.
concatenate
([
self
.
remained_wav
,
samples
])
logger
.
info
(
f
"The connection remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
)
# fbank
feat
=
self
.
preprocessing
(
self
.
remained_wav
,
**
self
.
preprocess_args
)
feat
=
paddle
.
to_tensor
(
feat
,
dtype
=
"float32"
).
unsqueeze
(
axis
=
0
)
if
self
.
cached_feat
is
None
:
self
.
cached_feat
=
feat
else
:
assert
(
len
(
feat
.
shape
)
==
3
)
assert
(
len
(
self
.
cached_feat
.
shape
)
==
3
)
self
.
cached_feat
=
paddle
.
concat
(
[
self
.
cached_feat
,
feat
],
axis
=
1
)
# set the feat device
if
self
.
device
is
None
:
self
.
device
=
self
.
cached_feat
.
place
# cur frame step
num_frames
=
feat
.
shape
[
1
]
self
.
num_frames
+=
num_frames
self
.
remained_wav
=
self
.
remained_wav
[
self
.
n_shift
*
num_frames
:]
logger
.
info
(
f
"process the audio feature success, the connection feat shape:
{
self
.
cached_feat
.
shape
}
"
)
logger
.
info
(
f
"After extract feat, the connection remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
)
elif
"conformer_online"
in
self
.
model_type
:
logger
.
info
(
"Online ASR extract the feat"
)
samples
=
np
.
frombuffer
(
samples
,
dtype
=
np
.
int16
)
assert
samples
.
ndim
==
1
self
.
num_samples
+=
samples
.
shape
[
0
]
logger
.
info
(
f
"This package receive
{
samples
.
shape
[
0
]
}
pcm data. Global samples:
{
self
.
num_samples
}
"
)
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
if
self
.
remained_wav
is
None
:
self
.
remained_wav
=
samples
else
:
assert
self
.
remained_wav
.
ndim
==
1
# (T,)
self
.
remained_wav
=
np
.
concatenate
([
self
.
remained_wav
,
samples
])
logger
.
info
(
f
"The concatenation of remain and now audio samples length is:
{
self
.
remained_wav
.
shape
}
"
)
if
len
(
self
.
remained_wav
)
<
self
.
win_length
:
# samples not enough for feature window
return
0
def
model_reset
(
self
):
if
"deepspeech2"
in
self
.
model_type
:
return
# fbank
x_chunk
=
self
.
preprocessing
(
self
.
remained_wav
,
**
self
.
preprocess_args
)
x_chunk
=
paddle
.
to_tensor
(
x_chunk
,
dtype
=
"float32"
).
unsqueeze
(
axis
=
0
)
# cache for audio and feat
self
.
remained_wav
=
None
self
.
cached_feat
=
None
# feature cache
if
self
.
cached_feat
is
None
:
self
.
cached_feat
=
x_chunk
else
:
assert
(
len
(
x_chunk
.
shape
)
==
3
)
# (B,T,D)
assert
(
len
(
self
.
cached_feat
.
shape
)
==
3
)
# (B,T,D)
self
.
cached_feat
=
paddle
.
concat
(
[
self
.
cached_feat
,
x_chunk
],
axis
=
1
)
## conformer
# cache for conformer online
self
.
subsampling_cache
=
None
self
.
elayers_output_cache
=
None
self
.
conformer_cnn_cache
=
None
self
.
encoder_out
=
None
# conformer decoding state
self
.
offset
=
0
# global offset in decoding frame unit
# set the feat device
if
self
.
device
is
None
:
self
.
device
=
self
.
cached_feat
.
place
## just for record info
self
.
chunk_num
=
0
# global decoding chunk num, not used
# cur frame step
num_frames
=
x_chunk
.
shape
[
1
]
def
output_reset
(
self
):
## outputs
# partial/ending decoding results
self
.
result_transcripts
=
[
''
]
# token timestamp result
self
.
word_time_stamp
=
[]
# global frame step
self
.
num_frames
+=
num_frames
## just for record
self
.
hyps
=
[]
# update remained wav
self
.
remained_wav
=
self
.
remained_wav
[
self
.
n_shift
*
num_frames
:
]
# one best timestamp viterbi prob is large.
self
.
time_stamp
=
[
]
logger
.
info
(
f
"process the audio feature success, the cached feat shape:
{
self
.
cached_feat
.
shape
}
"
)
logger
.
info
(
f
"After extract feat, the cached remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
)
logger
.
info
(
f
"global samples:
{
self
.
num_samples
}
"
)
logger
.
info
(
f
"global frames:
{
self
.
num_frames
}
"
)
else
:
raise
ValueError
(
f
"not supported:
{
self
.
model_type
}
"
)
def
reset_continuous_decoding
(
self
):
"""
when in continous decoding, reset for next utterance.
"""
self
.
global_frame_offset
=
self
.
num_frames
self
.
model_reset
()
self
.
searcher
.
reset
()
self
.
endpointer
.
reset
()
self
.
output_reset
()
def
reset
(
self
):
if
"deepspeech2"
in
self
.
model_type
:
...
...
@@ -241,38 +177,87 @@ class PaddleASRConnectionHanddler:
dtype
=
float32
)
self
.
decoder
.
reset_decoder
(
batch_size
=
1
)
if
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
:
self
.
searcher
.
reset
()
self
.
endpointer
.
reset
()
self
.
device
=
None
## common
# global sample and frame step
self
.
num_samples
=
0
self
.
global_frame_offset
=
0
# frame step of cur utterance
self
.
num_frames
=
0
# cache for audio and feat
self
.
remained_wav
=
None
self
.
cached_feat
=
None
# partial/ending decoding results
self
.
result_transcripts
=
[
''
]
## endpoint
self
.
endpoint_state
=
False
# True for detect endpoint
## conformer
self
.
model_reset
()
# cache for conformer online
self
.
subsampling_cache
=
None
self
.
elayers_output_cache
=
None
self
.
conformer_cnn_cache
=
None
self
.
encoder_out
=
None
# conformer decoding state
self
.
chunk_num
=
0
# globa decoding chunk num
self
.
offset
=
0
# global offset in decoding frame unit
self
.
hyps
=
[]
## outputs
self
.
output_reset
()
# token timestamp result
self
.
word_time_stamp
=
[]
def
extract_feat
(
self
,
samples
:
ByteString
):
logger
.
info
(
"Online ASR extract the feat"
)
samples
=
np
.
frombuffer
(
samples
,
dtype
=
np
.
int16
)
assert
samples
.
ndim
==
1
# one best timestamp viterbi prob is large.
self
.
time_stamp
=
[]
self
.
num_samples
+=
samples
.
shape
[
0
]
logger
.
info
(
f
"This package receive
{
samples
.
shape
[
0
]
}
pcm data. Global samples:
{
self
.
num_samples
}
"
)
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
if
self
.
remained_wav
is
None
:
self
.
remained_wav
=
samples
else
:
assert
self
.
remained_wav
.
ndim
==
1
# (T,)
self
.
remained_wav
=
np
.
concatenate
([
self
.
remained_wav
,
samples
])
logger
.
info
(
f
"The concatenation of remain and now audio samples length is:
{
self
.
remained_wav
.
shape
}
"
)
if
len
(
self
.
remained_wav
)
<
self
.
win_length
:
# samples not enough for feature window
return
0
# fbank
x_chunk
=
self
.
preprocessing
(
self
.
remained_wav
,
**
self
.
preprocess_args
)
x_chunk
=
paddle
.
to_tensor
(
x_chunk
,
dtype
=
"float32"
).
unsqueeze
(
axis
=
0
)
# feature cache
if
self
.
cached_feat
is
None
:
self
.
cached_feat
=
x_chunk
else
:
assert
(
len
(
x_chunk
.
shape
)
==
3
)
# (B,T,D)
assert
(
len
(
self
.
cached_feat
.
shape
)
==
3
)
# (B,T,D)
self
.
cached_feat
=
paddle
.
concat
(
[
self
.
cached_feat
,
x_chunk
],
axis
=
1
)
# set the feat device
if
self
.
device
is
None
:
self
.
device
=
self
.
cached_feat
.
place
# cur frame step
num_frames
=
x_chunk
.
shape
[
1
]
# global frame step
self
.
num_frames
+=
num_frames
# update remained wav
self
.
remained_wav
=
self
.
remained_wav
[
self
.
n_shift
*
num_frames
:]
logger
.
info
(
f
"process the audio feature success, the cached feat shape:
{
self
.
cached_feat
.
shape
}
"
)
logger
.
info
(
f
"After extract feat, the cached remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
)
logger
.
info
(
f
"global samples:
{
self
.
num_samples
}
"
)
logger
.
info
(
f
"global frames:
{
self
.
num_frames
}
"
)
def
decode
(
self
,
is_finished
=
False
):
"""advance decoding
...
...
@@ -280,14 +265,12 @@ class PaddleASRConnectionHanddler:
Args:
is_finished (bool, optional): Is last frame or not. Defaults to False.
Raises:
Exception: when not support model.
Returns:
None:
nothing
None:
"""
if
"deepspeech2
online
"
in
self
.
model_type
:
if
"deepspeech2"
in
self
.
model_type
:
decoding_chunk_size
=
1
# decoding chunk size = 1. int decoding frame unit
context
=
7
# context=7, in audio frame unit
subsampling
=
4
# subsampling=4, in audio frame unit
...
...
@@ -332,9 +315,11 @@ class PaddleASRConnectionHanddler:
end
=
None
for
cur
in
range
(
0
,
num_frames
-
left_frames
+
1
,
stride
):
end
=
min
(
cur
+
decoding_window
,
num_frames
)
# extract the audio
x_chunk
=
self
.
cached_feat
[:,
cur
:
end
,
:].
numpy
()
x_chunk_lens
=
np
.
array
([
x_chunk
.
shape
[
1
]])
trans_best
=
self
.
decode_one_chunk
(
x_chunk
,
x_chunk_lens
)
self
.
result_transcripts
=
[
trans_best
]
...
...
@@ -409,31 +394,41 @@ class PaddleASRConnectionHanddler:
@
paddle
.
no_grad
()
def
advance_decoding
(
self
,
is_finished
=
False
):
if
"deepspeech"
in
self
.
model_type
:
return
# reset endpiont state
self
.
endpoint_state
=
False
logger
.
info
(
"Conformer/Transformer: start to decode with advanced_decoding method"
)
cfg
=
self
.
ctc_decode_config
# cur chunk size, in decoding frame unit
# cur chunk size, in decoding frame unit
, e.g. 16
decoding_chunk_size
=
cfg
.
decoding_chunk_size
# using num of history chunks
# using num of history chunks
, e.g -1
num_decoding_left_chunks
=
cfg
.
num_decoding_left_chunks
assert
decoding_chunk_size
>
0
# e.g. 4
subsampling
=
self
.
model
.
encoder
.
embed
.
subsampling_rate
# e.g. 7
context
=
self
.
model
.
encoder
.
embed
.
right_context
+
1
# processed chunk feature cached for next chunk
# processed chunk feature cached for next chunk
, e.g. 3
cached_feature_num
=
context
-
subsampling
# decoding stride, in audio frame unit
stride
=
subsampling
*
decoding_chunk_size
# decoding window, in audio frame unit
decoding_window
=
(
decoding_chunk_size
-
1
)
*
subsampling
+
context
# decoding stride, in audio frame unit
stride
=
subsampling
*
decoding_chunk_size
if
self
.
cached_feat
is
None
:
logger
.
info
(
"no audio feat, please input more pcm data"
)
return
# (B=1,T,D)
num_frames
=
self
.
cached_feat
.
shape
[
1
]
logger
.
info
(
f
"Required decoding window
{
decoding_window
}
frames, and the connection has
{
num_frames
}
frames"
...
...
@@ -454,9 +449,6 @@ class PaddleASRConnectionHanddler:
return
None
,
None
logger
.
info
(
"start to do model forward"
)
# hist of chunks, in deocding frame unit
required_cache_size
=
decoding_chunk_size
*
num_decoding_left_chunks
outputs
=
[]
# num_frames - context + 1 ensure that current frame can get context window
if
is_finished
:
...
...
@@ -466,7 +458,11 @@ class PaddleASRConnectionHanddler:
# we only process decoding_window frames for one chunk
left_frames
=
decoding_window
# hist of chunks, in deocding frame unit
required_cache_size
=
decoding_chunk_size
*
num_decoding_left_chunks
# record the end for removing the processed feat
outputs
=
[]
end
=
None
for
cur
in
range
(
0
,
num_frames
-
left_frames
+
1
,
stride
):
end
=
min
(
cur
+
decoding_window
,
num_frames
)
...
...
@@ -491,30 +487,40 @@ class PaddleASRConnectionHanddler:
self
.
encoder_out
=
ys
else
:
self
.
encoder_out
=
paddle
.
concat
([
self
.
encoder_out
,
ys
],
axis
=
1
)
logger
.
info
(
f
"This connection handler encoder out shape:
{
self
.
encoder_out
.
shape
}
"
)
# get the ctc probs
ctc_probs
=
self
.
model
.
ctc
.
log_softmax
(
ys
)
# (1, maxlen, vocab_size)
ctc_probs
=
ctc_probs
.
squeeze
(
0
)
## decoding
# advance decoding
self
.
searcher
.
search
(
ctc_probs
,
self
.
cached_feat
.
place
)
# get one best hyps
self
.
hyps
=
self
.
searcher
.
get_one_best_hyps
()
assert
self
.
cached_feat
.
shape
[
0
]
==
1
assert
end
>=
cached_feature_num
# endpoint
if
not
is_finished
:
def
contain_nonsilence
():
return
len
(
self
.
hyps
)
>
0
and
len
(
self
.
hyps
[
0
])
>
0
decoding_something
=
contain_nonsilence
()
if
self
.
endpointer
.
endpoint_detected
(
ctc_probs
.
numpy
(),
decoding_something
):
self
.
endpoint_state
=
True
logger
.
info
(
f
"Endpoint is detected at
{
self
.
num_frames
}
frame."
)
# advance cache of feat
self
.
cached_feat
=
self
.
cached_feat
[
0
,
end
-
cached_feature_num
:,
:].
unsqueeze
(
0
)
assert
self
.
cached_feat
.
shape
[
0
]
==
1
#(B=1,T,D)
assert
end
>=
cached_feature_num
self
.
cached_feat
=
self
.
cached_feat
[:,
end
-
cached_feature_num
:,
:]
assert
len
(
self
.
cached_feat
.
shape
)
==
3
,
f
"current cache feat shape is:
{
self
.
cached_feat
.
shape
}
"
logger
.
info
(
f
"This connection handler encoder out shape:
{
self
.
encoder_out
.
shape
}
"
)
def
update_result
(
self
):
"""Conformer/Transformer hyps to result.
"""
...
...
@@ -654,24 +660,28 @@ class PaddleASRConnectionHanddler:
# update each word start and end time stamp
# decoding frame to audio frame
frame_shift
=
self
.
model
.
encoder
.
embed
.
subsampling_rate
frame_shift_in_sec
=
frame_shift
*
(
self
.
n_shift
/
self
.
sample_rate
)
logger
.
info
(
f
"frame shift sec:
{
frame_shift_in_sec
}
"
)
decode_frame_shift
=
self
.
model
.
encoder
.
embed
.
subsampling_rate
decode_frame_shift_in_sec
=
decode_frame_shift
*
(
self
.
n_shift
/
self
.
sample_rate
)
logger
.
info
(
f
"decode frame shift in sec:
{
decode_frame_shift_in_sec
}
"
)
global_offset_in_sec
=
self
.
global_frame_offset
*
self
.
frame_shift_in_ms
/
1000.0
logger
.
info
(
f
"global offset:
{
global_offset_in_sec
}
sec."
)
word_time_stamp
=
[]
for
idx
,
_
in
enumerate
(
self
.
time_stamp
):
start
=
(
self
.
time_stamp
[
idx
-
1
]
+
self
.
time_stamp
[
idx
]
)
/
2.0
if
idx
>
0
else
0
start
=
start
*
frame_shift_in_sec
start
=
start
*
decode_
frame_shift_in_sec
end
=
(
self
.
time_stamp
[
idx
]
+
self
.
time_stamp
[
idx
+
1
]
)
/
2.0
if
idx
<
len
(
self
.
time_stamp
)
-
1
else
self
.
offset
end
=
end
*
frame_shift_in_sec
end
=
end
*
decode_
frame_shift_in_sec
word_time_stamp
.
append
({
"w"
:
self
.
result_transcripts
[
0
][
idx
],
"bg"
:
start
,
"ed"
:
end
"bg"
:
global_offset_in_sec
+
start
,
"ed"
:
global_offset_in_sec
+
end
})
# logger.info(f"{word_time_stamp[-1]}")
...
...
@@ -705,13 +715,14 @@ class ASRServerExecutor(ASRExecutor):
self
.
model_type
=
model_type
self
.
sample_rate
=
sample_rate
logger
.
info
(
f
"model_type:
{
self
.
model_type
}
"
)
sample_rate_str
=
'16k'
if
sample_rate
==
16000
else
'8k'
tag
=
model_type
+
'-'
+
lang
+
'-'
+
sample_rate_str
self
.
task_resource
.
set_task_model
(
model_tag
=
tag
)
if
cfg_path
is
None
or
am_model
is
None
or
am_params
is
None
:
logger
.
info
(
f
"Load the pretrained model, tag =
{
tag
}
"
)
self
.
res_path
=
self
.
task_resource
.
res_dir
self
.
cfg_path
=
os
.
path
.
join
(
self
.
res_path
,
self
.
task_resource
.
res_dict
[
'cfg_path'
])
...
...
@@ -719,7 +730,6 @@ class ASRServerExecutor(ASRExecutor):
self
.
task_resource
.
res_dict
[
'model'
])
self
.
am_params
=
os
.
path
.
join
(
self
.
res_path
,
self
.
task_resource
.
res_dict
[
'params'
])
logger
.
info
(
self
.
res_path
)
else
:
self
.
cfg_path
=
os
.
path
.
abspath
(
cfg_path
)
self
.
am_model
=
os
.
path
.
abspath
(
am_model
)
...
...
@@ -727,9 +737,12 @@ class ASRServerExecutor(ASRExecutor):
self
.
res_path
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
cfg_path
)))
logger
.
info
(
self
.
cfg_path
)
logger
.
info
(
self
.
am_model
)
logger
.
info
(
self
.
am_params
)
logger
.
info
(
"Load the pretrained model:"
)
logger
.
info
(
f
" tag =
{
tag
}
"
)
logger
.
info
(
f
" res_path:
{
self
.
res_path
}
"
)
logger
.
info
(
f
" cfg path:
{
self
.
cfg_path
}
"
)
logger
.
info
(
f
" am_model path:
{
self
.
am_model
}
"
)
logger
.
info
(
f
" am_params path:
{
self
.
am_params
}
"
)
#Init body.
self
.
config
=
CfgNode
(
new_allowed
=
True
)
...
...
@@ -738,25 +751,39 @@ class ASRServerExecutor(ASRExecutor):
if
self
.
config
.
spm_model_prefix
:
self
.
config
.
spm_model_prefix
=
os
.
path
.
join
(
self
.
res_path
,
self
.
config
.
spm_model_prefix
)
logger
.
info
(
f
"spm model path:
{
self
.
config
.
spm_model_prefix
}
"
)
self
.
vocab
=
self
.
config
.
vocab_filepath
self
.
text_feature
=
TextFeaturizer
(
unit_type
=
self
.
config
.
unit_type
,
vocab
=
self
.
config
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
spm_model_prefix
)
self
.
vocab
=
self
.
config
.
vocab_filepath
with
UpdateConfig
(
self
.
config
):
if
"deepspeech2"
in
model_type
:
if
"deepspeech2"
in
model_type
:
with
UpdateConfig
(
self
.
config
):
# download lm
self
.
config
.
decode
.
lang_model_path
=
os
.
path
.
join
(
MODEL_HOME
,
'language_model'
,
self
.
config
.
decode
.
lang_model_path
)
lm_url
=
self
.
task_resource
.
res_dict
[
'lm_url'
]
lm_md5
=
self
.
task_resource
.
res_dict
[
'lm_md5'
]
logger
.
info
(
f
"Start to load language model
{
lm_url
}
"
)
self
.
download_lm
(
lm_url
,
os
.
path
.
dirname
(
self
.
config
.
decode
.
lang_model_path
),
lm_md5
)
lm_url
=
self
.
task_resource
.
res_dict
[
'lm_url'
]
lm_md5
=
self
.
task_resource
.
res_dict
[
'lm_md5'
]
logger
.
info
(
f
"Start to load language model
{
lm_url
}
"
)
self
.
download_lm
(
lm_url
,
os
.
path
.
dirname
(
self
.
config
.
decode
.
lang_model_path
),
lm_md5
)
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
# AM predictor
logger
.
info
(
"ASR engine start to init the am predictor"
)
self
.
am_predictor_conf
=
am_predictor_conf
self
.
am_predictor
=
init_predictor
(
model_file
=
self
.
am_model
,
params_file
=
self
.
am_params
,
predictor_conf
=
self
.
am_predictor_conf
)
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
with
UpdateConfig
(
self
.
config
):
logger
.
info
(
"start to create the stream conformer asr engine"
)
# update the decoding method
if
decode_method
:
...
...
@@ -770,37 +797,24 @@ class ASRServerExecutor(ASRExecutor):
logger
.
info
(
"we set the decoding_method to attention_rescoring"
)
self
.
config
.
decode
.
decoding_method
=
"attention_rescoring"
assert
self
.
config
.
decode
.
decoding_method
in
[
"ctc_prefix_beam_search"
,
"attention_rescoring"
],
f
"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is
{
self
.
config
.
decode
.
decoding_method
}
"
else
:
raise
Exception
(
"wrong type"
)
if
"deepspeech2"
in
model_type
:
# AM predictor
logger
.
info
(
"ASR engine start to init the am predictor"
)
self
.
am_predictor_conf
=
am_predictor_conf
self
.
am_predictor
=
init_predictor
(
model_file
=
self
.
am_model
,
params_file
=
self
.
am_params
,
predictor_conf
=
self
.
am_predictor_conf
)
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
# load model
model_name
=
model_type
[:
model_type
.
rindex
(
'_'
)]
# model_type: {model_name}_{dataset}
logger
.
info
(
f
"model name:
{
model_name
}
"
)
model_class
=
self
.
task_resource
.
get_model_class
(
model_name
)
model_conf
=
self
.
config
model
=
model_class
.
from_config
(
model_conf
)
model
=
model_class
.
from_config
(
self
.
config
)
self
.
model
=
model
self
.
model
.
set_state_dict
(
paddle
.
load
(
self
.
am_model
))
self
.
model
.
eval
()
# load model
model_dict
=
paddle
.
load
(
self
.
am_model
)
self
.
model
.
set_state_dict
(
model_dict
)
logger
.
info
(
"create the transformer like model success"
)
else
:
raise
ValueError
(
f
"N
ot support:
{
model_type
}
"
)
raise
Exception
(
f
"n
ot support:
{
model_type
}
"
)
logger
.
info
(
f
"create the
{
model_type
}
model success"
)
return
True
...
...
@@ -857,6 +871,14 @@ class ASREngine(BaseEngine):
logger
.
info
(
"Initialize ASR server engine successfully."
)
return
True
def
new_handler
(
self
):
"""New handler from model.
Returns:
PaddleASRConnectionHanddler: asr handler instance
"""
return
PaddleASRConnectionHanddler
(
self
)
def
preprocess
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
(
"Online not using this."
)
...
...
paddlespeech/server/engine/asr/online/ctc_endpoint.py
0 → 100644
浏览文件 @
8641608f
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
dataclasses
import
dataclass
import
numpy
as
np
from
paddlespeech.cli.log
import
logger
@
dataclass
class
OnlineCTCEndpointRule
:
must_contain_nonsilence
:
bool
=
True
min_trailing_silence
:
int
=
1000
min_utterance_length
:
int
=
0
@
dataclass
class
OnlineCTCEndpoingOpt
:
frame_shift_in_ms
:
int
=
10
blank
:
int
=
0
# blank id, that we consider as silence for purposes of endpointing.
blank_threshold
:
float
=
0.8
# above blank threshold is silence
# We support three rules. We terminate decoding if ANY of these rules
# evaluates to "true". If you want to add more rules, do it by changing this
# code. If you want to disable a rule, you can set the silence-timeout for
# that rule to a very large number.
# rule1 times out after 5 seconds of silence, even if we decoded nothing.
rule1
:
OnlineCTCEndpointRule
=
OnlineCTCEndpointRule
(
False
,
5000
,
0
)
# rule4 times out after 1.0 seconds of silence after decoding something,
# even if we did not reach a final-state at all.
rule2
:
OnlineCTCEndpointRule
=
OnlineCTCEndpointRule
(
True
,
1000
,
0
)
# rule5 times out after the utterance is 20 seconds long, regardless of
# anything else.
rule3
:
OnlineCTCEndpointRule
=
OnlineCTCEndpointRule
(
False
,
0
,
20000
)
class
OnlineCTCEndpoint
:
"""
[END-TO-END AUTOMATIC SPEECH RECOGNITION INTEGRATED WITH CTC-BASED VOICE ACTIVITY DETECTION](https://arxiv.org/pdf/2002.00551.pdf)
"""
def
__init__
(
self
,
opts
:
OnlineCTCEndpoingOpt
):
self
.
opts
=
opts
logger
.
info
(
f
"Endpont Opts:
{
opts
}
"
)
self
.
frame_shift_in_ms
=
opts
.
frame_shift_in_ms
self
.
num_frames_decoded
=
0
self
.
trailing_silence_frames
=
0
self
.
reset
()
def
reset
(
self
):
self
.
num_frames_decoded
=
0
self
.
trailing_silence_frames
=
0
def
rule_activated
(
self
,
rule
:
OnlineCTCEndpointRule
,
rule_name
:
str
,
decoding_something
:
bool
,
trailine_silence
:
int
,
utterance_length
:
int
)
->
bool
:
ans
=
(
decoding_something
or
(
not
rule
.
must_contain_nonsilence
)
)
and
trailine_silence
>=
rule
.
min_trailing_silence
and
utterance_length
>=
rule
.
min_utterance_length
if
(
ans
):
logger
.
info
(
f
"Endpoint Rule:
{
rule_name
}
activated:
{
rule
}
"
)
return
ans
def
endpoint_detected
(
self
,
ctc_log_probs
:
np
.
ndarray
,
decoding_something
:
bool
)
->
bool
:
"""detect endpoint.
Args:
ctc_log_probs (np.ndarray): (T, D)
decoding_something (bool): contain nonsilince.
Returns:
bool: whether endpoint detected.
"""
for
logprob
in
ctc_log_probs
:
blank_prob
=
np
.
exp
(
logprob
[
self
.
opts
.
blank
])
self
.
num_frames_decoded
+=
1
if
blank_prob
>
self
.
opts
.
blank_threshold
:
self
.
trailing_silence_frames
+=
1
else
:
self
.
trailing_silence_frames
=
0
assert
self
.
num_frames_decoded
>=
self
.
trailing_silence_frames
assert
self
.
frame_shift_in_ms
>
0
utterance_length
=
self
.
num_frames_decoded
*
self
.
frame_shift_in_ms
trailing_silence
=
self
.
trailing_silence_frames
*
self
.
frame_shift_in_ms
if
self
.
rule_activated
(
self
.
opts
.
rule1
,
'rule1'
,
decoding_something
,
trailing_silence
,
utterance_length
):
return
True
if
self
.
rule_activated
(
self
.
opts
.
rule2
,
'rule2'
,
decoding_something
,
trailing_silence
,
utterance_length
):
return
True
if
self
.
rule_activated
(
self
.
opts
.
rule3
,
'rule3'
,
decoding_something
,
trailing_silence
,
utterance_length
):
return
True
return
False
paddlespeech/server/engine/asr/online/ctc_search.py
浏览文件 @
8641608f
...
...
@@ -30,8 +30,29 @@ class CTCPrefixBeamSearch:
config (yacs.config.CfgNode): the ctc prefix beam search configuration
"""
self
.
config
=
config
# beam size
self
.
first_beam_size
=
self
.
config
.
beam_size
# TODO(support second beam size)
self
.
second_beam_size
=
int
(
self
.
first_beam_size
*
1.0
)
logger
.
info
(
f
"first and second beam size:
{
self
.
first_beam_size
}
,
{
self
.
second_beam_size
}
"
)
# state
self
.
cur_hyps
=
None
self
.
hyps
=
None
self
.
abs_time_step
=
0
self
.
reset
()
def
reset
(
self
):
"""Rest the search cache value
"""
self
.
cur_hyps
=
None
self
.
hyps
=
None
self
.
abs_time_step
=
0
@
paddle
.
no_grad
()
def
search
(
self
,
ctc_probs
,
device
,
blank_id
=
0
):
"""ctc prefix beam search method decode a chunk feature
...
...
@@ -47,12 +68,17 @@ class CTCPrefixBeamSearch:
"""
# decode
logger
.
info
(
"start to ctc prefix search"
)
assert
len
(
ctc_probs
.
shape
)
==
2
batch_size
=
1
beam_size
=
self
.
config
.
beam_size
maxlen
=
ctc_probs
.
shape
[
0
]
assert
len
(
ctc_probs
.
shape
)
==
2
vocab_size
=
ctc_probs
.
shape
[
1
]
first_beam_size
=
min
(
self
.
first_beam_size
,
vocab_size
)
second_beam_size
=
min
(
self
.
second_beam_size
,
vocab_size
)
logger
.
info
(
f
"effect first and second beam size:
{
self
.
first_beam_size
}
,
{
self
.
second_beam_size
}
"
)
maxlen
=
ctc_probs
.
shape
[
0
]
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# 0. blank_ending_score,
...
...
@@ -75,7 +101,8 @@ class CTCPrefixBeamSearch:
# 2.1 First beam prune: select topk best
# do token passing process
top_k_logp
,
top_k_index
=
logp
.
topk
(
beam_size
)
# (beam_size,)
top_k_logp
,
top_k_index
=
logp
.
topk
(
first_beam_size
)
# (first_beam_size,)
for
s
in
top_k_index
:
s
=
s
.
item
()
ps
=
logp
[
s
].
item
()
...
...
@@ -148,7 +175,7 @@ class CTCPrefixBeamSearch:
next_hyps
.
items
(),
key
=
lambda
x
:
log_add
([
x
[
1
][
0
],
x
[
1
][
1
]]),
reverse
=
True
)
self
.
cur_hyps
=
next_hyps
[:
beam_size
]
self
.
cur_hyps
=
next_hyps
[:
second_
beam_size
]
# 2.3 update the absolute time step
self
.
abs_time_step
+=
1
...
...
@@ -163,7 +190,7 @@ class CTCPrefixBeamSearch:
"""Return the one best result
Returns:
list: the one best result
list: the one best result
, List[str]
"""
return
[
self
.
hyps
[
0
][
0
]]
...
...
@@ -171,17 +198,10 @@ class CTCPrefixBeamSearch:
"""Return the search hyps
Returns:
list: return the search hyps
list: return the search hyps
, List[Tuple[str, float, ...]]
"""
return
self
.
hyps
def
reset
(
self
):
"""Rest the search cache value
"""
self
.
cur_hyps
=
None
self
.
hyps
=
None
self
.
abs_time_step
=
0
def
finalize_search
(
self
):
"""do nothing in ctc_prefix_beam_search
"""
...
...
paddlespeech/server/engine/tts/online/python/tts_engine.py
浏览文件 @
8641608f
...
...
@@ -42,7 +42,6 @@ class TTSServerExecutor(TTSExecutor):
self
.
task_resource
=
CommonTaskResource
(
task
=
'tts'
,
model_format
=
'dynamic'
,
inference_mode
=
'online'
)
def
get_model_info
(
self
,
field
:
str
,
model_name
:
str
,
...
...
paddlespeech/server/ws/asr_api.py
浏览文件 @
8641608f
...
...
@@ -19,7 +19,6 @@ from fastapi import WebSocketDisconnect
from
starlette.websockets
import
WebSocketState
as
WebSocketState
from
paddlespeech.cli.log
import
logger
from
paddlespeech.server.engine.asr.online.asr_engine
import
PaddleASRConnectionHanddler
from
paddlespeech.server.engine.engine_pool
import
get_engine_pool
router
=
APIRouter
()
...
...
@@ -38,7 +37,7 @@ async def websocket_endpoint(websocket: WebSocket):
#2. if we accept the websocket headers, we will get the online asr engine instance
engine_pool
=
get_engine_pool
()
asr_
engine
=
engine_pool
[
'asr'
]
asr_
model
=
engine_pool
[
'asr'
]
#3. each websocket connection, we will create an PaddleASRConnectionHanddler to process such audio
# and each connection has its own connection instance to process the request
...
...
@@ -70,7 +69,8 @@ async def websocket_endpoint(websocket: WebSocket):
resp
=
{
"status"
:
"ok"
,
"signal"
:
"server_ready"
}
# do something at begining here
# create the instance to process the audio
connection_handler
=
PaddleASRConnectionHanddler
(
asr_engine
)
#connection_handler = PaddleASRConnectionHanddler(asr_model)
connection_handler
=
asr_model
.
new_handler
()
await
websocket
.
send_json
(
resp
)
elif
message
[
'signal'
]
==
'end'
:
# reset single engine for an new connection
...
...
@@ -100,11 +100,34 @@ async def websocket_endpoint(websocket: WebSocket):
# and decode for the result in this package data
connection_handler
.
extract_feat
(
message
)
connection_handler
.
decode
(
is_finished
=
False
)
if
connection_handler
.
endpoint_state
:
logger
.
info
(
"endpoint: detected and rescoring."
)
connection_handler
.
rescoring
()
word_time_stamp
=
connection_handler
.
get_word_time_stamp
()
asr_results
=
connection_handler
.
get_result
()
# return the current period result
# if the engine create the vad instance, this connection will have many period results
if
connection_handler
.
endpoint_state
:
if
connection_handler
.
continuous_decoding
:
logger
.
info
(
"endpoint: continue decoding"
)
connection_handler
.
reset_continuous_decoding
()
else
:
logger
.
info
(
"endpoint: exit decoding"
)
# ending by endpoint
resp
=
{
"status"
:
"ok"
,
"signal"
:
"finished"
,
'result'
:
asr_results
,
'times'
:
word_time_stamp
}
await
websocket
.
send_json
(
resp
)
break
# return the current partial result
# if the engine create the vad instance, this connection will have many partial results
resp
=
{
'result'
:
asr_results
}
await
websocket
.
send_json
(
resp
)
except
WebSocketDisconnect
as
e
:
logger
.
error
(
e
)
paddlespeech/t2s/exps/synthesize.py
浏览文件 @
8641608f
...
...
@@ -140,10 +140,7 @@ def parse_args():
],
help
=
'Choose acoustic model type of tts task.'
)
parser
.
add_argument
(
'--am_config'
,
type
=
str
,
default
=
None
,
help
=
'Config of acoustic model.'
)
'--am_config'
,
type
=
str
,
default
=
None
,
help
=
'Config of acoustic model.'
)
parser
.
add_argument
(
'--am_ckpt'
,
type
=
str
,
...
...
@@ -179,10 +176,7 @@ def parse_args():
],
help
=
'Choose vocoder type of tts task.'
)
parser
.
add_argument
(
'--voc_config'
,
type
=
str
,
default
=
None
,
help
=
'Config of voc.'
)
'--voc_config'
,
type
=
str
,
default
=
None
,
help
=
'Config of voc.'
)
parser
.
add_argument
(
'--voc_ckpt'
,
type
=
str
,
default
=
None
,
help
=
'Checkpoint file of voc.'
)
parser
.
add_argument
(
...
...
paddlespeech/t2s/exps/synthesize_e2e.py
浏览文件 @
8641608f
...
...
@@ -174,10 +174,7 @@ def parse_args():
],
help
=
'Choose acoustic model type of tts task.'
)
parser
.
add_argument
(
'--am_config'
,
type
=
str
,
default
=
None
,
help
=
'Config of acoustic model.'
)
'--am_config'
,
type
=
str
,
default
=
None
,
help
=
'Config of acoustic model.'
)
parser
.
add_argument
(
'--am_ckpt'
,
type
=
str
,
...
...
@@ -220,10 +217,7 @@ def parse_args():
],
help
=
'Choose vocoder type of tts task.'
)
parser
.
add_argument
(
'--voc_config'
,
type
=
str
,
default
=
None
,
help
=
'Config of voc.'
)
'--voc_config'
,
type
=
str
,
default
=
None
,
help
=
'Config of voc.'
)
parser
.
add_argument
(
'--voc_ckpt'
,
type
=
str
,
default
=
None
,
help
=
'Checkpoint file of voc.'
)
parser
.
add_argument
(
...
...
paddlespeech/t2s/exps/voice_cloning.py
浏览文件 @
8641608f
...
...
@@ -131,10 +131,7 @@ def parse_args():
choices
=
[
'fastspeech2_aishell3'
,
'tacotron2_aishell3'
],
help
=
'Choose acoustic model type of tts task.'
)
parser
.
add_argument
(
'--am_config'
,
type
=
str
,
default
=
None
,
help
=
'Config of acoustic model.'
)
'--am_config'
,
type
=
str
,
default
=
None
,
help
=
'Config of acoustic model.'
)
parser
.
add_argument
(
'--am_ckpt'
,
type
=
str
,
...
...
@@ -160,10 +157,7 @@ def parse_args():
help
=
'Choose vocoder type of tts task.'
)
parser
.
add_argument
(
'--voc_config'
,
type
=
str
,
default
=
None
,
help
=
'Config of voc.'
)
'--voc_config'
,
type
=
str
,
default
=
None
,
help
=
'Config of voc.'
)
parser
.
add_argument
(
'--voc_ckpt'
,
type
=
str
,
default
=
None
,
help
=
'Checkpoint file of voc.'
)
parser
.
add_argument
(
...
...
paddlespeech/t2s/models/vits/__init__.py
浏览文件 @
8641608f
...
...
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
.vits
import
*
from
.vits_updater
import
*
\ No newline at end of file
from
.vits_updater
import
*
paddlespeech/t2s/models/vits/vits_updater.py
浏览文件 @
8641608f
...
...
@@ -56,7 +56,8 @@ class VITSUpdater(StandardUpdater):
self
.
models
:
Dict
[
str
,
Layer
]
=
models
# self.model = model
self
.
model
=
model
.
_layers
if
isinstance
(
model
,
paddle
.
DataParallel
)
else
model
self
.
model
=
model
.
_layers
if
isinstance
(
model
,
paddle
.
DataParallel
)
else
model
self
.
optimizers
=
optimizers
self
.
optimizer_g
:
Optimizer
=
optimizers
[
'generator'
]
...
...
@@ -225,7 +226,8 @@ class VITSEvaluator(StandardEvaluator):
models
=
{
"main"
:
model
}
self
.
models
:
Dict
[
str
,
Layer
]
=
models
# self.model = model
self
.
model
=
model
.
_layers
if
isinstance
(
model
,
paddle
.
DataParallel
)
else
model
self
.
model
=
model
.
_layers
if
isinstance
(
model
,
paddle
.
DataParallel
)
else
model
self
.
criterions
=
criterions
self
.
criterion_mel
=
criterions
[
'mel'
]
...
...
paddlespeech/t2s/modules/losses.py
浏览文件 @
8641608f
...
...
@@ -971,18 +971,18 @@ class FeatureMatchLoss(nn.Layer):
return
feat_match_loss
# loss for VITS
class
KLDivergenceLoss
(
nn
.
Layer
):
"""KL divergence loss."""
def
forward
(
self
,
z_p
:
paddle
.
Tensor
,
logs_q
:
paddle
.
Tensor
,
m_p
:
paddle
.
Tensor
,
logs_p
:
paddle
.
Tensor
,
z_mask
:
paddle
.
Tensor
,
)
->
paddle
.
Tensor
:
self
,
z_p
:
paddle
.
Tensor
,
logs_q
:
paddle
.
Tensor
,
m_p
:
paddle
.
Tensor
,
logs_p
:
paddle
.
Tensor
,
z_mask
:
paddle
.
Tensor
,
)
->
paddle
.
Tensor
:
"""Calculate KL divergence loss.
Args:
...
...
@@ -1002,8 +1002,8 @@ class KLDivergenceLoss(nn.Layer):
logs_p
=
paddle
.
cast
(
logs_p
,
'float32'
)
z_mask
=
paddle
.
cast
(
z_mask
,
'float32'
)
kl
=
logs_p
-
logs_q
-
0.5
kl
+=
0.5
*
((
z_p
-
m_p
)
**
2
)
*
paddle
.
exp
(
-
2.0
*
logs_p
)
kl
+=
0.5
*
((
z_p
-
m_p
)
**
2
)
*
paddle
.
exp
(
-
2.0
*
logs_p
)
kl
=
paddle
.
sum
(
kl
*
z_mask
)
loss
=
kl
/
paddle
.
sum
(
z_mask
)
return
loss
\ No newline at end of file
return
loss
speechx/examples/README.md
浏览文件 @
8641608f
...
...
@@ -25,4 +25,3 @@ netron exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel --port 8022 --host
> Reminder: Only for developer, make sure you know what's it.
*
codelab - for speechx developer, using for test.
speechx/examples/ds2_ol/README.md
浏览文件 @
8641608f
...
...
@@ -3,4 +3,4 @@
## Examples
*
`websocket`
- Streaming ASR with websocket for deepspeech2_aishell.
*
`aishell`
- Streaming Decoding under aishell dataset, for local WER test.
\ No newline at end of file
*
`aishell`
- Streaming Decoding under aishell dataset, for local WER test.
speechx/speechx/codelab/README.md
浏览文件 @
8641608f
...
...
@@ -4,4 +4,3 @@
> Reminder: Only for developer.
*
codelab - for speechx developer, using for test.
speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc
浏览文件 @
8641608f
...
...
@@ -91,8 +91,8 @@ int main(int argc, char* argv[]) {
std
::
shared_ptr
<
ppspeech
::
Decodable
>
decodable
(
new
ppspeech
::
Decodable
(
nnet
,
raw_data
));
int32
chunk_size
=
FLAGS_receptive_field_length
+
(
FLAGS_nnet_decoder_chunk
-
1
)
*
FLAGS_downsampling_rate
;
int32
chunk_size
=
FLAGS_receptive_field_length
+
(
FLAGS_nnet_decoder_chunk
-
1
)
*
FLAGS_downsampling_rate
;
int32
chunk_stride
=
FLAGS_downsampling_rate
*
FLAGS_nnet_decoder_chunk
;
int32
receptive_field_length
=
FLAGS_receptive_field_length
;
LOG
(
INFO
)
<<
"chunk size (frame): "
<<
chunk_size
;
...
...
speechx/speechx/decoder/ctc_tlg_decoder.cc
浏览文件 @
8641608f
...
...
@@ -64,7 +64,7 @@ std::string TLGDecoder::GetPartialResult() {
std
::
string
word
=
word_symbol_table_
->
Find
(
words_id
[
idx
]);
words
+=
word
;
}
return
words
;
return
words
;
}
std
::
string
TLGDecoder
::
GetFinalBestPath
()
{
...
...
speechx/speechx/decoder/param.h
浏览文件 @
8641608f
...
...
@@ -82,7 +82,7 @@ FeaturePipelineOptions InitFeaturePipelineOptions() {
opts
.
assembler_opts
.
subsampling_rate
=
FLAGS_downsampling_rate
;
opts
.
assembler_opts
.
receptive_filed_length
=
FLAGS_receptive_field_length
;
opts
.
assembler_opts
.
nnet_decoder_chunk
=
FLAGS_nnet_decoder_chunk
;
return
opts
;
}
...
...
speechx/speechx/decoder/tlg_decoder_main.cc
浏览文件 @
8641608f
...
...
@@ -93,8 +93,8 @@ int main(int argc, char* argv[]) {
std
::
shared_ptr
<
ppspeech
::
Decodable
>
decodable
(
new
ppspeech
::
Decodable
(
nnet
,
raw_data
,
FLAGS_acoustic_scale
));
int32
chunk_size
=
FLAGS_receptive_field_length
+
(
FLAGS_nnet_decoder_chunk
-
1
)
*
FLAGS_downsampling_rate
;
int32
chunk_size
=
FLAGS_receptive_field_length
+
(
FLAGS_nnet_decoder_chunk
-
1
)
*
FLAGS_downsampling_rate
;
int32
chunk_stride
=
FLAGS_downsampling_rate
*
FLAGS_nnet_decoder_chunk
;
int32
receptive_field_length
=
FLAGS_receptive_field_length
;
LOG
(
INFO
)
<<
"chunk size (frame): "
<<
chunk_size
;
...
...
speechx/speechx/frontend/audio/assembler.cc
浏览文件 @
8641608f
...
...
@@ -24,7 +24,8 @@ using std::unique_ptr;
Assembler
::
Assembler
(
AssemblerOptions
opts
,
unique_ptr
<
FrontendInterface
>
base_extractor
)
{
frame_chunk_stride_
=
opts
.
subsampling_rate
*
opts
.
nnet_decoder_chunk
;
frame_chunk_size_
=
(
opts
.
nnet_decoder_chunk
-
1
)
*
opts
.
subsampling_rate
+
opts
.
receptive_filed_length
;
frame_chunk_size_
=
(
opts
.
nnet_decoder_chunk
-
1
)
*
opts
.
subsampling_rate
+
opts
.
receptive_filed_length
;
receptive_filed_length_
=
opts
.
receptive_filed_length
;
base_extractor_
=
std
::
move
(
base_extractor
);
dim_
=
base_extractor_
->
Dim
();
...
...
@@ -50,8 +51,8 @@ bool Assembler::Compute(Vector<BaseFloat>* feats) {
Vector
<
BaseFloat
>
feature
;
result
=
base_extractor_
->
Read
(
&
feature
);
if
(
result
==
false
||
feature
.
Dim
()
==
0
)
{
if
(
IsFinished
()
==
false
)
return
false
;
break
;
if
(
IsFinished
()
==
false
)
return
false
;
break
;
}
feature_cache_
.
push
(
feature
);
}
...
...
@@ -61,22 +62,22 @@ bool Assembler::Compute(Vector<BaseFloat>* feats) {
}
while
(
feature_cache_
.
size
()
<
frame_chunk_size_
)
{
Vector
<
BaseFloat
>
feature
(
dim_
,
kaldi
::
kSetZero
);
feature_cache_
.
push
(
feature
);
Vector
<
BaseFloat
>
feature
(
dim_
,
kaldi
::
kSetZero
);
feature_cache_
.
push
(
feature
);
}
int32
counter
=
0
;
int32
counter
=
0
;
int32
cache_size
=
frame_chunk_size_
-
frame_chunk_stride_
;
int32
elem_dim
=
base_extractor_
->
Dim
();
while
(
counter
<
frame_chunk_size_
)
{
Vector
<
BaseFloat
>&
val
=
feature_cache_
.
front
();
int32
start
=
counter
*
elem_dim
;
feats
->
Range
(
start
,
elem_dim
).
CopyFromVec
(
val
);
if
(
frame_chunk_size_
-
counter
<=
cache_size
)
{
feature_cache_
.
push
(
val
);
}
feature_cache_
.
pop
();
counter
++
;
Vector
<
BaseFloat
>&
val
=
feature_cache_
.
front
();
int32
start
=
counter
*
elem_dim
;
feats
->
Range
(
start
,
elem_dim
).
CopyFromVec
(
val
);
if
(
frame_chunk_size_
-
counter
<=
cache_size
)
{
feature_cache_
.
push
(
val
);
}
feature_cache_
.
pop
();
counter
++
;
}
return
result
;
...
...
speechx/speechx/frontend/audio/assembler.h
浏览文件 @
8641608f
...
...
@@ -25,7 +25,7 @@ struct AssemblerOptions {
int32
receptive_filed_length
;
int32
subsampling_rate
;
int32
nnet_decoder_chunk
;
AssemblerOptions
()
:
receptive_filed_length
(
1
),
subsampling_rate
(
1
),
...
...
@@ -47,15 +47,11 @@ class Assembler : public FrontendInterface {
// feat dim
virtual
size_t
Dim
()
const
{
return
dim_
;
}
virtual
void
SetFinished
()
{
base_extractor_
->
SetFinished
();
}
virtual
void
SetFinished
()
{
base_extractor_
->
SetFinished
();
}
virtual
bool
IsFinished
()
const
{
return
base_extractor_
->
IsFinished
();
}
virtual
void
Reset
()
{
base_extractor_
->
Reset
();
}
virtual
void
Reset
()
{
base_extractor_
->
Reset
();
}
private:
bool
Compute
(
kaldi
::
Vector
<
kaldi
::
BaseFloat
>*
feats
);
...
...
speechx/speechx/frontend/audio/audio_cache.h
浏览文件 @
8641608f
...
...
@@ -30,7 +30,7 @@ class AudioCache : public FrontendInterface {
virtual
bool
Read
(
kaldi
::
Vector
<
kaldi
::
BaseFloat
>*
waves
);
// the audio dim is 1, one sample, which is useless,
// the audio dim is 1, one sample, which is useless,
// so we return size_(cache samples) instead.
virtual
size_t
Dim
()
const
{
return
size_
;
}
...
...
speechx/speechx/frontend/audio/fbank.cc
浏览文件 @
8641608f
...
...
@@ -29,19 +29,19 @@ using kaldi::Matrix;
using
std
::
vector
;
FbankComputer
::
FbankComputer
(
const
Options
&
opts
)
:
opts_
(
opts
),
computer_
(
opts
)
{}
:
opts_
(
opts
),
computer_
(
opts
)
{}
int32
FbankComputer
::
Dim
()
const
{
return
opts_
.
mel_opts
.
num_bins
+
(
opts_
.
use_energy
?
1
:
0
);
}
bool
FbankComputer
::
NeedRawLogEnergy
()
{
return
opts_
.
use_energy
&&
opts_
.
raw_energy
;
return
opts_
.
use_energy
&&
opts_
.
raw_energy
;
}
// Compute feat
bool
FbankComputer
::
Compute
(
Vector
<
BaseFloat
>*
window
,
Vector
<
BaseFloat
>*
feat
)
{
bool
FbankComputer
::
Compute
(
Vector
<
BaseFloat
>*
window
,
Vector
<
BaseFloat
>*
feat
)
{
RealFft
(
window
,
true
);
kaldi
::
ComputePowerSpectrum
(
window
);
const
kaldi
::
MelBanks
&
mel_bank
=
*
(
computer_
.
GetMelBanks
(
1.0
));
...
...
speechx/speechx/frontend/audio/feature_cache.cc
浏览文件 @
8641608f
...
...
@@ -72,9 +72,9 @@ bool FeatureCache::Compute() {
bool
result
=
base_extractor_
->
Read
(
&
feature
);
if
(
result
==
false
||
feature
.
Dim
()
==
0
)
return
false
;
int32
num_chunk
=
feature
.
Dim
()
/
dim_
;
int32
num_chunk
=
feature
.
Dim
()
/
dim_
;
for
(
int
chunk_idx
=
0
;
chunk_idx
<
num_chunk
;
++
chunk_idx
)
{
int32
start
=
chunk_idx
*
dim_
;
int32
start
=
chunk_idx
*
dim_
;
Vector
<
BaseFloat
>
feature_chunk
(
dim_
);
SubVector
<
BaseFloat
>
tmp
(
feature
.
Data
()
+
start
,
dim_
);
feature_chunk
.
CopyFromVec
(
tmp
);
...
...
speechx/speechx/frontend/audio/feature_cache.h
浏览文件 @
8641608f
...
...
@@ -22,9 +22,7 @@ namespace ppspeech {
struct
FeatureCacheOptions
{
int32
max_size
;
int32
timeout
;
// ms
FeatureCacheOptions
()
:
max_size
(
kint16max
),
timeout
(
1
)
{}
FeatureCacheOptions
()
:
max_size
(
kint16max
),
timeout
(
1
)
{}
};
class
FeatureCache
:
public
FrontendInterface
{
...
...
speechx/speechx/frontend/audio/feature_common.h
浏览文件 @
8641608f
...
...
@@ -23,11 +23,11 @@ template <class F>
class
StreamingFeatureTpl
:
public
FrontendInterface
{
public:
typedef
typename
F
::
Options
Options
;
StreamingFeatureTpl
(
const
Options
&
opts
,
StreamingFeatureTpl
(
const
Options
&
opts
,
std
::
unique_ptr
<
FrontendInterface
>
base_extractor
);
virtual
void
Accept
(
const
kaldi
::
VectorBase
<
kaldi
::
BaseFloat
>&
waves
);
virtual
bool
Read
(
kaldi
::
Vector
<
kaldi
::
BaseFloat
>*
feats
);
// the dim_ is the dim of single frame feature
virtual
size_t
Dim
()
const
{
return
computer_
.
Dim
();
}
...
...
@@ -39,8 +39,9 @@ class StreamingFeatureTpl : public FrontendInterface {
base_extractor_
->
Reset
();
remained_wav_
.
Resize
(
0
);
}
private:
bool
Compute
(
const
kaldi
::
Vector
<
kaldi
::
BaseFloat
>&
waves
,
bool
Compute
(
const
kaldi
::
Vector
<
kaldi
::
BaseFloat
>&
waves
,
kaldi
::
Vector
<
kaldi
::
BaseFloat
>*
feats
);
Options
opts_
;
std
::
unique_ptr
<
FrontendInterface
>
base_extractor_
;
...
...
speechx/speechx/frontend/audio/feature_common_inl.h
浏览文件 @
8641608f
...
...
@@ -16,16 +16,15 @@
namespace
ppspeech
{
template
<
class
F
>
StreamingFeatureTpl
<
F
>::
StreamingFeatureTpl
(
const
Options
&
opts
,
std
::
unique_ptr
<
FrontendInterface
>
base_extractor
)
:
opts_
(
opts
),
computer_
(
opts
),
window_function_
(
opts
.
frame_opts
)
{
StreamingFeatureTpl
<
F
>::
StreamingFeatureTpl
(
const
Options
&
opts
,
std
::
unique_ptr
<
FrontendInterface
>
base_extractor
)
:
opts_
(
opts
),
computer_
(
opts
),
window_function_
(
opts
.
frame_opts
)
{
base_extractor_
=
std
::
move
(
base_extractor
);
}
template
<
class
F
>
void
StreamingFeatureTpl
<
F
>::
Accept
(
const
kaldi
::
VectorBase
<
kaldi
::
BaseFloat
>&
waves
)
{
void
StreamingFeatureTpl
<
F
>::
Accept
(
const
kaldi
::
VectorBase
<
kaldi
::
BaseFloat
>&
waves
)
{
base_extractor_
->
Accept
(
waves
);
}
...
...
@@ -58,8 +57,9 @@ bool StreamingFeatureTpl<F>::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
// Compute feat
template
<
class
F
>
bool
StreamingFeatureTpl
<
F
>::
Compute
(
const
kaldi
::
Vector
<
kaldi
::
BaseFloat
>&
waves
,
kaldi
::
Vector
<
kaldi
::
BaseFloat
>*
feats
)
{
bool
StreamingFeatureTpl
<
F
>::
Compute
(
const
kaldi
::
Vector
<
kaldi
::
BaseFloat
>&
waves
,
kaldi
::
Vector
<
kaldi
::
BaseFloat
>*
feats
)
{
const
kaldi
::
FrameExtractionOptions
&
frame_opts
=
computer_
.
GetFrameOptions
();
int32
num_samples
=
waves
.
Dim
();
...
...
@@ -84,9 +84,11 @@ bool StreamingFeatureTpl<F>::Compute(const kaldi::Vector<kaldi::BaseFloat>& wave
&
window
,
need_raw_log_energy
?
&
raw_log_energy
:
NULL
);
kaldi
::
Vector
<
kaldi
::
BaseFloat
>
this_feature
(
computer_
.
Dim
(),
kaldi
::
kUndefined
);
kaldi
::
Vector
<
kaldi
::
BaseFloat
>
this_feature
(
computer_
.
Dim
(),
kaldi
::
kUndefined
);
computer_
.
Compute
(
&
window
,
&
this_feature
);
kaldi
::
SubVector
<
kaldi
::
BaseFloat
>
output_row
(
feats
->
Data
()
+
frame
*
Dim
(),
Dim
());
kaldi
::
SubVector
<
kaldi
::
BaseFloat
>
output_row
(
feats
->
Data
()
+
frame
*
Dim
(),
Dim
());
output_row
.
CopyFromVec
(
this_feature
);
}
return
true
;
...
...
speechx/speechx/frontend/audio/feature_pipeline.h
浏览文件 @
8641608f
...
...
@@ -16,6 +16,7 @@
#pragma once
#include "frontend/audio/assembler.h"
#include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/fbank.h"
...
...
@@ -23,7 +24,6 @@
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h"
#include "frontend/audio/assembler.h"
namespace
ppspeech
{
...
...
speechx/speechx/frontend/audio/linear_spectrogram.cc
浏览文件 @
8641608f
...
...
@@ -28,22 +28,21 @@ using kaldi::VectorBase;
using
kaldi
::
Matrix
;
using
std
::
vector
;
LinearSpectrogramComputer
::
LinearSpectrogramComputer
(
const
Options
&
opts
)
LinearSpectrogramComputer
::
LinearSpectrogramComputer
(
const
Options
&
opts
)
:
opts_
(
opts
)
{
kaldi
::
FeatureWindowFunction
feature_window_function
(
opts
.
frame_opts
);
int32
window_size
=
opts
.
frame_opts
.
WindowSize
();
frame_length_
=
window_size
;
dim_
=
window_size
/
2
+
1
;
BaseFloat
hanning_window_energy
=
kaldi
::
VecVec
(
feature_window_function
.
window
,
feature_window_function
.
window
);
BaseFloat
hanning_window_energy
=
kaldi
::
VecVec
(
feature_window_function
.
window
,
feature_window_function
.
window
);
int32
sample_rate
=
opts
.
frame_opts
.
samp_freq
;
scale_
=
2.0
/
(
hanning_window_energy
*
sample_rate
);
}
// Compute spectrogram feat
bool
LinearSpectrogramComputer
::
Compute
(
Vector
<
BaseFloat
>*
window
,
Vector
<
BaseFloat
>*
feat
)
{
Vector
<
BaseFloat
>*
feat
)
{
window
->
Resize
(
frame_length_
,
kaldi
::
kCopyData
);
RealFft
(
window
,
true
);
kaldi
::
ComputePowerSpectrum
(
window
);
...
...
speechx/speechx/nnet/nnet_forward_main.cc
浏览文件 @
8641608f
...
...
@@ -14,8 +14,8 @@
#include "base/flags.h"
#include "base/log.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/assembler.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
...
...
@@ -75,8 +75,8 @@ int main(int argc, char* argv[]) {
std
::
shared_ptr
<
ppspeech
::
Decodable
>
decodable
(
new
ppspeech
::
Decodable
(
nnet
,
raw_data
,
FLAGS_acoustic_scale
));
int32
chunk_size
=
FLAGS_receptive_field_length
+
(
FLAGS_nnet_decoder_chunk
-
1
)
*
FLAGS_downsampling_rate
;
int32
chunk_size
=
FLAGS_receptive_field_length
+
(
FLAGS_nnet_decoder_chunk
-
1
)
*
FLAGS_downsampling_rate
;
int32
chunk_stride
=
FLAGS_downsampling_rate
*
FLAGS_nnet_decoder_chunk
;
int32
receptive_field_length
=
FLAGS_receptive_field_length
;
LOG
(
INFO
)
<<
"chunk size (frame): "
<<
chunk_size
;
...
...
@@ -130,7 +130,9 @@ int main(int argc, char* argv[]) {
vector
<
kaldi
::
BaseFloat
>
prob
;
while
(
decodable
->
FrameLikelihood
(
frame_idx
,
&
prob
))
{
kaldi
::
Vector
<
kaldi
::
BaseFloat
>
vec_tmp
(
prob
.
size
());
std
::
memcpy
(
vec_tmp
.
Data
(),
prob
.
data
(),
sizeof
(
kaldi
::
BaseFloat
)
*
prob
.
size
());
std
::
memcpy
(
vec_tmp
.
Data
(),
prob
.
data
(),
sizeof
(
kaldi
::
BaseFloat
)
*
prob
.
size
());
prob_vec
.
push_back
(
vec_tmp
);
frame_idx
++
;
}
...
...
@@ -142,7 +144,8 @@ int main(int argc, char* argv[]) {
KALDI_LOG
<<
" the nnet prob of "
<<
utt
<<
" is empty"
;
continue
;
}
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>
result
(
prob_vec
.
size
(),
prob_vec
[
0
].
Dim
());
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>
result
(
prob_vec
.
size
(),
prob_vec
[
0
].
Dim
());
for
(
int32
row_idx
=
0
;
row_idx
<
prob_vec
.
size
();
++
row_idx
)
{
for
(
int32
col_idx
=
0
;
col_idx
<
prob_vec
[
0
].
Dim
();
++
col_idx
)
{
result
(
row_idx
,
col_idx
)
=
prob_vec
[
row_idx
](
col_idx
);
...
...
speechx/speechx/protocol/websocket/websocket_client.h
浏览文件 @
8641608f
...
...
@@ -40,8 +40,8 @@ class WebSocketClient {
void
SendEndSignal
();
void
SendDataEnd
();
bool
Done
()
const
{
return
done_
;
}
std
::
string
GetResult
()
const
{
return
result_
;
}
std
::
string
GetPartialResult
()
const
{
return
partial_result_
;}
std
::
string
GetResult
()
const
{
return
result_
;
}
std
::
string
GetPartialResult
()
const
{
return
partial_result_
;
}
private:
void
Connect
();
...
...
speechx/speechx/protocol/websocket/websocket_server.cc
浏览文件 @
8641608f
...
...
@@ -76,9 +76,10 @@ void ConnectionHandler::OnSpeechData(const beast::flat_buffer& buffer) {
recognizer_
->
Accept
(
pcm_data
);
std
::
string
partial_result
=
recognizer_
->
GetPartialResult
();
json
::
value
rv
=
{
{
"status"
,
"ok"
},
{
"type"
,
"partial_result"
},
{
"result"
,
partial_result
}};
json
::
value
rv
=
{{
"status"
,
"ok"
},
{
"type"
,
"partial_result"
},
{
"result"
,
partial_result
}};
ws_
.
text
(
true
);
ws_
.
write
(
asio
::
buffer
(
json
::
serialize
(
rv
)));
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录