Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
f869db7a
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f869db7a
编写于
5月 06, 2022
作者:
Y
Yang Zhou
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of github.com:SmileGoat/PaddleSpeech into add_fbank
上级
491f2d04
37c6106e
变更
40
展开全部
隐藏空白更改
内联
并排
Showing
40 changed file
with
1673 addition
and
85 deletion
+1673
-85
demos/speaker_verification/README.md
demos/speaker_verification/README.md
+1
-1
demos/speaker_verification/README_cn.md
demos/speaker_verification/README_cn.md
+3
-3
demos/streaming_asr_server/README.md
demos/streaming_asr_server/README.md
+277
-2
demos/streaming_asr_server/README_cn.md
demos/streaming_asr_server/README_cn.md
+277
-1
demos/streaming_asr_server/conf/punc_application.yaml
demos/streaming_asr_server/conf/punc_application.yaml
+35
-0
demos/streaming_asr_server/conf/ws_conformer_application.yaml
...s/streaming_asr_server/conf/ws_conformer_application.yaml
+3
-3
demos/streaming_asr_server/punc_server.py
demos/streaming_asr_server/punc_server.py
+38
-0
demos/streaming_asr_server/server.sh
demos/streaming_asr_server/server.sh
+5
-0
demos/streaming_asr_server/streaming_asr_server.py
demos/streaming_asr_server/streaming_asr_server.py
+38
-0
demos/streaming_asr_server/test.sh
demos/streaming_asr_server/test.sh
+5
-2
demos/streaming_asr_server/websocket_client.py
demos/streaming_asr_server/websocket_client.py
+6
-1
examples/aishell/asr1/conf/chunk_conformer.yaml
examples/aishell/asr1/conf/chunk_conformer.yaml
+3
-3
examples/aishell/asr1/conf/conformer.yaml
examples/aishell/asr1/conf/conformer.yaml
+1
-1
examples/aishell/asr1/conf/transformer.yaml
examples/aishell/asr1/conf/transformer.yaml
+2
-2
paddlespeech/cli/vector/infer.py
paddlespeech/cli/vector/infer.py
+7
-1
paddlespeech/server/README_cn.md
paddlespeech/server/README_cn.md
+20
-0
paddlespeech/server/bin/paddlespeech_client.py
paddlespeech/server/bin/paddlespeech_client.py
+143
-6
paddlespeech/server/conf/application.yaml
paddlespeech/server/conf/application.yaml
+12
-1
paddlespeech/server/conf/vector_application.yaml
paddlespeech/server/conf/vector_application.yaml
+32
-0
paddlespeech/server/engine/asr/online/asr_engine.py
paddlespeech/server/engine/asr/online/asr_engine.py
+50
-0
paddlespeech/server/engine/asr/online/ctc_search.py
paddlespeech/server/engine/asr/online/ctc_search.py
+74
-16
paddlespeech/server/engine/engine_factory.py
paddlespeech/server/engine/engine_factory.py
+3
-0
paddlespeech/server/engine/vector/__init__.py
paddlespeech/server/engine/vector/__init__.py
+13
-0
paddlespeech/server/engine/vector/python/__init__.py
paddlespeech/server/engine/vector/python/__init__.py
+13
-0
paddlespeech/server/engine/vector/python/vector_engine.py
paddlespeech/server/engine/vector/python/vector_engine.py
+200
-0
paddlespeech/server/restful/api.py
paddlespeech/server/restful/api.py
+3
-1
paddlespeech/server/restful/request.py
paddlespeech/server/restful/request.py
+41
-1
paddlespeech/server/restful/response.py
paddlespeech/server/restful/response.py
+62
-1
paddlespeech/server/restful/vector_api.py
paddlespeech/server/restful/vector_api.py
+151
-0
paddlespeech/server/utils/audio_handler.py
paddlespeech/server/utils/audio_handler.py
+101
-0
paddlespeech/server/ws/asr_socket.py
paddlespeech/server/ws/asr_socket.py
+3
-1
speechx/examples/ds2_ol/aishell/run.sh
speechx/examples/ds2_ol/aishell/run.sh
+0
-1
speechx/examples/ds2_ol/decoder/recognizer_test_main.cc
speechx/examples/ds2_ol/decoder/recognizer_test_main.cc
+3
-1
speechx/examples/ds2_ol/feat/compute_fbank_main.cc
speechx/examples/ds2_ol/feat/compute_fbank_main.cc
+1
-0
speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc
.../examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc
+2
-1
speechx/examples/ds2_ol/websocket/websocket_server.sh
speechx/examples/ds2_ol/websocket/websocket_server.sh
+0
-1
speechx/speechx/decoder/param.h
speechx/speechx/decoder/param.h
+29
-21
speechx/speechx/frontend/audio/fbank.cc
speechx/speechx/frontend/audio/fbank.cc
+7
-4
speechx/speechx/frontend/audio/feature_pipeline.cc
speechx/speechx/frontend/audio/feature_pipeline.cc
+5
-5
speechx/speechx/frontend/audio/feature_pipeline.h
speechx/speechx/frontend/audio/feature_pipeline.h
+4
-4
未找到文件。
demos/speaker_verification/README.md
浏览文件 @
f869db7a
...
...
@@ -14,7 +14,7 @@ see [installation](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/doc
You can choose one way from easy, meduim and hard to install paddlespeech.
### 2. Prepare Input File
The input of this demo should be a WAV file(
`.wav`
), and the sample rate must be the same as the model.
The input of this
cli
demo should be a WAV file(
`.wav`
), and the sample rate must be the same as the model.
Here are sample files for this demo that can be downloaded:
```
bash
...
...
demos/speaker_verification/README_cn.md
浏览文件 @
f869db7a
...
...
@@ -4,16 +4,16 @@
## 介绍
声纹识别是一项用计算机程序自动提取说话人特征的技术。
这个 demo 是
一个从给定音频文件
提取说话人特征,它可以通过使用
`PaddleSpeech`
的单个命令或 python 中的几行代码来实现。
这个 demo 是
从一个给定音频文件中
提取说话人特征,它可以通过使用
`PaddleSpeech`
的单个命令或 python 中的几行代码来实现。
## 使用方法
### 1. 安装
请看
[
安装文档
](
https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md
)
。
你可以从
easy,medium,hard 三中
方式中选择一种方式安装。
你可以从
easy medium,hard 三种
方式中选择一种方式安装。
### 2. 准备输入
这个
demo 的输入应该是一个 WAV 文件(
`.wav`
),并且采样率必须与模型的采样率相同。
声纹cli
demo 的输入应该是一个 WAV 文件(
`.wav`
),并且采样率必须与模型的采样率相同。
可以下载此 demo 的示例音频:
```
bash
...
...
demos/streaming_asr_server/README.md
浏览文件 @
f869db7a
此差异已折叠。
点击以展开。
demos/streaming_asr_server/README_cn.md
浏览文件 @
f869db7a
此差异已折叠。
点击以展开。
demos/streaming_asr_server/conf/punc_application.yaml
0 → 100644
浏览文件 @
f869db7a
# This is the parameter configuration file for PaddleSpeech Serving.
#################################################################################
# SERVER SETTING #
#################################################################################
host
:
0.0.0.0
port
:
8190
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_python']
# protocol = ['http'] (only one can be selected).
# http only support offline engine type.
protocol
:
'
http'
engine_list
:
[
'
text_python'
]
#################################################################################
# ENGINE CONFIG #
#################################################################################
################################### Text #########################################
################### text task: punc; engine_type: python #######################
text_python
:
task
:
punc
model_type
:
'
ernie_linear_p3_wudao'
lang
:
'
zh'
sample_rate
:
16000
cfg_path
:
# [optional]
ckpt_path
:
# [optional]
vocab_file
:
# [optional]
device
:
'
cpu'
# set 'gpu:id' or 'cpu'
demos/streaming_asr_server/conf/ws_conformer_application.yaml
浏览文件 @
f869db7a
...
...
@@ -4,7 +4,7 @@
# SERVER SETTING #
#################################################################################
host
:
0.0.0.0
port
:
8
0
90
port
:
8
2
90
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online']
...
...
@@ -29,7 +29,7 @@ asr_online:
cfg_path
:
decode_method
:
force_yes
:
True
device
:
# cpu or gpu:id
device
:
'
cpu'
# cpu or gpu:id
am_predictor_conf
:
device
:
# set 'gpu:id' or 'cpu'
switch_ir_optim
:
True
...
...
@@ -42,4 +42,4 @@ asr_online:
window_ms
:
25
# ms
shift_ms
:
10
# ms
sample_rate
:
16000
sample_width
:
2
\ No newline at end of file
sample_width
:
2
demos/streaming_asr_server/punc_server.py
0 → 100644
浏览文件 @
f869db7a
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
from
paddlespeech.cli.log
import
logger
from
paddlespeech.server.bin.paddlespeech_server
import
ServerExecutor
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
prog
=
'paddlespeech_server.start'
,
add_help
=
True
)
parser
.
add_argument
(
"--config_file"
,
action
=
"store"
,
help
=
"yaml file of the app"
,
default
=
None
,
required
=
True
)
parser
.
add_argument
(
"--log_file"
,
action
=
"store"
,
help
=
"log file"
,
default
=
"./log/paddlespeech.log"
)
logger
.
info
(
"start to parse the args"
)
args
=
parser
.
parse_args
()
logger
.
info
(
"start to launch the punctuation server"
)
punc_server
=
ServerExecutor
()
punc_server
(
config_file
=
args
.
config_file
,
log_file
=
args
.
log_file
)
demos/streaming_asr_server/server.sh
0 → 100755
浏览文件 @
f869db7a
export
CUDA_VISIBLE_DEVICE
=
0,1,2,3
nohup
python3 punc_server.py
--config_file
conf/punc_application.yaml
>
punc.log 2>&1 &
nohup
python3 streaming_asr_server.py
--config_file
conf/ws_conformer_application.yaml
>
streaming_asr.log 2>&1 &
demos/streaming_asr_server/streaming_asr_server.py
0 → 100644
浏览文件 @
f869db7a
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
from
paddlespeech.cli.log
import
logger
from
paddlespeech.server.bin.paddlespeech_server
import
ServerExecutor
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
prog
=
'paddlespeech_server.start'
,
add_help
=
True
)
parser
.
add_argument
(
"--config_file"
,
action
=
"store"
,
help
=
"yaml file of the app"
,
default
=
None
,
required
=
True
)
parser
.
add_argument
(
"--log_file"
,
action
=
"store"
,
help
=
"log file"
,
default
=
"./log/paddlespeech.log"
)
logger
.
info
(
"start to parse the args"
)
args
=
parser
.
parse_args
()
logger
.
info
(
"start to launch the streaming asr server"
)
streaming_asr_server
=
ServerExecutor
()
streaming_asr_server
(
config_file
=
args
.
config_file
,
log_file
=
args
.
log_file
)
demos/streaming_asr_server/test.sh
100644 → 100755
浏览文件 @
f869db7a
# download the test wav
wget
-c
https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
# read the wav and pass it to service
python3 websocket_client.py
--wavfile
./zh.wav
# read the wav and pass it to only streaming asr service
python3 websocket_client.py
--server_ip
127.0.0.1
--port
8290
--wavfile
./zh.wav
# read the wav and call streaming and punc service
python3 websocket_client.py
--server_ip
127.0.0.1
--port
8290
--punc
.server_ip 127.0.0.1
--punc
.port 8190
--wavfile
./zh.wav
demos/streaming_asr_server/websocket_client.py
浏览文件 @
f869db7a
...
...
@@ -28,6 +28,7 @@ def main(args):
handler
=
ASRWsAudioHandler
(
args
.
server_ip
,
args
.
port
,
endpoint
=
args
.
endpoint
,
punc_server_ip
=
args
.
punc_server_ip
,
punc_server_port
=
args
.
punc_server_port
)
loop
=
asyncio
.
get_event_loop
()
...
...
@@ -69,7 +70,11 @@ if __name__ == "__main__":
default
=
8091
,
dest
=
"punc_server_port"
,
help
=
'Punctuation server port'
)
parser
.
add_argument
(
"--endpoint"
,
type
=
str
,
default
=
"/paddlespeech/asr/streaming"
,
help
=
"ASR websocket endpoint"
)
parser
.
add_argument
(
"--wavfile"
,
action
=
"store"
,
...
...
examples/aishell/asr1/conf/chunk_conformer.yaml
浏览文件 @
f869db7a
...
...
@@ -10,7 +10,7 @@ encoder_conf:
attention_heads
:
4
linear_units
:
2048
# the number of units of position-wise feed forward
num_blocks
:
12
# the number of encoder blocks
dropout_rate
:
0.1
dropout_rate
:
0.1
# sublayer output dropout
positional_dropout_rate
:
0.1
attention_dropout_rate
:
0.0
input_layer
:
conv2d
# encoder input type, you can chose conv2d, conv2d6 and conv2d8
...
...
@@ -30,7 +30,7 @@ decoder_conf:
attention_heads
:
4
linear_units
:
2048
num_blocks
:
6
dropout_rate
:
0.1
dropout_rate
:
0.1
# sublayer output dropout
positional_dropout_rate
:
0.1
self_attention_dropout_rate
:
0.0
src_attention_dropout_rate
:
0.0
...
...
@@ -39,7 +39,7 @@ model_conf:
ctc_weight
:
0.3
lsm_weight
:
0.1
# label smoothing option
length_normalized_loss
:
false
init_type
:
'
kaiming_uniform'
init_type
:
'
kaiming_uniform'
# !Warning: need to convergence
###########################################
# Data #
...
...
examples/aishell/asr1/conf/conformer.yaml
浏览文件 @
f869db7a
...
...
@@ -37,7 +37,7 @@ model_conf:
ctc_weight
:
0.3
lsm_weight
:
0.1
# label smoothing option
length_normalized_loss
:
false
init_type
:
'
kaiming_uniform'
init_type
:
'
kaiming_uniform'
# !Warning: need to convergence
###########################################
# Data #
...
...
examples/aishell/asr1/conf/transformer.yaml
浏览文件 @
f869db7a
...
...
@@ -10,7 +10,7 @@ encoder_conf:
attention_heads
:
4
linear_units
:
2048
# the number of units of position-wise feed forward
num_blocks
:
12
# the number of encoder blocks
dropout_rate
:
0.1
dropout_rate
:
0.1
# sublayer output dropout
positional_dropout_rate
:
0.1
attention_dropout_rate
:
0.0
input_layer
:
conv2d
# encoder input type, you can chose conv2d, conv2d6 and conv2d8
...
...
@@ -21,7 +21,7 @@ decoder_conf:
attention_heads
:
4
linear_units
:
2048
num_blocks
:
6
dropout_rate
:
0.1
dropout_rate
:
0.1
# sublayer output dropout
positional_dropout_rate
:
0.1
self_attention_dropout_rate
:
0.0
src_attention_dropout_rate
:
0.0
...
...
paddlespeech/cli/vector/infer.py
浏览文件 @
f869db7a
...
...
@@ -272,7 +272,8 @@ class VectorExecutor(BaseExecutor):
model_type
:
str
=
'ecapatdnn_voxceleb12'
,
sample_rate
:
int
=
16000
,
cfg_path
:
Optional
[
os
.
PathLike
]
=
None
,
ckpt_path
:
Optional
[
os
.
PathLike
]
=
None
):
ckpt_path
:
Optional
[
os
.
PathLike
]
=
None
,
task
=
None
):
"""Init the neural network from the model path
Args:
...
...
@@ -284,8 +285,10 @@ class VectorExecutor(BaseExecutor):
Defaults to None.
ckpt_path (Optional[os.PathLike], optional): the pretrained model path, which is stored in the disk.
Defaults to None.
task (str, optional): the model task type
"""
# stage 0: avoid to init the mode again
self
.
task
=
task
if
hasattr
(
self
,
"model"
):
logger
.
info
(
"Model has been initialized"
)
return
...
...
@@ -434,6 +437,9 @@ class VectorExecutor(BaseExecutor):
if
self
.
sample_rate
!=
16000
and
self
.
sample_rate
!=
8000
:
logger
.
error
(
"invalid sample rate, please input --sr 8000 or --sr 16000"
)
logger
.
error
(
f
"The model sample rate:
{
self
.
sample_rate
}
, the external sample rate is:
{
sample_rate
}
"
)
return
False
if
isinstance
(
audio_file
,
(
str
,
os
.
PathLike
)):
...
...
paddlespeech/server/README_cn.md
浏览文件 @
f869db7a
...
...
@@ -63,3 +63,23 @@ paddlespeech_server start --config_file conf/tts_online_application.yaml
```
paddlespeech_client tts_online --server_ip 127.0.0.1 --port 8092 --input "您好,欢迎使用百度飞桨深度学习框架!" --output output.wav
```
## 声纹识别
### 启动声纹识别服务
```
paddlespeech_server start --config_file conf/vector_application.yaml
```
### 获取说话人音频声纹
```
paddlespeech_client vector --task spk --server_ip 127.0.0.1 --port 8090 --input 85236145389.wav
```
### 两个说话人音频声纹打分
```
paddlespeech_client vector --task score --server_ip 127.0.0.1 --port 8090 --enroll 123456789.wav --test 85236145389.wav
```
paddlespeech/server/bin/paddlespeech_client.py
浏览文件 @
f869db7a
...
...
@@ -35,7 +35,7 @@ from paddlespeech.server.utils.util import wav2base64
__all__
=
[
'TTSClientExecutor'
,
'TTSOnlineClientExecutor'
,
'ASRClientExecutor'
,
'ASROnlineClientExecutor'
,
'CLSClientExecutor'
'ASROnlineClientExecutor'
,
'CLSClientExecutor'
,
'VectorClientExecutor'
]
...
...
@@ -411,6 +411,18 @@ class ASROnlineClientExecutor(BaseExecutor):
'--lang'
,
type
=
str
,
default
=
"zh_cn"
,
help
=
'language'
)
self
.
parser
.
add_argument
(
'--audio_format'
,
type
=
str
,
default
=
"wav"
,
help
=
'audio format'
)
self
.
parser
.
add_argument
(
'--punc.server_ip'
,
type
=
str
,
default
=
None
,
dest
=
"punc_server_ip"
,
help
=
'Punctuation server ip'
)
self
.
parser
.
add_argument
(
'--punc.port'
,
type
=
int
,
default
=
8190
,
dest
=
"punc_server_port"
,
help
=
'Punctuation server port'
)
def
execute
(
self
,
argv
:
List
[
str
])
->
bool
:
args
=
self
.
parser
.
parse_args
(
argv
)
...
...
@@ -428,7 +440,9 @@ class ASROnlineClientExecutor(BaseExecutor):
port
=
port
,
sample_rate
=
sample_rate
,
lang
=
lang
,
audio_format
=
audio_format
)
audio_format
=
audio_format
,
punc_server_ip
=
args
.
punc_server_ip
,
punc_server_port
=
args
.
punc_server_port
)
time_end
=
time
.
time
()
logger
.
info
(
res
)
logger
.
info
(
"Response time %f s."
%
(
time_end
-
time_start
))
...
...
@@ -445,12 +459,30 @@ class ASROnlineClientExecutor(BaseExecutor):
port
:
int
=
8091
,
sample_rate
:
int
=
16000
,
lang
:
str
=
"zh_cn"
,
audio_format
:
str
=
"wav"
):
"""
Python API to call an executor.
audio_format
:
str
=
"wav"
,
punc_server_ip
:
str
=
None
,
punc_server_port
:
str
=
None
):
"""Python API to call asr online executor.
Args:
input (str): the audio file to be send to streaming asr service.
server_ip (str, optional): streaming asr server ip. Defaults to "127.0.0.1".
port (int, optional): streaming asr server port. Defaults to 8091.
sample_rate (int, optional): audio sample rate. Defaults to 16000.
lang (str, optional): audio language type. Defaults to "zh_cn".
audio_format (str, optional): audio format. Defaults to "wav".
punc_server_ip (str, optional): punctuation server ip. Defaults to None.
punc_server_port (str, optional): punctuation server port. Defaults to None.
Returns:
str: the audio text
"""
logger
.
info
(
"asr websocket client start"
)
handler
=
ASRWsAudioHandler
(
server_ip
,
port
)
handler
=
ASRWsAudioHandler
(
server_ip
,
port
,
punc_server_ip
=
punc_server_ip
,
punc_server_port
=
punc_server_port
)
loop
=
asyncio
.
get_event_loop
()
res
=
loop
.
run_until_complete
(
handler
.
run
(
input
))
logger
.
info
(
"asr websocket client finished"
)
...
...
@@ -583,3 +615,108 @@ class TextClientExecutor(BaseExecutor):
response_dict
=
res
.
json
()
punc_text
=
response_dict
[
"result"
][
"punc_text"
]
return
punc_text
@
cli_client_register
(
name
=
'paddlespeech_client.vector'
,
description
=
'visit the vector service'
)
class
VectorClientExecutor
(
BaseExecutor
):
def
__init__
(
self
):
super
(
VectorClientExecutor
,
self
).
__init__
()
self
.
parser
=
argparse
.
ArgumentParser
(
prog
=
'paddlespeech_client.vector'
,
add_help
=
True
)
self
.
parser
.
add_argument
(
'--server_ip'
,
type
=
str
,
default
=
'127.0.0.1'
,
help
=
'server ip'
)
self
.
parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
8090
,
help
=
'server port'
)
self
.
parser
.
add_argument
(
'--input'
,
type
=
str
,
default
=
None
,
help
=
'sentence to be process by text server.'
)
self
.
parser
.
add_argument
(
'--task'
,
type
=
str
,
default
=
"spk"
,
choices
=
[
"spk"
,
"score"
],
help
=
"The vector service task"
)
self
.
parser
.
add_argument
(
"--enroll"
,
type
=
str
,
default
=
None
,
help
=
"The enroll audio"
)
self
.
parser
.
add_argument
(
"--test"
,
type
=
str
,
default
=
None
,
help
=
"The test audio"
)
def
execute
(
self
,
argv
:
List
[
str
])
->
bool
:
"""Execute the request from the argv.
Args:
argv (List): the request arguments
Returns:
str: the request flag
"""
args
=
self
.
parser
.
parse_args
(
argv
)
input_
=
args
.
input
server_ip
=
args
.
server_ip
port
=
args
.
port
task
=
args
.
task
try
:
time_start
=
time
.
time
()
res
=
self
(
input
=
input_
,
server_ip
=
server_ip
,
port
=
port
,
enroll_audio
=
args
.
enroll
,
test_audio
=
args
.
test
,
task
=
task
)
time_end
=
time
.
time
()
logger
.
info
(
f
"The vector:
{
res
}
"
)
logger
.
info
(
"Response time %f s."
%
(
time_end
-
time_start
))
return
True
except
Exception
as
e
:
logger
.
error
(
"Failed to extract vector."
)
logger
.
error
(
e
)
return
False
@
stats_wrapper
def
__call__
(
self
,
input
:
str
,
server_ip
:
str
=
"127.0.0.1"
,
port
:
int
=
8090
,
audio_format
:
str
=
"wav"
,
sample_rate
:
int
=
16000
,
enroll_audio
:
str
=
None
,
test_audio
:
str
=
None
,
task
=
"spk"
):
"""
Python API to call text executor.
Args:
input (str): the request audio data
server_ip (str, optional): the server ip. Defaults to "127.0.0.1".
port (int, optional): the server port. Defaults to 8090.
audio_format (str, optional): audio format. Defaults to "wav".
sample_rate (str, optional): audio sample rate. Defaults to 16000.
enroll_audio (str, optional): enroll audio data. Defaults to None.
test_audio (str, optional): test audio data. Defaults to None.
task (str, optional): the task type, "spk" or "socre". Defaults to "spk"
Returns:
str: the audio embedding or score between enroll and test audio
"""
if
task
==
"spk"
:
from
paddlespeech.server.utils.audio_handler
import
VectorHttpHandler
logger
.
info
(
"vector http client start"
)
logger
.
info
(
f
"the input audio:
{
input
}
"
)
handler
=
VectorHttpHandler
(
server_ip
=
server_ip
,
port
=
port
)
res
=
handler
.
run
(
input
,
audio_format
,
sample_rate
)
return
res
elif
task
==
"score"
:
from
paddlespeech.server.utils.audio_handler
import
VectorScoreHttpHandler
logger
.
info
(
"vector score http client start"
)
logger
.
info
(
f
"enroll audio:
{
enroll_audio
}
, test audio:
{
test_audio
}
"
)
handler
=
VectorScoreHttpHandler
(
server_ip
=
server_ip
,
port
=
port
)
res
=
handler
.
run
(
enroll_audio
,
test_audio
,
audio_format
,
sample_rate
)
logger
.
info
(
f
"The vector score is:
{
res
}
"
)
else
:
logger
.
error
(
f
"Sorry, we have not support such task
{
task
}
"
)
paddlespeech/server/conf/application.yaml
浏览文件 @
f869db7a
...
...
@@ -11,7 +11,7 @@ port: 8090
# protocol = ['websocket', 'http'] (only one can be selected).
# http only support offline engine type.
protocol
:
'
http'
engine_list
:
[
'
asr_python'
,
'
tts_python'
,
'
cls_python'
,
'
text_python'
]
engine_list
:
[
'
asr_python'
,
'
tts_python'
,
'
cls_python'
,
'
text_python'
,
'
vector_python'
]
#################################################################################
...
...
@@ -166,4 +166,15 @@ text_python:
cfg_path
:
# [optional]
ckpt_path
:
# [optional]
vocab_file
:
# [optional]
device
:
# set 'gpu:id' or 'cpu'
################################### Vector ######################################
################### Vector task: spk; engine_type: python #######################
vector_python
:
task
:
spk
model_type
:
'
ecapatdnn_voxceleb12'
sample_rate
:
16000
cfg_path
:
# [optional]
ckpt_path
:
# [optional]
device
:
# set 'gpu:id' or 'cpu'
\ No newline at end of file
paddlespeech/server/conf/vector_application.yaml
0 → 100644
浏览文件 @
f869db7a
# This is the parameter configuration file for PaddleSpeech Serving.
#################################################################################
# SERVER SETTING #
#################################################################################
host
:
0.0.0.0
port
:
8090
# The task format in the engin_list is: <speech task>_<engine type>
# protocol = ['http'] (only one can be selected).
# http only support offline engine type.
protocol
:
'
http'
engine_list
:
[
'
vector_python'
]
#################################################################################
# ENGINE CONFIG #
#################################################################################
################################### Vector ######################################
################### Vector task: spk; engine_type: python #######################
vector_python
:
task
:
spk
model_type
:
'
ecapatdnn_voxceleb12'
sample_rate
:
16000
cfg_path
:
# [optional]
ckpt_path
:
# [optional]
device
:
# set 'gpu:id' or 'cpu'
paddlespeech/server/engine/asr/online/asr_engine.py
浏览文件 @
f869db7a
...
...
@@ -153,6 +153,12 @@ class PaddleASRConnectionHanddler:
self
.
n_shift
=
self
.
preprocess_conf
.
process
[
0
][
'n_shift'
]
def
extract_feat
(
self
,
samples
):
# we compute the elapsed time of first char occuring
# and we record the start time at the first pcm sample arraving
# if self.first_char_occur_elapsed is not None:
# self.first_char_occur_elapsed = time.time()
if
"deepspeech2online"
in
self
.
model_type
:
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
...
...
@@ -290,6 +296,9 @@ class PaddleASRConnectionHanddler:
self
.
chunk_num
=
0
self
.
global_frame_offset
=
0
self
.
result_transcripts
=
[
''
]
self
.
word_time_stamp
=
[]
self
.
time_stamp
=
[]
self
.
first_char_occur_elapsed
=
None
def
decode
(
self
,
is_finished
=
False
):
if
"deepspeech2online"
in
self
.
model_type
:
...
...
@@ -505,6 +514,9 @@ class PaddleASRConnectionHanddler:
else
:
return
''
def
get_word_time_stamp
(
self
):
return
self
.
word_time_stamp
@
paddle
.
no_grad
()
def
rescoring
(
self
):
if
"deepspeech2online"
in
self
.
model_type
or
"deepspeech2offline"
in
self
.
model_type
:
...
...
@@ -567,10 +579,48 @@ class PaddleASRConnectionHanddler:
best_index
=
i
# update the one best result
# hyps stored the beam results and each fields is:
logger
.
info
(
f
"best index:
{
best_index
}
"
)
# logger.info(f'best result: {hyps[best_index]}')
# the field of the hyps is:
# hyps[0][0]: the sentence word-id in the vocab with a tuple
# hyps[0][1]: the sentence decoding probability with all paths
# hyps[0][2]: viterbi_blank ending probability
# hyps[0][3]: viterbi_non_blank probability
# hyps[0][4]: current_token_prob,
# hyps[0][5]: times_viterbi_blank,
# hyps[0][6]: times_titerbi_non_blank
self
.
hyps
=
[
hyps
[
best_index
][
0
]]
# update the hyps time stamp
self
.
time_stamp
=
hyps
[
best_index
][
5
]
if
hyps
[
best_index
][
2
]
>
hyps
[
best_index
][
3
]
else
hyps
[
best_index
][
6
]
logger
.
info
(
f
"time stamp:
{
self
.
time_stamp
}
"
)
self
.
update_result
()
# update each word start and end time stamp
frame_shift_in_ms
=
self
.
model
.
encoder
.
embed
.
subsampling_rate
*
self
.
n_shift
/
self
.
sample_rate
logger
.
info
(
f
"frame shift ms:
{
frame_shift_in_ms
}
"
)
word_time_stamp
=
[]
for
idx
,
_
in
enumerate
(
self
.
time_stamp
):
start
=
(
self
.
time_stamp
[
idx
-
1
]
+
self
.
time_stamp
[
idx
]
)
/
2.0
if
idx
>
0
else
0
start
=
start
*
frame_shift_in_ms
end
=
(
self
.
time_stamp
[
idx
]
+
self
.
time_stamp
[
idx
+
1
]
)
/
2.0
if
idx
<
len
(
self
.
time_stamp
)
-
1
else
self
.
offset
end
=
end
*
frame_shift_in_ms
word_time_stamp
.
append
({
"w"
:
self
.
result_transcripts
[
0
][
idx
],
"bg"
:
start
,
"ed"
:
end
})
# logger.info(f"{self.result_transcripts[0][idx]}, start: {start}, end: {end}")
self
.
word_time_stamp
=
word_time_stamp
logger
.
info
(
f
"word time stamp:
{
self
.
word_time_stamp
}
"
)
class
ASRServerExecutor
(
ASRExecutor
):
def
__init__
(
self
):
...
...
paddlespeech/server/engine/asr/online/ctc_search.py
浏览文件 @
f869db7a
...
...
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
copy
from
collections
import
defaultdict
import
paddle
...
...
@@ -26,7 +27,7 @@ class CTCPrefixBeamSearch:
"""Implement the ctc prefix beam search
Args:
config (yacs.config.CfgNode):
_description_
config (yacs.config.CfgNode):
the ctc prefix beam search configuration
"""
self
.
config
=
config
self
.
reset
()
...
...
@@ -54,14 +55,23 @@ class CTCPrefixBeamSearch:
assert
len
(
ctc_probs
.
shape
)
==
2
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# blank_ending_score and none_blank_ending_score in ln domain
# 0. blank_ending_score,
# 1. none_blank_ending_score,
# 2. viterbi_blank ending,
# 3. viterbi_non_blank,
# 4. current_token_prob,
# 5. times_viterbi_blank,
# 6. times_titerbi_non_blank
if
self
.
cur_hyps
is
None
:
self
.
cur_hyps
=
[(
tuple
(),
(
0.0
,
-
float
(
'inf'
)))]
self
.
cur_hyps
=
[(
tuple
(),
(
0.0
,
-
float
(
'inf'
),
0.0
,
0.0
,
-
float
(
'inf'
),
[],
[]))]
# self.cur_hyps = [(tuple(), (0.0, -float('inf')))]
# 2. CTC beam search step by step
for
t
in
range
(
0
,
maxlen
):
logp
=
ctc_probs
[
t
]
# (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps
=
defaultdict
(
lambda
:
(
-
float
(
'inf'
),
-
float
(
'inf'
)))
# next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
next_hyps
=
defaultdict
(
lambda
:
(
-
float
(
'inf'
),
-
float
(
'inf'
),
-
float
(
'inf'
),
-
float
(
'inf'
),
-
float
(
'inf'
),
[],
[]))
# 2.1 First beam prune: select topk best
# do token passing process
...
...
@@ -69,36 +79,83 @@ class CTCPrefixBeamSearch:
for
s
in
top_k_index
:
s
=
s
.
item
()
ps
=
logp
[
s
].
item
()
for
prefix
,
(
pb
,
pnb
)
in
self
.
cur_hyps
:
for
prefix
,
(
pb
,
pnb
,
v_b_s
,
v_nb_s
,
cur_token_prob
,
times_s
,
times_ns
)
in
self
.
cur_hyps
:
last
=
prefix
[
-
1
]
if
len
(
prefix
)
>
0
else
None
if
s
==
blank_id
:
# blank
n_pb
,
n_pnb
=
next_hyps
[
prefix
]
n_pb
,
n_pnb
,
n_v_s
,
n_v_ns
,
n_cur_token_prob
,
n_times_s
,
n_times_ns
=
next_hyps
[
prefix
]
n_pb
=
log_add
([
n_pb
,
pb
+
ps
,
pnb
+
ps
])
next_hyps
[
prefix
]
=
(
n_pb
,
n_pnb
)
pre_times
=
times_s
if
v_b_s
>
v_nb_s
else
times_ns
n_times_s
=
copy
.
deepcopy
(
pre_times
)
viterbi_score
=
v_b_s
if
v_b_s
>
v_nb_s
else
v_nb_s
n_v_s
=
viterbi_score
+
ps
next_hyps
[
prefix
]
=
(
n_pb
,
n_pnb
,
n_v_s
,
n_v_ns
,
n_cur_token_prob
,
n_times_s
,
n_times_ns
)
elif
s
==
last
:
# Update *ss -> *s;
n_pb
,
n_pnb
=
next_hyps
[
prefix
]
# case1: *a + a => *a
n_pb
,
n_pnb
,
n_v_s
,
n_v_ns
,
n_cur_token_prob
,
n_times_s
,
n_times_ns
=
next_hyps
[
prefix
]
n_pnb
=
log_add
([
n_pnb
,
pnb
+
ps
])
next_hyps
[
prefix
]
=
(
n_pb
,
n_pnb
)
if
n_v_ns
<
v_nb_s
+
ps
:
n_v_ns
=
v_nb_s
+
ps
if
n_cur_token_prob
<
ps
:
n_cur_token_prob
=
ps
n_times_ns
=
copy
.
deepcopy
(
times_ns
)
n_times_ns
[
-
1
]
=
self
.
abs_time_step
# 注意,这里要重新使用绝对时间
next_hyps
[
prefix
]
=
(
n_pb
,
n_pnb
,
n_v_s
,
n_v_ns
,
n_cur_token_prob
,
n_times_s
,
n_times_ns
)
# Update *s-s -> *ss, - is for blank
# Case 2: *aε + a => *aa
n_prefix
=
prefix
+
(
s
,
)
n_pb
,
n_pnb
=
next_hyps
[
n_prefix
]
n_pb
,
n_pnb
,
n_v_s
,
n_v_ns
,
n_cur_token_prob
,
n_times_s
,
n_times_ns
=
next_hyps
[
n_prefix
]
if
n_v_ns
<
v_b_s
+
ps
:
n_v_ns
=
v_b_s
+
ps
n_cur_token_prob
=
ps
n_times_ns
=
copy
.
deepcopy
(
times_s
)
n_times_ns
.
append
(
self
.
abs_time_step
)
n_pnb
=
log_add
([
n_pnb
,
pb
+
ps
])
next_hyps
[
n_prefix
]
=
(
n_pb
,
n_pnb
)
next_hyps
[
n_prefix
]
=
(
n_pb
,
n_pnb
,
n_v_s
,
n_v_ns
,
n_cur_token_prob
,
n_times_s
,
n_times_ns
)
else
:
# Case 3: *a + b => *ab, *aε + b => *ab
n_prefix
=
prefix
+
(
s
,
)
n_pb
,
n_pnb
=
next_hyps
[
n_prefix
]
n_pb
,
n_pnb
,
n_v_s
,
n_v_ns
,
n_cur_token_prob
,
n_times_s
,
n_times_ns
=
next_hyps
[
n_prefix
]
viterbi_score
=
v_b_s
if
v_b_s
>
v_nb_s
else
v_nb_s
pre_times
=
times_s
if
v_b_s
>
v_nb_s
else
times_ns
if
n_v_ns
<
viterbi_score
+
ps
:
n_v_ns
=
viterbi_score
+
ps
n_cur_token_prob
=
ps
n_times_ns
=
copy
.
deepcopy
(
pre_times
)
n_times_ns
.
append
(
self
.
abs_time_step
)
n_pnb
=
log_add
([
n_pnb
,
pb
+
ps
,
pnb
+
ps
])
next_hyps
[
n_prefix
]
=
(
n_pb
,
n_pnb
)
next_hyps
[
n_prefix
]
=
(
n_pb
,
n_pnb
,
n_v_s
,
n_v_ns
,
n_cur_token_prob
,
n_times_s
,
n_times_ns
)
# 2.2 Second beam prune
next_hyps
=
sorted
(
next_hyps
.
items
(),
key
=
lambda
x
:
log_add
(
list
(
x
[
1
])
),
key
=
lambda
x
:
log_add
(
[
x
[
1
][
0
],
x
[
1
][
1
]]
),
reverse
=
True
)
self
.
cur_hyps
=
next_hyps
[:
beam_size
]
self
.
hyps
=
[(
y
[
0
],
log_add
([
y
[
1
][
0
],
y
[
1
][
1
]]))
for
y
in
self
.
cur_hyps
]
# 2.3 update the absolute time step
self
.
abs_time_step
+=
1
self
.
hyps
=
[(
y
[
0
],
log_add
([
y
[
1
][
0
],
y
[
1
][
1
]]),
y
[
1
][
2
],
y
[
1
][
3
],
y
[
1
][
4
],
y
[
1
][
5
],
y
[
1
][
6
])
for
y
in
self
.
cur_hyps
]
logger
.
info
(
"ctc prefix search success"
)
return
self
.
hyps
...
...
@@ -123,6 +180,7 @@ class CTCPrefixBeamSearch:
"""
self
.
cur_hyps
=
None
self
.
hyps
=
None
self
.
abs_time_step
=
0
def
finalize_search
(
self
):
"""do nothing in ctc_prefix_beam_search
...
...
paddlespeech/server/engine/engine_factory.py
浏览文件 @
f869db7a
...
...
@@ -49,5 +49,8 @@ class EngineFactory(object):
elif
engine_name
.
lower
()
==
'text'
and
engine_type
.
lower
()
==
'python'
:
from
paddlespeech.server.engine.text.python.text_engine
import
TextEngine
return
TextEngine
()
elif
engine_name
.
lower
()
==
'vector'
and
engine_type
.
lower
()
==
'python'
:
from
paddlespeech.server.engine.vector.python.vector_engine
import
VectorEngine
return
VectorEngine
()
else
:
return
None
paddlespeech/server/engine/vector/__init__.py
0 → 100644
浏览文件 @
f869db7a
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/server/engine/vector/python/__init__.py
0 → 100644
浏览文件 @
f869db7a
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/server/engine/vector/python/vector_engine.py
0 → 100644
浏览文件 @
f869db7a
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
io
from
collections
import
OrderedDict
import
numpy
as
np
import
paddle
from
paddleaudio.backends
import
load
as
load_audio
from
paddleaudio.compliance.librosa
import
melspectrogram
from
paddlespeech.cli.log
import
logger
from
paddlespeech.cli.vector.infer
import
VectorExecutor
from
paddlespeech.server.engine.base_engine
import
BaseEngine
from
paddlespeech.vector.io.batch
import
feature_normalize
class
PaddleVectorConnectionHandler
:
def
__init__
(
self
,
vector_engine
):
"""The PaddleSpeech Vector Server Connection Handler
This connection process every server request
Args:
vector_engine (VectorEngine): The Vector engine
"""
super
().
__init__
()
logger
.
info
(
"Create PaddleVectorConnectionHandler to process the vector request"
)
self
.
vector_engine
=
vector_engine
self
.
executor
=
self
.
vector_engine
.
executor
self
.
task
=
self
.
vector_engine
.
executor
.
task
self
.
model
=
self
.
vector_engine
.
executor
.
model
self
.
config
=
self
.
vector_engine
.
executor
.
config
self
.
_inputs
=
OrderedDict
()
self
.
_outputs
=
OrderedDict
()
@
paddle
.
no_grad
()
def
run
(
self
,
audio_data
,
task
=
"spk"
):
"""The connection process the http request audio
Args:
audio_data (bytes): base64.b64decode
Returns:
str: the punctuation text
"""
logger
.
info
(
f
"start to extract the do vector
{
self
.
task
}
from the http request"
)
if
self
.
task
==
"spk"
and
task
==
"spk"
:
embedding
=
self
.
extract_audio_embedding
(
audio_data
)
return
embedding
else
:
logger
.
error
(
"The request task is not matched with server model task"
)
logger
.
error
(
f
"The server model task is:
{
self
.
task
}
, but the request task is:
{
task
}
"
)
return
np
.
array
([
0.0
,
])
@
paddle
.
no_grad
()
def
get_enroll_test_score
(
self
,
enroll_audio
,
test_audio
):
"""Get the enroll and test audio score
Args:
enroll_audio (str): the base64 format enroll audio
test_audio (str): the base64 format test audio
Returns:
float: the score between enroll and test audio
"""
logger
.
info
(
"start to extract the enroll audio embedding"
)
enroll_emb
=
self
.
extract_audio_embedding
(
enroll_audio
)
logger
.
info
(
"start to extract the test audio embedding"
)
test_emb
=
self
.
extract_audio_embedding
(
test_audio
)
logger
.
info
(
"start to get the score between the enroll and test embedding"
)
score
=
self
.
executor
.
get_embeddings_score
(
enroll_emb
,
test_emb
)
logger
.
info
(
f
"get the enroll vs test score:
{
score
}
"
)
return
score
@
paddle
.
no_grad
()
def
extract_audio_embedding
(
self
,
audio
:
str
,
sample_rate
:
int
=
16000
):
"""extract the audio embedding
Args:
audio (str): the audio data
sample_rate (int, optional): the audio sample rate. Defaults to 16000.
"""
# we can not reuse the cache io.BytesIO(audio) data,
# because the soundfile will change the io.BytesIO(audio) to the end
# thus we should convert the base64 string to io.BytesIO when we need the audio data
if
not
self
.
executor
.
_check
(
io
.
BytesIO
(
audio
),
sample_rate
):
logger
.
info
(
"check the audio sample rate occurs error"
)
return
np
.
array
([
0.0
])
waveform
,
sr
=
load_audio
(
io
.
BytesIO
(
audio
))
logger
.
info
(
f
"load the audio sample points, shape is:
{
waveform
.
shape
}
"
)
# stage 2: get the audio feat
# Note: Now we only support fbank feature
try
:
feats
=
melspectrogram
(
x
=
waveform
,
sr
=
self
.
config
.
sr
,
n_mels
=
self
.
config
.
n_mels
,
window_size
=
self
.
config
.
window_size
,
hop_length
=
self
.
config
.
hop_size
)
logger
.
info
(
f
"extract the audio feats, shape is:
{
feats
.
shape
}
"
)
except
Exception
as
e
:
logger
.
info
(
f
"feats occurs exception
{
e
}
"
)
sys
.
exit
(
-
1
)
feats
=
paddle
.
to_tensor
(
feats
).
unsqueeze
(
0
)
# in inference period, the lengths is all one without padding
lengths
=
paddle
.
ones
([
1
])
# stage 3: we do feature normalize,
# Now we assume that the feats must do normalize
feats
=
feature_normalize
(
feats
,
mean_norm
=
True
,
std_norm
=
False
)
# stage 4: store the feats and length in the _inputs,
# which will be used in other function
logger
.
info
(
f
"feats shape:
{
feats
.
shape
}
"
)
logger
.
info
(
"audio extract the feats success"
)
logger
.
info
(
"start to extract the audio embedding"
)
embedding
=
self
.
model
.
backbone
(
feats
,
lengths
).
squeeze
().
numpy
()
logger
.
info
(
f
"embedding size:
{
embedding
.
shape
}
"
)
return
embedding
class
VectorServerExecutor
(
VectorExecutor
):
def
__init__
(
self
):
"""The wrapper for TextEcutor
"""
super
().
__init__
()
pass
class
VectorEngine
(
BaseEngine
):
def
__init__
(
self
):
"""The Vector Engine
"""
super
(
VectorEngine
,
self
).
__init__
()
logger
.
info
(
"Create the VectorEngine Instance"
)
def
init
(
self
,
config
:
dict
):
"""Init the Vector Engine
Args:
config (dict): The server configuation
Returns:
bool: The engine instance flag
"""
logger
.
info
(
"Init the vector engine"
)
try
:
self
.
config
=
config
if
self
.
config
.
device
:
self
.
device
=
self
.
config
.
device
else
:
self
.
device
=
paddle
.
get_device
()
paddle
.
set_device
(
self
.
device
)
logger
.
info
(
f
"Vector Engine set the device:
{
self
.
device
}
"
)
except
BaseException
as
e
:
logger
.
error
(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger
.
error
(
"Initialize Vector server engine Failed on device: %s."
%
(
self
.
device
))
return
False
self
.
executor
=
VectorServerExecutor
()
self
.
executor
.
_init_from_path
(
model_type
=
config
.
model_type
,
cfg_path
=
config
.
cfg_path
,
ckpt_path
=
config
.
ckpt_path
,
task
=
config
.
task
)
logger
.
info
(
"Init the Vector engine successfully"
)
return
True
paddlespeech/server/restful/api.py
浏览文件 @
f869db7a
...
...
@@ -21,7 +21,7 @@ from paddlespeech.server.restful.asr_api import router as asr_router
from
paddlespeech.server.restful.cls_api
import
router
as
cls_router
from
paddlespeech.server.restful.text_api
import
router
as
text_router
from
paddlespeech.server.restful.tts_api
import
router
as
tts_router
from
paddlespeech.server.restful.vector_api
import
router
as
vec_router
_router
=
APIRouter
()
...
...
@@ -43,6 +43,8 @@ def setup_router(api_list: List):
_router
.
include_router
(
cls_router
)
elif
api_name
==
'text'
:
_router
.
include_router
(
text_router
)
elif
api_name
.
lower
()
==
'vector'
:
_router
.
include_router
(
vec_router
)
else
:
logger
.
error
(
f
"PaddleSpeech has not support such service:
{
api_name
}
"
)
...
...
paddlespeech/server/restful/request.py
浏览文件 @
f869db7a
...
...
@@ -15,7 +15,10 @@ from typing import Optional
from
pydantic
import
BaseModel
__all__
=
[
'ASRRequest'
,
'TTSRequest'
,
'CLSRequest'
]
__all__
=
[
'ASRRequest'
,
'TTSRequest'
,
'CLSRequest'
,
'VectorRequest'
,
'VectorScoreRequest'
]
#****************************************************************************************/
...
...
@@ -85,3 +88,40 @@ class CLSRequest(BaseModel):
#****************************************************************************************/
class
TextRequest
(
BaseModel
):
text
:
str
#****************************************************************************************/
#************************************ Vecotr request ************************************/
#****************************************************************************************/
class
VectorRequest
(
BaseModel
):
"""
request body example
{
"audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...",
"task": "spk",
"audio_format": "wav",
"sample_rate": 16000,
}
"""
audio
:
str
task
:
str
audio_format
:
str
sample_rate
:
int
class
VectorScoreRequest
(
BaseModel
):
"""
request body example
{
"enroll_audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...",
"test_audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...",
"task": "score",
"audio_format": "wav",
"sample_rate": 16000,
}
"""
enroll_audio
:
str
test_audio
:
str
task
:
str
audio_format
:
str
sample_rate
:
int
paddlespeech/server/restful/response.py
浏览文件 @
f869db7a
...
...
@@ -15,7 +15,10 @@ from typing import List
from
pydantic
import
BaseModel
__all__
=
[
'ASRResponse'
,
'TTSResponse'
,
'CLSResponse'
]
__all__
=
[
'ASRResponse'
,
'TTSResponse'
,
'CLSResponse'
,
'TextResponse'
,
'VectorResponse'
,
'VectorScoreResponse'
]
class
Message
(
BaseModel
):
...
...
@@ -129,6 +132,11 @@ class CLSResponse(BaseModel):
result
:
CLSResult
#****************************************************************************************/
#************************************ Text response **************************************/
#****************************************************************************************/
class
TextResult
(
BaseModel
):
punc_text
:
str
...
...
@@ -153,6 +161,59 @@ class TextResponse(BaseModel):
result
:
TextResult
#****************************************************************************************/
#************************************ Vector response **************************************/
#****************************************************************************************/
class
VectorResult
(
BaseModel
):
vec
:
list
class
VectorResponse
(
BaseModel
):
"""
response example
{
"success": true,
"code": 0,
"message": {
"description": "success"
},
"result": {
"vec": [1.0, 1.0]
}
}
"""
success
:
bool
code
:
int
message
:
Message
result
:
VectorResult
class
VectorScoreResult
(
BaseModel
):
score
:
float
class
VectorScoreResponse
(
BaseModel
):
"""
response example
{
"success": true,
"code": 0,
"message": {
"description": "success"
},
"result": {
"score": 1.0
}
}
"""
success
:
bool
code
:
int
message
:
Message
result
:
VectorScoreResult
#****************************************************************************************/
#********************************** Error response **************************************/
#****************************************************************************************/
...
...
paddlespeech/server/restful/vector_api.py
0 → 100644
浏览文件 @
f869db7a
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
base64
import
traceback
from
typing
import
Union
import
numpy
as
np
from
fastapi
import
APIRouter
from
paddlespeech.cli.log
import
logger
from
paddlespeech.server.engine.engine_pool
import
get_engine_pool
from
paddlespeech.server.engine.vector.python.vector_engine
import
PaddleVectorConnectionHandler
from
paddlespeech.server.restful.request
import
VectorRequest
from
paddlespeech.server.restful.request
import
VectorScoreRequest
from
paddlespeech.server.restful.response
import
ErrorResponse
from
paddlespeech.server.restful.response
import
VectorResponse
from
paddlespeech.server.restful.response
import
VectorScoreResponse
from
paddlespeech.server.utils.errors
import
ErrorCode
from
paddlespeech.server.utils.errors
import
failed_response
from
paddlespeech.server.utils.exception
import
ServerBaseException
router
=
APIRouter
()
@
router
.
get
(
'/paddlespeech/vector/help'
)
def
help
():
"""help
Returns:
json: The /paddlespeech/vector api response content
"""
response
=
{
"success"
:
"True"
,
"code"
:
200
,
"message"
:
{
"global"
:
"success"
},
"vector"
:
[
2.3
,
3.5
,
5.5
,
6.2
,
2.8
,
1.2
,
0.3
,
3.6
]
}
return
response
@
router
.
post
(
"/paddlespeech/vector"
,
response_model
=
Union
[
VectorResponse
,
ErrorResponse
])
def
vector
(
request_body
:
VectorRequest
):
"""vector api
Args:
request_body (VectorRequest): the vector request body
Returns:
json: the vector response body
"""
try
:
# 1. get the audio data
# the audio must be base64 format
audio_data
=
base64
.
b64decode
(
request_body
.
audio
)
# 2. get single engine from engine pool
# and we use the vector_engine to create an connection handler to process the request
engine_pool
=
get_engine_pool
()
vector_engine
=
engine_pool
[
'vector'
]
connection_handler
=
PaddleVectorConnectionHandler
(
vector_engine
)
# 3. we use the connection handler to process the audio
audio_vec
=
connection_handler
.
run
(
audio_data
,
request_body
.
task
)
# 4. we need the result of the vector instance be numpy.ndarray
if
not
isinstance
(
audio_vec
,
np
.
ndarray
):
logger
.
error
(
f
"the vector type is not numpy.array, that is:
{
type
(
audio_vec
)
}
"
)
error_reponse
=
ErrorResponse
()
error_reponse
.
message
.
description
=
f
"the vector type is not numpy.array, that is:
{
type
(
audio_vec
)
}
"
return
error_reponse
response
=
{
"success"
:
True
,
"code"
:
200
,
"message"
:
{
"description"
:
"success"
},
"result"
:
{
"vec"
:
audio_vec
.
tolist
()
}
}
except
ServerBaseException
as
e
:
response
=
failed_response
(
e
.
error_code
,
e
.
msg
)
except
BaseException
:
response
=
failed_response
(
ErrorCode
.
SERVER_UNKOWN_ERR
)
traceback
.
print_exc
()
return
response
@
router
.
post
(
"/paddlespeech/vector/score"
,
response_model
=
Union
[
VectorScoreResponse
,
ErrorResponse
])
def
score
(
request_body
:
VectorScoreRequest
):
"""vector api
Args:
request_body (VectorScoreRequest): the punctuation request body
Returns:
json: the punctuation response body
"""
try
:
# 1. get the audio data
# the audio must be base64 format
enroll_data
=
base64
.
b64decode
(
request_body
.
enroll_audio
)
test_data
=
base64
.
b64decode
(
request_body
.
test_audio
)
# 2. get single engine from engine pool
# and we use the vector_engine to create an connection handler to process the request
engine_pool
=
get_engine_pool
()
vector_engine
=
engine_pool
[
'vector'
]
connection_handler
=
PaddleVectorConnectionHandler
(
vector_engine
)
# 3. we use the connection handler to process the audio
score
=
connection_handler
.
get_enroll_test_score
(
enroll_data
,
test_data
)
response
=
{
"success"
:
True
,
"code"
:
200
,
"message"
:
{
"description"
:
"success"
},
"result"
:
{
"score"
:
score
}
}
except
ServerBaseException
as
e
:
response
=
failed_response
(
e
.
error_code
,
e
.
msg
)
except
BaseException
:
response
=
failed_response
(
ErrorCode
.
SERVER_UNKOWN_ERR
)
traceback
.
print_exc
()
return
response
paddlespeech/server/utils/audio_handler.py
浏览文件 @
f869db7a
...
...
@@ -142,6 +142,7 @@ class ASRWsAudioHandler:
return
""
# 1. send websocket handshake protocal
start_time
=
time
.
time
()
async
with
websockets
.
connect
(
self
.
url
)
as
ws
:
# 2. server has already received handshake protocal
# client start to send the command
...
...
@@ -187,7 +188,14 @@ class ASRWsAudioHandler:
if
self
.
punc_server
:
msg
[
"result"
]
=
self
.
punc_server
.
run
(
msg
[
"result"
])
# 6. logging the final result and comptute the statstics
elapsed_time
=
time
.
time
()
-
start_time
audio_info
=
soundfile
.
info
(
wavfile_path
)
logger
.
info
(
"client final receive msg={}"
.
format
(
msg
))
logger
.
info
(
f
"audio duration:
{
audio_info
.
duration
}
, elapsed time:
{
elapsed_time
}
, RTF=
{
elapsed_time
/
audio_info
.
duration
}
"
)
result
=
msg
return
result
...
...
@@ -456,3 +464,96 @@ class TTSHttpHandler:
self
.
stream
.
stop_stream
()
self
.
stream
.
close
()
self
.
p
.
terminate
()
class
VectorHttpHandler
:
def
__init__
(
self
,
server_ip
=
None
,
port
=
None
):
"""The Vector client http request
Args:
server_ip (str, optional): the http vector server ip. Defaults to "127.0.0.1".
port (int, optional): the http vector server port. Defaults to 8090.
"""
super
().
__init__
()
self
.
server_ip
=
server_ip
self
.
port
=
port
if
server_ip
is
None
or
port
is
None
:
self
.
url
=
None
else
:
self
.
url
=
'http://'
+
self
.
server_ip
+
":"
+
str
(
self
.
port
)
+
'/paddlespeech/vector'
def
run
(
self
,
input
,
audio_format
,
sample_rate
,
task
=
"spk"
):
"""Call the http asr to process the audio
Args:
input (str): the audio file path
audio_format (str): the audio format
sample_rate (str): the audio sample rate
Returns:
list: the audio vector
"""
if
self
.
url
is
None
:
logger
.
error
(
"No vector server, please input valid ip and port"
)
return
""
audio
=
wav2base64
(
input
)
data
=
{
"audio"
:
audio
,
"task"
:
task
,
"audio_format"
:
audio_format
,
"sample_rate"
:
sample_rate
,
}
logger
.
info
(
self
.
url
)
res
=
requests
.
post
(
url
=
self
.
url
,
data
=
json
.
dumps
(
data
))
return
res
.
json
()
class
VectorScoreHttpHandler
:
def
__init__
(
self
,
server_ip
=
None
,
port
=
None
):
"""The Vector score client http request
Args:
server_ip (str, optional): the http vector server ip. Defaults to "127.0.0.1".
port (int, optional): the http vector server port. Defaults to 8090.
"""
super
().
__init__
()
self
.
server_ip
=
server_ip
self
.
port
=
port
if
server_ip
is
None
or
port
is
None
:
self
.
url
=
None
else
:
self
.
url
=
'http://'
+
self
.
server_ip
+
":"
+
str
(
self
.
port
)
+
'/paddlespeech/vector/score'
def
run
(
self
,
enroll_audio
,
test_audio
,
audio_format
,
sample_rate
):
"""Call the http asr to process the audio
Args:
input (str): the audio file path
audio_format (str): the audio format
sample_rate (str): the audio sample rate
Returns:
list: the audio vector
"""
if
self
.
url
is
None
:
logger
.
error
(
"No vector server, please input valid ip and port"
)
return
""
enroll_audio
=
wav2base64
(
enroll_audio
)
test_audio
=
wav2base64
(
test_audio
)
data
=
{
"enroll_audio"
:
enroll_audio
,
"test_audio"
:
test_audio
,
"task"
:
"score"
,
"audio_format"
:
audio_format
,
"sample_rate"
:
sample_rate
,
}
res
=
requests
.
post
(
url
=
self
.
url
,
data
=
json
.
dumps
(
data
))
return
res
.
json
()
paddlespeech/server/ws/asr_socket.py
浏览文件 @
f869db7a
...
...
@@ -78,12 +78,14 @@ async def websocket_endpoint(websocket: WebSocket):
connection_handler
.
decode
(
is_finished
=
True
)
connection_handler
.
rescoring
()
asr_results
=
connection_handler
.
get_result
()
word_time_stamp
=
connection_handler
.
get_word_time_stamp
()
connection_handler
.
reset
()
resp
=
{
"status"
:
"ok"
,
"signal"
:
"finished"
,
'result'
:
asr_results
'result'
:
asr_results
,
'times'
:
word_time_stamp
}
await
websocket
.
send_json
(
resp
)
break
...
...
speechx/examples/ds2_ol/aishell/run.sh
浏览文件 @
f869db7a
...
...
@@ -155,7 +155,6 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
--wav_rspecifier
=
scp:
$data
/split
${
nj
}
/JOB/
${
aishell_wav_scp
}
\
--cmvn_file
=
$cmvn
\
--model_path
=
$model_dir
/avg_1.jit.pdmodel
\
--to_float32
=
true
\
--streaming_chunk
=
30
\
--param_path
=
$model_dir
/avg_1.jit.pdiparams
\
--word_symbol_table
=
$wfst
/words.txt
\
...
...
speechx/examples/ds2_ol/decoder/recognizer_test_main.cc
浏览文件 @
f869db7a
...
...
@@ -19,6 +19,7 @@
DEFINE_string
(
wav_rspecifier
,
""
,
"test feature rspecifier"
);
DEFINE_string
(
result_wspecifier
,
""
,
"test result wspecifier"
);
DEFINE_int32
(
sample_rate
,
16000
,
"sample rate"
);
int
main
(
int
argc
,
char
*
argv
[])
{
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
...
...
@@ -30,7 +31,8 @@ int main(int argc, char* argv[]) {
kaldi
::
SequentialTableReader
<
kaldi
::
WaveHolder
>
wav_reader
(
FLAGS_wav_rspecifier
);
kaldi
::
TokenWriter
result_writer
(
FLAGS_result_wspecifier
);
int
sample_rate
=
16000
;
int
sample_rate
=
FLAGS_sample_rate
;
float
streaming_chunk
=
FLAGS_streaming_chunk
;
int
chunk_sample_size
=
streaming_chunk
*
sample_rate
;
LOG
(
INFO
)
<<
"sr: "
<<
sample_rate
;
...
...
speechx/examples/ds2_ol/feat/compute_fbank_main.cc
浏览文件 @
f869db7a
...
...
@@ -69,6 +69,7 @@ int main(int argc, char* argv[]) {
feat_cache_opts
.
frame_chunk_stride
=
1
;
feat_cache_opts
.
frame_chunk_size
=
1
;
ppspeech
::
FeatureCache
feature_cache
(
feat_cache_opts
,
std
::
move
(
cmvn
));
LOG
(
INFO
)
<<
"fbank: "
<<
true
;
LOG
(
INFO
)
<<
"feat dim: "
<<
feature_cache
.
Dim
();
int
sample_rate
=
16000
;
...
...
speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc
浏览文件 @
f869db7a
...
...
@@ -56,6 +56,7 @@ int main(int argc, char* argv[]) {
opt
.
frame_opts
.
remove_dc_offset
=
false
;
opt
.
frame_opts
.
window_type
=
"hanning"
;
opt
.
frame_opts
.
preemph_coeff
=
0.0
;
LOG
(
INFO
)
<<
"linear feature: "
<<
true
;
LOG
(
INFO
)
<<
"frame length (ms): "
<<
opt
.
frame_opts
.
frame_length_ms
;
LOG
(
INFO
)
<<
"frame shift (ms): "
<<
opt
.
frame_opts
.
frame_shift_ms
;
...
...
@@ -77,7 +78,7 @@ int main(int argc, char* argv[]) {
int
sample_rate
=
16000
;
float
streaming_chunk
=
FLAGS_streaming_chunk
;
int
chunk_sample_size
=
streaming_chunk
*
sample_rate
;
LOG
(
INFO
)
<<
"s
r
: "
<<
sample_rate
;
LOG
(
INFO
)
<<
"s
ample rate
: "
<<
sample_rate
;
LOG
(
INFO
)
<<
"chunk size (s): "
<<
streaming_chunk
;
LOG
(
INFO
)
<<
"chunk size (sample): "
<<
chunk_sample_size
;
...
...
speechx/examples/ds2_ol/websocket/websocket_server.sh
浏览文件 @
f869db7a
...
...
@@ -63,7 +63,6 @@ websocket_server_main \
--cmvn_file
=
$cmvn
\
--model_path
=
$model_dir
/avg_1.jit.pdmodel
\
--streaming_chunk
=
0.1
\
--to_float32
=
true
\
--param_path
=
$model_dir
/avg_1.jit.pdiparams
\
--word_symbol_table
=
$wfst
/words.txt
\
--model_output_names
=
softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0
\
...
...
speechx/speechx/decoder/param.h
浏览文件 @
f869db7a
...
...
@@ -19,23 +19,24 @@
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
// feature
DEFINE_bool
(
use_fbank
,
false
,
"False for fbank; or linear feature"
);
// DEFINE_bool(to_float32, true, "audio convert to pcm32. True for linear
// feature, or fbank");
DEFINE_int32
(
num_bins
,
161
,
"num bins of mel"
);
DEFINE_string
(
cmvn_file
,
""
,
"read cmvn"
);
DEFINE_double
(
streaming_chunk
,
0.1
,
"streaming feature chunk size"
);
DEFINE_bool
(
to_float32
,
true
,
"audio convert to pcm32"
);
DEFINE_string
(
model_path
,
"avg_1.jit.pdmodel"
,
"paddle nnet model"
);
DEFINE_string
(
param_path
,
"avg_1.jit.pdiparams"
,
"paddle nnet model param"
);
DEFINE_string
(
word_symbol_table
,
"words.txt"
,
"word symbol table"
);
DEFINE_string
(
graph_path
,
"TLG"
,
"decoder graph"
);
DEFINE_double
(
acoustic_scale
,
1.0
,
"acoustic scale"
);
DEFINE_int32
(
max_active
,
7500
,
"max active"
);
DEFINE_double
(
beam
,
15.0
,
"decoder beam"
);
DEFINE_double
(
lattice_beam
,
7.5
,
"decoder beam"
);
// feature sliding window
DEFINE_int32
(
receptive_field_length
,
7
,
"receptive field of two CNN(kernel=5) downsampling module."
);
DEFINE_int32
(
downsampling_rate
,
4
,
"two CNN(kernel=5) module downsampling rate."
);
// nnet
DEFINE_string
(
model_path
,
"avg_1.jit.pdmodel"
,
"paddle nnet model"
);
DEFINE_string
(
param_path
,
"avg_1.jit.pdiparams"
,
"paddle nnet model param"
);
DEFINE_string
(
model_input_names
,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box"
,
...
...
@@ -47,8 +48,14 @@ DEFINE_string(model_cache_names,
"chunk_state_h_box,chunk_state_c_box"
,
"model cache names"
);
DEFINE_string
(
model_cache_shapes
,
"5-1-1024,5-1-1024"
,
"model cache shapes"
);
DEFINE_bool
(
use_fbank
,
false
,
"use fbank or linear feature"
);
DEFINE_int32
(
num_bins
,
161
,
"num bins of mel"
);
// decoder
DEFINE_string
(
word_symbol_table
,
"words.txt"
,
"word symbol table"
);
DEFINE_string
(
graph_path
,
"TLG"
,
"decoder graph"
);
DEFINE_double
(
acoustic_scale
,
1.0
,
"acoustic scale"
);
DEFINE_int32
(
max_active
,
7500
,
"max active"
);
DEFINE_double
(
beam
,
15.0
,
"decoder beam"
);
DEFINE_double
(
lattice_beam
,
7.5
,
"decoder beam"
);
namespace
ppspeech
{
// todo refactor later
...
...
@@ -56,22 +63,23 @@ FeaturePipelineOptions InitFeaturePipelineOptions() {
FeaturePipelineOptions
opts
;
opts
.
cmvn_file
=
FLAGS_cmvn_file
;
opts
.
linear_spectrogram_opts
.
streaming_chunk
=
FLAGS_streaming_chunk
;
opts
.
to_float32
=
FLAGS_to_float32
;
kaldi
::
FrameExtractionOptions
frame_opts
;
frame_opts
.
dither
=
0.0
;
frame_opts
.
frame_shift_ms
=
10
;
opts
.
use_fbank
=
FLAGS_use_fbank
;
if
(
opts
.
use_fbank
)
{
frame_opts
.
window_type
=
"povey"
;
frame_opts
.
frame_length_ms
=
25
;
opts
.
fbank_opts
.
fbank_opts
.
mel_opts
.
num_bins
=
FLAGS_num_bins
;
opts
.
fbank_opts
.
fbank_opts
.
frame_opts
=
frame_opts
;
opts
.
to_float32
=
false
;
frame_opts
.
window_type
=
"povey"
;
frame_opts
.
frame_length_ms
=
25
;
opts
.
fbank_opts
.
fbank_opts
.
mel_opts
.
num_bins
=
FLAGS_num_bins
;
opts
.
fbank_opts
.
fbank_opts
.
frame_opts
=
frame_opts
;
}
else
{
frame_opts
.
remove_dc_offset
=
false
;
frame_opts
.
frame_length_ms
=
20
;
frame_opts
.
window_type
=
"hanning"
;
frame_opts
.
preemph_coeff
=
0.0
;
opts
.
linear_spectrogram_opts
.
frame_opts
=
frame_opts
;
opts
.
to_float32
=
true
;
frame_opts
.
remove_dc_offset
=
false
;
frame_opts
.
frame_length_ms
=
20
;
frame_opts
.
window_type
=
"hanning"
;
frame_opts
.
preemph_coeff
=
0.0
;
opts
.
linear_spectrogram_opts
.
frame_opts
=
frame_opts
;
}
opts
.
feature_cache_opts
.
frame_chunk_size
=
FLAGS_receptive_field_length
;
opts
.
feature_cache_opts
.
frame_chunk_stride
=
FLAGS_downsampling_rate
;
...
...
speechx/speechx/frontend/audio/fbank.cc
浏览文件 @
f869db7a
...
...
@@ -102,13 +102,16 @@ bool Fbank::Compute(const Vector<BaseFloat>& waves, Vector<BaseFloat>* feats) {
// note: this online feature-extraction code does not support VTLN.
RealFft
(
&
window
,
true
);
kaldi
::
ComputePowerSpectrum
(
&
window
);
const
kaldi
::
MelBanks
&
mel_bank
=
*
(
computer_
.
GetMelBanks
(
1.0
));
SubVector
<
BaseFloat
>
power_spectrum
(
window
,
0
,
window
.
Dim
()
/
2
+
1
);
const
kaldi
::
MelBanks
&
mel_bank
=
*
(
computer_
.
GetMelBanks
(
1.0
));
SubVector
<
BaseFloat
>
power_spectrum
(
window
,
0
,
window
.
Dim
()
/
2
+
1
);
if
(
!
opts_
.
fbank_opts
.
use_power
)
{
power_spectrum
.
ApplyPow
(
0.5
);
}
int32
mel_offset
=
((
opts_
.
fbank_opts
.
use_energy
&&
!
opts_
.
fbank_opts
.
htk_compat
)
?
1
:
0
);
SubVector
<
BaseFloat
>
mel_energies
(
this_feature
,
mel_offset
,
opts_
.
fbank_opts
.
mel_opts
.
num_bins
);
int32
mel_offset
=
((
opts_
.
fbank_opts
.
use_energy
&&
!
opts_
.
fbank_opts
.
htk_compat
)
?
1
:
0
);
SubVector
<
BaseFloat
>
mel_energies
(
this_feature
,
mel_offset
,
opts_
.
fbank_opts
.
mel_opts
.
num_bins
);
mel_bank
.
Compute
(
power_spectrum
,
&
mel_energies
);
mel_energies
.
ApplyFloor
(
1e-07
);
mel_energies
.
ApplyLog
();
...
...
speechx/speechx/frontend/audio/feature_pipeline.cc
浏览文件 @
f869db7a
...
...
@@ -23,13 +23,13 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) {
new
ppspeech
::
AudioCache
(
1000
*
kint16max
,
opts
.
to_float32
));
unique_ptr
<
FrontendInterface
>
base_feature
;
if
(
opts
.
use_fbank
)
{
base_feature
.
reset
(
new
ppspeech
::
Fbank
(
opts
.
fbank_opts
,
std
::
move
(
data_source
)));
base_feature
.
reset
(
new
ppspeech
::
Fbank
(
opts
.
fbank_opts
,
std
::
move
(
data_source
)));
}
else
{
base_feature
.
reset
(
new
ppspeech
::
LinearSpectrogram
(
opts
.
linear_spectrogram_opts
,
std
::
move
(
data_source
)));
base_feature
.
reset
(
new
ppspeech
::
LinearSpectrogram
(
opts
.
linear_spectrogram_opts
,
std
::
move
(
data_source
)));
}
unique_ptr
<
FrontendInterface
>
cmvn
(
...
...
speechx/speechx/frontend/audio/feature_pipeline.h
浏览文件 @
f869db7a
...
...
@@ -18,25 +18,25 @@
#include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/fbank.h"
#include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/fbank.h"
#include "frontend/audio/normalizer.h"
namespace
ppspeech
{
struct
FeaturePipelineOptions
{
std
::
string
cmvn_file
;
bool
to_float32
;
bool
to_float32
;
// true, only for linear feature
bool
use_fbank
;
LinearSpectrogramOptions
linear_spectrogram_opts
;
FbankOptions
fbank_opts
;
FeatureCacheOptions
feature_cache_opts
;
FeaturePipelineOptions
()
:
cmvn_file
(
""
),
to_float32
(
false
),
use_fbank
(
fals
e
),
to_float32
(
false
),
// true, only for linear feature
use_fbank
(
tru
e
),
linear_spectrogram_opts
(),
fbank_opts
(),
feature_cache_opts
()
{}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录