Unverified commit d7e75354, authored by Hui Zhang, committed by GitHub

Support paddle 2.x (#538)

* 2.x model

* model test pass

* fix data

* fix soundfile with flac support

* one thread dataloader test pass

* export feature size
add trainer and utils
add setup model and dataloader
update travis using Bionic dist

* add venv; test under venv

* fix unittest; train and valid

* add train and config

* add config and train script

* fix ctc cuda memcopy error

* fix imports

* fix train valid log

* fix dataset batch shuffle shift to start from 1
fix rank_zero_only decorator error (see the sketch below)
close tensorboard when training is over
add decoding config and code
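For context, rank_zero_only is the usual guard that restricts logging and TensorBoard writes to the rank-0 process in distributed training. A minimal sketch of the pattern, assuming the Paddle 2.x distributed API; the repo's actual helper may differ in detail:

    import functools
    from paddle import distributed as dist

    def rank_zero_only(func):
        """Run the wrapped function only on the rank-0 process."""
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if dist.get_rank() != 0:
                return None  # no-op on non-zero ranks
            return func(*args, **kwargs)
        return wrapper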

* test process can run

* test with decoding

* test and infer with decoding

* fix infer

* fix ctc loss (see the sketch below)
lr schedule
sortagrad
logger
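For reference, a minimal self-contained sketch of how a CTC loss is wired up in Paddle 2.x; shapes and values are illustrative, not taken from this commit's loss module (which sits in a collapsed diff below):

    import paddle

    T, B, C, S = 20, 4, 30, 5  # time steps, batch, classes, label length
    logits = paddle.randn([T, B, C])  # model outputs over time
    labels = paddle.randint(1, C, shape=[B, S], dtype='int32')
    input_lens = paddle.full([B], T, dtype='int64')
    label_lens = paddle.full([B], S, dtype='int64')

    ctc = paddle.nn.CTCLoss(blank=0, reduction='mean')
    loss = ctc(logits, labels, input_lens, label_lens)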

* aishell egs

* refactor train
add aishell egs

* fix dataset batch shuffle and add batch sampler log
print model parameter

* fix model and ctc

* sequence_mask makes all inputs zeros, which causes grads to be zero; this is a bug of LessThanOp
add grad clip by global norm (see the sketch below)
add model train test notebook
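Global-norm clipping bounds the overall gradient norm instead of clipping each tensor separately. A minimal sketch of the Paddle 2.x wiring; the stand-in model and learning rate are illustrative, not this repo's trainer:

    import paddle

    model = paddle.nn.Linear(10, 10)  # stand-in for the DeepSpeech2 network
    clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=5.0)  # cf. the clip-5/clip-400 results below
    optimizer = paddle.optimizer.Adam(
        learning_rate=5e-4,
        parameters=model.parameters(),
        grad_clip=clip)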

* ctc loss
remove run prefix
using ord value as text id

* use unk when training
compute_loss needs text ids
ord ids are used in test mode to compute wer/cer (see the sketch below)
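The ord/chr trick: in test mode, transcripts are carried through the pipeline as Unicode code points so WER/CER can be computed against the raw text. A sketch of the round trip (the tune script below uses the same idea in ordid2token):

    text = "hello"
    ids = [ord(ch) for ch in text]  # [104, 101, 108, 108, 111]
    assert ''.join(chr(i) for i in ids) == text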

* fix tester

* add lr_decay
refactor code

* fix tools

* fix ci
add tune
fix gru model bugs
add dataset and model test

* fix decoding

* refactor repo
fix decoding

* fix musan and rir dataset

* refactor io, loss, conv, rnn, gradclip, model, utils

* fix ci and import

* refactor model
add export jit model

* add deploy bin and test it

* rm useless egs

* add layer tools

* refactor socket server
new model from pretrain

* remove useless

* fix unstable loss and grad nan/inf for librispeech training

* fix sampler

* fix libri train.sh

* fix doc

* add license on cpp

* fix doc

* fix libri script

* fix install

* grad clip 5: wer 7.39; grad clip 400: wer 7.54; paddle 1.8 baseline with clip 400: wer 7.49
Parent 054d795d
.DS_Store
*.pyc
tools/venv
.vscode
*.log
*.pdmodel
*.pdiparams*
This diff is collapsed.
This diff is collapsed.
@@ -38,4 +38,4 @@
     entry: python .pre-commit-hooks/copyright-check.hook
     language: system
     files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
-    exclude: (?=decoders/swig).*(\.cpp|\.h)$
+    #exclude: (?=decoders/swig).*(\.cpp|\.h)$
\ No newline at end of file
 language: cpp
 cache: ccache
 sudo: required
-dist: xenial
+dist: Bionic
 services:
 - docker
 os:
@@ -26,7 +26,7 @@ script:
 - exit_code=0
 - .travis/precommit.sh || exit_code=$(( exit_code | $? ))
 - docker run -i --rm -v "$PWD:/py_unittest" paddlepaddle/paddle:latest /bin/bash -c
-  'cd /py_unittest; sh .travis/unittest.sh' || exit_code=$(( exit_code | $? ))
+  'cd /py_unittest; source env.sh; bash .travis/unittest.sh' || exit_code=$(( exit_code | $? ))
 exit $exit_code
 notifications:
......
@@ -15,7 +15,7 @@ unittest(){
     if [ $? != 0 ]; then
         exit 1
     fi
-    find . -name 'tests' -type d -print0 | \
+    find . -path ./tools/venv -prune -false -o -name 'tests' -type d -print0 | \
         xargs -0 -I{} -n1 bash -c \
         'python3 -m unittest discover -v -s {}'
     cd - > /dev/null
@@ -24,6 +24,10 @@ unittest(){
 trap 'abort' 0
 set -e
+cd tools; make; cd -
+. tools/venv/bin/activate
+pip3 install pytest
 unittest .
 trap : 0
This diff is collapsed.
This diff is collapsed.
ThreadPool/
build/
dist/
kenlm/
openfst-1.6.3/
openfst-1.6.3.tar.gz
swig_decoders.egg-info/
decoders_wrap.cxx
swig_decoders.py
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ctc_beam_search_decoder.h" #include "ctc_beam_search_decoder.h"
#include <algorithm> #include <algorithm>
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CTC_BEAM_SEARCH_DECODER_H_
#define CTC_BEAM_SEARCH_DECODER_H_
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ctc_greedy_decoder.h" #include "ctc_greedy_decoder.h"
#include "decoder_utils.h" #include "decoder_utils.h"
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CTC_GREEDY_DECODER_H
#define CTC_GREEDY_DECODER_H
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder_utils.h" #include "decoder_utils.h"
#include <algorithm> #include <algorithm>
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DECODER_UTILS_H_
#define DECODER_UTILS_H_
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "path_trie.h" #include "path_trie.h"
#include <algorithm> #include <algorithm>
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PATH_TRIE_H
#define PATH_TRIE_H
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "scorer.h" #include "scorer.h"
#include <unistd.h> #include <unistd.h>
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SCORER_H_
#define SCORER_H_
......
@@ -81,9 +81,8 @@ FILES = glob.glob('kenlm/util/*.cc') \
 FILES += glob.glob('openfst-1.6.3/src/lib/*.cc')
 FILES = [
-    fn for fn in FILES
-    if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith(
-        'unittest.cc'))
+    fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc')
+                               or fn.endswith('unittest.cc'))
 ]
 LIBS = ['stdc++']
......
@@ -46,7 +46,7 @@ def ctc_greedy_decoder(probs_seq, vocabulary):
     :rtype: str
     """
     result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary)
-    return result.decode('utf-8')
+    return result

 def ctc_beam_search_decoder(probs_seq,
......
@@ -14,7 +14,7 @@
 """Test decoders."""
 import unittest

-from decoders import decoders_deprecated as decoder
+from deepspeech.decoders import decoders_deprecated as decoder

 class TestDecoders(unittest.TestCase):
......
@@ -19,6 +19,8 @@ import sys
 import argparse
 import pyaudio

+from deepspeech.utils.socket_server import socket_send
+
 parser = argparse.ArgumentParser(description=__doc__)
 parser.add_argument(
     "--host_ip",
@@ -61,16 +63,7 @@ def callback(in_data, frame_count, time_info, status):
         data_list.append(in_data)
         enable_trigger_record = False
     elif len(data_list) > 0:
-        # Connect to server and send data
-        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        sock.connect((args.host_ip, args.host_port))
-        sent = ''.join(data_list)
-        sock.sendall(struct.pack('>i', len(sent)) + sent)
-        print('Speech[length=%d] Sent.' % len(sent))
-        # Receive data from the server and shut down
-        received = sock.recv(1024)
-        print("Recognition Results: {}".format(received))
-        sock.close()
+        socket_send(args.host_ip, args.host_port, ''.join(data_list))
         data_list = []
         enable_trigger_record = True
     return (in_data, pyaudio.paContinue)
@@ -80,7 +73,7 @@ def main():
     # prepare audio recorder
     p = pyaudio.PyAudio()
     stream = p.open(
-        format=pyaudio.paInt32,
+        format=pyaudio.paInt16,
         channels=1,
         rate=16000,
         input=True,
......
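The refactor above swaps the inlined client code for deepspeech.utils.socket_server.socket_send. A plausible implementation, reconstructed from the removed lines (the repo's real helper may differ): a 4-byte big-endian length prefix followed by the raw audio payload.

    import socket
    import struct

    def socket_send(host, port, data):
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect((host, port))
        sock.sendall(struct.pack('>i', len(data)) + data)
        print('Speech[length=%d] Sent.' % len(data))
        # receive the recognition result and shut down
        received = sock.recv(1024)
        print("Recognition Results: {}".format(received))
        sock.close()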
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Record wav from Microphone"""
# http://people.csail.mit.edu/hubert/pyaudio/
import pyaudio
import wave
CHUNK = 1024
FORMAT = pyaudio.paInt16
CHANNELS = 1
RATE = 16000
RECORD_SECONDS = 5
WAVE_OUTPUT_FILENAME = "output.wav"
p = pyaudio.PyAudio()
stream = p.open(
format=FORMAT,
channels=CHANNELS,
rate=RATE,
input=True,
frames_per_buffer=CHUNK)
print("* recording")
frames = []
for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
data = stream.read(CHUNK)
frames.append(data)
print("* done recording")
stream.stop_stream()
stream.close()
p.terminate()
wf = wave.open(WAVE_OUTPUT_FILENAME, 'wb')
wf.setnchannels(CHANNELS)
wf.setsampwidth(p.get_sample_size(FORMAT))
wf.setframerate(RATE)
wf.writeframes(b''.join(frames))
wf.close()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Server-end for the ASR demo."""
import os
import time
import argparse
import functools
import paddle
import numpy as np
from deepspeech.utils.socket_server import warm_up_test
from deepspeech.utils.socket_server import AsrTCPServer
from deepspeech.utils.socket_server import AsrRequestHandler
from deepspeech.training.cli import default_argument_parser
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.utility import add_arguments, print_arguments
from deepspeech.models.deepspeech2 import DeepSpeech2Model
from deepspeech.io.dataset import ManifestDataset
from paddle.inference import Config
from paddle.inference import create_predictor
def init_predictor(args):
if args.model_dir is not None:
config = Config(args.model_dir)
else:
config = Config(args.model_file, args.params_file)
config.enable_memory_optim()
if args.use_gpu:
config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
else:
# When not using mkldnn, you can set the blas thread num instead.
# The thread num should not be greater than the number of cores in the CPU.
config.set_cpu_math_library_num_threads(4)
config.enable_mkldnn()
predictor = create_predictor(config)
return predictor
def run(predictor, img):
# copy img data to input tensor
input_names = predictor.get_input_names()
for i, name in enumerate(input_names):
input_tensor = predictor.get_input_handle(name)
#input_tensor.reshape(img[i].shape)
#input_tensor.copy_from_cpu(img[i].copy())
# do the inference
predictor.run()
results = []
# get out data from output tensor
output_names = predictor.get_output_names()
for i, name in enumerate(output_names):
output_tensor = predictor.get_output_handle(name)
output_data = output_tensor.copy_to_cpu()
results.append(output_data)
return results
def inference(config, args):
predictor = init_predictor(args)
def start_server(config, args):
"""Start the ASR server"""
dataset = ManifestDataset(
config.data.test_manifest,
config.data.vocab_filepath,
config.data.mean_std_filepath,
augmentation_config="{}",
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=True)
model = DeepSpeech2Model.from_pretrained(dataset, config,
args.checkpoint_path)
model.eval()
# prepare ASR inference handler
def file_to_transcript(filename):
feature = dataset.process_utterance(filename, "")
audio = np.array([feature[0]]).astype('float32') #[1, D, T]
audio_len = feature[0].shape[1]
audio_len = np.array([audio_len]).astype('int64') # [1]
result_transcript = model.decode(
paddle.to_tensor(audio),
paddle.to_tensor(audio_len),
vocab_list=dataset.vocab_list,
decoding_method=config.decoding.decoding_method,
lang_model_path=config.decoding.lang_model_path,
beam_alpha=config.decoding.alpha,
beam_beta=config.decoding.beta,
beam_size=config.decoding.beam_size,
cutoff_prob=config.decoding.cutoff_prob,
cutoff_top_n=config.decoding.cutoff_top_n,
num_processes=config.decoding.num_proc_bsearch)
return result_transcript[0]
# warming up with utterances sampled from Librispeech
print('-----------------------------------------------------------')
print('Warming up ...')
warm_up_test(
audio_process_handler=file_to_transcript,
manifest_path=args.warmup_manifest,
num_test_cases=3)
print('-----------------------------------------------------------')
# start the server
server = AsrTCPServer(
server_address=(args.host_ip, args.host_port),
RequestHandlerClass=AsrRequestHandler,
speech_save_dir=args.speech_save_dir,
audio_process_handler=file_to_transcript)
print("ASR Server Started.")
server.serve_forever()
def main(config, args):
start_server(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('host_ip', str,
'localhost',
"Server's IP address.")
add_arg('host_port', int, 8086, "Server's IP port.")
add_arg('speech_save_dir', str,
'demo_cache',
"Directory to save demo audios.")
add_arg('warmup_manifest', str, None, "Filepath of manifest to warm up.")
parser.add_argument(
"--model_file",
type=str,
default="",
help="Model filename. Specify this when your model is a combined model.")
parser.add_argument(
"--params_file",
type=str,
default="",
help="Parameter filename. Specify this when your model is a combined model.")
parser.add_argument(
"--model_dir",
type=str,
default=None,
help="Model dir. If you load a non-combined model, specify the directory of the model.")
parser.add_argument(
"--use_gpu", type=bool, default=False, help="Whether to use gpu.")
args = parser.parse_args()
print_arguments(args)
# https://yaml.org/type/float.html
config = get_cfg_defaults()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
args.warmup_manifest = config.data.test_manifest
print_arguments(args)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
main(config, args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Socket client to send wav to ASR server."""
import struct
import socket
import argparse
import wave
from deepspeech.utils.socket_server import socket_send
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--host_ip",
default="localhost",
type=str,
help="Server IP address. (default: %(default)s)")
parser.add_argument(
"--host_port",
default=8086,
type=int,
help="Server Port. (default: %(default)s)")
args = parser.parse_args()
WAVE_OUTPUT_FILENAME = "output.wav"
def main():
wf = wave.open(WAVE_OUTPUT_FILENAME, 'rb')
nframe = wf.getnframes()
data = wf.readframes(nframe)
print(f"Wave: {WAVE_OUTPUT_FILENAME}")
print(f"Wave samples: {nframe}")
print(f"Wave channels: {wf.getnchannels()}")
print(f"Wave sample rate: {wf.getframerate()}")
print(f"Wave sample width: {wf.getsampwidth()}")
assert isinstance(data, bytes)
socket_send(args.host_ip, args.host_port, data)
if __name__ == "__main__":
main()
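On the server side, AsrRequestHandler presumably reverses this framing. A sketch of reading one length-prefixed message; this is an assumption about the protocol, not code from the repo:

    import struct

    def recv_one_message(conn):
        nbytes = struct.unpack('>i', conn.recv(4))[0]  # 4-byte big-endian length
        chunks = []
        while nbytes > 0:
            chunk = conn.recv(min(nbytes, 4096))
            if not chunk:
                raise ConnectionError("peer closed before message completed")
            chunks.append(chunk)
            nbytes -= len(chunk)
        return b''.join(chunks)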
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Server-end for the ASR demo."""
import os
import time
import argparse
import functools
import paddle
import numpy as np
from deepspeech.utils.socket_server import warm_up_test
from deepspeech.utils.socket_server import AsrTCPServer
from deepspeech.utils.socket_server import AsrRequestHandler
from deepspeech.training.cli import default_argument_parser
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.utility import add_arguments, print_arguments
from deepspeech.models.deepspeech2 import DeepSpeech2Model
from deepspeech.io.dataset import ManifestDataset
def start_server(config, args):
"""Start the ASR server"""
dataset = ManifestDataset(
config.data.test_manifest,
config.data.vocab_filepath,
config.data.mean_std_filepath,
augmentation_config="{}",
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=True)
model = DeepSpeech2Model.from_pretrained(dataset, config,
args.checkpoint_path)
model.eval()
# prepare ASR inference handler
def file_to_transcript(filename):
feature = dataset.process_utterance(filename, "")
audio = np.array([feature[0]]).astype('float32') #[1, D, T]
audio_len = feature[0].shape[1]
audio_len = np.array([audio_len]).astype('int64') # [1]
result_transcript = model.decode(
paddle.to_tensor(audio),
paddle.to_tensor(audio_len),
vocab_list=dataset.vocab_list,
decoding_method=config.decoding.decoding_method,
lang_model_path=config.decoding.lang_model_path,
beam_alpha=config.decoding.alpha,
beam_beta=config.decoding.beta,
beam_size=config.decoding.beam_size,
cutoff_prob=config.decoding.cutoff_prob,
cutoff_top_n=config.decoding.cutoff_top_n,
num_processes=config.decoding.num_proc_bsearch)
return result_transcript[0]
# warming up with utterances sampled from Librispeech
print('-----------------------------------------------------------')
print('Warming up ...')
warm_up_test(
audio_process_handler=file_to_transcript,
manifest_path=args.warmup_manifest,
num_test_cases=3)
print('-----------------------------------------------------------')
# start the server
server = AsrTCPServer(
server_address=(args.host_ip, args.host_port),
RequestHandlerClass=AsrRequestHandler,
speech_save_dir=args.speech_save_dir,
audio_process_handler=file_to_transcript)
print("ASR Server Started.")
server.serve_forever()
def main(config, args):
start_server(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('host_ip', str,
'localhost',
"Server's IP address.")
add_arg('host_port', int, 8086, "Server's IP port.")
add_arg('speech_save_dir', str,
'demo_cache',
"Directory to save demo audios.")
add_arg('warmup_manifest', str, None, "Filepath of manifest to warm up.")
args = parser.parse_args()
print_arguments(args)
# https://yaml.org/type/float.html
config = get_cfg_defaults()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
args.warmup_manifest = config.data.test_manifest
print_arguments(args)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
main(config, args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export for DeepSpeech2 model."""
import io
import logging
import argparse
import functools
from paddle import distributed as dist
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments
from deepspeech.utils.error_rate import char_errors, word_errors
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.exps.deepspeech2.model import DeepSpeech2Tester as Tester
def main_sp(config, args):
exp = Tester(config, args)
exp.setup()
exp.run_export()
def main(config, args):
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
args = parser.parse_args()
print_arguments(args)
# https://yaml.org/type/float.html
config = get_cfg_defaults()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
main(config, args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inferer for DeepSpeech2 model."""
import io
import logging
import argparse
import functools
from paddle import distributed as dist
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments
from deepspeech.utils.error_rate import char_errors, word_errors
# TODO(hui zhang): dynamic load
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.exps.deepspeech2.model import DeepSpeech2Tester as Tester
def main_sp(config, args):
exp = Tester(config, args)
exp.setup()
exp.run_test()
def main(config, args):
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
args = parser.parse_args()
print_arguments(args)
# https://yaml.org/type/float.html
config = get_cfg_defaults()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
main(config, args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for DeepSpeech2 model."""
import io
import logging
import argparse
import functools
from paddle import distributed as dist
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments
from deepspeech.utils.error_rate import char_errors, word_errors
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.exps.deepspeech2.model import DeepSpeech2Tester as Tester
def main_sp(config, args):
exp = Tester(config, args)
exp.setup()
exp.run_test()
def main(config, args):
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
args = parser.parse_args()
print_arguments(args)
# https://yaml.org/type/float.html
config = get_cfg_defaults()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
main(config, args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer for DeepSpeech2 model."""
import io
import logging
import argparse
import functools
from paddle import distributed as dist
from deepspeech.utils.utility import print_arguments
from deepspeech.training.cli import default_argument_parser
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.exps.deepspeech2.model import DeepSpeech2Trainer as Trainer
def main_sp(config, args):
exp = Trainer(config, args)
exp.setup()
exp.run()
def main(config, args):
if args.device == "gpu" and args.nprocs > 1:
dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
else:
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
args = parser.parse_args()
print_arguments(args)
# https://yaml.org/type/float.html
config = get_cfg_defaults()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
main(config, args)
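The dist.spawn pattern above launches one training process per device. A self-contained sketch with a toy worker (illustrative, not the trainer itself):

    import paddle.distributed as dist

    def worker():
        dist.init_parallel_env()  # set up NCCL and the process group
        print(f"rank {dist.get_rank()} of {dist.get_world_size()}")

    if __name__ == "__main__":
        dist.spawn(worker, nprocs=2)  # one process per GPU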
@@ -20,105 +20,60 @@ import argparse
 import functools
 import gzip
 import logging
-import paddle.fluid as fluid
-
-import _init_paths
-from data_utils.data import DataGenerator
-from model_utils.model import DeepSpeech2Model
-from utils.error_rate import char_errors, word_errors
-from utils.utility import add_arguments, print_arguments
-
-parser = argparse.ArgumentParser(description=__doc__)
-add_arg = functools.partial(add_arguments, argparser=parser)
-# yapf: disable
-add_arg('num_batches', int, -1, "# of batches tuning on. "
-        "Default -1, on whole dev set.")
-add_arg('batch_size', int, 256, "# of samples per batch.")
-add_arg('trainer_count', int, 8, "# of Trainers (CPUs or GPUs).")
-add_arg('beam_size', int, 500, "Beam search width.")
-add_arg('num_proc_bsearch', int, 8, "# of CPUs for beam search.")
-add_arg('num_conv_layers', int, 2, "# of convolution layers.")
-add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
-add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.")
-add_arg('num_alphas', int, 45, "# of alpha candidates for tuning.")
-add_arg('num_betas', int, 8, "# of beta candidates for tuning.")
-add_arg('alpha_from', float, 1.0, "Where alpha starts tuning from.")
-add_arg('alpha_to', float, 3.2, "Where alpha ends tuning with.")
-add_arg('beta_from', float, 0.1, "Where beta starts tuning from.")
-add_arg('beta_to', float, 0.45, "Where beta ends tuning with.")
-add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
-add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.")
-add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.")
-add_arg('use_gpu', bool, True, "Use GPU or not.")
-add_arg('share_rnn_weights', bool, True, "Share input-hidden weights across "
-        "bi-directional RNNs. Not for GRU.")
-add_arg('tune_manifest', str,
-        'data/librispeech/manifest.dev-clean',
-        "Filepath of manifest to tune.")
-add_arg('mean_std_path', str,
-        'data/librispeech/mean_std.npz',
-        "Filepath of normalizer's mean & std.")
-add_arg('vocab_path', str,
-        'data/librispeech/vocab.txt',
-        "Filepath of vocabulary.")
-add_arg('lang_model_path', str,
-        'models/lm/common_crawl_00.prune01111.trie.klm',
-        "Filepath for language model.")
-add_arg('model_path', str,
-        './checkpoints/libri/params.latest.tar.gz',
-        "If None, the training starts from scratch, "
-        "otherwise, it resumes from the pre-trained model.")
-add_arg('error_rate_type', str,
-        'wer',
-        "Error rate type for evaluation.",
-        choices=['wer', 'cer'])
-add_arg('specgram_type', str,
-        'linear',
-        "Audio feature type. Options: linear, mfcc.",
-        choices=['linear', 'mfcc'])
-# yapf: disable
-args = parser.parse_args()
-
-def tune():
+from paddle.io import DataLoader
+
+from deepspeech.utils import error_rate
+from deepspeech.utils.utility import add_arguments, print_arguments
+
+from deepspeech.models.deepspeech2 import DeepSpeech2Model
+from deepspeech.io.collator import SpeechCollator
+from deepspeech.io.dataset import ManifestDataset
+
+from deepspeech.training.cli import default_argument_parser
+from deepspeech.exps.deepspeech2.config import get_cfg_defaults
+
+
+def tune(config, args):
     """Tune parameters alpha and beta incrementally."""
     if not args.num_alphas >= 0:
         raise ValueError("num_alphas must be non-negative!")
     if not args.num_betas >= 0:
         raise ValueError("num_betas must be non-negative!")
-    if args.use_gpu:
-        place = fluid.CUDAPlace(0)
-    else:
-        place = fluid.CPUPlace()
-
-    data_generator = DataGenerator(
-        vocab_filepath=args.vocab_path,
-        mean_std_filepath=args.mean_std_path,
-        augmentation_config='{}',
-        specgram_type=args.specgram_type,
-        keep_transcription_text=True,
-        place=place,
-        is_training=False)
-
-    batch_reader = data_generator.batch_reader_creator(
-        manifest_path=args.tune_manifest,
-        batch_size=args.batch_size,
-        sortagrad=False,
-        shuffle_method=None)
-
-    ds2_model = DeepSpeech2Model(
-        vocab_size=data_generator.vocab_size,
-        num_conv_layers=args.num_conv_layers,
-        num_rnn_layers=args.num_rnn_layers,
-        rnn_layer_size=args.rnn_layer_size,
-        use_gru=args.use_gru,
-        place=place,
-        init_from_pretrained_model=args.model_path,
-        share_rnn_weights=args.share_rnn_weights)
+
+    dev_dataset = ManifestDataset(
+        config.data.dev_manifest,
+        config.data.vocab_filepath,
+        config.data.mean_std_filepath,
+        augmentation_config="{}",
+        max_duration=config.data.max_duration,
+        min_duration=config.data.min_duration,
+        stride_ms=config.data.stride_ms,
+        window_ms=config.data.window_ms,
+        n_fft=config.data.n_fft,
+        max_freq=config.data.max_freq,
+        target_sample_rate=config.data.target_sample_rate,
+        specgram_type=config.data.specgram_type,
+        use_dB_normalization=config.data.use_dB_normalization,
+        target_dB=config.data.target_dB,
+        random_seed=config.data.random_seed,
+        keep_transcription_text=True)
+
+    valid_loader = DataLoader(
+        dev_dataset,
+        batch_size=config.data.batch_size,
+        shuffle=False,
+        drop_last=False,
+        collate_fn=SpeechCollator(is_training=False))
+
+    model = DeepSpeech2Model.from_pretrained(dev_dataset, config,
+                                             args.checkpoint_path)
+    model.eval()

     # decoders only accept string encoded in utf-8
-    vocab_list = [chars for chars in data_generator.vocab_list]
-    errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
+    vocab_list = valid_loader.dataset.vocab_list
+    errors_func = error_rate.char_errors if config.decoding.error_rate_type == 'cer' else error_rate.word_errors
     # create grid for search
     cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
     cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
@@ -127,36 +82,50 @@ def tune():
     err_sum = [0.0 for i in range(len(params_grid))]
     err_ave = [0.0 for i in range(len(params_grid))]
     num_ins, len_refs, cur_batch = 0, 0, 0
     # initialize external scorer
-    ds2_model.init_ext_scorer(args.alpha_from, args.beta_from,
-                              args.lang_model_path, vocab_list)
+    model.decoder.init_decode(args.alpha_from, args.beta_from,
+                              config.decoding.lang_model_path, vocab_list,
+                              config.decoding.decoding_method)

     ## incremental tuning parameters over multiple batches
-    ds2_model.logger.info("start tuning ...")
-    for infer_data in batch_reader():
+    print("start tuning ...")
+    for infer_data in valid_loader():
         if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
             break
-        probs_split = ds2_model.infer_batch_probs(
-            infer_data=infer_data,
-            feeding_dict=data_generator.feeding)
-        target_transcripts = infer_data[1]
-
-        num_ins += len(target_transcripts)
+
+        def ordid2token(texts, texts_len):
+            """ord() id to chr() char"""
+            trans = []
+            for text, n in zip(texts, texts_len):
+                n = n.numpy().item()
+                ids = text[:n]
+                trans.append(''.join([chr(i) for i in ids]))
+            return trans
+
+        audio, text, audio_len, text_len = infer_data
+        target_transcripts = ordid2token(text, text_len)
+        num_ins += audio.shape[0]
+        # model infer
+        eouts, eouts_len = model.encoder(audio, audio_len)
+        probs = model.decoder.probs(eouts)

         # grid search
         for index, (alpha, beta) in enumerate(params_grid):
-            result_transcripts = ds2_model.decode_batch_beam_search(
-                probs_split=probs_split,
-                beam_alpha=alpha,
-                beam_beta=beta,
-                beam_size=args.beam_size,
-                cutoff_prob=args.cutoff_prob,
-                cutoff_top_n=args.cutoff_top_n,
-                vocab_list=vocab_list,
-                num_processes=args.num_proc_bsearch)
+            print(f"tuning: alpha={alpha} beta={beta}")
+            result_transcripts = model.decoder.decode_probs(
+                probs.numpy(), eouts_len, vocab_list,
+                config.decoding.decoding_method,
+                config.decoding.lang_model_path, alpha, beta,
+                config.decoding.beam_size, config.decoding.cutoff_prob,
+                config.decoding.cutoff_top_n, config.decoding.num_proc_bsearch)
+
             for target, result in zip(target_transcripts, result_transcripts):
                 errors, len_ref = errors_func(target, result)
                 err_sum[index] += errors
                 # accumulate the length of references of every batch
                 # in the first iteration
                 if args.alpha_from == alpha and args.beta_from == beta:
                     len_refs += len_ref
@@ -165,37 +134,77 @@ def tune():
             if index % 2 == 0:
                 sys.stdout.write('.')
                 sys.stdout.flush()
+        print("tuning: one grid done!")

         # output on-line tuning result at the end of current batch
         err_ave_min = min(err_ave)
         min_index = err_ave.index(err_ave_min)
-        print("\nBatch %d [%d/?], current opt (alpha, beta) = (%s, %s), "
-              " min [%s] = %f" %(cur_batch, num_ins,
-              "%.3f" % params_grid[min_index][0],
-              "%.3f" % params_grid[min_index][1],
-              args.error_rate_type, err_ave_min))
+        print("\nBatch %d [%d/?], current opt (alpha, beta) = (%s, %s), "
+              " min [%s] = %f" %
+              (cur_batch, num_ins, "%.3f" % params_grid[min_index][0],
+               "%.3f" % params_grid[min_index][1],
+               config.decoding.error_rate_type, err_ave_min))
         cur_batch += 1

     # output WER/CER at every (alpha, beta)
-    print("\nFinal %s:\n" % args.error_rate_type)
+    print("\nFinal %s:\n" % config.decoding.error_rate_type)
     for index in range(len(params_grid)):
-        print("(alpha, beta) = (%s, %s), [%s] = %f"
-              % ("%.3f" % params_grid[index][0], "%.3f" % params_grid[index][1],
-                 args.error_rate_type, err_ave[index]))
+        print("(alpha, beta) = (%s, %s), [%s] = %f" %
+              ("%.3f" % params_grid[index][0], "%.3f" % params_grid[index][1],
+               config.decoding.error_rate_type, err_ave[index]))

     err_ave_min = min(err_ave)
     min_index = err_ave.index(err_ave_min)
-    print("\nFinish tuning on %d batches, final opt (alpha, beta) = (%s, %s)"
-          % (cur_batch, "%.3f" % params_grid[min_index][0],
-             "%.3f" % params_grid[min_index][1]))
-
-    ds2_model.logger.info("finish tuning")
+    print("\nFinish tuning on %d batches, final opt (alpha, beta) = (%s, %s)" %
+          (cur_batch, "%.3f" % params_grid[min_index][0],
+           "%.3f" % params_grid[min_index][1]))
+    print("finish tuning")


-def main():
-    print_arguments(args)
-    tune()
+def main(config, args):
+    tune(config, args)


-if __name__ == '__main__':
-    main()
+if __name__ == "__main__":
+    parser = default_argument_parser()
+    add_arg = functools.partial(add_arguments, argparser=parser)
+    add_arg('num_batches', int, -1, "# of batches tuning on. "
+            "Default -1, on whole dev set.")
+    add_arg('num_alphas', int, 45, "# of alpha candidates for tuning.")
+    add_arg('num_betas', int, 8, "# of beta candidates for tuning.")
+    add_arg('alpha_from', float, 1.0, "Where alpha starts tuning from.")
+    add_arg('alpha_to', float, 3.2, "Where alpha ends tuning with.")
+    add_arg('beta_from', float, 0.1, "Where beta starts tuning from.")
+    add_arg('beta_to', float, 0.45, "Where beta ends tuning with.")
+    add_arg('batch_size', int, 256, "# of samples per batch.")
+    add_arg('beam_size', int, 500, "Beam search width.")
+    add_arg('num_proc_bsearch', int, 8, "# of CPUs for beam search.")
+    add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
+    add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.")
+    args = parser.parse_args()
+    print_arguments(args)
+
+    # https://yaml.org/type/float.html
+    config = get_cfg_defaults()
+    if args.config:
+        config.merge_from_file(args.config)
+    if args.opts:
+        config.merge_from_list(args.opts)
+    config.data.batch_size = args.batch_size
+    config.decoding.beam_size = args.beam_size
+    config.decoding.num_proc_bsearch = args.num_proc_bsearch
+    config.decoding.cutoff_prob = args.cutoff_prob
+    config.decoding.cutoff_top_n = args.cutoff_top_n
+    config.freeze()
+    print(config)
+    if args.dump_config:
+        with open(args.dump_config, 'w') as f:
+            print(config, file=f)
+
+    main(config, args)
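The params_grid iterated above is the Cartesian product of the alpha and beta candidates; its construction sits in a collapsed hunk, but with the default arguments it amounts to:

    import numpy as np

    cand_alphas = np.linspace(1.0, 3.2, 45)  # alpha_from, alpha_to, num_alphas
    cand_betas = np.linspace(0.1, 0.45, 8)   # beta_from, beta_to, num_betas
    params_grid = [(alpha, beta) for alpha in cand_alphas for beta in cand_betas]
    assert len(params_grid) == 45 * 8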
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from yacs.config import CfgNode as CN
from deepspeech.models.deepspeech2 import DeepSpeech2Model
_C = CN()
_C.data = CN(
dict(
train_manifest="",
dev_manifest="",
test_manifest="",
vocab_filepath="",
mean_std_filepath="",
augmentation_config="",
max_duration=float('inf'),
min_duration=0.0,
stride_ms=10.0, # ms
window_ms=20.0, # ms
n_fft=None, # fft points
max_freq=None, # None for samplerate/2
specgram_type='linear', # 'linear', 'mfcc'
target_sample_rate=16000, # sample rate
use_dB_normalization=True,
target_dB=-20,
random_seed=0,
keep_transcription_text=False,
batch_size=32, # batch size
num_workers=0, # data loader workers
sortagrad=False, # sorted in first epoch when True
shuffle_method="batch_shuffle", # 'batch_shuffle', 'instance_shuffle'
))
_C.model = CN(
dict(
num_conv_layers=2,  # Number of stacking convolution layers.
num_rnn_layers=3,  # Number of stacking RNN layers.
rnn_layer_size=1024,  # RNN layer size (number of RNN cells).
use_gru=True,  # Use GRU if set True. Use simple RNN if set False.
share_rnn_weights=True  # Whether to share input-hidden weights between forward and backward directional RNNs. Notice that for GRU, weight sharing is not supported.
))
DeepSpeech2Model.params(_C.model)
_C.training = CN(
dict(
lr=5e-4, # learning rate
lr_decay=1.0, # learning rate decay
weight_decay=1e-6, # the coeff of weight decay
global_grad_clip=5.0, # the global norm clip
n_epoch=50, # train epochs
))
_C.decoding = CN(
dict(
alpha=2.5, # Coef of LM for beam search.
beta=0.3, # Coef of WC for beam search.
cutoff_prob=1.0, # Cutoff probability for pruning.
cutoff_top_n=40, # Cutoff number for pruning.
lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model.
decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy
error_rate_type='wer',  # Error rate type for evaluation. Options: 'wer', 'cer'
num_proc_bsearch=8, # # of CPUs for beam search.
beam_size=500, # Beam search width.
batch_size=128, # decoding batch size
))
def get_cfg_defaults():
"""Get a yacs CfgNode object with default values for my_project."""
# Return a clone so that the defaults will not be altered
# This is for the "local variable" use pattern
return _C.clone()
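Typical consumption of this config module, mirroring the entry scripts above; the YAML path is a placeholder, not a file shipped in this commit:

    config = get_cfg_defaults()
    config.merge_from_file("conf/deepspeech2.yaml")  # per-experiment overrides
    config.merge_from_list(["training.lr", 1e-3])    # CLI-style overrides
    config.freeze()                                  # make it read-only
    print(config.model.rnn_layer_size)               # 1024 by default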
This diff is collapsed.
@@ -15,17 +15,17 @@
 import json
 import random

-from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor
-from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor
-from data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor
-from data_utils.augmentor.noise_perturb import NoisePerturbAugmentor
-from data_utils.augmentor.impulse_response import ImpulseResponseAugmentor
-from data_utils.augmentor.resample import ResampleAugmentor
-from data_utils.augmentor.online_bayesian_normalization import \
-    OnlineBayesianNormalizationAugmentor
+from deepspeech.frontend.augmentor.volume_perturb import VolumePerturbAugmentor
+from deepspeech.frontend.augmentor.shift_perturb import ShiftPerturbAugmentor
+from deepspeech.frontend.augmentor.speed_perturb import SpeedPerturbAugmentor
+from deepspeech.frontend.augmentor.noise_perturb import NoisePerturbAugmentor
+from deepspeech.frontend.augmentor.impulse_response import ImpulseResponseAugmentor
+from deepspeech.frontend.augmentor.resample import ResampleAugmentor
+from deepspeech.frontend.augmentor.online_bayesian_normalization import \
+    OnlineBayesianNormalizationAugmentor

-class AugmentationPipeline(object):
+class AugmentationPipeline():
     """Build a pre-processing pipeline with various augmentation models. Such a
     data augmentation pipeline is often leveraged to augment the training
     samples to make the model invariant to certain types of perturbations in the
......
@@ -16,7 +16,7 @@
 from abc import ABCMeta, abstractmethod

-class AugmentorBase(object):
+class AugmentorBase():
     """Abstract base class for augmentation model (augmentor) class.
     All augmentor classes should inherit from this class, and implement the
     following abstract methods.
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 """Contains the impulse response augmentation model."""
-from data_utils.augmentor.base import AugmentorBase
-from data_utils.utility import read_manifest
-from data_utils.audio import AudioSegment
+from deepspeech.frontend.augmentor.base import AugmentorBase
+from deepspeech.frontend.utility import read_manifest
+from deepspeech.frontend.audio import AudioSegment

 class ImpulseResponseAugmentor(AugmentorBase):
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 """Contains the noise perturb augmentation model."""
-from data_utils.augmentor.base import AugmentorBase
-from data_utils.utility import read_manifest
-from data_utils.audio import AudioSegment
+from deepspeech.frontend.augmentor.base import AugmentorBase
+from deepspeech.frontend.utility import read_manifest
+from deepspeech.frontend.audio import AudioSegment

 class NoisePerturbAugmentor(AugmentorBase):
......
@@ -13,7 +13,7 @@
 # limitations under the License.
 """Contains the online bayesian normalization augmentation model."""
-from data_utils.augmentor.base import AugmentorBase
+from deepspeech.frontend.augmentor.base import AugmentorBase

 class OnlineBayesianNormalizationAugmentor(AugmentorBase):
......
@@ -13,7 +13,7 @@
 # limitations under the License.
 """Contains the resample augmentation model."""
-from data_utils.augmentor.base import AugmentorBase
+from deepspeech.frontend.augmentor.base import AugmentorBase

 class ResampleAugmentor(AugmentorBase):
......
@@ -13,7 +13,7 @@
 # limitations under the License.
 """Contains the shift perturb augmentation model."""
-from data_utils.augmentor.base import AugmentorBase
+from deepspeech.frontend.augmentor.base import AugmentorBase

 class ShiftPerturbAugmentor(AugmentorBase):
......
@@ -13,7 +13,7 @@
 # limitations under the License.
 """Contains the speed perturbation augmentation model."""
-from data_utils.augmentor.base import AugmentorBase
+from deepspeech.frontend.augmentor.base import AugmentorBase

 class SpeedPerturbAugmentor(AugmentorBase):
......
@@ -13,7 +13,7 @@
 # limitations under the License.
 """Contains the volume perturb augmentation model."""
-from data_utils.augmentor.base import AugmentorBase
+from deepspeech.frontend.augmentor.base import AugmentorBase

 class VolumePerturbAugmentor(AugmentorBase):
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -14,8 +14,8 @@
 """Contains the audio featurizer class."""
 import numpy as np
-from data_utils.utility import read_manifest
-from data_utils.audio import AudioSegment
+from deepspeech.frontend.utility import read_manifest
+from deepspeech.frontend.audio import AudioSegment
 from python_speech_features import mfcc
 from python_speech_features import delta
@@ -52,6 +52,7 @@ class AudioFeaturizer(object):
                  specgram_type='linear',
                  stride_ms=10.0,
                  window_ms=20.0,
+                 n_fft=None,
                  max_freq=None,
                  target_sample_rate=16000,
                  use_dB_normalization=True,
@@ -63,6 +64,7 @@ class AudioFeaturizer(object):
         self._target_sample_rate = target_sample_rate
         self._use_dB_normalization = use_dB_normalization
         self._target_dB = target_dB
+        self._fft_point = n_fft

     def featurize(self,
                   audio_segment,
@@ -98,6 +100,22 @@ class AudioFeaturizer(object):
         return self._compute_specgram(audio_segment.samples,
                                       audio_segment.sample_rate)

+    @property
+    def feature_size(self):
+        """audio feature size"""
+        feat_dim = 0
+        if self._specgram_type == 'linear':
+            fft_point = self._window_ms if self._fft_point is None else self._fft_point
+            feat_dim = int(fft_point * (self._target_sample_rate / 1000) / 2 + 1)
+        elif self._specgram_type == 'mfcc':
+            # mfcc, delta, delta-delta
+            feat_dim = int(13 * 3)
+        else:
+            raise ValueError("Unknown specgram_type %s. "
+                             "Supported values: linear, mfcc." % self._specgram_type)
+        return feat_dim
+
     def _compute_specgram(self, samples, sample_rate):
         """Extract various audio features."""
         if self._specgram_type == 'linear':
@@ -150,7 +168,8 @@ class AudioFeaturizer(object):
             windows[:, 1] == samples[stride_size:(stride_size + window_size)])
         # window weighting, squared Fast Fourier Transform (fft), scaling
         weighting = np.hanning(window_size)[:, None]
-        fft = np.fft.rfft(windows * weighting, axis=0)
+        # https://numpy.org/doc/stable/reference/generated/numpy.fft.rfft.html
+        fft = np.fft.rfft(windows * weighting, n=None, axis=0)
         fft = np.absolute(fft)
         fft = fft**2
         scale = np.sum(weighting**2) * sample_rate
......
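The new feature_size property and the explicit n=None in np.fft.rfft are two views of the same arithmetic: an N-point real FFT yields N // 2 + 1 frequency bins, and with n_fft unset the window length doubles as the FFT length. A minimal sanity check of those numbers, assuming the defaults shown above (window_ms=20.0, target_sample_rate=16000):

import numpy as np

window_ms = 20.0
sample_rate = 16000

# 20 ms at 16 kHz -> 320 samples per window
window_size = int(window_ms * sample_rate / 1000)

# rfft of an N-point real signal returns N // 2 + 1 frequency bins
n_bins = np.fft.rfft(np.zeros(window_size)).shape[0]

# closed form used by feature_size when n_fft is None:
# fft_point (= window_ms) * (sample_rate / 1000) / 2 + 1
feat_dim = int(window_ms * (sample_rate / 1000) / 2 + 1)

assert n_bins == feat_dim == 161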
@@ -13,8 +13,8 @@
 # limitations under the License.
 """Contains the speech featurizer class."""
-from data_utils.featurizer.audio_featurizer import AudioFeaturizer
-from data_utils.featurizer.text_featurizer import TextFeaturizer
+from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer
+from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
 class SpeechFeaturizer(object):
@@ -56,6 +56,7 @@ class SpeechFeaturizer(object):
                  specgram_type='linear',
                  stride_ms=10.0,
                  window_ms=20.0,
+                 n_fft=None,
                  max_freq=None,
                  target_sample_rate=16000,
                  use_dB_normalization=True,
@@ -64,6 +65,7 @@ class SpeechFeaturizer(object):
             specgram_type=specgram_type,
             stride_ms=stride_ms,
             window_ms=window_ms,
+            n_fft=n_fft,
             max_freq=max_freq,
             target_sample_rate=target_sample_rate,
             use_dB_normalization=use_dB_normalization,
@@ -106,3 +108,12 @@ class SpeechFeaturizer(object):
         :rtype: list
         """
         return self._text_featurizer.vocab_list
+
+    @property
+    def feature_size(self):
+        """Return the audio feature size.
+
+        :return: audio feature size.
+        :rtype: int
+        """
+        return self._audio_featurizer.feature_size
\ No newline at end of file
@@ -30,6 +30,7 @@ class TextFeaturizer(object):
     """

     def __init__(self, vocab_filepath):
+        self.unk = '<unk>'
         self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file(
             vocab_filepath)
@@ -43,7 +44,11 @@ class TextFeaturizer(object):
         :rtype: list
         """
         tokens = self._char_tokenize(text)
-        return [self._vocab_dict[token] for token in tokens]
+        ids = []
+        for token in tokens:
+            token = token if token in self._vocab_dict else self.unk
+            ids.append(self._vocab_dict[token])
+        return ids

     @property
     def vocab_size(self):
...
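The rewritten featurize replaces the old list comprehension, which raised KeyError on out-of-vocabulary characters, with a fallback to the <unk> id; this assumes the vocabulary file actually contains an <unk> entry. A minimal sketch of the lookup with a hypothetical toy vocabulary:

# hypothetical toy vocabulary; the real one is loaded from vocab_filepath
vocab_dict = {'<unk>': 0, 'a': 1, 'b': 2}

def char_ids(text, vocab_dict, unk='<unk>'):
    """Map characters to ids, sending OOV characters to the <unk> id."""
    ids = []
    for token in list(text):
        token = token if token in vocab_dict else unk
        ids.append(vocab_dict[token])
    return ids

print(char_ids('abz', vocab_dict))  # [1, 2, 0] -- 'z' falls back to <unk>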
@@ -15,8 +15,8 @@
 import numpy as np
 import random
-from data_utils.utility import read_manifest
-from data_utils.audio import AudioSegment
+from deepspeech.frontend.utility import read_manifest
+from deepspeech.frontend.audio import AudioSegment
 class FeatureNormalizer(object):
...
@@ -14,28 +14,33 @@
 """Contains the speech segment class."""
 import numpy as np
-from data_utils.audio import AudioSegment
+from deepspeech.frontend.audio import AudioSegment


 class SpeechSegment(AudioSegment):
-    """Speech segment abstraction, a subclass of AudioSegment,
-    with an additional transcript.
+    """Speech Segment with Text

-    :param samples: Audio samples [num_samples x num_channels].
-    :type samples: ndarray.float32
-    :param sample_rate: Audio sample rate.
-    :type sample_rate: int
-    :param transcript: Transcript text for the speech.
-    :type transript: str
-    :raises TypeError: If the sample data type is not float or int.
+    Args:
+        AudioSegment (AudioSegment): Audio Segment
     """

     def __init__(self, samples, sample_rate, transcript):
+        """Speech segment abstraction, a subclass of AudioSegment,
+        with an additional transcript.
+
+        Args:
+            samples (ndarray.float32): Audio samples [num_samples x num_channels].
+            sample_rate (int): Audio sample rate.
+            transcript (str): Transcript text for the speech.
+        """
         AudioSegment.__init__(self, samples, sample_rate)
         self._transcript = transcript

     def __eq__(self, other):
         """Return whether two objects are equal.
+
+        Returns:
+            bool: True, when equal to other
         """
         if not AudioSegment.__eq__(self, other):
             return False
...
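The __eq__ above follows the usual subclass pattern: delegate to the parent for the shared fields, then compare only the field the subclass adds. A stripped-down sketch of that pattern with hypothetical minimal classes (not the real AudioSegment):

import numpy as np

class AudioSegment(object):
    def __init__(self, samples, sample_rate):
        self._samples = np.asarray(samples, dtype=np.float32)
        self._sample_rate = sample_rate

    def __eq__(self, other):
        if type(other) is not type(self):
            return False
        return (self._sample_rate == other._sample_rate and
                np.array_equal(self._samples, other._samples))

class SpeechSegment(AudioSegment):
    def __init__(self, samples, sample_rate, transcript):
        AudioSegment.__init__(self, samples, sample_rate)
        self._transcript = transcript

    def __eq__(self, other):
        # parent checks type, sample rate, and samples; subclass adds transcript
        if not AudioSegment.__eq__(self, other):
            return False
        return self._transcript == other._transcript

a = SpeechSegment([0.0, 0.1], 16000, "hi")
b = SpeechSegment([0.0, 0.1], 16000, "hi")
c = SpeechSegment([0.0, 0.1], 16000, "bye")
assert a == b and a != c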
@@ -20,6 +20,7 @@ import tarfile
 import time
 from threading import Thread
 from multiprocessing import Process, Manager, Value
+
 from paddle.dataset.common import md5file
@@ -49,51 +50,3 @@ def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
                 json_data["duration"] >= min_duration):
             manifest.append(json_data)
     return manifest
-
-
-def getfile_insensitive(path):
-    """Get the actual file path when given insensitive filename."""
-    directory, filename = os.path.split(path)
-    directory, filename = (directory or '.'), filename.lower()
-    for f in os.listdir(directory):
-        newpath = os.path.join(directory, f)
-        if os.path.isfile(newpath) and f.lower() == filename:
-            return newpath
-
-
-def download_multi(url, target_dir, extra_args):
-    """Download multiple files from url to target_dir."""
-    if not os.path.exists(target_dir): os.makedirs(target_dir)
-    print("Downloading %s ..." % url)
-    ret_code = os.system("wget -c " + url + ' ' + extra_args + " -P " +
-                         target_dir)
-    return ret_code
-
-
-def download(url, md5sum, target_dir):
-    """Download file from url to target_dir, and check md5sum."""
-    if not os.path.exists(target_dir): os.makedirs(target_dir)
-    filepath = os.path.join(target_dir, url.split("/")[-1])
-    if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
-        print("Downloading %s ..." % url)
-        os.system("wget -c " + url + " -P " + target_dir)
-        print("\nMD5 Chesksum %s ..." % filepath)
-        if not md5file(filepath) == md5sum:
-            raise RuntimeError("MD5 checksum failed.")
-    else:
-        print("File exists, skip downloading. (%s)" % filepath)
-    return filepath
-
-
-def unpack(filepath, target_dir, rm_tar=False):
-    """Unpack the file to the target_dir."""
-    print("Unpacking %s ..." % filepath)
-    tar = tarfile.open(filepath)
-    tar.extractall(target_dir)
-    tar.close()
-    if rm_tar == True:
-        os.remove(filepath)
-
-
-class XmapEndSignal():
-    pass
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from deepspeech.training.trainer import *