Unverified commit c938a450, authored by H Hui Zhang, committed via GitHub

Merge branch 'develop' into ngram

......@@ -27,20 +27,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--phones_dict=dump/phone_id_map.txt
fi
# style melgan
# style melgan's Dygraph to Static Graph conversion is not ready yet
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
python3 ${BIN_DIR}/../inference.py \
--inference_dir=${train_output_path}/inference \
--am=tacotron2_csmsc \
--voc=style_melgan_csmsc \
--text=${BIN_DIR}/../sentences.txt \
--output_dir=${train_output_path}/pd_infer_out \
--phones_dict=dump/phone_id_map.txt
fi
# hifigan
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
python3 ${BIN_DIR}/../inference.py \
--inference_dir=${train_output_path}/inference \
--am=tacotron2_csmsc \
......
......@@ -28,7 +28,6 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--phones_dict=dump/phone_id_map.txt
fi
# hifigan
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
python3 ${BIN_DIR}/../inference.py \
......
......@@ -109,6 +109,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
--lang=zh \
--text=${BIN_DIR}/../sentences.txt \
--output_dir=${train_output_path}/test_e2e \
--phones_dict=dump/phone_id_map.txt \
--inference_dir=${train_output_path}/inference
--phones_dict=dump/phone_id_map.txt #\
# --inference_dir=${train_output_path}/inference
fi
......@@ -26,7 +26,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
fi
# hifigan
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize.py \
......
......@@ -40,7 +40,6 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ['ASRExecutor']
@cli_register(
name='paddlespeech.asr', description='Speech to text infer command.')
class ASRExecutor(BaseExecutor):
......@@ -125,6 +124,7 @@ class ASRExecutor(BaseExecutor):
"""
Init model and other resources from a specific path.
"""
logger.info("start to init the model")
if hasattr(self, 'model'):
logger.info('Model had been initialized.')
return
......@@ -140,14 +140,15 @@ class ASRExecutor(BaseExecutor):
res_path,
self.pretrained_models[tag]['ckpt_path'] + ".pdparams")
logger.info(res_path)
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
else:
self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
#Init body.
self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path)
......@@ -176,7 +177,6 @@ class ASRExecutor(BaseExecutor):
vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix)
self.config.decode.decoding_method = decode_method
else:
raise Exception("wrong type")
model_name = model_type[:model_type.rindex(
......@@ -254,12 +254,14 @@ class ASRExecutor(BaseExecutor):
else:
raise Exception("wrong type")
logger.info("audio feat process success")
@paddle.no_grad()
def infer(self, model_type: str):
"""
Model inference and result stored in self.output.
"""
logger.info("start to infer the model to get the output")
cfg = self.config.decode
audio = self._inputs["audio"]
audio_len = self._inputs["audio_len"]
......@@ -276,17 +278,22 @@ class ASRExecutor(BaseExecutor):
self._outputs["result"] = result_transcripts[0]
elif "conformer" in model_type or "transformer" in model_type:
result_transcripts = self.model.decode(
audio,
audio_len,
text_feature=self.text_feature,
decoding_method=cfg.decoding_method,
beam_size=cfg.beam_size,
ctc_weight=cfg.ctc_weight,
decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks,
simulate_streaming=cfg.simulate_streaming)
self._outputs["result"] = result_transcripts[0][0]
logger.info(f"we will use the transformer like model : {model_type}")
try:
result_transcripts = self.model.decode(
audio,
audio_len,
text_feature=self.text_feature,
decoding_method=cfg.decoding_method,
beam_size=cfg.beam_size,
ctc_weight=cfg.ctc_weight,
decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks,
simulate_streaming=cfg.simulate_streaming)
self._outputs["result"] = result_transcripts[0][0]
except Exception as e:
logger.exception(e)
else:
raise Exception("invalid model name")
......
......@@ -88,6 +88,8 @@ model_alias = {
"paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
"conformer":
"paddlespeech.s2t.models.u2:U2Model",
"conformer_online":
"paddlespeech.s2t.models.u2:U2Model",
"transformer":
"paddlespeech.s2t.models.u2:U2Model",
"wenetspeech":
......
......@@ -279,14 +279,13 @@ class U2BaseModel(ASRInterface, nn.Layer):
# TODO(Hui Zhang): if end_flag.sum() == running_size:
if end_flag.cast(paddle.int64).sum() == running_size:
break
# 2.1 Forward decoder step
hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
running_size, 1, 1).to(device) # (B*N, i, i)
# logp: (B*N, vocab)
logp, cache = self.decoder.forward_one_step(
encoder_out, encoder_mask, hyps, hyps_mask, cache)
# 2.2 First beam prune: select topk best prob at current time
top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N)
top_k_logp = mask_finished_scores(top_k_logp, end_flag)
......@@ -708,11 +707,11 @@ class U2BaseModel(ASRInterface, nn.Layer):
batch_size = feats.shape[0]
if decoding_method in ['ctc_prefix_beam_search',
'attention_rescoring'] and batch_size > 1:
logger.fatal(
logger.error(
f'decoding mode {decoding_method} must be running with batch_size == 1'
)
logger.error(f"current batch_size is {batch_size}")
sys.exit(1)
if decoding_method == 'attention':
hyps = self.recognize(
feats,
......
......@@ -180,7 +180,7 @@ class CTCDecoder(CTCDecoderBase):
# init once
if self._ext_scorer is not None:
return
if language_model_path != '':
logger.info("begin to initialize the external scorer "
"for decoding")
......
......@@ -35,3 +35,16 @@
```bash
paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav
```
## Online ASR Server
### Launch the online ASR server
```bash
paddlespeech_server start --config_file conf/ws_conformer_application.yaml
```
### Access the online ASR server
```bash
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input input_16k.wav
```
\ No newline at end of file
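For readers who want to talk to the server without the CLI, the websocket exchange used by `paddlespeech_client asr_online` is also visible in the handler code later in this commit: a JSON `start` message, raw int16 PCM chunks, then a JSON `end` message whose reply carries the final `asr_results`. The sketch below is illustrative only; the `/ws/asr` route, the chunking, and the exact JSON fields are assumptions drawn from `ASRAudioHandler` and the server endpoint in this diff, not a tested client.
```python
# Minimal sketch of the streaming ASR websocket exchange (illustrative only).
# Assumptions: /ws/asr route, 16 kHz 16-bit mono wav, simplified JSON payloads.
import asyncio
import json

import soundfile
import websockets


async def recognize(wav_path: str, url: str="ws://127.0.0.1:8090/ws/asr"):
    samples, sample_rate = soundfile.read(wav_path, dtype='int16')
    async with websockets.connect(url) as ws:
        # 1. tell the server that a new utterance starts
        await ws.send(json.dumps({"name": wav_path, "signal": "start"}))
        print(await ws.recv())  # expect {"status": "ok", "signal": "server_ready"}

        # 2. stream raw int16 PCM in ~85 ms chunks (85 * 16 samples at 16 kHz)
        chunk_size = 85 * 16
        for start in range(0, len(samples), chunk_size):
            await ws.send(samples[start:start + chunk_size].tobytes())
            print(await ws.recv())  # partial {"asr_results": "..."}

        # 3. finish the utterance and read the final result
        await ws.send(json.dumps({"name": wav_path, "signal": "end"}))
        final = json.loads(await ws.recv())
        return final.get("asr_results", "")


if __name__ == "__main__":
    result = asyncio.get_event_loop().run_until_complete(recognize("input_16k.wav"))
    print(result)
```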
......@@ -35,3 +35,17 @@
```bash
paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav
```
## Streaming ASR
### Launch the streaming ASR server
```bash
paddlespeech_server start --config_file conf/ws_conformer_application.yaml
```
### Access the streaming ASR server
```bash
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input zh.wav
```
\ No newline at end of file
......@@ -277,11 +277,12 @@ class ASRClientExecutor(BaseExecutor):
lang=lang,
audio_format=audio_format)
time_end = time.time()
logger.info(res.json())
logger.info(res)
logger.info("Response time %f s." % (time_end - time_start))
return True
except Exception as e:
logger.error("Failed to speech recognition.")
logger.error(e)
return False
@stats_wrapper
......@@ -299,9 +300,10 @@ class ASRClientExecutor(BaseExecutor):
logging.info("asr websocket client start")
handler = ASRAudioHandler(server_ip, port)
loop = asyncio.get_event_loop()
loop.run_until_complete(handler.run(input))
res = loop.run_until_complete(handler.run(input))
logging.info("asr websocket client finished")
return res['asr_results']
@cli_client_register(
name='paddlespeech_client.cls', description='visit cls service')
......
......@@ -41,11 +41,7 @@ asr_online:
shift_ms: 40
sample_rate: 16000
sample_width: 2
vad_conf:
aggressiveness: 2
sample_rate: 16000
frame_duration_ms: 20
sample_width: 2
padding_ms: 200
padding_ratio: 0.9
window_n: 7 # frame
shift_n: 4 # frame
window_ms: 20 # ms
shift_ms: 10 # ms
# This is the parameter configuration file for PaddleSpeech Serving.
#################################################################################
# SERVER SETTING #
#################################################################################
host: 0.0.0.0
port: 8090
# The task format in the engine_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online']
# protocol = ['websocket', 'http'] (only one can be selected).
# websocket only supports the online engine type.
protocol: 'websocket'
engine_list: ['asr_online']
#################################################################################
# ENGINE CONFIG #
#################################################################################
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online:
model_type: 'conformer_online_multicn'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
sample_rate: 16000
cfg_path:
decode_method:
force_yes: True
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True
glog_info: False # True -> print glog
summary: True # False -> do not show predictor config
chunk_buffer_conf:
window_n: 7 # frame
shift_n: 4 # frame
window_ms: 25 # ms
shift_ms: 10 # ms
sample_rate: 16000
sample_width: 2
\ No newline at end of file
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import paddle
from paddlespeech.cli.log import logger
from paddlespeech.s2t.utils.utility import log_add
__all__ = ['CTCPrefixBeamSearch']
class CTCPrefixBeamSearch:
def __init__(self, config):
"""Implement the ctc prefix beam search
Args:
config (yacs.config.CfgNode): the decoding config, e.g. containing beam_size
"""
self.config = config
self.reset()
@paddle.no_grad()
def search(self, ctc_probs, device, blank_id=0):
"""ctc prefix beam search method decode a chunk feature
Args:
ctc_probs (paddle.Tensor): the ctc probabilities of all the tokens, shape (T, vocab_size)
device (paddle.fluid.core_avx.Place): the feature host device, such as CUDAPlace(0).
blank_id (int, optional): the blank id in the vocab. Defaults to 0.
Returns:
list: the search result
"""
# decode
logger.info("start to ctc prefix search")
batch_size = 1
beam_size = self.config.beam_size
maxlen = ctc_probs.shape[0]
assert len(ctc_probs.shape) == 2
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# blank_ending_score and none_blank_ending_score in ln domain
if self.cur_hyps is None:
self.cur_hyps = [(tuple(), (0.0, -float('inf')))]
# 2. CTC beam search step by step
for t in range(0, maxlen):
logp = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
# 2.1 First beam prune: select topk best
# do token passing process
top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,)
for s in top_k_index:
s = s.item()
ps = logp[s].item()
for prefix, (pb, pnb) in self.cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == blank_id: # blank
n_pb, n_pnb = next_hyps[prefix]
n_pb = log_add([n_pb, pb + ps, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
elif s == last:
# Update *ss -> *s;
n_pb, n_pnb = next_hyps[prefix]
n_pnb = log_add([n_pnb, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
# Update *s-s -> *ss, - is for blank
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
else:
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
# 2.2 Second beam prune
next_hyps = sorted(
next_hyps.items(),
key=lambda x: log_add(list(x[1])),
reverse=True)
self.cur_hyps = next_hyps[:beam_size]
self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps]
logger.info("ctc prefix search success")
return self.hyps
def get_one_best_hyps(self):
"""Return the one best result
Returns:
list: the one best result
"""
return [self.hyps[0][0]]
def get_hyps(self):
"""Return the search hyps
Returns:
list: return the search hyps
"""
return self.hyps
def reset(self):
"""Rest the search cache value
"""
self.cur_hyps = None
self.hyps = None
def finalize_search(self):
"""do nothing in ctc_prefix_beam_search
"""
pass
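A minimal usage sketch of the searcher above, for orientation only: the `CfgNode` contents and the random log-probabilities are assumptions for illustration; in the serving engine the probabilities come from the streaming model's CTC head.
```python
# Illustrative only: drive CTCPrefixBeamSearch with dummy CTC log-probabilities.
# The config contents and the (T, vocab_size) shape follow how the class reads
# config.beam_size and ctc_probs above; this is not a tested configuration.
import paddle
import paddle.nn.functional as F
from yacs.config import CfgNode

config = CfgNode(new_allowed=True)
config.beam_size = 10

searcher = CTCPrefixBeamSearch(config)

# fake (T=50, vocab_size=100) log-probabilities standing in for the CTC head output
ctc_probs = F.log_softmax(paddle.randn([50, 100]), axis=-1)
device = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()

hyps = searcher.search(ctc_probs, device, blank_id=0)  # list of (prefix, score)
print(searcher.get_one_best_hyps())  # best token-id sequence so far
searcher.reset()  # clear the cache before the next utterance
```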
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -34,10 +34,9 @@ class ASRAudioHandler:
def read_wave(self, wavfile_path: str):
samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
x_len = len(samples)
# chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
chunk_size = 80 * 16 #80ms, sample_rate = 16kHz
if x_len % chunk_size != 0:
chunk_size = 85 * 16  #85ms, sample_rate = 16kHz
if x_len % chunk_size != 0:
padding_len_x = chunk_size - x_len % chunk_size
else:
padding_len_x = 0
......@@ -48,7 +47,6 @@ class ASRAudioHandler:
assert (x_len + padding_len_x) % chunk_size == 0
num_chunk = (x_len + padding_len_x) / chunk_size
num_chunk = int(num_chunk)
for i in range(0, num_chunk):
start = i * chunk_size
end = start + chunk_size
......@@ -57,7 +55,11 @@ class ASRAudioHandler:
async def run(self, wavfile_path: str):
logging.info("send a message to the server")
# self.read_wave()
# send websocket handshake protocol
async with websockets.connect(self.url) as ws:
# server has already received handshake protocol
# client starts to send the command
audio_info = json.dumps(
{
"name": "test.wav",
......@@ -78,7 +80,6 @@ class ASRAudioHandler:
msg = json.loads(msg)
logging.info("receive msg={}".format(msg))
result = msg
# finished
audio_info = json.dumps(
{
......@@ -91,10 +92,12 @@ class ASRAudioHandler:
separators=(',', ': '))
await ws.send(audio_info)
msg = await ws.recv()
# decode the bytes to str
msg = json.loads(msg)
logging.info("receive msg={}".format(msg))
return result
logging.info("final receive msg={}".format(msg))
result = msg
return result
def main(args):
......
......@@ -63,12 +63,12 @@ class ChunkBuffer(object):
the sample rate.
Yields Frames of the requested duration.
"""
audio = self.remained_audio + audio
self.remained_audio = b''
offset = 0
timestamp = 0.0
while offset + self.window_bytes <= len(audio):
yield Frame(audio[offset:offset + self.window_bytes], timestamp,
self.window_sec)
......
......@@ -52,6 +52,10 @@ def get_chunks(data, block_size, pad_size, step):
Returns:
list: chunks list
"""
if block_size == -1:
return [data]
if step == "am":
data_len = data.shape[1]
elif step == "voc":
......
......@@ -13,12 +13,12 @@
# limitations under the License.
import json
import numpy as np
from fastapi import APIRouter
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from starlette.websockets import WebSocketState as WebSocketState
from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler
from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.utils.buffer import ChunkBuffer
from paddlespeech.server.utils.vad import VADAudio
......@@ -28,26 +28,29 @@ router = APIRouter()
@router.websocket('/ws/asr')
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
connection_handler = None
# init buffer
# each websocket connection has its own chunk buffer
chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
chunk_buffer = ChunkBuffer(
window_n=7,
shift_n=4,
window_ms=20,
shift_ms=10,
sample_rate=chunk_buffer_conf['sample_rate'],
sample_width=chunk_buffer_conf['sample_width'])
window_n=chunk_buffer_conf.window_n,
shift_n=chunk_buffer_conf.shift_n,
window_ms=chunk_buffer_conf.window_ms,
shift_ms=chunk_buffer_conf.shift_ms,
sample_rate=chunk_buffer_conf.sample_rate,
sample_width=chunk_buffer_conf.sample_width)
# init vad
vad_conf = asr_engine.config.vad_conf
vad = VADAudio(
aggressiveness=vad_conf['aggressiveness'],
rate=vad_conf['sample_rate'],
frame_duration_ms=vad_conf['frame_duration_ms'])
vad_conf = asr_engine.config.get('vad_conf', None)
if vad_conf:
vad = VADAudio(
aggressiveness=vad_conf['aggressiveness'],
rate=vad_conf['sample_rate'],
frame_duration_ms=vad_conf['frame_duration_ms'])
try:
while True:
......@@ -64,13 +67,21 @@ async def websocket_endpoint(websocket: WebSocket):
if message['signal'] == 'start':
resp = {"status": "ok", "signal": "server_ready"}
# do something at the beginning here
# create the instance to process the audio
connection_handler = PaddleASRConnectionHanddler(asr_engine)
await websocket.send_json(resp)
elif message['signal'] == 'end':
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
# reset single engine for an new connection
asr_engine.reset()
resp = {"status": "ok", "signal": "finished"}
connection_handler.decode(is_finished=True)
connection_handler.rescoring()
asr_results = connection_handler.get_result()
connection_handler.reset()
resp = {
"status": "ok",
"signal": "finished",
'asr_results': asr_results
}
await websocket.send_json(resp)
break
else:
......@@ -79,21 +90,11 @@ async def websocket_endpoint(websocket: WebSocket):
elif "bytes" in message:
message = message["bytes"]
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
asr_results = ""
frames = chunk_buffer.frame_generator(message)
for frame in frames:
samples = np.frombuffer(frame.bytes, dtype=np.int16)
sample_rate = asr_engine.config.sample_rate
x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
sample_rate)
asr_engine.run(x_chunk, x_chunk_lens)
asr_results = asr_engine.postprocess()
connection_handler.extract_feat(message)
connection_handler.decode(is_finished=False)
asr_results = connection_handler.get_result()
asr_results = asr_engine.postprocess()
resp = {'asr_results': asr_results}
await websocket.send_json(resp)
except WebSocketDisconnect:
pass
......@@ -14,6 +14,7 @@
import argparse
from pathlib import Path
import paddle
import soundfile as sf
from timer import timer
......@@ -101,21 +102,35 @@ def parse_args():
# only inference for models trained with csmsc now
def main():
args = parse_args()
paddle.set_device(args.device)
# frontend
frontend = get_frontend(args)
frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
# am_predictor
am_predictor = get_predictor(args, filed='am')
am_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.am + ".pdmodel",
params_file=args.am + ".pdiparams",
device=args.device)
# model: {model_name}_{dataset}
am_dataset = args.am[args.am.rindex('_') + 1:]
# voc_predictor
voc_predictor = get_predictor(args, filed='voc')
voc_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.voc + ".pdmodel",
params_file=args.voc + ".pdiparams",
device=args.device)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
sentences = get_sentences(args)
sentences = get_sentences(text_file=args.text, lang=args.lang)
merge_sentences = True
fs = 24000 if am_dataset != 'ljspeech' else 22050
......@@ -123,11 +138,13 @@ def main():
for utt_id, sentence in sentences[:3]:
with timer() as t:
am_output_data = get_am_output(
args,
input=sentence,
am_predictor=am_predictor,
am=args.am,
frontend=frontend,
lang=args.lang,
merge_sentences=merge_sentences,
input=sentence)
speaker_dict=args.speaker_dict, )
wav = get_voc_output(
voc_predictor=voc_predictor, input=am_output_data)
speed = wav.size / t.elapse
......@@ -143,11 +160,13 @@ def main():
for utt_id, sentence in sentences:
with timer() as t:
am_output_data = get_am_output(
args,
input=sentence,
am_predictor=am_predictor,
am=args.am,
frontend=frontend,
lang=args.lang,
merge_sentences=merge_sentences,
input=sentence)
speaker_dict=args.speaker_dict, )
wav = get_voc_output(
voc_predictor=voc_predictor, input=am_output_data)
......
......@@ -15,6 +15,7 @@ import argparse
from pathlib import Path
import numpy as np
import paddle
import soundfile as sf
from timer import timer
......@@ -25,7 +26,6 @@ from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_predictor
from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.exps.syn_utils import get_streaming_am_output
from paddlespeech.t2s.exps.syn_utils import get_streaming_am_predictor
from paddlespeech.t2s.exps.syn_utils import get_voc_output
from paddlespeech.t2s.utils import str2bool
......@@ -101,23 +101,47 @@ def parse_args():
# only inference for models trained with csmsc now
def main():
args = parse_args()
paddle.set_device(args.device)
# frontend
frontend = get_frontend(args)
frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
# am_predictor
am_encoder_infer_predictor, am_decoder_predictor, am_postnet_predictor = get_streaming_am_predictor(
args)
am_encoder_infer_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.am + "_am_encoder_infer" + ".pdmodel",
params_file=args.am + "_am_encoder_infer" + ".pdiparams",
device=args.device)
am_decoder_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.am + "_am_decoder" + ".pdmodel",
params_file=args.am + "_am_decoder" + ".pdiparams",
device=args.device)
am_postnet_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.am + "_am_postnet" + ".pdmodel",
params_file=args.am + "_am_postnet" + ".pdiparams",
device=args.device)
am_mu, am_std = np.load(args.am_stat)
# model: {model_name}_{dataset}
am_dataset = args.am[args.am.rindex('_') + 1:]
# voc_predictor
voc_predictor = get_predictor(args, filed='voc')
voc_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.voc + ".pdmodel",
params_file=args.voc + ".pdiparams",
device=args.device)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
sentences = get_sentences(args)
sentences = get_sentences(text_file=args.text, lang=args.lang)
merge_sentences = True
......@@ -126,13 +150,13 @@ def main():
for utt_id, sentence in sentences[:3]:
with timer() as t:
normalized_mel = get_streaming_am_output(
args,
input=sentence,
am_encoder_infer_predictor=am_encoder_infer_predictor,
am_decoder_predictor=am_decoder_predictor,
am_postnet_predictor=am_postnet_predictor,
frontend=frontend,
merge_sentences=merge_sentences,
input=sentence)
lang=args.lang,
merge_sentences=merge_sentences, )
mel = denorm(normalized_mel, am_mu, am_std)
wav = get_voc_output(voc_predictor=voc_predictor, input=mel)
speed = wav.size / t.elapse
......
......@@ -16,6 +16,7 @@ from pathlib import Path
import jsonlines
import numpy as np
import paddle
import soundfile as sf
from timer import timer
......@@ -25,12 +26,13 @@ from paddlespeech.t2s.utils import str2bool
def ort_predict(args):
# construct dataset for evaluation
with jsonlines.open(args.test_metadata, 'r') as reader:
test_metadata = list(reader)
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
test_dataset = get_test_dataset(args, test_metadata, am_name, am_dataset)
test_dataset = get_test_dataset(test_metadata=test_metadata, am=args.am)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
......@@ -38,10 +40,18 @@ def ort_predict(args):
fs = 24000 if am_dataset != 'ljspeech' else 22050
# am
am_sess = get_sess(args, filed='am')
am_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.am + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
# vocoder
voc_sess = get_sess(args, filed='voc')
voc_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.voc + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
# am warmup
for T in [27, 38, 54]:
......@@ -135,6 +145,8 @@ def parse_args():
def main():
args = parse_args()
paddle.set_device(args.device)
ort_predict(args)
......
......@@ -15,6 +15,7 @@ import argparse
from pathlib import Path
import numpy as np
import paddle
import soundfile as sf
from timer import timer
......@@ -27,21 +28,31 @@ from paddlespeech.t2s.utils import str2bool
def ort_predict(args):
# frontend
frontend = get_frontend(args)
frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
sentences = get_sentences(args)
sentences = get_sentences(text_file=args.text, lang=args.lang)
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
fs = 24000 if am_dataset != 'ljspeech' else 22050
# am
am_sess = get_sess(args, filed='am')
am_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.am + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
# vocoder
voc_sess = get_sess(args, filed='voc')
voc_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.voc + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
# frontend warmup
# Loading model cost 0.5+ seconds
......@@ -168,6 +179,8 @@ def parse_args():
def main():
args = parse_args()
paddle.set_device(args.device)
ort_predict(args)
......
......@@ -15,6 +15,7 @@ import argparse
from pathlib import Path
import numpy as np
import paddle
import soundfile as sf
from timer import timer
......@@ -23,30 +24,50 @@ from paddlespeech.t2s.exps.syn_utils import get_chunks
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.exps.syn_utils import get_sess
from paddlespeech.t2s.exps.syn_utils import get_streaming_am_sess
from paddlespeech.t2s.utils import str2bool
def ort_predict(args):
# frontend
frontend = get_frontend(args)
frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
sentences = get_sentences(args)
sentences = get_sentences(text_file=args.text, lang=args.lang)
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
fs = 24000 if am_dataset != 'ljspeech' else 22050
# am
am_encoder_infer_sess, am_decoder_sess, am_postnet_sess = get_streaming_am_sess(
args)
# streaming acoustic model
am_encoder_infer_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.am + "_am_encoder_infer" + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
am_decoder_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.am + "_am_decoder" + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
am_postnet_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.am + "_am_postnet" + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
am_mu, am_std = np.load(args.am_stat)
# vocoder
voc_sess = get_sess(args, filed='voc')
voc_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.voc + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
# frontend warmup
# Loading model cost 0.5+ seconds
......@@ -226,6 +247,8 @@ def parse_args():
def main():
args = parse_args()
paddle.set_device(args.device)
ort_predict(args)
......
......@@ -14,6 +14,10 @@
import math
import os
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
import numpy as np
import onnxruntime as ort
......@@ -21,6 +25,7 @@ import paddle
from paddle import inference
from paddle import jit
from paddle.static import InputSpec
from yacs.config import CfgNode
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.t2s.datasets.data_table import DataTable
......@@ -70,7 +75,7 @@ def denorm(data, mean, std):
return data * std + mean
def get_chunks(data, chunk_size, pad_size):
def get_chunks(data, chunk_size: int, pad_size: int):
data_len = data.shape[1]
chunks = []
n = math.ceil(data_len / chunk_size)
......@@ -82,28 +87,34 @@ def get_chunks(data, chunk_size, pad_size):
# input
def get_sentences(args):
def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'):
# construct dataset for evaluation
sentences = []
with open(args.text, 'rt') as f:
with open(text_file, 'rt') as f:
for line in f:
items = line.strip().split()
utt_id = items[0]
if 'lang' in args and args.lang == 'zh':
if lang == 'zh':
sentence = "".join(items[1:])
elif 'lang' in args and args.lang == 'en':
elif lang == 'en':
sentence = " ".join(items[1:])
sentences.append((utt_id, sentence))
return sentences
def get_test_dataset(args, test_metadata, am_name, am_dataset):
def get_test_dataset(test_metadata: List[Dict[str, Any]],
am: str,
speaker_dict: Optional[os.PathLike]=None,
voice_cloning: bool=False):
# model: {model_name}_{dataset}
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
if am_name == 'fastspeech2':
fields = ["utt_id", "text"]
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
print("multiple speaker fastspeech2!")
fields += ["spk_id"]
elif 'voice_cloning' in args and args.voice_cloning:
elif voice_cloning:
print("voice cloning!")
fields += ["spk_emb"]
else:
......@@ -112,7 +123,7 @@ def get_test_dataset(args, test_metadata, am_name, am_dataset):
fields = ["utt_id", "phones", "tones"]
elif am_name == 'tacotron2':
fields = ["utt_id", "text"]
if 'voice_cloning' in args and args.voice_cloning:
if voice_cloning:
print("voice cloning!")
fields += ["spk_emb"]
......@@ -121,12 +132,14 @@ def get_test_dataset(args, test_metadata, am_name, am_dataset):
# frontend
def get_frontend(args):
if 'lang' in args and args.lang == 'zh':
def get_frontend(lang: str='zh',
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None):
if lang == 'zh':
frontend = Frontend(
phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
elif 'lang' in args and args.lang == 'en':
frontend = English(phone_vocab_path=args.phones_dict)
phone_vocab_path=phones_dict, tone_vocab_path=tones_dict)
elif lang == 'en':
frontend = English(phone_vocab_path=phones_dict)
else:
print("wrong lang!")
print("frontend done!")
......@@ -134,30 +147,37 @@ def get_frontend(args):
# dygraph
def get_am_inference(args, am_config):
with open(args.phones_dict, "r") as f:
def get_am_inference(
am: str='fastspeech2_csmsc',
am_config: CfgNode=None,
am_ckpt: Optional[os.PathLike]=None,
am_stat: Optional[os.PathLike]=None,
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None,
speaker_dict: Optional[os.PathLike]=None, ):
with open(phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
tone_size = None
if 'tones_dict' in args and args.tones_dict:
with open(args.tones_dict, "r") as f:
if tones_dict is not None:
with open(tones_dict, "r") as f:
tone_id = [line.strip().split() for line in f.readlines()]
tone_size = len(tone_id)
print("tone_size:", tone_size)
spk_num = None
if 'speaker_dict' in args and args.speaker_dict:
with open(args.speaker_dict, 'rt') as f:
if speaker_dict is not None:
with open(speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
print("spk_num:", spk_num)
odim = am_config.n_mels
# model: {model_name}_{dataset}
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
am_class = dynamic_import(am_name, model_alias)
am_inference_class = dynamic_import(am_name + '_inference', model_alias)
......@@ -174,34 +194,38 @@ def get_am_inference(args, am_config):
elif am_name == 'tacotron2':
am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
am.set_state_dict(paddle.load(args.am_ckpt)["main_params"])
am.set_state_dict(paddle.load(am_ckpt)["main_params"])
am.eval()
am_mu, am_std = np.load(args.am_stat)
am_mu, am_std = np.load(am_stat)
am_mu = paddle.to_tensor(am_mu)
am_std = paddle.to_tensor(am_std)
am_normalizer = ZScore(am_mu, am_std)
am_inference = am_inference_class(am_normalizer, am)
am_inference.eval()
print("acoustic model done!")
return am_inference, am_name, am_dataset
return am_inference
def get_voc_inference(args, voc_config):
def get_voc_inference(
voc: str='pwgan_csmsc',
voc_config: Optional[os.PathLike]=None,
voc_ckpt: Optional[os.PathLike]=None,
voc_stat: Optional[os.PathLike]=None, ):
# model: {model_name}_{dataset}
voc_name = args.voc[:args.voc.rindex('_')]
voc_name = voc[:voc.rindex('_')]
voc_class = dynamic_import(voc_name, model_alias)
voc_inference_class = dynamic_import(voc_name + '_inference', model_alias)
if voc_name != 'wavernn':
voc = voc_class(**voc_config["generator_params"])
voc.set_state_dict(paddle.load(args.voc_ckpt)["generator_params"])
voc.set_state_dict(paddle.load(voc_ckpt)["generator_params"])
voc.remove_weight_norm()
voc.eval()
else:
voc = voc_class(**voc_config["model"])
voc.set_state_dict(paddle.load(args.voc_ckpt)["main_params"])
voc.set_state_dict(paddle.load(voc_ckpt)["main_params"])
voc.eval()
voc_mu, voc_std = np.load(args.voc_stat)
voc_mu, voc_std = np.load(voc_stat)
voc_mu = paddle.to_tensor(voc_mu)
voc_std = paddle.to_tensor(voc_std)
voc_normalizer = ZScore(voc_mu, voc_std)
......@@ -211,10 +235,16 @@ def get_voc_inference(args, voc_config):
return voc_inference
# to static
def am_to_static(args, am_inference, am_name, am_dataset):
# dygraph to static graph
def am_to_static(am_inference,
am: str='fastspeech2_csmsc',
inference_dir=Optional[os.PathLike],
speaker_dict: Optional[os.PathLike]=None):
# model: {model_name}_{dataset}
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
if am_name == 'fastspeech2':
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
am_inference = jit.to_static(
am_inference,
input_spec=[
......@@ -226,7 +256,7 @@ def am_to_static(args, am_inference, am_name, am_dataset):
am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
elif am_name == 'speedyspeech':
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
am_inference = jit.to_static(
am_inference,
input_spec=[
......@@ -247,56 +277,64 @@ def am_to_static(args, am_inference, am_name, am_dataset):
am_inference = jit.to_static(
am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
paddle.jit.save(am_inference, os.path.join(args.inference_dir, args.am))
am_inference = paddle.jit.load(os.path.join(args.inference_dir, args.am))
paddle.jit.save(am_inference, os.path.join(inference_dir, am))
am_inference = paddle.jit.load(os.path.join(inference_dir, am))
return am_inference
def voc_to_static(args, voc_inference):
def voc_to_static(voc_inference,
voc: str='pwgan_csmsc',
inference_dir=Optional[os.PathLike]):
voc_inference = jit.to_static(
voc_inference, input_spec=[
InputSpec([-1, 80], dtype=paddle.float32),
])
paddle.jit.save(voc_inference, os.path.join(args.inference_dir, args.voc))
voc_inference = paddle.jit.load(os.path.join(args.inference_dir, args.voc))
paddle.jit.save(voc_inference, os.path.join(inference_dir, voc))
voc_inference = paddle.jit.load(os.path.join(inference_dir, voc))
return voc_inference
# inference
def get_predictor(args, filed='am'):
full_name = ''
if filed == 'am':
full_name = args.am
elif filed == 'voc':
full_name = args.voc
def get_predictor(model_dir: Optional[os.PathLike]=None,
model_file: Optional[os.PathLike]=None,
params_file: Optional[os.PathLike]=None,
device: str='cpu'):
config = inference.Config(
str(Path(args.inference_dir) / (full_name + ".pdmodel")),
str(Path(args.inference_dir) / (full_name + ".pdiparams")))
if args.device == "gpu":
str(Path(model_dir) / model_file), str(Path(model_dir) / params_file))
if device == "gpu":
config.enable_use_gpu(100, 0)
elif args.device == "cpu":
elif device == "cpu":
config.disable_gpu()
config.enable_memory_optim()
predictor = inference.create_predictor(config)
return predictor
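For orientation, a hedged example of how the refactored `get_predictor` is meant to be called; the directory and file names below are placeholders mirroring the `inference.py` call sites earlier in this diff.
```python
# Illustrative only; paths and file names are placeholders.
am_predictor = get_predictor(
    model_dir="exp/default/inference",
    model_file="fastspeech2_csmsc.pdmodel",
    params_file="fastspeech2_csmsc.pdiparams",
    device="gpu")
```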
def get_am_output(args, am_predictor, frontend, merge_sentences, input):
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
def get_am_output(
input: str,
am_predictor,
am,
frontend,
lang: str='zh',
merge_sentences: bool=True,
speaker_dict: Optional[os.PathLike]=None,
spk_id: int=0, ):
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
am_input_names = am_predictor.get_input_names()
get_tone_ids = False
get_spk_id = False
if am_name == 'speedyspeech':
get_tone_ids = True
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
if am_dataset in {"aishell3", "vctk"} and speaker_dict:
get_spk_id = True
spk_id = np.array([args.spk_id])
if args.lang == 'zh':
spk_id = np.array([spk_id])
if lang == 'zh':
input_ids = frontend.get_input_ids(
input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
elif args.lang == 'en':
elif lang == 'en':
input_ids = frontend.get_input_ids(
input, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
......@@ -338,50 +376,6 @@ def get_voc_output(voc_predictor, input):
return wav
# streaming am
def get_streaming_am_predictor(args):
full_name = args.am
am_encoder_infer_config = inference.Config(
str(
Path(args.inference_dir) /
(full_name + "_am_encoder_infer" + ".pdmodel")),
str(
Path(args.inference_dir) /
(full_name + "_am_encoder_infer" + ".pdiparams")))
am_decoder_config = inference.Config(
str(
Path(args.inference_dir) /
(full_name + "_am_decoder" + ".pdmodel")),
str(
Path(args.inference_dir) /
(full_name + "_am_decoder" + ".pdiparams")))
am_postnet_config = inference.Config(
str(
Path(args.inference_dir) /
(full_name + "_am_postnet" + ".pdmodel")),
str(
Path(args.inference_dir) /
(full_name + "_am_postnet" + ".pdiparams")))
if args.device == "gpu":
am_encoder_infer_config.enable_use_gpu(100, 0)
am_decoder_config.enable_use_gpu(100, 0)
am_postnet_config.enable_use_gpu(100, 0)
elif args.device == "cpu":
am_encoder_infer_config.disable_gpu()
am_decoder_config.disable_gpu()
am_postnet_config.disable_gpu()
am_encoder_infer_config.enable_memory_optim()
am_decoder_config.enable_memory_optim()
am_postnet_config.enable_memory_optim()
am_encoder_infer_predictor = inference.create_predictor(
am_encoder_infer_config)
am_decoder_predictor = inference.create_predictor(am_decoder_config)
am_postnet_predictor = inference.create_predictor(am_postnet_config)
return am_encoder_infer_predictor, am_decoder_predictor, am_postnet_predictor
def get_am_sublayer_output(am_sublayer_predictor, input):
am_sublayer_input_names = am_sublayer_predictor.get_input_names()
input_handle = am_sublayer_predictor.get_input_handle(
......@@ -397,11 +391,15 @@ def get_am_sublayer_output(am_sublayer_predictor, input):
return am_sublayer_output
def get_streaming_am_output(args, am_encoder_infer_predictor,
am_decoder_predictor, am_postnet_predictor,
frontend, merge_sentences, input):
def get_streaming_am_output(input: str,
am_encoder_infer_predictor,
am_decoder_predictor,
am_postnet_predictor,
frontend,
lang: str='zh',
merge_sentences: bool=True):
get_tone_ids = False
if args.lang == 'zh':
if lang == 'zh':
input_ids = frontend.get_input_ids(
input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
......@@ -423,58 +421,27 @@ def get_streaming_am_output(args, am_encoder_infer_predictor,
return normalized_mel
def get_sess(args, filed='am'):
full_name = ''
if filed == 'am':
full_name = args.am
elif filed == 'voc':
full_name = args.voc
model_dir = str(Path(args.inference_dir) / (full_name + ".onnx"))
# onnx
def get_sess(model_dir: Optional[os.PathLike]=None,
model_file: Optional[os.PathLike]=None,
device: str='cpu',
cpu_threads: int=1,
use_trt: bool=False):
model_dir = str(Path(model_dir) / model_file)
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
if args.device == "gpu":
if device == "gpu":
# fastspeech2/mb_melgan can't use trt now!
if args.use_trt:
if use_trt:
providers = ['TensorrtExecutionProvider']
else:
providers = ['CUDAExecutionProvider']
elif args.device == "cpu":
elif device == "cpu":
providers = ['CPUExecutionProvider']
sess_options.intra_op_num_threads = args.cpu_threads
sess_options.intra_op_num_threads = cpu_threads
sess = ort.InferenceSession(
model_dir, providers=providers, sess_options=sess_options)
return sess
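Similarly, a hedged example of the refactored ONNX helper `get_sess`; the model directory and file name are placeholders mirroring the `ort_predict` call sites earlier in this diff.
```python
# Illustrative only; paths and file names are placeholders.
voc_sess = get_sess(
    model_dir="exp/default/inference_onnx",
    model_file="hifigan_csmsc.onnx",
    device="cpu",
    cpu_threads=2)
```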
# streaming am
def get_streaming_am_sess(args):
full_name = args.am
am_encoder_infer_model_dir = str(
Path(args.inference_dir) / (full_name + "_am_encoder_infer" + ".onnx"))
am_decoder_model_dir = str(
Path(args.inference_dir) / (full_name + "_am_decoder" + ".onnx"))
am_postnet_model_dir = str(
Path(args.inference_dir) / (full_name + "_am_postnet" + ".onnx"))
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
if args.device == "gpu":
# fastspeech2/mb_melgan can't use trt now!
if args.use_trt:
providers = ['TensorrtExecutionProvider']
else:
providers = ['CUDAExecutionProvider']
elif args.device == "cpu":
providers = ['CPUExecutionProvider']
sess_options.intra_op_num_threads = args.cpu_threads
am_encoder_infer_sess = ort.InferenceSession(
am_encoder_infer_model_dir,
providers=providers,
sess_options=sess_options)
am_decoder_sess = ort.InferenceSession(
am_decoder_model_dir, providers=providers, sess_options=sess_options)
am_postnet_sess = ort.InferenceSession(
am_postnet_model_dir, providers=providers, sess_options=sess_options)
return am_encoder_infer_sess, am_decoder_sess, am_postnet_sess
......@@ -50,11 +50,29 @@ def evaluate(args):
print(voc_config)
# acoustic model
am_inference, am_name, am_dataset = get_am_inference(args, am_config)
test_dataset = get_test_dataset(args, test_metadata, am_name, am_dataset)
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
am_inference = get_am_inference(
am=args.am,
am_config=am_config,
am_ckpt=args.am_ckpt,
am_stat=args.am_stat,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict,
speaker_dict=args.speaker_dict)
test_dataset = get_test_dataset(
test_metadata=test_metadata,
am=args.am,
speaker_dict=args.speaker_dict,
voice_cloning=args.voice_cloning)
# vocoder
voc_inference = get_voc_inference(args, voc_config)
voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
......
......@@ -42,24 +42,48 @@ def evaluate(args):
print(am_config)
print(voc_config)
sentences = get_sentences(args)
sentences = get_sentences(text_file=args.text, lang=args.lang)
# frontend
frontend = get_frontend(args)
frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
# acoustic model
am_inference, am_name, am_dataset = get_am_inference(args, am_config)
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
am_inference = get_am_inference(
am=args.am,
am_config=am_config,
am_ckpt=args.am_ckpt,
am_stat=args.am_stat,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict,
speaker_dict=args.speaker_dict)
# vocoder
voc_inference = get_voc_inference(args, voc_config)
voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
# whether dygraph to static
if args.inference_dir:
# acoustic model
am_inference = am_to_static(args, am_inference, am_name, am_dataset)
am_inference = am_to_static(
am_inference=am_inference,
am=args.am,
inference_dir=args.inference_dir,
speaker_dict=args.speaker_dict)
# vocoder
voc_inference = voc_to_static(args, voc_inference)
voc_inference = voc_to_static(
voc_inference=voc_inference,
voc=args.voc,
inference_dir=args.inference_dir)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
......
......@@ -49,10 +49,13 @@ def evaluate(args):
print(am_config)
print(voc_config)
sentences = get_sentences(args)
sentences = get_sentences(text_file=args.text, lang=args.lang)
# frontend
frontend = get_frontend(args)
frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
......@@ -60,7 +63,6 @@ def evaluate(args):
print("vocab_size:", vocab_size)
# acoustic model, only support fastspeech2 here now!
# am_inference, am_name, am_dataset = get_am_inference(args, am_config)
# model: {model_name}_{dataset}
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
......@@ -80,7 +82,11 @@ def evaluate(args):
am_postnet = am.postnet
# vocoder
voc_inference = get_voc_inference(args, voc_config)
voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
# whether dygraph to static
if args.inference_dir:
......@@ -115,7 +121,10 @@ def evaluate(args):
os.path.join(args.inference_dir, args.am + "_am_postnet"))
# vocoder
voc_inference = voc_to_static(args, voc_inference)
voc_inference = voc_to_static(
voc_inference=voc_inference,
voc=args.voc,
inference_dir=args.inference_dir)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
......
......@@ -66,10 +66,19 @@ def voice_cloning(args):
print("frontend done!")
# acoustic model
am_inference, *_ = get_am_inference(args, am_config)
am_inference = get_am_inference(
am=args.am,
am_config=am_config,
am_ckpt=args.am_ckpt,
am_stat=args.am_stat,
phones_dict=args.phones_dict)
# vocoder
voc_inference = get_voc_inference(args, voc_config)
voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
......
......@@ -58,8 +58,7 @@ def main():
else:
print("ngpu should >= 0 !")
model = WaveRNN(
hop_length=config.n_shift, sample_rate=config.fs, **config["model"])
model = WaveRNN(**config["model"])
state_dict = paddle.load(args.checkpoint)
model.set_state_dict(state_dict["main_params"])
......
......@@ -91,3 +91,199 @@ class LogSoftmaxWrapper(nn.Layer):
predictions = F.log_softmax(predictions, axis=1)
loss = self.criterion(predictions, targets) / targets.sum()
return loss
class NCELoss(nn.Layer):
"""Noise Contrastive Estimation loss funtion
Noise Contrastive Estimation (NCE) is an approximation method that is used to
work around the huge computational cost of large softmax layer.
The basic idea is to convert the prediction problem into classification problem
at training stage. It has been proved that these two criterions converges to
the same minimal point as long as noise distribution is close enough to real one.
NCE bridges the gap between generative models and discriminative models,
rather than simply speedup the softmax layer.
With NCE, you can turn almost anything into posterior with less effort (I think).
Refs:
NCE:http://www.cs.helsinki.fi/u/ahyvarin/papers/Gutmann10AISTATS.pdf
Thanks: https://github.com/mingen-pan/easy-to-use-NCE-RNN-for-Pytorch/blob/master/nce.py
Examples:
Q = Q_from_tokens(output_dim)
NCELoss(Q)
"""
def __init__(self, Q, noise_ratio=100, Z_offset=9.5):
"""Noise Contrastive Estimation loss funtion
Args:
Q (tensor): prior model, uniform or guassian
noise_ratio (int, optional): noise sampling times. Defaults to 100.
Z_offset (float, optional): scale of post processing the score. Defaults to 9.5.
"""
super(NCELoss, self).__init__()
assert type(noise_ratio) is int
self.Q = paddle.to_tensor(Q, stop_gradient=False)
self.N = self.Q.shape[0]
self.K = noise_ratio
self.Z_offset = Z_offset
def forward(self, output, target):
"""Forward inference
Args:
output (tensor): the model output, which is the input of loss function
"""
output = paddle.reshape(output, [-1, self.N])
B = output.shape[0]
noise_idx = self.get_noise(B)
idx = self.get_combined_idx(target, noise_idx)
P_target, P_noise = self.get_prob(idx, output, sep_target=True)
Q_target, Q_noise = self.get_Q(idx)
loss = self.nce_loss(P_target, P_noise, Q_noise, Q_target)
return loss.mean()
def get_Q(self, idx, sep_target=True):
"""Get prior model of batchsize data
"""
idx_size = idx.size
prob_model = paddle.to_tensor(
self.Q.numpy()[paddle.reshape(idx, [-1]).numpy()])
prob_model = paddle.reshape(prob_model, [idx.shape[0], idx.shape[1]])
if sep_target:
return prob_model[:, 0], prob_model[:, 1:]
else:
return prob_model
def get_prob(self, idx, scores, sep_target=True):
"""Post processing the score of post model(output of nn) of batchsize data
"""
scores = self.get_scores(idx, scores)
scale = paddle.to_tensor([self.Z_offset], dtype='float64')
scores = paddle.add(scores, -scale)
prob = paddle.exp(scores)
if sep_target:
return prob[:, 0], prob[:, 1:]
else:
return prob
def get_scores(self, idx, scores):
"""Get the score of post model(output of nn) of batchsize data
"""
B, N = scores.shape
K = idx.shape[1]
idx_increment = paddle.to_tensor(
N * paddle.reshape(paddle.arange(B), [B, 1]) * paddle.ones([1, K]),
dtype="int64",
stop_gradient=False)
new_idx = idx_increment + idx
new_scores = paddle.index_select(
paddle.reshape(scores, [-1]), paddle.reshape(new_idx, [-1]))
return paddle.reshape(new_scores, [B, K])
def get_noise(self, batch_size, uniform=True):
"""Select noise sample
"""
if uniform:
noise = np.random.randint(self.N, size=self.K * batch_size)
else:
noise = np.random.choice(
self.N, self.K * batch_size, replace=True, p=self.Q.data)
noise = paddle.to_tensor(noise, dtype='int64', stop_gradient=False)
noise_idx = paddle.reshape(noise, [batch_size, self.K])
return noise_idx
def get_combined_idx(self, target_idx, noise_idx):
"""Combined target and noise
"""
target_idx = paddle.reshape(target_idx, [-1, 1])
return paddle.concat((target_idx, noise_idx), 1)
def nce_loss(self, prob_model, prob_noise_in_model, prob_noise,
prob_target_in_noise):
"""Combined the loss of target and noise
"""
def safe_log(tensor):
"""Safe log
"""
EPSILON = 1e-10
return paddle.log(EPSILON + tensor)
model_loss = safe_log(prob_model /
(prob_model + self.K * prob_target_in_noise))
model_loss = paddle.reshape(model_loss, [-1])
noise_loss = paddle.sum(
safe_log((self.K * prob_noise) /
(prob_noise_in_model + self.K * prob_noise)), -1)
noise_loss = paddle.reshape(noise_loss, [-1])
loss = -(model_loss + noise_loss)
return loss
class FocalLoss(nn.Layer):
"""This criterion is a implemenation of Focal Loss, which is proposed in
Focal Loss for Dense Object Detection.
Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
The losses are averaged across observations for each minibatch.
Args:
alpha(1D Tensor, Variable) : the scalar factor for this criterion
gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
putting more focus on hard, misclassified examples
size_average(bool): By default, the losses are averaged over observations for each minibatch.
However, if the field size_average is set to False, the losses are
instead summed for each minibatch.
"""
def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=-100):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.size_average = size_average
self.ce = nn.CrossEntropyLoss(
ignore_index=ignore_index, reduction="none")
def forward(self, outputs, targets):
"""Forword inference.
Args:
outputs: input tensor
target: target label tensor
"""
ce_loss = self.ce(outputs, targets)
pt = paddle.exp(-ce_loss)
focal_loss = self.alpha * (1 - pt)**self.gamma * ce_loss
if self.size_average:
return focal_loss.mean()
else:
return focal_loss.sum()
if __name__ == "__main__":
import numpy as np
from paddlespeech.vector.utils.vector_utils import Q_from_tokens
paddle.set_device("cpu")
input_data = paddle.uniform([5, 100], dtype="float64")
label_data = np.random.randint(0, 100, size=(5)).astype(np.int64)
input = paddle.to_tensor(input_data)
label = paddle.to_tensor(label_data)
loss1 = FocalLoss()
loss = loss1.forward(input, label)
print("loss: %.5f" % (loss))
Q = Q_from_tokens(100)
loss2 = NCELoss(Q)
loss = loss2.forward(input, label)
print("loss: %.5f" % (loss))
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
def get_chunks(seg_dur, audio_id, audio_duration):
......@@ -30,3 +31,11 @@ def get_chunks(seg_dur, audio_id, audio_duration):
for i in range(num_chunks)
]
return chunk_lst
def Q_from_tokens(token_num):
"""Get prior model, data from uniform, would support others(guassian) in future
"""
freq = [1] * token_num
Q = paddle.to_tensor(freq, dtype='float64')
return Q / Q.sum()
......@@ -63,7 +63,8 @@ include(libsndfile)
# include(boost) # not work
set(boost_SOURCE_DIR ${fc_patch}/boost-src)
set(BOOST_ROOT ${boost_SOURCE_DIR})
# #find_package(boost REQUIRED PATHS ${BOOST_ROOT})
include_directories(${boost_SOURCE_DIR})
link_directories(${boost_SOURCE_DIR}/stage/lib)
# Eigen
include(eigen)
......@@ -141,4 +142,4 @@ set(DEPS ${DEPS}
set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx)
add_subdirectory(speechx)
add_subdirectory(examples)
\ No newline at end of file
add_subdirectory(examples)
......@@ -2,4 +2,5 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(feat)
add_subdirectory(nnet)
add_subdirectory(decoder)
\ No newline at end of file
add_subdirectory(decoder)
add_subdirectory(websocket)
# This contains the locations of the binaries built for running the examples.
SPEECHX_ROOT=$PWD/../../../
SPEECHX_ROOT=$PWD/../../..
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
......@@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_ALL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
\ No newline at end of file
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat:$SPEECHX_EXAMPLES/ds2_ol/websocket
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
......@@ -87,7 +87,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
ctc-prefix-beam-search-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--params_path=$model_dir/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$vocb_dir/vocab.txt \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result
......@@ -102,7 +102,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
ctc-prefix-beam-search-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--params_path=$model_dir/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$vocb_dir/vocab.txt \
--lm_path=$lm \
......@@ -129,7 +129,7 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
wfst-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--params_path=$model_dir/avg_1.jit.pdiparams \
--word_symbol_table=$graph_dir/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--graph_path=$graph_dir/TLG.fst --max_active=7500 \
......
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char
# output
aishell_wav_scp=aishell_test.scp
if [ ! -d $data/test ]; then
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
unzip aishell_test.zip
popd
realpath $data/test/*/*.wav > $data/wavlist
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi
export GLOG_logtostderr=1
# websocket client
websocket_client_main \
--wav_rspecifier=scp:$data/$aishell_wav_scp --streaming_chunk=0.36
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/
# output
aishell_wav_scp=aishell_test.scp
if [ ! -d $data/test ]; then
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
unzip aishell_test.zip
popd
realpath $data/test/*/*.wav > $data/wavlist
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi
if [ ! -d $ckpt_dir ]; then
mkdir -p $ckpt_dir
wget -P $ckpt_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv $ckpt_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $ckpt_dir
fi
export GLOG_logtostderr=1
# 3. gen cmvn
cmvn=$PWD/cmvn.ark
cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
text=$data/test/text
graph_dir=./aishell_graph
if [ ! -d $graph_dir ]; then
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip
unzip aishell_graph.zip
fi
# 5. test websocket server
websocket_server_main \
--cmvn_file=$cmvn \
--model_path=$model_dir/avg_1.jit.pdmodel \
--streaming_chunk=0.1 \
--convert2PCM32=true \
--params_path=$model_dir/avg_1.jit.pdiparams \
--word_symbol_table=$graph_dir/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--graph_path=$graph_dir/TLG.fst --max_active=7500 \
--acoustic_scale=1.2
......@@ -17,3 +17,6 @@ add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
add_executable(recognizer_test_main ${CMAKE_CURRENT_SOURCE_DIR}/recognizer_test_main.cc)
target_include_directories(recognizer_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(recognizer_test_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
......@@ -34,12 +34,10 @@ DEFINE_int32(receptive_field_length,
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=5) module downsampling rate.");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names,
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
......@@ -58,12 +56,11 @@ int main(int argc, char* argv[]) {
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_graph = FLAGS_model_path;
std::string model_path = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file;
std::string lm_path = FLAGS_lm_path;
LOG(INFO) << "model path: " << model_graph;
LOG(INFO) << "model path: " << model_path;
LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path;
......@@ -76,10 +73,9 @@ int main(int argc, char* argv[]) {
ppspeech::CTCBeamSearch decoder(opts);
ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph;
model_opts.model_path = model_path;
model_opts.params_path = model_params;
model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
......@@ -125,7 +121,6 @@ int main(int argc, char* argv[]) {
if (feature_chunk_size < receptive_field_length) break;
int32 start = chunk_idx * chunk_stride;
int32 end = start + chunk_size;
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/recognizer.h"
#include "decoder/param.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/table-types.h"
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure();
ppspeech::Recognizer recognizer(resource);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int sample_rate = 16000;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
int32 num_done = 0, num_err = 0;
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
std::vector<kaldi::Vector<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
recognizer.Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
recognizer.SetFinished();
}
recognizer.Decode();
sample_offset += cur_chunk_size;
}
std::string result;
result = recognizer.GetFinalResult();
recognizer.Reset();
if (result.empty()) {
// the TokenWriter cannot write an empty string.
++num_err;
KALDI_LOG << " the result of " << utt << " is empty";
continue;
}
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
++num_done;
}
}
\ No newline at end of file
......@@ -73,9 +73,9 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;
LOG(INFO) << "Binary: " << FLAGS_binary;
} catch (simdjson::simdjson_error& err) {
LOG(ERR) << err.what();
LOG(ERROR) << err.what();
}
return 0;
}
\ No newline at end of file
}
......@@ -32,7 +32,6 @@ DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
......@@ -66,7 +65,8 @@ int main(int argc, char* argv[]) {
std::unique_ptr<ppspeech::FrontendInterface> cmvn(
new ppspeech::CMVN(FLAGS_cmvn_file, std::move(linear_spectrogram)));
ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
ppspeech::FeatureCacheOptions feat_cache_opts;
ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
LOG(INFO) << "feat dim: " << feature_cache.Dim();
int sample_rate = 16000;
......
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(websocket_server_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_server_main.cc)
target_include_directories(websocket_server_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(websocket_server_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS})
add_executable(websocket_client_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_client_main.cc)
target_include_directories(websocket_client_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(websocket_client_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS})
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "websocket/websocket_client.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
DEFINE_string(host, "127.0.0.1", "host of websocket server");
DEFINE_int32(port, 201314, "port of websocket server");
DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_double(streaming_chunk, 0.1, "streaming feature chunk size");
using kaldi::int16;
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
ppspeech::WebSocketClient client(FLAGS_host, FLAGS_port);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
const int sample_rate = 16000;
const float streaming_chunk = FLAGS_streaming_chunk;
const int chunk_sample_size = streaming_chunk * sample_rate;
for (; !wav_reader.Done(); wav_reader.Next()) {
client.SendStartSignal();
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
CHECK_EQ(wave_data.SampFreq(), sample_rate);
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
const int tot_samples = waveform.Dim();
int sample_offset = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
std::vector<int16> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = static_cast<int16>(waveform(sample_offset + i));
}
client.SendBinaryData(wav_chunk.data(),
wav_chunk.size() * sizeof(int16));
sample_offset += cur_chunk_size;
LOG(INFO) << "Send " << cur_chunk_size << " samples";
std::this_thread::sleep_for(
std::chrono::milliseconds(static_cast<int>(1 * 1000)));
if (cur_chunk_size < chunk_sample_size) {
client.SendEndSignal();
}
}
while (!client.Done()) {
}
std::string result = client.GetResult();
LOG(INFO) << "utt: " << utt << " " << result;
client.Join();
return 0;
}
return 0;
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "websocket/websocket_server.h"
#include "decoder/param.h"
DEFINE_int32(port, 201314, "websocket listening port");
int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure();
ppspeech::WebSocketServer server(FLAGS_port, resource);
LOG(INFO) << "Listening at port " << FLAGS_port;
server.Start();
return 0;
}
......@@ -30,4 +30,10 @@ include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/decoder
)
add_subdirectory(decoder)
\ No newline at end of file
add_subdirectory(decoder)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/websocket
)
add_subdirectory(websocket)
......@@ -28,8 +28,10 @@
#include <sstream>
#include <stack>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "base/basic_types.h"
......
......@@ -7,5 +7,6 @@ add_library(decoder STATIC
ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp
ctc_tlg_decoder.cc
recognizer.cc
)
target_link_libraries(decoder PUBLIC kenlm utils fst)
target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder)
......@@ -33,7 +33,6 @@ void TLGDecoder::InitDecoder() {
void TLGDecoder::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (!decodable->IsLastFrame(frame_decoded_size_)) {
LOG(INFO) << "num frame decode: " << frame_decoded_size_;
AdvanceDecoding(decodable.get());
}
}
......@@ -63,4 +62,4 @@ std::string TLGDecoder::GetFinalBestPath() {
}
return words;
}
}
\ No newline at end of file
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
DEFINE_string(cmvn_file, "", "read cmvn");
DEFINE_double(streaming_chunk, 0.1, "streaming feature chunk size");
DEFINE_bool(convert2PCM32, true, "audio convert to pcm32");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(params_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "max active");
DEFINE_double(beam, 15.0, "decoder beam");
DEFINE_double(lattice_beam, 7.5, "decoder beam");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=5) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=5) module downsampling rate.");
DEFINE_string(model_output_names,
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
namespace ppspeech {
// todo refactor later
FeaturePipelineOptions InitFeaturePipelineOptions() {
FeaturePipelineOptions opts;
opts.cmvn_file = FLAGS_cmvn_file;
opts.linear_spectrogram_opts.streaming_chunk = FLAGS_streaming_chunk;
opts.convert2PCM32 = FLAGS_convert2PCM32;
kaldi::FrameExtractionOptions frame_opts;
frame_opts.frame_length_ms = 20;
frame_opts.frame_shift_ms = 10;
frame_opts.remove_dc_offset = false;
frame_opts.window_type = "hanning";
frame_opts.preemph_coeff = 0.0;
frame_opts.dither = 0.0;
opts.linear_spectrogram_opts.frame_opts = frame_opts;
opts.feature_cache_opts.frame_chunk_size = FLAGS_receptive_field_length;
opts.feature_cache_opts.frame_chunk_stride = FLAGS_downsampling_rate;
return opts;
}
ModelOptions InitModelOptions() {
ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
model_opts.params_path = FLAGS_params_path;
model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.output_names = FLAGS_model_output_names;
return model_opts;
}
TLGDecoderOptions InitDecoderOptions() {
TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
return decoder_opts;
}
RecognizerResource InitRecognizerResoure() {
RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = InitFeaturePipelineOptions();
resource.model_opts = InitModelOptions();
resource.tlg_opts = InitDecoderOptions();
return resource;
}
}
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/recognizer.h"
namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector;
using std::unique_ptr;
Recognizer::Recognizer(const RecognizerResource& resource) {
// resource_ = resource;
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
feature_pipeline_.reset(new FeaturePipeline(feature_opts));
std::shared_ptr<PaddleNnet> nnet(new PaddleNnet(resource.model_opts));
BaseFloat ac_scale = resource.acoustic_scale;
decodable_.reset(new Decodable(nnet, feature_pipeline_, ac_scale));
decoder_.reset(new TLGDecoder(resource.tlg_opts));
input_finished_ = false;
}
void Recognizer::Accept(const Vector<BaseFloat>& waves) {
feature_pipeline_->Accept(waves);
}
void Recognizer::Decode() { decoder_->AdvanceDecode(decodable_); }
std::string Recognizer::GetFinalResult() {
return decoder_->GetFinalBestPath();
}
void Recognizer::SetFinished() {
feature_pipeline_->SetFinished();
input_finished_ = true;
}
bool Recognizer::IsFinished() { return input_finished_; }
void Recognizer::Reset() {
feature_pipeline_->Reset();
decodable_->Reset();
decoder_->Reset();
}
} // namespace ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor later (SGoat)
#pragma once
#include "decoder/ctc_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
namespace ppspeech {
struct RecognizerResource {
FeaturePipelineOptions feature_pipeline_opts;
ModelOptions model_opts;
TLGDecoderOptions tlg_opts;
// CTCBeamSearchOptions beam_search_opts;
kaldi::BaseFloat acoustic_scale;
RecognizerResource()
: acoustic_scale(1.0),
feature_pipeline_opts(),
model_opts(),
tlg_opts() {}
};
class Recognizer {
public:
explicit Recognizer(const RecognizerResource& resouce);
void Accept(const kaldi::Vector<kaldi::BaseFloat>& waves);
void Decode();
std::string GetFinalResult();
void SetFinished();
bool IsFinished();
void Reset();
private:
// std::shared_ptr<RecognizerResource> resource_;
// RecognizerResource resource_;
std::shared_ptr<FeaturePipeline> feature_pipeline_;
std::shared_ptr<Decodable> decodable_;
std::unique_ptr<TLGDecoder> decoder_;
bool input_finished_;
};
} // namespace ppspeech
\ No newline at end of file
......@@ -6,6 +6,7 @@ add_library(frontend STATIC
linear_spectrogram.cc
audio_cache.cc
feature_cache.cc
feature_pipeline.cc
)
target_link_libraries(frontend PUBLIC kaldi-matrix)
\ No newline at end of file
target_link_libraries(frontend PUBLIC kaldi-matrix kaldi-feat-common)
......@@ -41,7 +41,7 @@ void AudioCache::Accept(const VectorBase<BaseFloat>& waves) {
ready_feed_condition_.wait(lock);
}
for (size_t idx = 0; idx < waves.Dim(); ++idx) {
int32 buffer_idx = (idx + offset_) % ring_buffer_.size();
int32 buffer_idx = (idx + offset_ + size_) % ring_buffer_.size();
ring_buffer_[buffer_idx] = waves(idx);
if (convert2PCM32_)
ring_buffer_[buffer_idx] = Convert2PCM32(waves(idx));
......
......@@ -24,7 +24,7 @@ namespace ppspeech {
class AudioCache : public FrontendInterface {
public:
explicit AudioCache(int buffer_size = 1000 * kint16max,
bool convert2PCM32 = false);
bool convert2PCM32 = true);
virtual void Accept(const kaldi::VectorBase<BaseFloat>& waves);
......
......@@ -23,10 +23,13 @@ using std::vector;
using kaldi::SubVector;
using std::unique_ptr;
FeatureCache::FeatureCache(int max_size,
FeatureCache::FeatureCache(FeatureCacheOptions opts,
unique_ptr<FrontendInterface> base_extractor) {
max_size_ = max_size;
max_size_ = opts.max_size;
frame_chunk_stride_ = opts.frame_chunk_stride;
frame_chunk_size_ = opts.frame_chunk_size;
base_extractor_ = std::move(base_extractor);
dim_ = base_extractor_->Dim();
}
void FeatureCache::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
......@@ -44,13 +47,14 @@ bool FeatureCache::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
std::unique_lock<std::mutex> lock(mutex_);
while (cache_.empty() && base_extractor_->IsFinished() == false) {
ready_read_condition_.wait(lock);
BaseFloat elapsed = timer.Elapsed() * 1000;
// todo replace 1.0 with timeout_
if (elapsed > 1.0) {
// todo refactor: wait
// ready_read_condition_.wait(lock);
int32 elapsed = static_cast<int32>(timer.Elapsed() * 1000);
// todo replace 1 with timeout_, 1 ms
if (elapsed > 1) {
return false;
}
usleep(1000); // sleep 1 ms
usleep(100); // sleep 0.1 ms
}
if (cache_.empty()) return false;
feats->Resize(cache_.front().Dim());
......@@ -63,25 +67,41 @@ bool FeatureCache::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
// read all data from base_feature_extractor_ into cache_
bool FeatureCache::Compute() {
// compute and feed
Vector<BaseFloat> feature_chunk;
bool result = base_extractor_->Read(&feature_chunk);
Vector<BaseFloat> feature;
bool result = base_extractor_->Read(&feature);
if (result == false || feature.Dim() == 0) return false;
int32 joint_len = feature.Dim() + remained_feature_.Dim();
int32 num_chunk =
((joint_len / dim_) - frame_chunk_size_) / frame_chunk_stride_ + 1;
std::unique_lock<std::mutex> lock(mutex_);
while (cache_.size() >= max_size_) {
ready_feed_condition_.wait(lock);
}
Vector<BaseFloat> joint_feature(joint_len);
joint_feature.Range(0, remained_feature_.Dim())
.CopyFromVec(remained_feature_);
joint_feature.Range(remained_feature_.Dim(), feature.Dim())
.CopyFromVec(feature);
// feed cache
if (feature_chunk.Dim() != 0) {
for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) {
int32 start = chunk_idx * frame_chunk_stride_ * dim_;
Vector<BaseFloat> feature_chunk(frame_chunk_size_ * dim_);
SubVector<BaseFloat> tmp(joint_feature.Data() + start,
frame_chunk_size_ * dim_);
feature_chunk.CopyFromVec(tmp);
std::unique_lock<std::mutex> lock(mutex_);
while (cache_.size() >= max_size_) {
ready_feed_condition_.wait(lock);
}
// feed cache
cache_.push(feature_chunk);
ready_read_condition_.notify_one();
}
ready_read_condition_.notify_one();
int32 remained_feature_len =
joint_len - num_chunk * frame_chunk_stride_ * dim_;
remained_feature_.Resize(remained_feature_len);
remained_feature_.CopyFromVec(joint_feature.Range(
frame_chunk_stride_ * num_chunk * dim_, remained_feature_len));
return result;
}
void Reset() {
// std::lock_guard<std::mutex> lock(mutex_);
return;
}
} // namespace ppspeech
\ No newline at end of file
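The chunking logic in FeatureCache::Compute above is easier to follow in a small standalone sketch (plain Python with a hypothetical chunk_frames helper, not code from the repository): the remainder kept from the previous call is prepended to the newly read feature, full windows of frame_chunk_size frames are emitted every frame_chunk_stride frames, and the tail is carried over to the next call.
import numpy as np

def chunk_frames(remained, feature, dim, chunk_size, chunk_stride):
    # Both inputs are frame-major flattened vectors of length num_frames * dim,
    # mirroring the kaldi::Vector layout used in the C++ code.
    joint = np.concatenate([remained, feature])
    num_frames = joint.size // dim
    # Guard against short input; the C++ code assumes enough frames are present.
    num_chunks = max((num_frames - chunk_size) // chunk_stride + 1, 0)
    chunks = [joint[i * chunk_stride * dim:(i * chunk_stride + chunk_size) * dim]
              for i in range(num_chunks)]
    rest = joint[num_chunks * chunk_stride * dim:]   # carried to the next call
    return chunks, rest

chunks, rest = chunk_frames(np.zeros(0), np.arange(9 * 80, dtype=float),
                            dim=80, chunk_size=7, chunk_stride=4)
print(len(chunks), rest.size // 80)   # 1 chunk of 7 frames, 5 frames carried over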
......@@ -19,10 +19,18 @@
namespace ppspeech {
struct FeatureCacheOptions {
int32 max_size;
int32 frame_chunk_size;
int32 frame_chunk_stride;
FeatureCacheOptions()
: max_size(kint16max), frame_chunk_size(1), frame_chunk_stride(1) {}
};
class FeatureCache : public FrontendInterface {
public:
explicit FeatureCache(
int32 max_size = kint16max,
FeatureCacheOptions opts,
std::unique_ptr<FrontendInterface> base_extractor = NULL);
// Feed feats or waves
......@@ -32,12 +40,15 @@ class FeatureCache : public FrontendInterface {
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
// feat dim
virtual size_t Dim() const { return base_extractor_->Dim(); }
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() {
// std::unique_lock<std::mutex> lock(mutex_);
base_extractor_->SetFinished();
LOG(INFO) << "set finished";
// read the last chunk data
Compute();
// ready_feed_condition_.notify_one();
}
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
......@@ -52,9 +63,13 @@ class FeatureCache : public FrontendInterface {
private:
bool Compute();
int32 dim_;
size_t max_size_;
std::unique_ptr<FrontendInterface> base_extractor_;
int32 frame_chunk_size_;
int32 frame_chunk_stride_;
kaldi::Vector<kaldi::BaseFloat> remained_feature_;
std::unique_ptr<FrontendInterface> base_extractor_;
std::mutex mutex_;
std::queue<kaldi::Vector<BaseFloat>> cache_;
std::condition_variable ready_feed_condition_;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/audio/feature_pipeline.h"
namespace ppspeech {
using std::unique_ptr;
FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) {
unique_ptr<FrontendInterface> data_source(
new ppspeech::AudioCache(1000 * kint16max, opts.convert2PCM32));
unique_ptr<FrontendInterface> linear_spectrogram(
new ppspeech::LinearSpectrogram(opts.linear_spectrogram_opts,
std::move(data_source)));
unique_ptr<FrontendInterface> cmvn(
new ppspeech::CMVN(opts.cmvn_file, std::move(linear_spectrogram)));
base_extractor_.reset(
new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn)));
}
} // ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor later (SGoat)
#pragma once
#include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h"
namespace ppspeech {
struct FeaturePipelineOptions {
std::string cmvn_file;
bool convert2PCM32;
LinearSpectrogramOptions linear_spectrogram_opts;
FeatureCacheOptions feature_cache_opts;
FeaturePipelineOptions()
: cmvn_file(""),
convert2PCM32(false),
linear_spectrogram_opts(),
feature_cache_opts() {}
};
class FeaturePipeline : public FrontendInterface {
public:
explicit FeaturePipeline(const FeaturePipelineOptions& opts);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves) {
base_extractor_->Accept(waves);
}
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
return base_extractor_->Read(feats);
}
virtual size_t Dim() const { return base_extractor_->Dim(); }
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() { base_extractor_->Reset(); }
private:
std::unique_ptr<FrontendInterface> base_extractor_;
};
}
\ No newline at end of file
......@@ -52,16 +52,16 @@ bool LinearSpectrogram::Read(Vector<BaseFloat>* feats) {
if (flag == false || input_feats.Dim() == 0) return false;
int32 feat_len = input_feats.Dim();
int32 left_len = reminded_wav_.Dim();
int32 left_len = remained_wav_.Dim();
Vector<BaseFloat> waves(feat_len + left_len);
waves.Range(0, left_len).CopyFromVec(reminded_wav_);
waves.Range(0, left_len).CopyFromVec(remained_wav_);
waves.Range(left_len, feat_len).CopyFromVec(input_feats);
Compute(waves, feats);
int32 frame_shift = opts_.frame_opts.WindowShift();
int32 num_frames = kaldi::NumFrames(waves.Dim(), opts_.frame_opts);
int32 left_samples = waves.Dim() - frame_shift * num_frames;
reminded_wav_.Resize(left_samples);
reminded_wav_.CopyFromVec(
remained_wav_.Resize(left_samples);
remained_wav_.CopyFromVec(
waves.Range(frame_shift * num_frames, left_samples));
return true;
}
......
......@@ -25,12 +25,12 @@ struct LinearSpectrogramOptions {
kaldi::FrameExtractionOptions frame_opts;
kaldi::BaseFloat streaming_chunk; // second
LinearSpectrogramOptions() : streaming_chunk(0.36), frame_opts() {}
LinearSpectrogramOptions() : streaming_chunk(0.1), frame_opts() {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("streaming-chunk",
&streaming_chunk,
"streaming chunk size, default: 0.36 sec");
"streaming chunk size, default: 0.1 sec");
frame_opts.Register(opts);
}
};
......@@ -48,7 +48,7 @@ class LinearSpectrogram : public FrontendInterface {
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() {
base_extractor_->Reset();
reminded_wav_.Resize(0);
remained_wav_.Resize(0);
}
private:
......@@ -60,7 +60,7 @@ class LinearSpectrogram : public FrontendInterface {
kaldi::BaseFloat hanning_window_energy_;
LinearSpectrogramOptions opts_;
std::unique_ptr<FrontendInterface> base_extractor_;
kaldi::Vector<kaldi::BaseFloat> reminded_wav_;
kaldi::Vector<kaldi::BaseFloat> remained_wav_;
int chunk_sample_size_;
DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
};
......
......@@ -78,7 +78,6 @@ bool Decodable::AdvanceChunk() {
}
int32 nnet_dim = 0;
Vector<BaseFloat> inferences;
Matrix<BaseFloat> nnet_cache_tmp;
nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim);
nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim);
nnet_cache_.CopyRowsFromVec(inferences);
......