未验证 提交 43582f50 编写于 作者: Honei_X's avatar Honei_X 提交者: GitHub

Merge branch 'develop' into asr_time

......@@ -14,7 +14,7 @@ see [installation](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/doc
You can choose one way from easy, meduim and hard to install paddlespeech.
### 2. Prepare Input File
The input of this demo should be a WAV file(`.wav`), and the sample rate must be the same as the model.
The input of this cli demo should be a WAV file(`.wav`), and the sample rate must be the same as the model.
Here are sample files for this demo that can be downloaded:
```bash
......
......@@ -4,16 +4,16 @@
## 介绍
声纹识别是一项用计算机程序自动提取说话人特征的技术。
这个 demo 是一个从给定音频文件提取说话人特征,它可以通过使用 `PaddleSpeech` 的单个命令或 python 中的几行代码来实现。
这个 demo 是从一个给定音频文件中提取说话人特征,它可以通过使用 `PaddleSpeech` 的单个命令或 python 中的几行代码来实现。
## 使用方法
### 1. 安装
请看[安装文档](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md)
你可以从 easy,medium,hard 三中方式中选择一种方式安装。
你可以从easy medium,hard 三种方式中选择一种方式安装。
### 2. 准备输入
这个 demo 的输入应该是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。
声纹cli demo 的输入应该是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。
可以下载此 demo 的示例音频:
```bash
......
# This is the parameter configuration file for PaddleSpeech Serving.
#################################################################################
# SERVER SETTING #
#################################################################################
host: 0.0.0.0
port: 8190
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_python']
# protocol = ['http'] (only one can be selected).
# http only support offline engine type.
protocol: 'http'
engine_list: ['text_python']
#################################################################################
# ENGINE CONFIG #
#################################################################################
################################### Text #########################################
################### text task: punc; engine_type: python #######################
text_python:
task: punc
model_type: 'ernie_linear_p3_wudao'
lang: 'zh'
sample_rate: 16000
cfg_path: # [optional]
ckpt_path: # [optional]
vocab_file: # [optional]
device: 'cpu' # set 'gpu:id' or 'cpu'
......@@ -4,7 +4,7 @@
# SERVER SETTING #
#################################################################################
host: 0.0.0.0
port: 8090
port: 8290
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online']
......@@ -29,7 +29,7 @@ asr_online:
cfg_path:
decode_method:
force_yes: True
device: # cpu or gpu:id
device: 'cpu' # cpu or gpu:id
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True
......@@ -42,4 +42,4 @@ asr_online:
window_ms: 25 # ms
shift_ms: 10 # ms
sample_rate: 16000
sample_width: 2
\ No newline at end of file
sample_width: 2
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from paddlespeech.cli.log import logger
from paddlespeech.server.bin.paddlespeech_server import ServerExecutor
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog='paddlespeech_server.start', add_help=True)
parser.add_argument(
"--config_file",
action="store",
help="yaml file of the app",
default=None,
required=True)
parser.add_argument(
"--log_file",
action="store",
help="log file",
default="./log/paddlespeech.log")
logger.info("start to parse the args")
args = parser.parse_args()
logger.info("start to launch the punctuation server")
punc_server = ServerExecutor()
punc_server(config_file=args.config_file, log_file=args.log_file)
export CUDA_VISIBLE_DEVICE=0,1,2,3
nohup python3 punc_server.py --config_file conf/punc_application.yaml > punc.log 2>&1 &
nohup python3 streaming_asr_server.py --config_file conf/ws_conformer_application.yaml > streaming_asr.log 2>&1 &
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from paddlespeech.cli.log import logger
from paddlespeech.server.bin.paddlespeech_server import ServerExecutor
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog='paddlespeech_server.start', add_help=True)
parser.add_argument(
"--config_file",
action="store",
help="yaml file of the app",
default=None,
required=True)
parser.add_argument(
"--log_file",
action="store",
help="log file",
default="./log/paddlespeech.log")
logger.info("start to parse the args")
args = parser.parse_args()
logger.info("start to launch the streaming asr server")
streaming_asr_server = ServerExecutor()
streaming_asr_server(config_file=args.config_file, log_file=args.log_file)
# download the test wav
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
# read the wav and pass it to service
python3 websocket_client.py --wavfile ./zh.wav
# read the wav and pass it to only streaming asr service
python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --wavfile ./zh.wav
# read the wav and call streaming and punc service
python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav
......@@ -28,6 +28,7 @@ def main(args):
handler = ASRWsAudioHandler(
args.server_ip,
args.port,
endpoint=args.endpoint,
punc_server_ip=args.punc_server_ip,
punc_server_port=args.punc_server_port)
loop = asyncio.get_event_loop()
......@@ -69,7 +70,11 @@ if __name__ == "__main__":
default=8091,
dest="punc_server_port",
help='Punctuation server port')
parser.add_argument(
"--endpoint",
type=str,
default="/paddlespeech/asr/streaming",
help="ASR websocket endpoint")
parser.add_argument(
"--wavfile",
action="store",
......
......@@ -13,6 +13,7 @@ We borrowed a lot of code from these repos to build `model` and `engine`, thanks
- Apache-2.0 License
- U2 model
- Building TLG based Graph
- websocket server & client
* [kaldi](https://github.com/kaldi-asr/kaldi/blob/master/COPYING)
- Apache-2.0 License
......
......@@ -10,7 +10,7 @@ encoder_conf:
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
......@@ -30,7 +30,7 @@ decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
......@@ -39,7 +39,7 @@ model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
init_type: 'kaiming_uniform'
init_type: 'kaiming_uniform' # !Warning: need to convergence
###########################################
# Data #
......
......@@ -37,7 +37,7 @@ model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
init_type: 'kaiming_uniform'
init_type: 'kaiming_uniform' # !Warning: need to convergence
###########################################
# Data #
......
......@@ -10,7 +10,7 @@ encoder_conf:
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
......@@ -21,7 +21,7 @@ decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
......
......@@ -272,7 +272,8 @@ class VectorExecutor(BaseExecutor):
model_type: str='ecapatdnn_voxceleb12',
sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None):
ckpt_path: Optional[os.PathLike]=None,
task=None):
"""Init the neural network from the model path
Args:
......@@ -284,8 +285,10 @@ class VectorExecutor(BaseExecutor):
Defaults to None.
ckpt_path (Optional[os.PathLike], optional): the pretrained model path, which is stored in the disk.
Defaults to None.
task (str, optional): the model task type
"""
# stage 0: avoid to init the mode again
self.task = task
if hasattr(self, "model"):
logger.info("Model has been initialized")
return
......@@ -434,6 +437,7 @@ class VectorExecutor(BaseExecutor):
if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error(
"invalid sample rate, please input --sr 8000 or --sr 16000")
logger.error(f"The model sample rate: {self.sample_rate}, the external sample rate is: {sample_rate}")
return False
if isinstance(audio_file, (str, os.PathLike)):
......
......@@ -63,3 +63,23 @@ paddlespeech_server start --config_file conf/tts_online_application.yaml
```
paddlespeech_client tts_online --server_ip 127.0.0.1 --port 8092 --input "您好,欢迎使用百度飞桨深度学习框架!" --output output.wav
```
## 声纹识别
### 启动声纹识别服务
```
paddlespeech_server start --config_file conf/vector_application.yaml
```
### 获取说话人音频声纹
```
paddlespeech_client vector --task spk --server_ip 127.0.0.1 --port 8090 --input 85236145389.wav
```
### 两个说话人音频声纹打分
```
paddlespeech_client vector --task score --server_ip 127.0.0.1 --port 8090 --enroll 123456789.wav --test 85236145389.wav
```
\ No newline at end of file
......@@ -35,7 +35,7 @@ from paddlespeech.server.utils.util import wav2base64
__all__ = [
'TTSClientExecutor', 'TTSOnlineClientExecutor', 'ASRClientExecutor',
'ASROnlineClientExecutor', 'CLSClientExecutor'
'ASROnlineClientExecutor', 'CLSClientExecutor', 'VectorClientExecutor'
]
......@@ -411,6 +411,18 @@ class ASROnlineClientExecutor(BaseExecutor):
'--lang', type=str, default="zh_cn", help='language')
self.parser.add_argument(
'--audio_format', type=str, default="wav", help='audio format')
self.parser.add_argument(
'--punc.server_ip',
type=str,
default=None,
dest="punc_server_ip",
help='Punctuation server ip')
self.parser.add_argument(
'--punc.port',
type=int,
default=8190,
dest="punc_server_port",
help='Punctuation server port')
def execute(self, argv: List[str]) -> bool:
args = self.parser.parse_args(argv)
......@@ -428,7 +440,9 @@ class ASROnlineClientExecutor(BaseExecutor):
port=port,
sample_rate=sample_rate,
lang=lang,
audio_format=audio_format)
audio_format=audio_format,
punc_server_ip=args.punc_server_ip,
punc_server_port=args.punc_server_port)
time_end = time.time()
logger.info(res)
logger.info("Response time %f s." % (time_end - time_start))
......@@ -445,12 +459,30 @@ class ASROnlineClientExecutor(BaseExecutor):
port: int=8091,
sample_rate: int=16000,
lang: str="zh_cn",
audio_format: str="wav"):
"""
Python API to call an executor.
audio_format: str="wav",
punc_server_ip: str=None,
punc_server_port: str=None):
"""Python API to call asr online executor.
Args:
input (str): the audio file to be send to streaming asr service.
server_ip (str, optional): streaming asr server ip. Defaults to "127.0.0.1".
port (int, optional): streaming asr server port. Defaults to 8091.
sample_rate (int, optional): audio sample rate. Defaults to 16000.
lang (str, optional): audio language type. Defaults to "zh_cn".
audio_format (str, optional): audio format. Defaults to "wav".
punc_server_ip (str, optional): punctuation server ip. Defaults to None.
punc_server_port (str, optional): punctuation server port. Defaults to None.
Returns:
str: the audio text
"""
logger.info("asr websocket client start")
handler = ASRWsAudioHandler(server_ip, port)
handler = ASRWsAudioHandler(
server_ip,
port,
punc_server_ip=punc_server_ip,
punc_server_port=punc_server_port)
loop = asyncio.get_event_loop()
res = loop.run_until_complete(handler.run(input))
logger.info("asr websocket client finished")
......@@ -583,3 +615,108 @@ class TextClientExecutor(BaseExecutor):
response_dict = res.json()
punc_text = response_dict["result"]["punc_text"]
return punc_text
@cli_client_register(
name='paddlespeech_client.vector', description='visit the vector service')
class VectorClientExecutor(BaseExecutor):
def __init__(self):
super(VectorClientExecutor, self).__init__()
self.parser = argparse.ArgumentParser(
prog='paddlespeech_client.vector', add_help=True)
self.parser.add_argument(
'--server_ip', type=str, default='127.0.0.1', help='server ip')
self.parser.add_argument(
'--port', type=int, default=8090, help='server port')
self.parser.add_argument(
'--input',
type=str,
default=None,
help='sentence to be process by text server.')
self.parser.add_argument(
'--task',
type=str,
default="spk",
choices=["spk", "score"],
help="The vector service task")
self.parser.add_argument(
"--enroll", type=str, default=None, help="The enroll audio")
self.parser.add_argument(
"--test", type=str, default=None, help="The test audio")
def execute(self, argv: List[str]) -> bool:
"""Execute the request from the argv.
Args:
argv (List): the request arguments
Returns:
str: the request flag
"""
args = self.parser.parse_args(argv)
input_ = args.input
server_ip = args.server_ip
port = args.port
task = args.task
try:
time_start = time.time()
res = self(
input=input_,
server_ip=server_ip,
port=port,
enroll_audio=args.enroll,
test_audio=args.test,
task=task)
time_end = time.time()
logger.info(f"The vector: {res}")
logger.info("Response time %f s." % (time_end - time_start))
return True
except Exception as e:
logger.error("Failed to extract vector.")
logger.error(e)
return False
@stats_wrapper
def __call__(self,
input: str,
server_ip: str="127.0.0.1",
port: int=8090,
audio_format: str="wav",
sample_rate: int=16000,
enroll_audio: str=None,
test_audio: str=None,
task="spk"):
"""
Python API to call text executor.
Args:
input (str): the request audio data
server_ip (str, optional): the server ip. Defaults to "127.0.0.1".
port (int, optional): the server port. Defaults to 8090.
audio_format (str, optional): audio format. Defaults to "wav".
sample_rate (str, optional): audio sample rate. Defaults to 16000.
enroll_audio (str, optional): enroll audio data. Defaults to None.
test_audio (str, optional): test audio data. Defaults to None.
task (str, optional): the task type, "spk" or "socre". Defaults to "spk"
Returns:
str: the audio embedding or score between enroll and test audio
"""
if task == "spk":
from paddlespeech.server.utils.audio_handler import VectorHttpHandler
logger.info("vector http client start")
logger.info(f"the input audio: {input}")
handler = VectorHttpHandler(server_ip=server_ip, port=port)
res = handler.run(input, audio_format, sample_rate)
return res
elif task == "score":
from paddlespeech.server.utils.audio_handler import VectorScoreHttpHandler
logger.info("vector score http client start")
logger.info(
f"enroll audio: {enroll_audio}, test audio: {test_audio}")
handler = VectorScoreHttpHandler(server_ip=server_ip, port=port)
res = handler.run(enroll_audio, test_audio, audio_format,
sample_rate)
logger.info(f"The vector score is: {res}")
else:
logger.error(f"Sorry, we have not support such task {task}")
......@@ -11,7 +11,7 @@ port: 8090
# protocol = ['websocket', 'http'] (only one can be selected).
# http only support offline engine type.
protocol: 'http'
engine_list: ['asr_python', 'tts_python', 'cls_python', 'text_python']
engine_list: ['asr_python', 'tts_python', 'cls_python', 'text_python', 'vector_python']
#################################################################################
......@@ -166,4 +166,15 @@ text_python:
cfg_path: # [optional]
ckpt_path: # [optional]
vocab_file: # [optional]
device: # set 'gpu:id' or 'cpu'
################################### Vector ######################################
################### Vector task: spk; engine_type: python #######################
vector_python:
task: spk
model_type: 'ecapatdnn_voxceleb12'
sample_rate: 16000
cfg_path: # [optional]
ckpt_path: # [optional]
device: # set 'gpu:id' or 'cpu'
\ No newline at end of file
# This is the parameter configuration file for PaddleSpeech Serving.
#################################################################################
# SERVER SETTING #
#################################################################################
host: 0.0.0.0
port: 8090
# The task format in the engin_list is: <speech task>_<engine type>
# protocol = ['http'] (only one can be selected).
# http only support offline engine type.
protocol: 'http'
engine_list: ['vector_python']
#################################################################################
# ENGINE CONFIG #
#################################################################################
################################### Vector ######################################
################### Vector task: spk; engine_type: python #######################
vector_python:
task: spk
model_type: 'ecapatdnn_voxceleb12'
sample_rate: 16000
cfg_path: # [optional]
ckpt_path: # [optional]
device: # set 'gpu:id' or 'cpu'
......@@ -13,6 +13,7 @@
# limitations under the License.
import copy
import os
import time
from typing import Optional
import numpy as np
......@@ -153,6 +154,12 @@ class PaddleASRConnectionHanddler:
self.n_shift = self.preprocess_conf.process[0]['n_shift']
def extract_feat(self, samples):
# we compute the elapsed time of first char occuring
# and we record the start time at the first pcm sample arraving
# if self.first_char_occur_elapsed is not None:
# self.first_char_occur_elapsed = time.time()
if "deepspeech2online" in self.model_type:
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
......@@ -291,6 +298,7 @@ class PaddleASRConnectionHanddler:
self.global_frame_offset = 0
self.result_transcripts = ['']
self.word_time_stamp = None
self.first_char_occur_elapsed = None
def decode(self, is_finished=False):
if "deepspeech2online" in self.model_type:
......
......@@ -49,5 +49,8 @@ class EngineFactory(object):
elif engine_name.lower() == 'text' and engine_type.lower() == 'python':
from paddlespeech.server.engine.text.python.text_engine import TextEngine
return TextEngine()
elif engine_name.lower() == 'vector' and engine_type.lower() == 'python':
from paddlespeech.server.engine.vector.python.vector_engine import VectorEngine
return VectorEngine()
else:
return None
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
from collections import OrderedDict
import numpy as np
import paddle
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.cli.log import logger
from paddlespeech.cli.vector.infer import VectorExecutor
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.vector.io.batch import feature_normalize
class PaddleVectorConnectionHandler:
def __init__(self, vector_engine):
"""The PaddleSpeech Vector Server Connection Handler
This connection process every server request
Args:
vector_engine (VectorEngine): The Vector engine
"""
super().__init__()
logger.info(
"Create PaddleVectorConnectionHandler to process the vector request")
self.vector_engine = vector_engine
self.executor = self.vector_engine.executor
self.task = self.vector_engine.executor.task
self.model = self.vector_engine.executor.model
self.config = self.vector_engine.executor.config
self._inputs = OrderedDict()
self._outputs = OrderedDict()
@paddle.no_grad()
def run(self, audio_data, task="spk"):
"""The connection process the http request audio
Args:
audio_data (bytes): base64.b64decode
Returns:
str: the punctuation text
"""
logger.info(
f"start to extract the do vector {self.task} from the http request")
if self.task == "spk" and task == "spk":
embedding = self.extract_audio_embedding(audio_data)
return embedding
else:
logger.error(
"The request task is not matched with server model task")
logger.error(
f"The server model task is: {self.task}, but the request task is: {task}"
)
return np.array([
0.0,
])
@paddle.no_grad()
def get_enroll_test_score(self, enroll_audio, test_audio):
"""Get the enroll and test audio score
Args:
enroll_audio (str): the base64 format enroll audio
test_audio (str): the base64 format test audio
Returns:
float: the score between enroll and test audio
"""
logger.info("start to extract the enroll audio embedding")
enroll_emb = self.extract_audio_embedding(enroll_audio)
logger.info("start to extract the test audio embedding")
test_emb = self.extract_audio_embedding(test_audio)
logger.info(
"start to get the score between the enroll and test embedding")
score = self.executor.get_embeddings_score(enroll_emb, test_emb)
logger.info(f"get the enroll vs test score: {score}")
return score
@paddle.no_grad()
def extract_audio_embedding(self, audio: str, sample_rate: int=16000):
"""extract the audio embedding
Args:
audio (str): the audio data
sample_rate (int, optional): the audio sample rate. Defaults to 16000.
"""
# we can not reuse the cache io.BytesIO(audio) data,
# because the soundfile will change the io.BytesIO(audio) to the end
# thus we should convert the base64 string to io.BytesIO when we need the audio data
if not self.executor._check(io.BytesIO(audio), sample_rate):
logger.info("check the audio sample rate occurs error")
return np.array([0.0])
waveform, sr = load_audio(io.BytesIO(audio))
logger.info(f"load the audio sample points, shape is: {waveform.shape}")
# stage 2: get the audio feat
# Note: Now we only support fbank feature
try:
feats = melspectrogram(
x=waveform,
sr=self.config.sr,
n_mels=self.config.n_mels,
window_size=self.config.window_size,
hop_length=self.config.hop_size)
logger.info(f"extract the audio feats, shape is: {feats.shape}")
except Exception as e:
logger.info(f"feats occurs exception {e}")
sys.exit(-1)
feats = paddle.to_tensor(feats).unsqueeze(0)
# in inference period, the lengths is all one without padding
lengths = paddle.ones([1])
# stage 3: we do feature normalize,
# Now we assume that the feats must do normalize
feats = feature_normalize(feats, mean_norm=True, std_norm=False)
# stage 4: store the feats and length in the _inputs,
# which will be used in other function
logger.info(f"feats shape: {feats.shape}")
logger.info("audio extract the feats success")
logger.info("start to extract the audio embedding")
embedding = self.model.backbone(feats, lengths).squeeze().numpy()
logger.info(f"embedding size: {embedding.shape}")
return embedding
class VectorServerExecutor(VectorExecutor):
def __init__(self):
"""The wrapper for TextEcutor
"""
super().__init__()
pass
class VectorEngine(BaseEngine):
def __init__(self):
"""The Vector Engine
"""
super(VectorEngine, self).__init__()
logger.info("Create the VectorEngine Instance")
def init(self, config: dict):
"""Init the Vector Engine
Args:
config (dict): The server configuation
Returns:
bool: The engine instance flag
"""
logger.info("Init the vector engine")
try:
self.config = config
if self.config.device:
self.device = self.config.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
logger.info(f"Vector Engine set the device: {self.device}")
except BaseException as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error("Initialize Vector server engine Failed on device: %s."
% (self.device))
return False
self.executor = VectorServerExecutor()
self.executor._init_from_path(
model_type=config.model_type,
cfg_path=config.cfg_path,
ckpt_path=config.ckpt_path,
task=config.task)
logger.info("Init the Vector engine successfully")
return True
......@@ -21,7 +21,7 @@ from paddlespeech.server.restful.asr_api import router as asr_router
from paddlespeech.server.restful.cls_api import router as cls_router
from paddlespeech.server.restful.text_api import router as text_router
from paddlespeech.server.restful.tts_api import router as tts_router
from paddlespeech.server.restful.vector_api import router as vec_router
_router = APIRouter()
......@@ -43,6 +43,8 @@ def setup_router(api_list: List):
_router.include_router(cls_router)
elif api_name == 'text':
_router.include_router(text_router)
elif api_name.lower() == 'vector':
_router.include_router(vec_router)
else:
logger.error(
f"PaddleSpeech has not support such service: {api_name}")
......
......@@ -15,7 +15,10 @@ from typing import Optional
from pydantic import BaseModel
__all__ = ['ASRRequest', 'TTSRequest', 'CLSRequest']
__all__ = [
'ASRRequest', 'TTSRequest', 'CLSRequest', 'VectorRequest',
'VectorScoreRequest'
]
#****************************************************************************************/
......@@ -85,3 +88,40 @@ class CLSRequest(BaseModel):
#****************************************************************************************/
class TextRequest(BaseModel):
text: str
#****************************************************************************************/
#************************************ Vecotr request ************************************/
#****************************************************************************************/
class VectorRequest(BaseModel):
"""
request body example
{
"audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...",
"task": "spk",
"audio_format": "wav",
"sample_rate": 16000,
}
"""
audio: str
task: str
audio_format: str
sample_rate: int
class VectorScoreRequest(BaseModel):
"""
request body example
{
"enroll_audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...",
"test_audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...",
"task": "score",
"audio_format": "wav",
"sample_rate": 16000,
}
"""
enroll_audio: str
test_audio: str
task: str
audio_format: str
sample_rate: int
......@@ -15,7 +15,10 @@ from typing import List
from pydantic import BaseModel
__all__ = ['ASRResponse', 'TTSResponse', 'CLSResponse']
__all__ = [
'ASRResponse', 'TTSResponse', 'CLSResponse', 'TextResponse',
'VectorResponse', 'VectorScoreResponse'
]
class Message(BaseModel):
......@@ -129,6 +132,11 @@ class CLSResponse(BaseModel):
result: CLSResult
#****************************************************************************************/
#************************************ Text response **************************************/
#****************************************************************************************/
class TextResult(BaseModel):
punc_text: str
......@@ -153,6 +161,59 @@ class TextResponse(BaseModel):
result: TextResult
#****************************************************************************************/
#************************************ Vector response **************************************/
#****************************************************************************************/
class VectorResult(BaseModel):
vec: list
class VectorResponse(BaseModel):
"""
response example
{
"success": true,
"code": 0,
"message": {
"description": "success"
},
"result": {
"vec": [1.0, 1.0]
}
}
"""
success: bool
code: int
message: Message
result: VectorResult
class VectorScoreResult(BaseModel):
score: float
class VectorScoreResponse(BaseModel):
"""
response example
{
"success": true,
"code": 0,
"message": {
"description": "success"
},
"result": {
"score": 1.0
}
}
"""
success: bool
code: int
message: Message
result: VectorScoreResult
#****************************************************************************************/
#********************************** Error response **************************************/
#****************************************************************************************/
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import traceback
from typing import Union
import numpy as np
from fastapi import APIRouter
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.engine.vector.python.vector_engine import PaddleVectorConnectionHandler
from paddlespeech.server.restful.request import VectorRequest
from paddlespeech.server.restful.request import VectorScoreRequest
from paddlespeech.server.restful.response import ErrorResponse
from paddlespeech.server.restful.response import VectorResponse
from paddlespeech.server.restful.response import VectorScoreResponse
from paddlespeech.server.utils.errors import ErrorCode
from paddlespeech.server.utils.errors import failed_response
from paddlespeech.server.utils.exception import ServerBaseException
router = APIRouter()
@router.get('/paddlespeech/vector/help')
def help():
"""help
Returns:
json: The /paddlespeech/vector api response content
"""
response = {
"success": "True",
"code": 200,
"message": {
"global": "success"
},
"vector": [2.3, 3.5, 5.5, 6.2, 2.8, 1.2, 0.3, 3.6]
}
return response
@router.post(
"/paddlespeech/vector", response_model=Union[VectorResponse, ErrorResponse])
def vector(request_body: VectorRequest):
"""vector api
Args:
request_body (VectorRequest): the vector request body
Returns:
json: the vector response body
"""
try:
# 1. get the audio data
# the audio must be base64 format
audio_data = base64.b64decode(request_body.audio)
# 2. get single engine from engine pool
# and we use the vector_engine to create an connection handler to process the request
engine_pool = get_engine_pool()
vector_engine = engine_pool['vector']
connection_handler = PaddleVectorConnectionHandler(vector_engine)
# 3. we use the connection handler to process the audio
audio_vec = connection_handler.run(audio_data, request_body.task)
# 4. we need the result of the vector instance be numpy.ndarray
if not isinstance(audio_vec, np.ndarray):
logger.error(
f"the vector type is not numpy.array, that is: {type(audio_vec)}"
)
error_reponse = ErrorResponse()
error_reponse.message.description = f"the vector type is not numpy.array, that is: {type(audio_vec)}"
return error_reponse
response = {
"success": True,
"code": 200,
"message": {
"description": "success"
},
"result": {
"vec": audio_vec.tolist()
}
}
except ServerBaseException as e:
response = failed_response(e.error_code, e.msg)
except BaseException:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
traceback.print_exc()
return response
@router.post(
"/paddlespeech/vector/score",
response_model=Union[VectorScoreResponse, ErrorResponse])
def score(request_body: VectorScoreRequest):
"""vector api
Args:
request_body (VectorScoreRequest): the punctuation request body
Returns:
json: the punctuation response body
"""
try:
# 1. get the audio data
# the audio must be base64 format
enroll_data = base64.b64decode(request_body.enroll_audio)
test_data = base64.b64decode(request_body.test_audio)
# 2. get single engine from engine pool
# and we use the vector_engine to create an connection handler to process the request
engine_pool = get_engine_pool()
vector_engine = engine_pool['vector']
connection_handler = PaddleVectorConnectionHandler(vector_engine)
# 3. we use the connection handler to process the audio
score = connection_handler.get_enroll_test_score(enroll_data, test_data)
response = {
"success": True,
"code": 200,
"message": {
"description": "success"
},
"result": {
"score": score
}
}
except ServerBaseException as e:
response = failed_response(e.error_code, e.msg)
except BaseException:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
traceback.print_exc()
return response
......@@ -142,6 +142,7 @@ class ASRWsAudioHandler:
return ""
# 1. send websocket handshake protocal
start_time = time.time()
async with websockets.connect(self.url) as ws:
# 2. server has already received handshake protocal
# client start to send the command
......@@ -187,7 +188,14 @@ class ASRWsAudioHandler:
if self.punc_server:
msg["result"] = self.punc_server.run(msg["result"])
# 6. logging the final result and comptute the statstics
elapsed_time = time.time() - start_time
audio_info = soundfile.info(wavfile_path)
logger.info("client final receive msg={}".format(msg))
logger.info(
f"audio duration: {audio_info.duration}, elapsed time: {elapsed_time}, RTF={elapsed_time/audio_info.duration}"
)
result = msg
return result
......@@ -456,3 +464,96 @@ class TTSHttpHandler:
self.stream.stop_stream()
self.stream.close()
self.p.terminate()
class VectorHttpHandler:
def __init__(self, server_ip=None, port=None):
"""The Vector client http request
Args:
server_ip (str, optional): the http vector server ip. Defaults to "127.0.0.1".
port (int, optional): the http vector server port. Defaults to 8090.
"""
super().__init__()
self.server_ip = server_ip
self.port = port
if server_ip is None or port is None:
self.url = None
else:
self.url = 'http://' + self.server_ip + ":" + str(
self.port) + '/paddlespeech/vector'
def run(self, input, audio_format, sample_rate, task="spk"):
"""Call the http asr to process the audio
Args:
input (str): the audio file path
audio_format (str): the audio format
sample_rate (str): the audio sample rate
Returns:
list: the audio vector
"""
if self.url is None:
logger.error("No vector server, please input valid ip and port")
return ""
audio = wav2base64(input)
data = {
"audio": audio,
"task": task,
"audio_format": audio_format,
"sample_rate": sample_rate,
}
logger.info(self.url)
res = requests.post(url=self.url, data=json.dumps(data))
return res.json()
class VectorScoreHttpHandler:
def __init__(self, server_ip=None, port=None):
"""The Vector score client http request
Args:
server_ip (str, optional): the http vector server ip. Defaults to "127.0.0.1".
port (int, optional): the http vector server port. Defaults to 8090.
"""
super().__init__()
self.server_ip = server_ip
self.port = port
if server_ip is None or port is None:
self.url = None
else:
self.url = 'http://' + self.server_ip + ":" + str(
self.port) + '/paddlespeech/vector/score'
def run(self, enroll_audio, test_audio, audio_format, sample_rate):
"""Call the http asr to process the audio
Args:
input (str): the audio file path
audio_format (str): the audio format
sample_rate (str): the audio sample rate
Returns:
list: the audio vector
"""
if self.url is None:
logger.error("No vector server, please input valid ip and port")
return ""
enroll_audio = wav2base64(enroll_audio)
test_audio = wav2base64(test_audio)
data = {
"enroll_audio": enroll_audio,
"test_audio": test_audio,
"task": "score",
"audio_format": audio_format,
"sample_rate": sample_rate,
}
res = requests.post(url=self.url, data=json.dumps(data))
return res.json()
# Examples for SpeechX
* ds2_ol - ds2 streaming test under `aishell-1` test dataset.
The entrypoint is `ds2_ol/aishell/run.sh`
* `ds2_ol` - ds2 streaming test under `aishell-1` test dataset.
## How to run
......
# Deepspeech2 Streaming ASR
* websocket
Streaming ASR with websocket.
## Examples
* aishell
Streaming Decoding under aishell dataset, for local WER test and so on.
* `websocket` - Streaming ASR with websocket.
* `aishell` - Streaming Decoding under aishell dataset, for local WER test.
## More
The below is for developing and offline testing:
> The below is for developing and offline testing. Do not run it only if you know what it is.
* nnet
* feat
* decoder
# Aishell - Deepspeech2 Streaming
## CTC Prefix Beam Search w/o LM
## How to run
```
bash run.sh
```
## Results
### CTC Prefix Beam Search w/o LM
```
Overall -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
......@@ -8,7 +16,7 @@ Mandarin -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
Other -> 0.00 % N=0 C=0 S=0 D=0 I=0
```
## CTC Prefix Beam Search w/ LM
### CTC Prefix Beam Search w/ LM
LM: zh_giga.no_cna_cmn.prune01244.klm
```
......@@ -17,7 +25,7 @@ Mandarin -> 7.86 % N=104768 C=96865 S=7573 D=330 I=327
Other -> 0.00 % N=0 C=0 S=0 D=0 I=0
```
## CTC WFST
### CTC WFST
LM: [aishell train](http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/aishell/aishell_graph.zip)
--acoustic_scale=1.2
......
......@@ -98,6 +98,7 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
decoder.InitDecoder();
kaldi::Timer timer;
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
......@@ -160,5 +161,7 @@ int main(int argc, char* argv[]) {
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
double elapsed = timer.Elapsed();
KALDI_LOG << " cost:" << elapsed << " s";
return (num_done != 0 ? 0 : 1);
}
......@@ -38,6 +38,9 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0;
kaldi::Timer timer;
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
......@@ -47,6 +50,7 @@ int main(int argc, char* argv[]) {
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
tot_wav_duration += tot_samples * 1.0 / sample_rate;
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
......@@ -85,4 +89,9 @@ int main(int argc, char* argv[]) {
result_writer.Write(utt, result);
++num_done;
}
double elapsed = timer.Elapsed();
KALDI_LOG << "Done " << num_done << " out of " << (num_err + num_done);
KALDI_LOG << " cost:" << elapsed << " s";
KALDI_LOG << "total wav duration is: " << tot_wav_duration << " s";
KALDI_LOG << "the RTF is: " << elapsed / tot_wav_duration;
}
\ No newline at end of file
......@@ -100,7 +100,7 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
decoder.InitDecoder();
kaldi::Timer timer;
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
......@@ -160,6 +160,9 @@ int main(int argc, char* argv[]) {
++num_done;
}
double elapsed = timer.Elapsed();
KALDI_LOG << " cost:" << elapsed << " s";
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
......
......@@ -5,8 +5,12 @@ add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} frontend kaldi-util kaldi-feat-common gflags glog)
set(bin_name compute_fbank_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} frontend kaldi-util kaldi-feat-common gflags glog)
set(bin_name cmvn-json2kaldi)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog ${DEPS})
\ No newline at end of file
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog ${DEPS})
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "base/flags.h"
#include "base/log.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
#include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/fbank.h"
#include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/normalizer.h"
DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_file, "", "read cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(num_bins, 161, "fbank num bins");
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
int32 num_done = 0, num_err = 0;
// feature pipeline: wave cache --> povey window
// -->fbank --> global cmvn -> feat cache
std::unique_ptr<ppspeech::FrontendInterface> data_source(
new ppspeech::AudioCache(3600 * 1600, false));
ppspeech::FbankOptions opt;
opt.fbank_opts.frame_opts.frame_length_ms = 25;
opt.fbank_opts.frame_opts.frame_shift_ms = 10;
opt.streaming_chunk = FLAGS_streaming_chunk;
opt.fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
opt.fbank_opts.frame_opts.dither = 0.0;
std::unique_ptr<ppspeech::FrontendInterface> fbank(
new ppspeech::Fbank(opt, std::move(data_source)));
std::unique_ptr<ppspeech::FrontendInterface> cmvn(
new ppspeech::CMVN(FLAGS_cmvn_file, std::move(fbank)));
ppspeech::FeatureCacheOptions feat_cache_opts;
// the feature cache output feature chunk by chunk.
// frame_chunk_size : num frame of a chunk.
// frame_chunk_stride: chunk sliding window stride.
feat_cache_opts.frame_chunk_stride = 1;
feat_cache_opts.frame_chunk_size = 1;
ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
LOG(INFO) << "feat dim: " << feature_cache.Dim();
int sample_rate = 16000;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "process utt: " << utt;
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
std::vector<kaldi::Vector<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
kaldi::Vector<BaseFloat> features;
feature_cache.Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
feature_cache.SetFinished();
}
bool flag = true;
do {
flag = feature_cache.Read(&features);
feats.push_back(features);
feature_rows += features.Dim() / feature_cache.Dim();
} while (flag == true && features.Dim() != 0);
sample_offset += cur_chunk_size;
}
int cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
feature_cache.Dim());
for (auto feat : feats) {
int num_rows = feat.Dim() / feature_cache.Dim();
for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
for (size_t col_idx = 0; col_idx < feature_cache.Dim();
++col_idx) {
features(cur_idx, col_idx) =
feat(row_idx * feature_cache.Dim() + col_idx);
}
++cur_idx;
}
}
feat_writer.Write(utt, features);
feature_cache.Reset();
if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
num_done++;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
reference:
this patch is from WeNet wenet/runtime/core/patch
......@@ -47,7 +47,8 @@ DEFINE_string(model_cache_names,
"chunk_state_h_box,chunk_state_c_box",
"model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
DEFINE_bool(use_fbank, false, "use fbank or linear feature");
DEFINE_int32(num_bins, 161, "num bins of mel");
namespace ppspeech {
// todo refactor later
......@@ -57,13 +58,21 @@ FeaturePipelineOptions InitFeaturePipelineOptions() {
opts.linear_spectrogram_opts.streaming_chunk = FLAGS_streaming_chunk;
opts.to_float32 = FLAGS_to_float32;
kaldi::FrameExtractionOptions frame_opts;
frame_opts.frame_length_ms = 20;
frame_opts.frame_shift_ms = 10;
frame_opts.remove_dc_offset = false;
frame_opts.window_type = "hanning";
frame_opts.preemph_coeff = 0.0;
frame_opts.dither = 0.0;
opts.linear_spectrogram_opts.frame_opts = frame_opts;
frame_opts.frame_shift_ms = 10;
opts.use_fbank = FLAGS_use_fbank;
if (opts.use_fbank) {
frame_opts.window_type = "povey";
frame_opts.frame_length_ms = 25;
opts.fbank_opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
opts.fbank_opts.fbank_opts.frame_opts = frame_opts;
} else {
frame_opts.remove_dc_offset = false;
frame_opts.frame_length_ms = 20;
frame_opts.window_type = "hanning";
frame_opts.preemph_coeff = 0.0;
opts.linear_spectrogram_opts.frame_opts = frame_opts;
}
opts.feature_cache_opts.frame_chunk_size = FLAGS_receptive_field_length;
opts.feature_cache_opts.frame_chunk_stride = FLAGS_downsampling_rate;
return opts;
......
......@@ -7,6 +7,7 @@ add_library(frontend STATIC
audio_cache.cc
feature_cache.cc
feature_pipeline.cc
fbank.cc
)
target_link_libraries(frontend PUBLIC kaldi-matrix kaldi-feat-common)
target_link_libraries(frontend PUBLIC kaldi-matrix kaldi-feat-common kaldi-fbank)
......@@ -29,14 +29,16 @@ using kaldi::VectorBase;
using kaldi::Matrix;
using std::vector;
// todo refactor later:(SmileGoat)
Fbank::Fbank(const FbankOptions& opts,
std::unique_ptr<FrontendInterface> base_extractor)
: opts_(opts),
computer_(opts.fbank_opts),
window_function_(computer_.GetFrameOptions()) {
window_function_(opts.fbank_opts.frame_opts) {
base_extractor_ = std::move(base_extractor);
chunk_sample_size_ =
static_cast<int32>(opts.streaming_chunk * opts.frame_opts.samp_freq);
chunk_sample_size_ = static_cast<int32>(
opts.streaming_chunk * opts.fbank_opts.frame_opts.samp_freq);
}
void Fbank::Accept(const VectorBase<BaseFloat>& inputs) {
......@@ -71,7 +73,8 @@ bool Fbank::Read(Vector<BaseFloat>* feats) {
// Compute spectrogram feat
bool Fbank::Compute(const Vector<BaseFloat>& waves, Vector<BaseFloat>* feats) {
const FrameExtractionOptions& frame_opts = computer_.GetFrameOptions();
const kaldi::FrameExtractionOptions& frame_opts =
computer_.GetFrameOptions();
int32 num_samples = waves.Dim();
int32 frame_length = frame_opts.WindowSize();
int32 sample_rate = frame_opts.samp_freq;
......@@ -80,7 +83,7 @@ bool Fbank::Compute(const Vector<BaseFloat>& waves, Vector<BaseFloat>* feats) {
}
int32 num_frames = kaldi::NumFrames(num_samples, frame_opts);
feats->Rsize(num_frames * Dim());
feats->Resize(num_frames * Dim());
Vector<BaseFloat> window;
bool need_raw_log_energy = computer_.NeedRawLogEnergy();
......@@ -95,14 +98,24 @@ bool Fbank::Compute(const Vector<BaseFloat>& waves, Vector<BaseFloat>* feats) {
need_raw_log_energy ? &raw_log_energy : NULL);
Vector<BaseFloat> this_feature(computer_.Dim(), kUndefined);
Vector<BaseFloat> this_feature(computer_.Dim(), kaldi::kUndefined);
// note: this online feature-extraction code does not support VTLN.
BaseFloat vtln_warp = 1.0;
computer_.Compute(raw_log_energy, vtln_warp, &window, &this_feature);
RealFft(&window, true);
kaldi::ComputePowerSpectrum(&window);
const kaldi::MelBanks &mel_bank = *(computer_.GetMelBanks(1.0));
SubVector<BaseFloat> power_spectrum(window, 0, window.Dim() / 2 + 1);
if (!opts_.fbank_opts.use_power) {
power_spectrum.ApplyPow(0.5);
}
int32 mel_offset = ((opts_.fbank_opts.use_energy && !opts_.fbank_opts.htk_compat) ? 1 : 0);
SubVector<BaseFloat> mel_energies(this_feature, mel_offset, opts_.fbank_opts.mel_opts.num_bins);
mel_bank.Compute(power_spectrum, &mel_energies);
mel_energies.ApplyFloor(1e-07);
mel_energies.ApplyLog();
SubVector<BaseFloat> output_row(feats->Data() + frame * Dim(), Dim());
output_row.CopyFromVec(this_feature);
}
return true;
}
} // namespace ppspeech
\ No newline at end of file
} // namespace ppspeech
......@@ -14,6 +14,8 @@
#pragma once
#include "base/common.h"
#include "frontend/audio/frontend_itf.h"
#include "kaldi/feat/feature-fbank.h"
#include "kaldi/feat/feature-mfcc.h"
#include "kaldi/matrix/kaldi-vector.h"
......@@ -38,7 +40,7 @@ struct FbankOptions {
class Fbank : public FrontendInterface {
public:
explicit Fbank(const FbankOptions& opts,
unique_ptr<FrontendInterface> base_extractor);
std::unique_ptr<FrontendInterface> base_extractor);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
......@@ -61,15 +63,15 @@ class Fbank : public FrontendInterface {
FbankOptions opts_;
std::unique_ptr<FrontendInterface> base_extractor_;
FeatureWindowFunction window_function_;
kaldi::FeatureWindowFunction window_function_;
kaldi::FbankComputer computer_;
// features_ is the Mfcc or Plp or Fbank features that we have already
// computed.
kaldi::Vector<kaldi::BaseFloat> features_;
kaldi::Vector<kaldi::BaseFloat> remained_wav_;
kaldi::int32 chunk_sample_size_;
DISALLOW_COPY_AND_ASSIGN(Fbank);
};
} // namespace ppspeech
\ No newline at end of file
} // namespace ppspeech
......@@ -22,12 +22,18 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) {
unique_ptr<FrontendInterface> data_source(
new ppspeech::AudioCache(1000 * kint16max, opts.to_float32));
unique_ptr<FrontendInterface> linear_spectrogram(
new ppspeech::LinearSpectrogram(opts.linear_spectrogram_opts,
std::move(data_source)));
unique_ptr<FrontendInterface> base_feature;
if (opts.use_fbank) {
base_feature.reset(new ppspeech::Fbank(opts.fbank_opts,
std::move(data_source)));
} else {
base_feature.reset(new ppspeech::LinearSpectrogram(opts.linear_spectrogram_opts,
std::move(data_source)));
}
unique_ptr<FrontendInterface> cmvn(
new ppspeech::CMVN(opts.cmvn_file, std::move(linear_spectrogram)));
new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature)));
base_extractor_.reset(
new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn)));
......
......@@ -21,6 +21,7 @@
#include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/fbank.h"
#include "frontend/audio/normalizer.h"
namespace ppspeech {
......@@ -28,12 +29,16 @@ namespace ppspeech {
struct FeaturePipelineOptions {
std::string cmvn_file;
bool to_float32;
bool use_fbank;
LinearSpectrogramOptions linear_spectrogram_opts;
FbankOptions fbank_opts;
FeatureCacheOptions feature_cache_opts;
FeaturePipelineOptions()
: cmvn_file(""),
to_float32(false),
use_fbank(false),
linear_spectrogram_opts(),
fbank_opts(),
feature_cache_opts() {}
};
......
......@@ -3,10 +3,10 @@ add_library(kaldi-mfcc
)
target_link_libraries(kaldi-mfcc PUBLIC kaldi-feat-common)
add_library(fbank
add_library(kaldi-fbank
feature-fbank.cc
)
target_link_libraries(fbank PUBLIC kaldi-feat-common)
target_link_libraries(kaldi-fbank PUBLIC kaldi-feat-common)
add_library(kaldi-feat-common
wave-reader.cc
......
......@@ -128,8 +128,8 @@ class FbankComputer {
~FbankComputer();
private:
const MelBanks *GetMelBanks(BaseFloat vtln_warp);
private:
FbankOptions opts_;
......
......@@ -120,8 +120,8 @@ MelBanks::MelBanks(const MelBanksOptions &opts,
last_index = i;
}
}
KALDI_ASSERT(first_index != -1 && last_index >= first_index
&& "You may have set --num-mel-bins too large.");
//KALDI_ASSERT(first_index != -1 && last_index >= first_index
// && "You may have set --num-mel-bins too large.");
bins_[bin].first = first_index;
int32 size = last_index + 1 - first_index;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Wenet Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Wenet Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Wenet Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Wenet Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册