提交 f869db7a 编写于 作者: Y Yang Zhou

Merge branch 'develop' of github.com:SmileGoat/PaddleSpeech into add_fbank

......@@ -14,7 +14,7 @@ see [installation](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/doc
You can choose one way from easy, meduim and hard to install paddlespeech.
### 2. Prepare Input File
The input of this demo should be a WAV file(`.wav`), and the sample rate must be the same as the model.
The input of this cli demo should be a WAV file(`.wav`), and the sample rate must be the same as the model.
Here are sample files for this demo that can be downloaded:
```bash
......
......@@ -4,16 +4,16 @@
## 介绍
声纹识别是一项用计算机程序自动提取说话人特征的技术。
这个 demo 是一个从给定音频文件提取说话人特征,它可以通过使用 `PaddleSpeech` 的单个命令或 python 中的几行代码来实现。
这个 demo 是从一个给定音频文件中提取说话人特征,它可以通过使用 `PaddleSpeech` 的单个命令或 python 中的几行代码来实现。
## 使用方法
### 1. 安装
请看[安装文档](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md)
你可以从 easy,medium,hard 三中方式中选择一种方式安装。
你可以从easy medium,hard 三种方式中选择一种方式安装。
### 2. 准备输入
这个 demo 的输入应该是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。
声纹cli demo 的输入应该是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。
可以下载此 demo 的示例音频:
```bash
......
# This is the parameter configuration file for PaddleSpeech Serving.
#################################################################################
# SERVER SETTING #
#################################################################################
host: 0.0.0.0
port: 8190
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_python']
# protocol = ['http'] (only one can be selected).
# http only support offline engine type.
protocol: 'http'
engine_list: ['text_python']
#################################################################################
# ENGINE CONFIG #
#################################################################################
################################### Text #########################################
################### text task: punc; engine_type: python #######################
text_python:
task: punc
model_type: 'ernie_linear_p3_wudao'
lang: 'zh'
sample_rate: 16000
cfg_path: # [optional]
ckpt_path: # [optional]
vocab_file: # [optional]
device: 'cpu' # set 'gpu:id' or 'cpu'
......@@ -4,7 +4,7 @@
# SERVER SETTING #
#################################################################################
host: 0.0.0.0
port: 8090
port: 8290
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online']
......@@ -29,7 +29,7 @@ asr_online:
cfg_path:
decode_method:
force_yes: True
device: # cpu or gpu:id
device: 'cpu' # cpu or gpu:id
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True
......@@ -42,4 +42,4 @@ asr_online:
window_ms: 25 # ms
shift_ms: 10 # ms
sample_rate: 16000
sample_width: 2
\ No newline at end of file
sample_width: 2
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from paddlespeech.cli.log import logger
from paddlespeech.server.bin.paddlespeech_server import ServerExecutor
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog='paddlespeech_server.start', add_help=True)
parser.add_argument(
"--config_file",
action="store",
help="yaml file of the app",
default=None,
required=True)
parser.add_argument(
"--log_file",
action="store",
help="log file",
default="./log/paddlespeech.log")
logger.info("start to parse the args")
args = parser.parse_args()
logger.info("start to launch the punctuation server")
punc_server = ServerExecutor()
punc_server(config_file=args.config_file, log_file=args.log_file)
export CUDA_VISIBLE_DEVICE=0,1,2,3
nohup python3 punc_server.py --config_file conf/punc_application.yaml > punc.log 2>&1 &
nohup python3 streaming_asr_server.py --config_file conf/ws_conformer_application.yaml > streaming_asr.log 2>&1 &
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from paddlespeech.cli.log import logger
from paddlespeech.server.bin.paddlespeech_server import ServerExecutor
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog='paddlespeech_server.start', add_help=True)
parser.add_argument(
"--config_file",
action="store",
help="yaml file of the app",
default=None,
required=True)
parser.add_argument(
"--log_file",
action="store",
help="log file",
default="./log/paddlespeech.log")
logger.info("start to parse the args")
args = parser.parse_args()
logger.info("start to launch the streaming asr server")
streaming_asr_server = ServerExecutor()
streaming_asr_server(config_file=args.config_file, log_file=args.log_file)
# download the test wav
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
# read the wav and pass it to service
python3 websocket_client.py --wavfile ./zh.wav
# read the wav and pass it to only streaming asr service
python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --wavfile ./zh.wav
# read the wav and call streaming and punc service
python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav
......@@ -28,6 +28,7 @@ def main(args):
handler = ASRWsAudioHandler(
args.server_ip,
args.port,
endpoint=args.endpoint,
punc_server_ip=args.punc_server_ip,
punc_server_port=args.punc_server_port)
loop = asyncio.get_event_loop()
......@@ -69,7 +70,11 @@ if __name__ == "__main__":
default=8091,
dest="punc_server_port",
help='Punctuation server port')
parser.add_argument(
"--endpoint",
type=str,
default="/paddlespeech/asr/streaming",
help="ASR websocket endpoint")
parser.add_argument(
"--wavfile",
action="store",
......
......@@ -10,7 +10,7 @@ encoder_conf:
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
......@@ -30,7 +30,7 @@ decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
......@@ -39,7 +39,7 @@ model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
init_type: 'kaiming_uniform'
init_type: 'kaiming_uniform' # !Warning: need to convergence
###########################################
# Data #
......
......@@ -37,7 +37,7 @@ model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
init_type: 'kaiming_uniform'
init_type: 'kaiming_uniform' # !Warning: need to convergence
###########################################
# Data #
......
......@@ -10,7 +10,7 @@ encoder_conf:
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
......@@ -21,7 +21,7 @@ decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
......
......@@ -272,7 +272,8 @@ class VectorExecutor(BaseExecutor):
model_type: str='ecapatdnn_voxceleb12',
sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None):
ckpt_path: Optional[os.PathLike]=None,
task=None):
"""Init the neural network from the model path
Args:
......@@ -284,8 +285,10 @@ class VectorExecutor(BaseExecutor):
Defaults to None.
ckpt_path (Optional[os.PathLike], optional): the pretrained model path, which is stored in the disk.
Defaults to None.
task (str, optional): the model task type
"""
# stage 0: avoid to init the mode again
self.task = task
if hasattr(self, "model"):
logger.info("Model has been initialized")
return
......@@ -434,6 +437,9 @@ class VectorExecutor(BaseExecutor):
if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error(
"invalid sample rate, please input --sr 8000 or --sr 16000")
logger.error(
f"The model sample rate: {self.sample_rate}, the external sample rate is: {sample_rate}"
)
return False
if isinstance(audio_file, (str, os.PathLike)):
......
......@@ -63,3 +63,23 @@ paddlespeech_server start --config_file conf/tts_online_application.yaml
```
paddlespeech_client tts_online --server_ip 127.0.0.1 --port 8092 --input "您好,欢迎使用百度飞桨深度学习框架!" --output output.wav
```
## 声纹识别
### 启动声纹识别服务
```
paddlespeech_server start --config_file conf/vector_application.yaml
```
### 获取说话人音频声纹
```
paddlespeech_client vector --task spk --server_ip 127.0.0.1 --port 8090 --input 85236145389.wav
```
### 两个说话人音频声纹打分
```
paddlespeech_client vector --task score --server_ip 127.0.0.1 --port 8090 --enroll 123456789.wav --test 85236145389.wav
```
......@@ -35,7 +35,7 @@ from paddlespeech.server.utils.util import wav2base64
__all__ = [
'TTSClientExecutor', 'TTSOnlineClientExecutor', 'ASRClientExecutor',
'ASROnlineClientExecutor', 'CLSClientExecutor'
'ASROnlineClientExecutor', 'CLSClientExecutor', 'VectorClientExecutor'
]
......@@ -411,6 +411,18 @@ class ASROnlineClientExecutor(BaseExecutor):
'--lang', type=str, default="zh_cn", help='language')
self.parser.add_argument(
'--audio_format', type=str, default="wav", help='audio format')
self.parser.add_argument(
'--punc.server_ip',
type=str,
default=None,
dest="punc_server_ip",
help='Punctuation server ip')
self.parser.add_argument(
'--punc.port',
type=int,
default=8190,
dest="punc_server_port",
help='Punctuation server port')
def execute(self, argv: List[str]) -> bool:
args = self.parser.parse_args(argv)
......@@ -428,7 +440,9 @@ class ASROnlineClientExecutor(BaseExecutor):
port=port,
sample_rate=sample_rate,
lang=lang,
audio_format=audio_format)
audio_format=audio_format,
punc_server_ip=args.punc_server_ip,
punc_server_port=args.punc_server_port)
time_end = time.time()
logger.info(res)
logger.info("Response time %f s." % (time_end - time_start))
......@@ -445,12 +459,30 @@ class ASROnlineClientExecutor(BaseExecutor):
port: int=8091,
sample_rate: int=16000,
lang: str="zh_cn",
audio_format: str="wav"):
"""
Python API to call an executor.
audio_format: str="wav",
punc_server_ip: str=None,
punc_server_port: str=None):
"""Python API to call asr online executor.
Args:
input (str): the audio file to be send to streaming asr service.
server_ip (str, optional): streaming asr server ip. Defaults to "127.0.0.1".
port (int, optional): streaming asr server port. Defaults to 8091.
sample_rate (int, optional): audio sample rate. Defaults to 16000.
lang (str, optional): audio language type. Defaults to "zh_cn".
audio_format (str, optional): audio format. Defaults to "wav".
punc_server_ip (str, optional): punctuation server ip. Defaults to None.
punc_server_port (str, optional): punctuation server port. Defaults to None.
Returns:
str: the audio text
"""
logger.info("asr websocket client start")
handler = ASRWsAudioHandler(server_ip, port)
handler = ASRWsAudioHandler(
server_ip,
port,
punc_server_ip=punc_server_ip,
punc_server_port=punc_server_port)
loop = asyncio.get_event_loop()
res = loop.run_until_complete(handler.run(input))
logger.info("asr websocket client finished")
......@@ -583,3 +615,108 @@ class TextClientExecutor(BaseExecutor):
response_dict = res.json()
punc_text = response_dict["result"]["punc_text"]
return punc_text
@cli_client_register(
name='paddlespeech_client.vector', description='visit the vector service')
class VectorClientExecutor(BaseExecutor):
def __init__(self):
super(VectorClientExecutor, self).__init__()
self.parser = argparse.ArgumentParser(
prog='paddlespeech_client.vector', add_help=True)
self.parser.add_argument(
'--server_ip', type=str, default='127.0.0.1', help='server ip')
self.parser.add_argument(
'--port', type=int, default=8090, help='server port')
self.parser.add_argument(
'--input',
type=str,
default=None,
help='sentence to be process by text server.')
self.parser.add_argument(
'--task',
type=str,
default="spk",
choices=["spk", "score"],
help="The vector service task")
self.parser.add_argument(
"--enroll", type=str, default=None, help="The enroll audio")
self.parser.add_argument(
"--test", type=str, default=None, help="The test audio")
def execute(self, argv: List[str]) -> bool:
"""Execute the request from the argv.
Args:
argv (List): the request arguments
Returns:
str: the request flag
"""
args = self.parser.parse_args(argv)
input_ = args.input
server_ip = args.server_ip
port = args.port
task = args.task
try:
time_start = time.time()
res = self(
input=input_,
server_ip=server_ip,
port=port,
enroll_audio=args.enroll,
test_audio=args.test,
task=task)
time_end = time.time()
logger.info(f"The vector: {res}")
logger.info("Response time %f s." % (time_end - time_start))
return True
except Exception as e:
logger.error("Failed to extract vector.")
logger.error(e)
return False
@stats_wrapper
def __call__(self,
input: str,
server_ip: str="127.0.0.1",
port: int=8090,
audio_format: str="wav",
sample_rate: int=16000,
enroll_audio: str=None,
test_audio: str=None,
task="spk"):
"""
Python API to call text executor.
Args:
input (str): the request audio data
server_ip (str, optional): the server ip. Defaults to "127.0.0.1".
port (int, optional): the server port. Defaults to 8090.
audio_format (str, optional): audio format. Defaults to "wav".
sample_rate (str, optional): audio sample rate. Defaults to 16000.
enroll_audio (str, optional): enroll audio data. Defaults to None.
test_audio (str, optional): test audio data. Defaults to None.
task (str, optional): the task type, "spk" or "socre". Defaults to "spk"
Returns:
str: the audio embedding or score between enroll and test audio
"""
if task == "spk":
from paddlespeech.server.utils.audio_handler import VectorHttpHandler
logger.info("vector http client start")
logger.info(f"the input audio: {input}")
handler = VectorHttpHandler(server_ip=server_ip, port=port)
res = handler.run(input, audio_format, sample_rate)
return res
elif task == "score":
from paddlespeech.server.utils.audio_handler import VectorScoreHttpHandler
logger.info("vector score http client start")
logger.info(
f"enroll audio: {enroll_audio}, test audio: {test_audio}")
handler = VectorScoreHttpHandler(server_ip=server_ip, port=port)
res = handler.run(enroll_audio, test_audio, audio_format,
sample_rate)
logger.info(f"The vector score is: {res}")
else:
logger.error(f"Sorry, we have not support such task {task}")
......@@ -11,7 +11,7 @@ port: 8090
# protocol = ['websocket', 'http'] (only one can be selected).
# http only support offline engine type.
protocol: 'http'
engine_list: ['asr_python', 'tts_python', 'cls_python', 'text_python']
engine_list: ['asr_python', 'tts_python', 'cls_python', 'text_python', 'vector_python']
#################################################################################
......@@ -166,4 +166,15 @@ text_python:
cfg_path: # [optional]
ckpt_path: # [optional]
vocab_file: # [optional]
device: # set 'gpu:id' or 'cpu'
################################### Vector ######################################
################### Vector task: spk; engine_type: python #######################
vector_python:
task: spk
model_type: 'ecapatdnn_voxceleb12'
sample_rate: 16000
cfg_path: # [optional]
ckpt_path: # [optional]
device: # set 'gpu:id' or 'cpu'
\ No newline at end of file
# This is the parameter configuration file for PaddleSpeech Serving.
#################################################################################
# SERVER SETTING #
#################################################################################
host: 0.0.0.0
port: 8090
# The task format in the engin_list is: <speech task>_<engine type>
# protocol = ['http'] (only one can be selected).
# http only support offline engine type.
protocol: 'http'
engine_list: ['vector_python']
#################################################################################
# ENGINE CONFIG #
#################################################################################
################################### Vector ######################################
################### Vector task: spk; engine_type: python #######################
vector_python:
task: spk
model_type: 'ecapatdnn_voxceleb12'
sample_rate: 16000
cfg_path: # [optional]
ckpt_path: # [optional]
device: # set 'gpu:id' or 'cpu'
......@@ -153,6 +153,12 @@ class PaddleASRConnectionHanddler:
self.n_shift = self.preprocess_conf.process[0]['n_shift']
def extract_feat(self, samples):
# we compute the elapsed time of first char occuring
# and we record the start time at the first pcm sample arraving
# if self.first_char_occur_elapsed is not None:
# self.first_char_occur_elapsed = time.time()
if "deepspeech2online" in self.model_type:
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
......@@ -290,6 +296,9 @@ class PaddleASRConnectionHanddler:
self.chunk_num = 0
self.global_frame_offset = 0
self.result_transcripts = ['']
self.word_time_stamp = []
self.time_stamp = []
self.first_char_occur_elapsed = None
def decode(self, is_finished=False):
if "deepspeech2online" in self.model_type:
......@@ -505,6 +514,9 @@ class PaddleASRConnectionHanddler:
else:
return ''
def get_word_time_stamp(self):
return self.word_time_stamp
@paddle.no_grad()
def rescoring(self):
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
......@@ -567,10 +579,48 @@ class PaddleASRConnectionHanddler:
best_index = i
# update the one best result
# hyps stored the beam results and each fields is:
logger.info(f"best index: {best_index}")
# logger.info(f'best result: {hyps[best_index]}')
# the field of the hyps is:
# hyps[0][0]: the sentence word-id in the vocab with a tuple
# hyps[0][1]: the sentence decoding probability with all paths
# hyps[0][2]: viterbi_blank ending probability
# hyps[0][3]: viterbi_non_blank probability
# hyps[0][4]: current_token_prob,
# hyps[0][5]: times_viterbi_blank,
# hyps[0][6]: times_titerbi_non_blank
self.hyps = [hyps[best_index][0]]
# update the hyps time stamp
self.time_stamp = hyps[best_index][5] if hyps[best_index][2] > hyps[
best_index][3] else hyps[best_index][6]
logger.info(f"time stamp: {self.time_stamp}")
self.update_result()
# update each word start and end time stamp
frame_shift_in_ms = self.model.encoder.embed.subsampling_rate * self.n_shift / self.sample_rate
logger.info(f"frame shift ms: {frame_shift_in_ms}")
word_time_stamp = []
for idx, _ in enumerate(self.time_stamp):
start = (self.time_stamp[idx - 1] + self.time_stamp[idx]
) / 2.0 if idx > 0 else 0
start = start * frame_shift_in_ms
end = (self.time_stamp[idx] + self.time_stamp[idx + 1]
) / 2.0 if idx < len(self.time_stamp) - 1 else self.offset
end = end * frame_shift_in_ms
word_time_stamp.append({
"w": self.result_transcripts[0][idx],
"bg": start,
"ed": end
})
# logger.info(f"{self.result_transcripts[0][idx]}, start: {start}, end: {end}")
self.word_time_stamp = word_time_stamp
logger.info(f"word time stamp: {self.word_time_stamp}")
class ASRServerExecutor(ASRExecutor):
def __init__(self):
......
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from collections import defaultdict
import paddle
......@@ -26,7 +27,7 @@ class CTCPrefixBeamSearch:
"""Implement the ctc prefix beam search
Args:
config (yacs.config.CfgNode): _description_
config (yacs.config.CfgNode): the ctc prefix beam search configuration
"""
self.config = config
self.reset()
......@@ -54,14 +55,23 @@ class CTCPrefixBeamSearch:
assert len(ctc_probs.shape) == 2
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# blank_ending_score and none_blank_ending_score in ln domain
# 0. blank_ending_score,
# 1. none_blank_ending_score,
# 2. viterbi_blank ending,
# 3. viterbi_non_blank,
# 4. current_token_prob,
# 5. times_viterbi_blank,
# 6. times_titerbi_non_blank
if self.cur_hyps is None:
self.cur_hyps = [(tuple(), (0.0, -float('inf')))]
self.cur_hyps = [(tuple(), (0.0, -float('inf'), 0.0, 0.0,
-float('inf'), [], []))]
# self.cur_hyps = [(tuple(), (0.0, -float('inf')))]
# 2. CTC beam search step by step
for t in range(0, maxlen):
logp = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
# next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
next_hyps = defaultdict(
lambda: (-float('inf'), -float('inf'), -float('inf'), -float('inf'), -float('inf'), [], []))
# 2.1 First beam prune: select topk best
# do token passing process
......@@ -69,36 +79,83 @@ class CTCPrefixBeamSearch:
for s in top_k_index:
s = s.item()
ps = logp[s].item()
for prefix, (pb, pnb) in self.cur_hyps:
for prefix, (pb, pnb, v_b_s, v_nb_s, cur_token_prob, times_s,
times_ns) in self.cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == blank_id: # blank
n_pb, n_pnb = next_hyps[prefix]
n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[
prefix]
n_pb = log_add([n_pb, pb + ps, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
pre_times = times_s if v_b_s > v_nb_s else times_ns
n_times_s = copy.deepcopy(pre_times)
viterbi_score = v_b_s if v_b_s > v_nb_s else v_nb_s
n_v_s = viterbi_score + ps
next_hyps[prefix] = (n_pb, n_pnb, n_v_s, n_v_ns,
n_cur_token_prob, n_times_s,
n_times_ns)
elif s == last:
# Update *ss -> *s;
n_pb, n_pnb = next_hyps[prefix]
# case1: *a + a => *a
n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[
prefix]
n_pnb = log_add([n_pnb, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
if n_v_ns < v_nb_s + ps:
n_v_ns = v_nb_s + ps
if n_cur_token_prob < ps:
n_cur_token_prob = ps
n_times_ns = copy.deepcopy(times_ns)
n_times_ns[
-1] = self.abs_time_step # 注意,这里要重新使用绝对时间
next_hyps[prefix] = (n_pb, n_pnb, n_v_s, n_v_ns,
n_cur_token_prob, n_times_s,
n_times_ns)
# Update *s-s -> *ss, - is for blank
# Case 2: *aε + a => *aa
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[
n_prefix]
if n_v_ns < v_b_s + ps:
n_v_ns = v_b_s + ps
n_cur_token_prob = ps
n_times_ns = copy.deepcopy(times_s)
n_times_ns.append(self.abs_time_step)
n_pnb = log_add([n_pnb, pb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
next_hyps[n_prefix] = (n_pb, n_pnb, n_v_s, n_v_ns,
n_cur_token_prob, n_times_s,
n_times_ns)
else:
# Case 3: *a + b => *ab, *aε + b => *ab
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[
n_prefix]
viterbi_score = v_b_s if v_b_s > v_nb_s else v_nb_s
pre_times = times_s if v_b_s > v_nb_s else times_ns
if n_v_ns < viterbi_score + ps:
n_v_ns = viterbi_score + ps
n_cur_token_prob = ps
n_times_ns = copy.deepcopy(pre_times)
n_times_ns.append(self.abs_time_step)
n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
next_hyps[n_prefix] = (n_pb, n_pnb, n_v_s, n_v_ns,
n_cur_token_prob, n_times_s,
n_times_ns)
# 2.2 Second beam prune
next_hyps = sorted(
next_hyps.items(),
key=lambda x: log_add(list(x[1])),
key=lambda x: log_add([x[1][0], x[1][1]]),
reverse=True)
self.cur_hyps = next_hyps[:beam_size]
self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps]
# 2.3 update the absolute time step
self.abs_time_step += 1
self.hyps = [(y[0], log_add([y[1][0], y[1][1]]), y[1][2], y[1][3],
y[1][4], y[1][5], y[1][6]) for y in self.cur_hyps]
logger.info("ctc prefix search success")
return self.hyps
......@@ -123,6 +180,7 @@ class CTCPrefixBeamSearch:
"""
self.cur_hyps = None
self.hyps = None
self.abs_time_step = 0
def finalize_search(self):
"""do nothing in ctc_prefix_beam_search
......
......@@ -49,5 +49,8 @@ class EngineFactory(object):
elif engine_name.lower() == 'text' and engine_type.lower() == 'python':
from paddlespeech.server.engine.text.python.text_engine import TextEngine
return TextEngine()
elif engine_name.lower() == 'vector' and engine_type.lower() == 'python':
from paddlespeech.server.engine.vector.python.vector_engine import VectorEngine
return VectorEngine()
else:
return None
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
from collections import OrderedDict
import numpy as np
import paddle
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.cli.log import logger
from paddlespeech.cli.vector.infer import VectorExecutor
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.vector.io.batch import feature_normalize
class PaddleVectorConnectionHandler:
def __init__(self, vector_engine):
"""The PaddleSpeech Vector Server Connection Handler
This connection process every server request
Args:
vector_engine (VectorEngine): The Vector engine
"""
super().__init__()
logger.info(
"Create PaddleVectorConnectionHandler to process the vector request")
self.vector_engine = vector_engine
self.executor = self.vector_engine.executor
self.task = self.vector_engine.executor.task
self.model = self.vector_engine.executor.model
self.config = self.vector_engine.executor.config
self._inputs = OrderedDict()
self._outputs = OrderedDict()
@paddle.no_grad()
def run(self, audio_data, task="spk"):
"""The connection process the http request audio
Args:
audio_data (bytes): base64.b64decode
Returns:
str: the punctuation text
"""
logger.info(
f"start to extract the do vector {self.task} from the http request")
if self.task == "spk" and task == "spk":
embedding = self.extract_audio_embedding(audio_data)
return embedding
else:
logger.error(
"The request task is not matched with server model task")
logger.error(
f"The server model task is: {self.task}, but the request task is: {task}"
)
return np.array([
0.0,
])
@paddle.no_grad()
def get_enroll_test_score(self, enroll_audio, test_audio):
"""Get the enroll and test audio score
Args:
enroll_audio (str): the base64 format enroll audio
test_audio (str): the base64 format test audio
Returns:
float: the score between enroll and test audio
"""
logger.info("start to extract the enroll audio embedding")
enroll_emb = self.extract_audio_embedding(enroll_audio)
logger.info("start to extract the test audio embedding")
test_emb = self.extract_audio_embedding(test_audio)
logger.info(
"start to get the score between the enroll and test embedding")
score = self.executor.get_embeddings_score(enroll_emb, test_emb)
logger.info(f"get the enroll vs test score: {score}")
return score
@paddle.no_grad()
def extract_audio_embedding(self, audio: str, sample_rate: int=16000):
"""extract the audio embedding
Args:
audio (str): the audio data
sample_rate (int, optional): the audio sample rate. Defaults to 16000.
"""
# we can not reuse the cache io.BytesIO(audio) data,
# because the soundfile will change the io.BytesIO(audio) to the end
# thus we should convert the base64 string to io.BytesIO when we need the audio data
if not self.executor._check(io.BytesIO(audio), sample_rate):
logger.info("check the audio sample rate occurs error")
return np.array([0.0])
waveform, sr = load_audio(io.BytesIO(audio))
logger.info(f"load the audio sample points, shape is: {waveform.shape}")
# stage 2: get the audio feat
# Note: Now we only support fbank feature
try:
feats = melspectrogram(
x=waveform,
sr=self.config.sr,
n_mels=self.config.n_mels,
window_size=self.config.window_size,
hop_length=self.config.hop_size)
logger.info(f"extract the audio feats, shape is: {feats.shape}")
except Exception as e:
logger.info(f"feats occurs exception {e}")
sys.exit(-1)
feats = paddle.to_tensor(feats).unsqueeze(0)
# in inference period, the lengths is all one without padding
lengths = paddle.ones([1])
# stage 3: we do feature normalize,
# Now we assume that the feats must do normalize
feats = feature_normalize(feats, mean_norm=True, std_norm=False)
# stage 4: store the feats and length in the _inputs,
# which will be used in other function
logger.info(f"feats shape: {feats.shape}")
logger.info("audio extract the feats success")
logger.info("start to extract the audio embedding")
embedding = self.model.backbone(feats, lengths).squeeze().numpy()
logger.info(f"embedding size: {embedding.shape}")
return embedding
class VectorServerExecutor(VectorExecutor):
def __init__(self):
"""The wrapper for TextEcutor
"""
super().__init__()
pass
class VectorEngine(BaseEngine):
def __init__(self):
"""The Vector Engine
"""
super(VectorEngine, self).__init__()
logger.info("Create the VectorEngine Instance")
def init(self, config: dict):
"""Init the Vector Engine
Args:
config (dict): The server configuation
Returns:
bool: The engine instance flag
"""
logger.info("Init the vector engine")
try:
self.config = config
if self.config.device:
self.device = self.config.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
logger.info(f"Vector Engine set the device: {self.device}")
except BaseException as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error("Initialize Vector server engine Failed on device: %s."
% (self.device))
return False
self.executor = VectorServerExecutor()
self.executor._init_from_path(
model_type=config.model_type,
cfg_path=config.cfg_path,
ckpt_path=config.ckpt_path,
task=config.task)
logger.info("Init the Vector engine successfully")
return True
......@@ -21,7 +21,7 @@ from paddlespeech.server.restful.asr_api import router as asr_router
from paddlespeech.server.restful.cls_api import router as cls_router
from paddlespeech.server.restful.text_api import router as text_router
from paddlespeech.server.restful.tts_api import router as tts_router
from paddlespeech.server.restful.vector_api import router as vec_router
_router = APIRouter()
......@@ -43,6 +43,8 @@ def setup_router(api_list: List):
_router.include_router(cls_router)
elif api_name == 'text':
_router.include_router(text_router)
elif api_name.lower() == 'vector':
_router.include_router(vec_router)
else:
logger.error(
f"PaddleSpeech has not support such service: {api_name}")
......
......@@ -15,7 +15,10 @@ from typing import Optional
from pydantic import BaseModel
__all__ = ['ASRRequest', 'TTSRequest', 'CLSRequest']
__all__ = [
'ASRRequest', 'TTSRequest', 'CLSRequest', 'VectorRequest',
'VectorScoreRequest'
]
#****************************************************************************************/
......@@ -85,3 +88,40 @@ class CLSRequest(BaseModel):
#****************************************************************************************/
class TextRequest(BaseModel):
text: str
#****************************************************************************************/
#************************************ Vecotr request ************************************/
#****************************************************************************************/
class VectorRequest(BaseModel):
"""
request body example
{
"audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...",
"task": "spk",
"audio_format": "wav",
"sample_rate": 16000,
}
"""
audio: str
task: str
audio_format: str
sample_rate: int
class VectorScoreRequest(BaseModel):
"""
request body example
{
"enroll_audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...",
"test_audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...",
"task": "score",
"audio_format": "wav",
"sample_rate": 16000,
}
"""
enroll_audio: str
test_audio: str
task: str
audio_format: str
sample_rate: int
......@@ -15,7 +15,10 @@ from typing import List
from pydantic import BaseModel
__all__ = ['ASRResponse', 'TTSResponse', 'CLSResponse']
__all__ = [
'ASRResponse', 'TTSResponse', 'CLSResponse', 'TextResponse',
'VectorResponse', 'VectorScoreResponse'
]
class Message(BaseModel):
......@@ -129,6 +132,11 @@ class CLSResponse(BaseModel):
result: CLSResult
#****************************************************************************************/
#************************************ Text response **************************************/
#****************************************************************************************/
class TextResult(BaseModel):
punc_text: str
......@@ -153,6 +161,59 @@ class TextResponse(BaseModel):
result: TextResult
#****************************************************************************************/
#************************************ Vector response **************************************/
#****************************************************************************************/
class VectorResult(BaseModel):
vec: list
class VectorResponse(BaseModel):
"""
response example
{
"success": true,
"code": 0,
"message": {
"description": "success"
},
"result": {
"vec": [1.0, 1.0]
}
}
"""
success: bool
code: int
message: Message
result: VectorResult
class VectorScoreResult(BaseModel):
score: float
class VectorScoreResponse(BaseModel):
"""
response example
{
"success": true,
"code": 0,
"message": {
"description": "success"
},
"result": {
"score": 1.0
}
}
"""
success: bool
code: int
message: Message
result: VectorScoreResult
#****************************************************************************************/
#********************************** Error response **************************************/
#****************************************************************************************/
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import traceback
from typing import Union
import numpy as np
from fastapi import APIRouter
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.engine.vector.python.vector_engine import PaddleVectorConnectionHandler
from paddlespeech.server.restful.request import VectorRequest
from paddlespeech.server.restful.request import VectorScoreRequest
from paddlespeech.server.restful.response import ErrorResponse
from paddlespeech.server.restful.response import VectorResponse
from paddlespeech.server.restful.response import VectorScoreResponse
from paddlespeech.server.utils.errors import ErrorCode
from paddlespeech.server.utils.errors import failed_response
from paddlespeech.server.utils.exception import ServerBaseException
router = APIRouter()
@router.get('/paddlespeech/vector/help')
def help():
"""help
Returns:
json: The /paddlespeech/vector api response content
"""
response = {
"success": "True",
"code": 200,
"message": {
"global": "success"
},
"vector": [2.3, 3.5, 5.5, 6.2, 2.8, 1.2, 0.3, 3.6]
}
return response
@router.post(
"/paddlespeech/vector", response_model=Union[VectorResponse, ErrorResponse])
def vector(request_body: VectorRequest):
"""vector api
Args:
request_body (VectorRequest): the vector request body
Returns:
json: the vector response body
"""
try:
# 1. get the audio data
# the audio must be base64 format
audio_data = base64.b64decode(request_body.audio)
# 2. get single engine from engine pool
# and we use the vector_engine to create an connection handler to process the request
engine_pool = get_engine_pool()
vector_engine = engine_pool['vector']
connection_handler = PaddleVectorConnectionHandler(vector_engine)
# 3. we use the connection handler to process the audio
audio_vec = connection_handler.run(audio_data, request_body.task)
# 4. we need the result of the vector instance be numpy.ndarray
if not isinstance(audio_vec, np.ndarray):
logger.error(
f"the vector type is not numpy.array, that is: {type(audio_vec)}"
)
error_reponse = ErrorResponse()
error_reponse.message.description = f"the vector type is not numpy.array, that is: {type(audio_vec)}"
return error_reponse
response = {
"success": True,
"code": 200,
"message": {
"description": "success"
},
"result": {
"vec": audio_vec.tolist()
}
}
except ServerBaseException as e:
response = failed_response(e.error_code, e.msg)
except BaseException:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
traceback.print_exc()
return response
@router.post(
"/paddlespeech/vector/score",
response_model=Union[VectorScoreResponse, ErrorResponse])
def score(request_body: VectorScoreRequest):
"""vector api
Args:
request_body (VectorScoreRequest): the punctuation request body
Returns:
json: the punctuation response body
"""
try:
# 1. get the audio data
# the audio must be base64 format
enroll_data = base64.b64decode(request_body.enroll_audio)
test_data = base64.b64decode(request_body.test_audio)
# 2. get single engine from engine pool
# and we use the vector_engine to create an connection handler to process the request
engine_pool = get_engine_pool()
vector_engine = engine_pool['vector']
connection_handler = PaddleVectorConnectionHandler(vector_engine)
# 3. we use the connection handler to process the audio
score = connection_handler.get_enroll_test_score(enroll_data, test_data)
response = {
"success": True,
"code": 200,
"message": {
"description": "success"
},
"result": {
"score": score
}
}
except ServerBaseException as e:
response = failed_response(e.error_code, e.msg)
except BaseException:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
traceback.print_exc()
return response
......@@ -142,6 +142,7 @@ class ASRWsAudioHandler:
return ""
# 1. send websocket handshake protocal
start_time = time.time()
async with websockets.connect(self.url) as ws:
# 2. server has already received handshake protocal
# client start to send the command
......@@ -187,7 +188,14 @@ class ASRWsAudioHandler:
if self.punc_server:
msg["result"] = self.punc_server.run(msg["result"])
# 6. logging the final result and comptute the statstics
elapsed_time = time.time() - start_time
audio_info = soundfile.info(wavfile_path)
logger.info("client final receive msg={}".format(msg))
logger.info(
f"audio duration: {audio_info.duration}, elapsed time: {elapsed_time}, RTF={elapsed_time/audio_info.duration}"
)
result = msg
return result
......@@ -456,3 +464,96 @@ class TTSHttpHandler:
self.stream.stop_stream()
self.stream.close()
self.p.terminate()
class VectorHttpHandler:
def __init__(self, server_ip=None, port=None):
"""The Vector client http request
Args:
server_ip (str, optional): the http vector server ip. Defaults to "127.0.0.1".
port (int, optional): the http vector server port. Defaults to 8090.
"""
super().__init__()
self.server_ip = server_ip
self.port = port
if server_ip is None or port is None:
self.url = None
else:
self.url = 'http://' + self.server_ip + ":" + str(
self.port) + '/paddlespeech/vector'
def run(self, input, audio_format, sample_rate, task="spk"):
"""Call the http asr to process the audio
Args:
input (str): the audio file path
audio_format (str): the audio format
sample_rate (str): the audio sample rate
Returns:
list: the audio vector
"""
if self.url is None:
logger.error("No vector server, please input valid ip and port")
return ""
audio = wav2base64(input)
data = {
"audio": audio,
"task": task,
"audio_format": audio_format,
"sample_rate": sample_rate,
}
logger.info(self.url)
res = requests.post(url=self.url, data=json.dumps(data))
return res.json()
class VectorScoreHttpHandler:
def __init__(self, server_ip=None, port=None):
"""The Vector score client http request
Args:
server_ip (str, optional): the http vector server ip. Defaults to "127.0.0.1".
port (int, optional): the http vector server port. Defaults to 8090.
"""
super().__init__()
self.server_ip = server_ip
self.port = port
if server_ip is None or port is None:
self.url = None
else:
self.url = 'http://' + self.server_ip + ":" + str(
self.port) + '/paddlespeech/vector/score'
def run(self, enroll_audio, test_audio, audio_format, sample_rate):
"""Call the http asr to process the audio
Args:
input (str): the audio file path
audio_format (str): the audio format
sample_rate (str): the audio sample rate
Returns:
list: the audio vector
"""
if self.url is None:
logger.error("No vector server, please input valid ip and port")
return ""
enroll_audio = wav2base64(enroll_audio)
test_audio = wav2base64(test_audio)
data = {
"enroll_audio": enroll_audio,
"test_audio": test_audio,
"task": "score",
"audio_format": audio_format,
"sample_rate": sample_rate,
}
res = requests.post(url=self.url, data=json.dumps(data))
return res.json()
......@@ -78,12 +78,14 @@ async def websocket_endpoint(websocket: WebSocket):
connection_handler.decode(is_finished=True)
connection_handler.rescoring()
asr_results = connection_handler.get_result()
word_time_stamp = connection_handler.get_word_time_stamp()
connection_handler.reset()
resp = {
"status": "ok",
"signal": "finished",
'result': asr_results
'result': asr_results,
'times': word_time_stamp
}
await websocket.send_json(resp)
break
......
......@@ -155,7 +155,6 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--cmvn_file=$cmvn \
--model_path=$model_dir/avg_1.jit.pdmodel \
--to_float32=true \
--streaming_chunk=30 \
--param_path=$model_dir/avg_1.jit.pdiparams \
--word_symbol_table=$wfst/words.txt \
......
......@@ -19,6 +19,7 @@
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_int32(sample_rate, 16000, "sample rate");
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
......@@ -30,7 +31,8 @@ int main(int argc, char* argv[]) {
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int sample_rate = 16000;
int sample_rate = FLAGS_sample_rate;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << "sr: " << sample_rate;
......
......@@ -69,6 +69,7 @@ int main(int argc, char* argv[]) {
feat_cache_opts.frame_chunk_stride = 1;
feat_cache_opts.frame_chunk_size = 1;
ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
LOG(INFO) << "fbank: " << true;
LOG(INFO) << "feat dim: " << feature_cache.Dim();
int sample_rate = 16000;
......
......@@ -56,6 +56,7 @@ int main(int argc, char* argv[]) {
opt.frame_opts.remove_dc_offset = false;
opt.frame_opts.window_type = "hanning";
opt.frame_opts.preemph_coeff = 0.0;
LOG(INFO) << "linear feature: " << true;
LOG(INFO) << "frame length (ms): " << opt.frame_opts.frame_length_ms;
LOG(INFO) << "frame shift (ms): " << opt.frame_opts.frame_shift_ms;
......@@ -77,7 +78,7 @@ int main(int argc, char* argv[]) {
int sample_rate = 16000;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "sample rate: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
......
......@@ -63,7 +63,6 @@ websocket_server_main \
--cmvn_file=$cmvn \
--model_path=$model_dir/avg_1.jit.pdmodel \
--streaming_chunk=0.1 \
--to_float32=true \
--param_path=$model_dir/avg_1.jit.pdiparams \
--word_symbol_table=$wfst/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
......
......@@ -19,23 +19,24 @@
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
// feature
DEFINE_bool(use_fbank, false, "False for fbank; or linear feature");
// DEFINE_bool(to_float32, true, "audio convert to pcm32. True for linear
// feature, or fbank");
DEFINE_int32(num_bins, 161, "num bins of mel");
DEFINE_string(cmvn_file, "", "read cmvn");
DEFINE_double(streaming_chunk, 0.1, "streaming feature chunk size");
DEFINE_bool(to_float32, true, "audio convert to pcm32");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "max active");
DEFINE_double(beam, 15.0, "decoder beam");
DEFINE_double(lattice_beam, 7.5, "decoder beam");
// feature sliding window
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=5) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=5) module downsampling rate.");
// nnet
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
......@@ -47,8 +48,14 @@ DEFINE_string(model_cache_names,
"chunk_state_h_box,chunk_state_c_box",
"model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
DEFINE_bool(use_fbank, false, "use fbank or linear feature");
DEFINE_int32(num_bins, 161, "num bins of mel");
// decoder
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "max active");
DEFINE_double(beam, 15.0, "decoder beam");
DEFINE_double(lattice_beam, 7.5, "decoder beam");
namespace ppspeech {
// todo refactor later
......@@ -56,22 +63,23 @@ FeaturePipelineOptions InitFeaturePipelineOptions() {
FeaturePipelineOptions opts;
opts.cmvn_file = FLAGS_cmvn_file;
opts.linear_spectrogram_opts.streaming_chunk = FLAGS_streaming_chunk;
opts.to_float32 = FLAGS_to_float32;
kaldi::FrameExtractionOptions frame_opts;
frame_opts.dither = 0.0;
frame_opts.frame_shift_ms = 10;
opts.use_fbank = FLAGS_use_fbank;
if (opts.use_fbank) {
frame_opts.window_type = "povey";
frame_opts.frame_length_ms = 25;
opts.fbank_opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
opts.fbank_opts.fbank_opts.frame_opts = frame_opts;
opts.to_float32 = false;
frame_opts.window_type = "povey";
frame_opts.frame_length_ms = 25;
opts.fbank_opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
opts.fbank_opts.fbank_opts.frame_opts = frame_opts;
} else {
frame_opts.remove_dc_offset = false;
frame_opts.frame_length_ms = 20;
frame_opts.window_type = "hanning";
frame_opts.preemph_coeff = 0.0;
opts.linear_spectrogram_opts.frame_opts = frame_opts;
opts.to_float32 = true;
frame_opts.remove_dc_offset = false;
frame_opts.frame_length_ms = 20;
frame_opts.window_type = "hanning";
frame_opts.preemph_coeff = 0.0;
opts.linear_spectrogram_opts.frame_opts = frame_opts;
}
opts.feature_cache_opts.frame_chunk_size = FLAGS_receptive_field_length;
opts.feature_cache_opts.frame_chunk_stride = FLAGS_downsampling_rate;
......
......@@ -102,13 +102,16 @@ bool Fbank::Compute(const Vector<BaseFloat>& waves, Vector<BaseFloat>* feats) {
// note: this online feature-extraction code does not support VTLN.
RealFft(&window, true);
kaldi::ComputePowerSpectrum(&window);
const kaldi::MelBanks &mel_bank = *(computer_.GetMelBanks(1.0));
SubVector<BaseFloat> power_spectrum(window, 0, window.Dim() / 2 + 1);
const kaldi::MelBanks& mel_bank = *(computer_.GetMelBanks(1.0));
SubVector<BaseFloat> power_spectrum(window, 0, window.Dim() / 2 + 1);
if (!opts_.fbank_opts.use_power) {
power_spectrum.ApplyPow(0.5);
}
int32 mel_offset = ((opts_.fbank_opts.use_energy && !opts_.fbank_opts.htk_compat) ? 1 : 0);
SubVector<BaseFloat> mel_energies(this_feature, mel_offset, opts_.fbank_opts.mel_opts.num_bins);
int32 mel_offset =
((opts_.fbank_opts.use_energy && !opts_.fbank_opts.htk_compat) ? 1
: 0);
SubVector<BaseFloat> mel_energies(
this_feature, mel_offset, opts_.fbank_opts.mel_opts.num_bins);
mel_bank.Compute(power_spectrum, &mel_energies);
mel_energies.ApplyFloor(1e-07);
mel_energies.ApplyLog();
......
......@@ -23,13 +23,13 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) {
new ppspeech::AudioCache(1000 * kint16max, opts.to_float32));
unique_ptr<FrontendInterface> base_feature;
if (opts.use_fbank) {
base_feature.reset(new ppspeech::Fbank(opts.fbank_opts,
std::move(data_source)));
base_feature.reset(
new ppspeech::Fbank(opts.fbank_opts, std::move(data_source)));
} else {
base_feature.reset(new ppspeech::LinearSpectrogram(opts.linear_spectrogram_opts,
std::move(data_source)));
base_feature.reset(new ppspeech::LinearSpectrogram(
opts.linear_spectrogram_opts, std::move(data_source)));
}
unique_ptr<FrontendInterface> cmvn(
......
......@@ -18,25 +18,25 @@
#include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/fbank.h"
#include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/fbank.h"
#include "frontend/audio/normalizer.h"
namespace ppspeech {
struct FeaturePipelineOptions {
std::string cmvn_file;
bool to_float32;
bool to_float32; // true, only for linear feature
bool use_fbank;
LinearSpectrogramOptions linear_spectrogram_opts;
FbankOptions fbank_opts;
FeatureCacheOptions feature_cache_opts;
FeaturePipelineOptions()
: cmvn_file(""),
to_float32(false),
use_fbank(false),
to_float32(false), // true, only for linear feature
use_fbank(true),
linear_spectrogram_opts(),
fbank_opts(),
feature_cache_opts() {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册