未验证 提交 d2bdd254 编写于 作者: Y Yibing Liu 提交者: GitHub

Merge pull request #355 from PaddlePaddle/update_master

Update master
# This file is used by clang-format to autoformat paddle source code
#
# The clang-format is part of llvm toolchain.
# It need to install llvm and clang to format source code style.
#
# The basic usage is,
# clang-format -i -style=file PATH/TO/SOURCE/CODE
#
# The -style=file implicit use ".clang-format" file located in one of
# parent directory.
# The -i means inplace change.
#
# The document of clang-format is
# http://clang.llvm.org/docs/ClangFormat.html
# http://clang.llvm.org/docs/ClangFormatStyleOptions.html
---
Language: Cpp
BasedOnStyle: Google
IndentWidth: 2
TabWidth: 2
ContinuationIndentWidth: 4
MaxEmptyLinesToKeep: 2
AccessModifierOffset: -2 # The private/protected/public has no indent in class
Standard: Cpp11
AllowAllParametersOfDeclarationOnNextLine: true
BinPackParameters: false
BinPackArguments: false
...
#!/usr/bin/env bash
set -e
readonly VERSION="3.9"
version=$(clang-format -version)
if ! [[ $version == *"$VERSION"* ]]; then
echo "clang-format version check failed."
echo "a version contains '$VERSION' is needed, but get '$version'"
echo "you can install the right version, and make an soft-link to '\$PATH' env"
exit -1
fi
clang-format $@
.DS_Store
*.pyc
- repo: https://github.com/pre-commit/mirrors-yapf.git
sha: v0.16.0
hooks:
- id: yapf
files: \.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
sha: a11d9314b22d8f8c7556443875b731ef05965464
hooks:
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
files: (?!.*paddle)^.*$
- id: end-of-file-fixer
files: \.md$
- id: trailing-whitespace
files: \.md$
- repo: https://github.com/Lucas-C/pre-commit-hooks
sha: v1.0.1
hooks:
- id: forbid-crlf
files: \.md$
- id: remove-crlf
files: \.md$
- id: forbid-tabs
files: \.md$
- id: remove-tabs
files: \.md$
- repo: local
hooks:
- id: clang-format
name: clang-format
description: Format files with ClangFormat
entry: bash .clang_format.hook -i
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
[style]
based_on_style = pep8
column_limit = 80
language: cpp
cache: ccache
sudo: required
dist: trusty
services:
- docker
os:
- linux
env:
- JOB=PRE_COMMIT
addons:
apt:
packages:
- git
- python
- python-pip
- python2.7-dev
before_install:
- sudo pip install -U virtualenv pre-commit pip
- docker pull paddlepaddle/paddle:latest
script:
- exit_code=0
- .travis/precommit.sh || exit_code=$(( exit_code | $? ))
- docker run -i --rm -v "$PWD:/py_unittest" paddlepaddle/paddle:latest /bin/bash -c
'cd /py_unittest; sh .travis/unittest.sh' || exit_code=$(( exit_code | $? ))
exit $exit_code
notifications:
email:
on_success: change
on_failure: always
#!/bin/bash
function abort(){
echo "Your commit not fit PaddlePaddle code style" 1>&2
echo "Please use pre-commit scripts to auto-format your code" 1>&2
exit 1
}
trap 'abort' 0
set -e
cd `dirname $0`
cd ..
export PATH=/usr/bin:$PATH
pre-commit install
if ! pre-commit run -a ; then
ls -lh
git diff --exit-code
exit 1
fi
trap : 0
#!/bin/bash
abort(){
echo "Run unittest failed" 1>&2
echo "Please check your code" 1>&2
exit 1
}
unittest(){
cd $1 > /dev/null
if [ -f "setup.sh" ]; then
sh setup.sh
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
fi
if [ $? != 0 ]; then
exit 1
fi
find . -name 'tests' -type d -print0 | \
xargs -0 -I{} -n1 bash -c \
'python -m unittest discover -v -s {}'
cd - > /dev/null
}
trap 'abort' 0
set -e
unittest .
trap : 0
此差异已折叠。
此差异已折叠。
# Train DeepSpeech2 on PaddleCloud
>Note:
>Please make sure [PaddleCloud Client](https://github.com/PaddlePaddle/cloud/blob/develop/doc/usage_cn.md#%E4%B8%8B%E8%BD%BD%E5%B9%B6%E9%85%8D%E7%BD%AEpaddlecloud) has be installed and current directory is `deep_speech_2/cloud/`
## Step 1: Upload Data
Provided with several input manifests, `pcloud_upload_data.sh` will pack and upload all the containing audio files to PaddleCloud filesystem, and also generate some corresponding manifest files with updated cloud paths.
Please modify the following arguments in `pcloud_upload_data.sh`:
- `IN_MANIFESTS`: Paths (in local filesystem) of manifest files containing the audio files to be uploaded. Multiple paths can be concatenated with a whitespace delimeter.
- `OUT_MANIFESTS`: Paths (in local filesystem) to write the updated output manifest files to. Multiple paths can be concatenated with a whitespace delimeter. The values of `audio_filepath` in the output manifests are updated with cloud filesystem paths.
- `CLOUD_DATA_DIR`: Directory (in PaddleCloud filesystem) to upload the data to. Don't forget to replace `USERNAME` in the default directory and make sure that you have the permission to write it.
- `NUM_SHARDS`: Number of data shards / parts (in tar files) to be generated when packing and uploading data. Smaller `num_shards` requires larger temoporal local disk space for packing data.
By running:
```
sh pcloud_upload_data.sh
```
all the audio files will be uploaded to PaddleCloud filesystem, and you will get modified manifests files in `OUT_MANIFESTS`.
You have to take this step only once, in the very first time you do the cloud training. Later on, the data is persisitent on the cloud filesystem and reusable for further job submissions.
## Step 2: Configure Training
Configure cloud training arguments in `pcloud_submit.sh`, with the following arguments:
- `TRAIN_MANIFEST`: Manifest filepath (in local filesystem) for training. Notice that the`audio_filepath` should be in cloud filesystem, like those generated by `pcloud_upload_data.sh`.
- `DEV_MANIFEST`: Manifest filepath (in local filesystem) for validation.
- `CLOUD_MODEL_DIR`: Directory (in PaddleCloud filesystem) to save the model parameters (checkpoints). Don't forget to replace `USERNAME` in the default directory and make sure that you have the permission to write it.
- `BATCH_SIZE`: Training batch size for a single node.
- `NUM_GPU`: Number of GPUs allocated for a single node.
- `NUM_NODE`: Number of nodes (machines) allocated for this job.
- `IS_LOCAL`: Set to False to enable parameter server, if using multiple nodes.
Configure other training hyper-parameters in `pcloud_train.sh` as you wish, just as what you can do in local training.
By running:
```
sh pcloud_submit.sh
```
you submit a training job to PaddleCloud. And you will see the job name when the submission is done.
## Step 3 Get Job Logs
Run this to list all the jobs you have submitted, as well as their running status:
```
paddlecloud get jobs
```
Run this, the corresponding job's logs will be printed.
```
paddlecloud logs -n 10000 $REPLACED_WITH_YOUR_ACTUAL_JOB_NAME
```
## More Help
For more information about the usage of PaddleCloud, please refer to [PaddleCloud Usage](https://github.com/PaddlePaddle/cloud/blob/develop/doc/usage_cn.md#提交任务).
"""Set up paths for DS2"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import sys
def add_path(path):
if path not in sys.path:
sys.path.insert(0, path)
this_dir = os.path.dirname(__file__)
proj_path = os.path.join(this_dir, '..')
add_path(proj_path)
#! /usr/bin/env bash
TRAIN_MANIFEST="cloud/cloud_manifests/cloud.manifest.train"
DEV_MANIFEST="cloud/cloud_manifests/cloud.manifest.dev"
CLOUD_MODEL_DIR="./checkpoints"
BATCH_SIZE=512
NUM_GPU=8
NUM_NODE=1
IS_LOCAL="True"
JOB_NAME=deepspeech-`date +%Y%m%d%H%M%S`
DS2_PATH=${PWD%/*}
cp -f pcloud_train.sh ${DS2_PATH}
paddlecloud submit \
-image bootstrapper:5000/paddlepaddle/pcloud_ds2:latest \
-jobname ${JOB_NAME} \
-cpu ${NUM_GPU} \
-gpu ${NUM_GPU} \
-memory 64Gi \
-parallelism ${NUM_NODE} \
-pscpu 1 \
-pservers 1 \
-psmemory 64Gi \
-passes 1 \
-entry "sh pcloud_train.sh ${TRAIN_MANIFEST} ${DEV_MANIFEST} ${CLOUD_MODEL_DIR} ${NUM_GPU} ${BATCH_SIZE} ${IS_LOCAL}" \
${DS2_PATH}
rm ${DS2_PATH}/pcloud_train.sh
#! /usr/bin/env bash
TRAIN_MANIFEST=$1
DEV_MANIFEST=$2
MODEL_PATH=$3
NUM_GPU=$4
BATCH_SIZE=$5
IS_LOCAL=$6
python ./cloud/split_data.py \
--in_manifest_path=${TRAIN_MANIFEST} \
--out_manifest_path='/local.manifest.train'
python ./cloud/split_data.py \
--in_manifest_path=${DEV_MANIFEST} \
--out_manifest_path='/local.manifest.dev'
mkdir ./logs
python -u train.py \
--batch_size=${BATCH_SIZE} \
--trainer_count=${NUM_GPU} \
--num_passes=200 \
--num_proc_data=${NUM_GPU} \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--num_iter_print=100 \
--learning_rate=5e-4 \
--max_duration=27.0 \
--min_duration=0.0 \
--use_sortagrad=True \
--use_gru=False \
--use_gpu=True \
--is_local=${IS_LOCAL} \
--share_rnn_weights=True \
--train_manifest='/local.manifest.train' \
--dev_manifest='/local.manifest.dev' \
--mean_std_path='data/librispeech/mean_std.npz' \
--vocab_path='data/librispeech/vocab.txt' \
--output_model_dir='./checkpoints' \
--output_model_dir=${MODEL_PATH} \
--augment_conf_path='conf/augmentation.config' \
--specgram_type='linear' \
--shuffle_method='batch_shuffle_clipped' \
2>&1 | tee ./logs/train.log
#! /usr/bin/env bash
mkdir cloud_manifests
IN_MANIFESTS="../data/librispeech/manifest.train ../data/librispeech/manifest.dev-clean ../data/librispeech/manifest.test-clean"
OUT_MANIFESTS="cloud_manifests/cloud.manifest.train cloud_manifests/cloud.manifest.dev cloud_manifests/cloud.manifest.test"
CLOUD_DATA_DIR="/pfs/dlnel/home/USERNAME/deepspeech2/data/librispeech"
NUM_SHARDS=50
python upload_data.py \
--in_manifest_paths ${IN_MANIFESTS} \
--out_manifest_paths ${OUT_MANIFESTS} \
--cloud_data_dir ${CLOUD_DATA_DIR} \
--num_shards ${NUM_SHARDS}
if [ $? -ne 0 ]
then
echo "Upload Data Failed!"
exit 1
fi
echo "All Done."
"""This tool is used for splitting data into each node of
paddlecloud. This script should be called in paddlecloud.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import json
import argparse
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--in_manifest_path",
type=str,
required=True,
help="Input manifest path for all nodes.")
parser.add_argument(
"--out_manifest_path",
type=str,
required=True,
help="Output manifest file path for current node.")
args = parser.parse_args()
def split_data(in_manifest_path, out_manifest_path):
with open("/trainer_id", "r") as f:
trainer_id = int(f.readline()[:-1])
with open("/trainer_count", "r") as f:
trainer_count = int(f.readline()[:-1])
out_manifest = []
for index, json_line in enumerate(open(in_manifest_path, 'r')):
if (index % trainer_count) == trainer_id:
out_manifest.append("%s\n" % json_line.strip())
with open(out_manifest_path, 'w') as f:
f.writelines(out_manifest)
if __name__ == '__main__':
split_data(args.in_manifest_path, args.out_manifest_path)
"""This script is for uploading data for DeepSpeech2 training on paddlecloud.
Steps:
1. Read original manifests and extract local sound files.
2. Tar all local sound files into multiple tar files and upload them.
3. Modify original manifests with updated paths in cloud filesystem.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import tarfile
import sys
import argparse
import shutil
from subprocess import call
import _init_paths
from data_utils.utils import read_manifest
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--in_manifest_paths",
default=[
"../datasets/manifest.train", "../datasets/manifest.dev",
"../datasets/manifest.test"
],
type=str,
nargs='+',
help="Local filepaths of input manifests to load, pack and upload."
"(default: %(default)s)")
parser.add_argument(
"--out_manifest_paths",
default=[
"./cloud.manifest.train", "./cloud.manifest.dev",
"./cloud.manifest.test"
],
type=str,
nargs='+',
help="Local filepaths of modified manifests to write to. "
"(default: %(default)s)")
parser.add_argument(
"--cloud_data_dir",
required=True,
type=str,
help="Destination directory on paddlecloud to upload data to.")
parser.add_argument(
"--num_shards",
default=10,
type=int,
help="Number of parts to split data to. (default: %(default)s)")
parser.add_argument(
"--local_tmp_dir",
default="./tmp/",
type=str,
help="Local directory for storing temporary data. (default: %(default)s)")
args = parser.parse_args()
def upload_data(in_manifest_path_list, out_manifest_path_list, local_tmp_dir,
upload_tar_dir, num_shards):
"""Extract and pack sound files listed in the manifest files into multple
tar files and upload them to padldecloud. Besides, generate new manifest
files with updated paths in paddlecloud.
"""
# compute total audio number
total_line = 0
for manifest_path in in_manifest_path_list:
with open(manifest_path, 'r') as f:
total_line += len(f.readlines())
line_per_tar = (total_line // num_shards) + 1
# pack and upload shard by shard
line_count, tar_file = 0, None
for manifest_path, out_manifest_path in zip(in_manifest_path_list,
out_manifest_path_list):
manifest = read_manifest(manifest_path)
out_manifest = []
for json_data in manifest:
sound_filepath = json_data['audio_filepath']
sound_filename = os.path.basename(sound_filepath)
if line_count % line_per_tar == 0:
if tar_file != None:
tar_file.close()
pcloud_cp(tar_path, upload_tar_dir)
os.remove(tar_path)
tar_name = 'part-%s-of-%s.tar' % (
str(line_count // line_per_tar).zfill(5),
str(num_shards).zfill(5))
tar_path = os.path.join(local_tmp_dir, tar_name)
tar_file = tarfile.open(tar_path, 'w')
tar_file.add(sound_filepath, arcname=sound_filename)
line_count += 1
json_data['audio_filepath'] = "tar:%s#%s" % (
os.path.join(upload_tar_dir, tar_name), sound_filename)
out_manifest.append("%s\n" % json.dumps(json_data))
with open(out_manifest_path, 'w') as f:
f.writelines(out_manifest)
pcloud_cp(out_manifest_path, upload_tar_dir)
tar_file.close()
pcloud_cp(tar_path, upload_tar_dir)
os.remove(tar_path)
def pcloud_mkdir(dir):
"""Make directory in PaddleCloud filesystem.
"""
if call(['paddlecloud', 'mkdir', dir]) != 0:
raise IOError("PaddleCloud mkdir failed: %s." % dir)
def pcloud_cp(src, dst):
"""Copy src from local filesytem to dst in PaddleCloud filesystem,
or downlowd src from PaddleCloud filesystem to dst in local filesystem.
"""
if call(['paddlecloud', 'cp', src, dst]) != 0:
raise IOError("PaddleCloud cp failed: from [%s] to [%s]." % (src, dst))
if __name__ == '__main__':
if not os.path.exists(args.local_tmp_dir):
os.makedirs(args.local_tmp_dir)
pcloud_mkdir(args.cloud_data_dir)
upload_data(args.in_manifest_paths, args.out_manifest_paths,
args.local_tmp_dir, args.cloud_data_dir, args.num_shards)
shutil.rmtree(args.local_tmp_dir)
[
{
"type": "shift",
"params": {"min_shift_ms": -5,
"max_shift_ms": 5},
"prob": 1.0
}
]
[
{
"type": "noise",
"params": {"min_snr_dB": 40,
"max_snr_dB": 50,
"noise_manifest_path": "datasets/manifest.noise"},
"prob": 0.6
},
{
"type": "impulse",
"params": {"impulse_manifest_path": "datasets/manifest.impulse"},
"prob": 0.5
},
{
"type": "speed",
"params": {"min_speed_rate": 0.95,
"max_speed_rate": 1.05},
"prob": 0.5
},
{
"type": "shift",
"params": {"min_shift_ms": -5,
"max_shift_ms": 5},
"prob": 1.0
},
{
"type": "volume",
"params": {"min_gain_dBFS": -10,
"max_gain_dBFS": 10},
"prob": 0.0
},
{
"type": "bayesian_normal",
"params": {"target_db": -20,
"prior_db": -20,
"prior_samples": 100},
"prob": 0.0
}
]
"""Prepare Aishell mandarin dataset
Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import codecs
import soundfile
import json
import argparse
from data_utils.utility import download, unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
URL_ROOT = 'http://www.openslr.org/resources/33'
DATA_URL = URL_ROOT + '/data_aishell.tgz'
MD5_DATA = '2f494334227864a8a8fec932999db9d8'
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default=DATA_HOME + "/Aishell",
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
def create_manifest(data_dir, manifest_path_prefix):
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
transcript_path = os.path.join(data_dir, 'transcript',
'aishell_transcript_v0.8.txt')
transcript_dict = {}
for line in codecs.open(transcript_path, 'r', 'utf-8'):
line = line.strip()
if line == '': continue
audio_id, text = line.split(' ', 1)
# remove withespace
text = ''.join(text.split())
transcript_dict[audio_id] = text
data_types = ['train', 'dev', 'test']
for type in data_types:
del json_lines[:]
audio_dir = os.path.join(data_dir, 'wav', type)
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for fname in filelist:
audio_path = os.path.join(subfolder, fname)
audio_id = fname[:-4]
# if no transcription for audio then skipped
if audio_id not in transcript_dict:
continue
audio_data, samplerate = soundfile.read(audio_path)
duration = float(len(audio_data) / samplerate)
text = transcript_dict[audio_id]
json_lines.append(
json.dumps(
{
'audio_filepath': audio_path,
'duration': duration,
'text': text
},
ensure_ascii=False))
manifest_path = manifest_path_prefix + '.' + type
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
def prepare_dataset(url, md5sum, target_dir, manifest_path):
"""Download, unpack and create manifest file."""
data_dir = os.path.join(target_dir, 'data_aishell')
if not os.path.exists(data_dir):
filepath = download(url, md5sum, target_dir)
unpack(filepath, target_dir)
# unpack all audio tar files
audio_dir = os.path.join(data_dir, 'wav')
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for ftar in filelist:
unpack(os.path.join(subfolder, ftar), subfolder, True)
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
create_manifest(data_dir, manifest_path)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
prepare_dataset(
url=DATA_URL,
md5sum=MD5_DATA,
target_dir=args.target_dir,
manifest_path=args.manifest_prefix)
if __name__ == '__main__':
main()
"""Prepare Librispeech ASR datasets.
Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import distutils.util
import os
import sys
import argparse
import soundfile
import json
import codecs
from data_utils.utility import download, unpack
URL_ROOT = "http://www.openslr.org/resources/12"
URL_TEST_CLEAN = URL_ROOT + "/test-clean.tar.gz"
URL_TEST_OTHER = URL_ROOT + "/test-other.tar.gz"
URL_DEV_CLEAN = URL_ROOT + "/dev-clean.tar.gz"
URL_DEV_OTHER = URL_ROOT + "/dev-other.tar.gz"
URL_TRAIN_CLEAN_100 = URL_ROOT + "/train-clean-100.tar.gz"
URL_TRAIN_CLEAN_360 = URL_ROOT + "/train-clean-360.tar.gz"
URL_TRAIN_OTHER_500 = URL_ROOT + "/train-other-500.tar.gz"
MD5_TEST_CLEAN = "32fa31d27d2e1cad72775fee3f4849a9"
MD5_TEST_OTHER = "fb5a50374b501bb3bac4815ee91d3135"
MD5_DEV_CLEAN = "42e2234ba48799c1f50f24a7926300a1"
MD5_DEV_OTHER = "c8d0bcc9cca99d4f8b62fcc847357931"
MD5_TRAIN_CLEAN_100 = "2a93770f6d5c6c964bc36631d331a522"
MD5_TRAIN_CLEAN_360 = "c0e676e450a7ff2f54aeade5171606fa"
MD5_TRAIN_OTHER_500 = "d1a0fd59409feb2c614ce4d30c387708"
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default='~/.cache/paddle/dataset/speech/libri',
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
parser.add_argument(
"--full_download",
default="True",
type=distutils.util.strtobool,
help="Download all datasets for Librispeech."
" If False, only download a minimal requirement (test-clean, dev-clean"
" train-clean-100). (default: %(default)s)")
args = parser.parse_args()
def create_manifest(data_dir, manifest_path):
"""Create a manifest json file summarizing the data set, with each line
containing the meta data (i.e. audio filepath, transcription text, audio
duration) of each audio file within the data set.
"""
print("Creating manifest %s ..." % manifest_path)
json_lines = []
for subfolder, _, filelist in sorted(os.walk(data_dir)):
text_filelist = [
filename for filename in filelist if filename.endswith('trans.txt')
]
if len(text_filelist) > 0:
text_filepath = os.path.join(data_dir, subfolder, text_filelist[0])
for line in open(text_filepath):
segments = line.strip().split()
text = ' '.join(segments[1:]).lower()
audio_filepath = os.path.join(data_dir, subfolder,
segments[0] + '.flac')
audio_data, samplerate = soundfile.read(audio_filepath)
duration = float(len(audio_data)) / samplerate
json_lines.append(
json.dumps({
'audio_filepath': audio_filepath,
'duration': duration,
'text': text
}))
with codecs.open(manifest_path, 'w', 'utf-8') as out_file:
for line in json_lines:
out_file.write(line + '\n')
def prepare_dataset(url, md5sum, target_dir, manifest_path):
"""Download, unpack and create summmary manifest file.
"""
if not os.path.exists(os.path.join(target_dir, "LibriSpeech")):
# download
filepath = download(url, md5sum, target_dir)
# unpack
unpack(filepath, target_dir)
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
# create manifest json file
create_manifest(target_dir, manifest_path)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
prepare_dataset(
url=URL_TEST_CLEAN,
md5sum=MD5_TEST_CLEAN,
target_dir=os.path.join(args.target_dir, "test-clean"),
manifest_path=args.manifest_prefix + ".test-clean")
prepare_dataset(
url=URL_DEV_CLEAN,
md5sum=MD5_DEV_CLEAN,
target_dir=os.path.join(args.target_dir, "dev-clean"),
manifest_path=args.manifest_prefix + ".dev-clean")
if args.full_download:
prepare_dataset(
url=URL_TRAIN_CLEAN_100,
md5sum=MD5_TRAIN_CLEAN_100,
target_dir=os.path.join(args.target_dir, "train-clean-100"),
manifest_path=args.manifest_prefix + ".train-clean-100")
prepare_dataset(
url=URL_TEST_OTHER,
md5sum=MD5_TEST_OTHER,
target_dir=os.path.join(args.target_dir, "test-other"),
manifest_path=args.manifest_prefix + ".test-other")
prepare_dataset(
url=URL_DEV_OTHER,
md5sum=MD5_DEV_OTHER,
target_dir=os.path.join(args.target_dir, "dev-other"),
manifest_path=args.manifest_prefix + ".dev-other")
prepare_dataset(
url=URL_TRAIN_CLEAN_360,
md5sum=MD5_TRAIN_CLEAN_360,
target_dir=os.path.join(args.target_dir, "train-clean-360"),
manifest_path=args.manifest_prefix + ".train-clean-360")
prepare_dataset(
url=URL_TRAIN_OTHER_500,
md5sum=MD5_TRAIN_OTHER_500,
target_dir=os.path.join(args.target_dir, "train-other-500"),
manifest_path=args.manifest_prefix + ".train-other-500")
if __name__ == '__main__':
main()
"""Prepare CHiME3 background data.
Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import distutils.util
import os
import wget
import zipfile
import argparse
import soundfile
import json
from paddle.v2.dataset.common import md5file
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
URL = "https://d4s.myairbridge.com/packagev2/AG0Y3DNBE5IWRRTV/?dlid=W19XG7T0NNHB027139H0EQ"
MD5 = "c3ff512618d7a67d4f85566ea1bc39ec"
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default=DATA_HOME + "/chime3_background",
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_filepath",
default="manifest.chime3.background",
type=str,
help="Filepath for output manifests. (default: %(default)s)")
args = parser.parse_args()
def download(url, md5sum, target_dir, filename=None):
"""Download file from url to target_dir, and check md5sum."""
if filename == None:
filename = url.split("/")[-1]
if not os.path.exists(target_dir): os.makedirs(target_dir)
filepath = os.path.join(target_dir, filename)
if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
print("Downloading %s ..." % url)
wget.download(url, target_dir)
print("\nMD5 Chesksum %s ..." % filepath)
if not md5file(filepath) == md5sum:
raise RuntimeError("MD5 checksum failed.")
else:
print("File exists, skip downloading. (%s)" % filepath)
return filepath
def unpack(filepath, target_dir):
"""Unpack the file to the target_dir."""
print("Unpacking %s ..." % filepath)
if filepath.endswith('.zip'):
zip = zipfile.ZipFile(filepath, 'r')
zip.extractall(target_dir)
zip.close()
elif filepath.endswith('.tar') or filepath.endswith('.tar.gz'):
tar = zipfile.open(filepath)
tar.extractall(target_dir)
tar.close()
else:
raise ValueError("File format is not supported for unpacking.")
def create_manifest(data_dir, manifest_path):
"""Create a manifest json file summarizing the data set, with each line
containing the meta data (i.e. audio filepath, transcription text, audio
duration) of each audio file within the data set.
"""
print("Creating manifest %s ..." % manifest_path)
json_lines = []
for subfolder, _, filelist in sorted(os.walk(data_dir)):
for filename in filelist:
if filename.endswith('.wav'):
filepath = os.path.join(data_dir, subfolder, filename)
audio_data, samplerate = soundfile.read(filepath)
duration = float(len(audio_data)) / samplerate
json_lines.append(
json.dumps({
'audio_filepath': filepath,
'duration': duration,
'text': ''
}))
with open(manifest_path, 'w') as out_file:
for line in json_lines:
out_file.write(line + '\n')
def prepare_chime3(url, md5sum, target_dir, manifest_path):
"""Download, unpack and create summmary manifest file."""
if not os.path.exists(os.path.join(target_dir, "CHiME3")):
# download
filepath = download(url, md5sum, target_dir,
"myairbridge-AG0Y3DNBE5IWRRTV.zip")
# unpack
unpack(filepath, target_dir)
unpack(
os.path.join(target_dir, 'CHiME3_background_bus.zip'), target_dir)
unpack(
os.path.join(target_dir, 'CHiME3_background_caf.zip'), target_dir)
unpack(
os.path.join(target_dir, 'CHiME3_background_ped.zip'), target_dir)
unpack(
os.path.join(target_dir, 'CHiME3_background_str.zip'), target_dir)
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
# create manifest json file
create_manifest(target_dir, manifest_path)
def main():
prepare_chime3(
url=URL,
md5sum=MD5,
target_dir=args.target_dir,
manifest_path=args.manifest_filepath)
if __name__ == '__main__':
main()
#! /usr/bin/env bash
# download data, generate manifests
PYTHONPATH=../../:$PYTHONPATH python voxforge.py \
--manifest_prefix='./manifest' \
--target_dir='~/.cache/paddle/dataset/speech/VoxForge' \
--is_merge_dialect=True \
--dialects 'american' 'british' 'australian' 'european' 'irish' 'canadian' 'indian'
if [ $? -ne 0 ]; then
echo "Prepare VoxForge failed. Terminated."
exit 1
fi
echo "VoxForge Data preparation done."
exit 0
"""Prepare VoxForge dataset
Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import codecs
import soundfile
import json
import argparse
import shutil
import subprocess
from data_utils.utility import download_multi, unpack, getfile_insensitive
DATA_HOME = '~/.cache/paddle/dataset/speech'
DATA_URL = 'http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/' \
'Audio/Main/16kHz_16bit'
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default=DATA_HOME + "/VoxForge",
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--dialects",
default=[
'american', 'british', 'australian', 'european', 'irish', 'canadian',
'indian'
],
nargs='+',
type=str,
help="Dialect types. (default: %(default)s)")
parser.add_argument(
"--is_merge_dialect",
default=True,
type=bool,
help="If set True, manifests of american dialect and canadian dialect will "
"be merged to american-canadian dialect; manifests of british "
"dialect, irish dialect and australian dialect will be merged to "
"commonwealth dialect. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
def download_and_unpack(target_dir, url):
wget_args = '-q -l 1 -N -nd -c -e robots=off -A tgz -r -np'
tgz_dir = os.path.join(target_dir, 'tgz')
exit_code = download_multi(url, tgz_dir, wget_args)
if exit_code != 0:
print('Download tgz audio files failed with exit code %d.' % exit_code)
else:
print('Download done, start unpacking ...')
audio_dir = os.path.join(target_dir, 'audio')
for root, dirs, files in os.walk(tgz_dir):
for file in files:
print(file)
if file.endswith('.tgz'):
unpack(os.path.join(root, file), audio_dir)
def select_dialects(target_dir, dialect_list):
"""Classify audio files by dialect."""
dialect_root_dir = os.path.join(target_dir, 'dialect')
if os.path.exists(dialect_root_dir):
shutil.rmtree(dialect_root_dir)
os.mkdir(dialect_root_dir)
audio_dir = os.path.abspath(os.path.join(target_dir, 'audio'))
for dialect in dialect_list:
# filter files by dialect
command = 'find %s -iwholename "*etc/readme*" -exec egrep -iHl \
"pronunciation dialect.*%s" {} \;' % (audio_dir, dialect)
p = subprocess.Popen(
command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, shell=True)
output, err = p.communicate()
dialect_dir = os.path.join(dialect_root_dir, dialect)
if os.path.exists(dialect_dir):
shutil.rmtree(dialect_dir)
os.mkdir(dialect_dir)
for path in output.splitlines():
src_dir = os.path.dirname(os.path.dirname(path))
link = os.path.basename(os.path.normpath(src_dir))
os.symlink(src_dir, os.path.join(dialect_dir, link))
def generate_manifest(data_dir, manifest_path):
json_lines = []
for path in os.listdir(data_dir):
audio_link = os.path.join(data_dir, path)
assert os.path.islink(
audio_link), '%s should be symbolic link.' % audio_link
actual_audio_dir = os.path.abspath(os.readlink(audio_link))
audio_type = ''
if os.path.isdir(os.path.join(actual_audio_dir, 'wav')):
audio_type = 'wav'
elif os.path.isdir(os.path.join(actual_audio_dir, 'flac')):
audio_type = 'flac'
else:
print('Unknown audio type, skipped processing %s.' %
actual_audio_dir)
continue
etc_dir = os.path.join(actual_audio_dir, 'etc')
prompts_file = os.path.join(etc_dir, 'PROMPTS')
if not os.path.isfile(prompts_file):
print('PROMPTS file missing, skip processing %s.' %
actual_audio_dir)
continue
readme_file = getfile_insensitive(os.path.join(etc_dir, 'README'))
if readme_file is None:
print('README file missing, skip processing %s.' % actual_audio_dir)
continue
for line in file(prompts_file):
u, trans = line.strip().split(None, 1)
u_parts = u.split('/')
# try to format the date time
try:
speaker, date, sfx = u_parts[-3].split('-')
obj = datetime.datetime.strptime(date, '%y.%m.%d')
formatted = obj.strftime('%Y%m%d')
u_parts[-3] = '-'.join([speaker, formatted, sfx])
except Exception as e:
pass
if len(u_parts) < 2:
u_parts = [audio_type] + u_parts
u_parts[-2] = audio_type
u_parts[-1] += '.' + audio_type
u = os.path.join(actual_audio_dir, '/'.join(u_parts[-2:]))
if not os.path.isfile(u):
print('Audio file missing, skip processing %s.' % u)
continue
if os.stat(u).st_size == 0:
print('Empty audio file, skip processing %s.' % u)
continue
trans = trans.strip().replace('-', ' ')
if not trans.isupper() or \
not trans.strip().replace(' ', '').replace("'", "").isalpha():
print("Transcript not normalized properly, skip processing %s."
% u)
continue
audio_data, samplerate = soundfile.read(u)
duration = float(len(audio_data)) / samplerate
json_lines.append(
json.dumps({
'audio_filepath': u,
'duration': duration,
'text': trans.lower()
}))
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
def merge_manifests(manifest_files, save_path):
lines = []
for manifest_file in manifest_files:
line = codecs.open(manifest_file, 'r', 'utf-8').readlines()
lines += line
with codecs.open(save_path, 'w', 'utf-8') as fout:
for line in lines:
fout.write(line)
def prepare_dataset(url, dialects, target_dir, manifest_prefix, is_merge):
download_and_unpack(target_dir, url)
select_dialects(target_dir, dialects)
american_canadian_manifests = []
commonwealth_manifests = []
for dialect in dialects:
dialect_dir = os.path.join(target_dir, 'dialect', dialect)
manifest_fpath = manifest_prefix + '.' + dialect
if dialect == 'american' or dialect == 'canadian':
american_canadian_manifests.append(manifest_fpath)
if dialect == 'australian' \
or dialect == 'british' \
or dialect == 'irish':
commonwealth_manifests.append(manifest_fpath)
generate_manifest(dialect_dir, manifest_fpath)
if is_merge:
if len(american_canadian_manifests) > 0:
manifest_fpath = manifest_prefix + '.american-canadian'
merge_manifests(american_canadian_manifests, manifest_fpath)
if len(commonwealth_manifests) > 0:
manifest_fpath = manifest_prefix + '.commonwealth'
merge_manifests(commonwealth_manifests, manifest_fpath)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
prepare_dataset(DATA_URL, args.dialects, args.target_dir,
args.manifest_prefix, args.is_merge_dialect)
if __name__ == '__main__':
main()
此差异已折叠。
"""Contains the data augmentation pipeline."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import random
from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor
from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor
from data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor
from data_utils.augmentor.noise_perturb import NoisePerturbAugmentor
from data_utils.augmentor.impulse_response import ImpulseResponseAugmentor
from data_utils.augmentor.resample import ResampleAugmentor
from data_utils.augmentor.online_bayesian_normalization import \
OnlineBayesianNormalizationAugmentor
class AugmentationPipeline(object):
"""Build a pre-processing pipeline with various augmentation models.Such a
data augmentation pipeline is oftern leveraged to augment the training
samples to make the model invariant to certain types of perturbations in the
real world, improving model's generalization ability.
The pipeline is built according the the augmentation configuration in json
string, e.g.
.. code-block::
[ {
"type": "noise",
"params": {"min_snr_dB": 10,
"max_snr_dB": 20,
"noise_manifest_path": "datasets/manifest.noise"},
"prob": 0.0
},
{
"type": "speed",
"params": {"min_speed_rate": 0.9,
"max_speed_rate": 1.1},
"prob": 1.0
},
{
"type": "shift",
"params": {"min_shift_ms": -5,
"max_shift_ms": 5},
"prob": 1.0
},
{
"type": "volume",
"params": {"min_gain_dBFS": -10,
"max_gain_dBFS": 10},
"prob": 0.0
},
{
"type": "bayesian_normal",
"params": {"target_db": -20,
"prior_db": -20,
"prior_samples": 100},
"prob": 0.0
}
]
This augmentation configuration inserts two augmentation models
into the pipeline, with one is VolumePerturbAugmentor and the other
SpeedPerturbAugmentor. "prob" indicates the probability of the current
augmentor to take effect. If "prob" is zero, the augmentor does not take
effect.
:param augmentation_config: Augmentation configuration in json string.
:type augmentation_config: str
:param random_seed: Random seed.
:type random_seed: int
:raises ValueError: If the augmentation json config is in incorrect format".
"""
def __init__(self, augmentation_config, random_seed=0):
self._rng = random.Random(random_seed)
self._augmentors, self._rates = self._parse_pipeline_from(
augmentation_config)
def transform_audio(self, audio_segment):
"""Run the pre-processing pipeline for data augmentation.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to process.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
for augmentor, rate in zip(self._augmentors, self._rates):
if self._rng.uniform(0., 1.) < rate:
augmentor.transform_audio(audio_segment)
def _parse_pipeline_from(self, config_json):
"""Parse the config json to build a augmentation pipelien."""
try:
configs = json.loads(config_json)
augmentors = [
self._get_augmentor(config["type"], config["params"])
for config in configs
]
rates = [config["prob"] for config in configs]
except Exception as e:
raise ValueError("Failed to parse the augmentation config json: "
"%s" % str(e))
return augmentors, rates
def _get_augmentor(self, augmentor_type, params):
"""Return an augmentation model by the type name, and pass in params."""
if augmentor_type == "volume":
return VolumePerturbAugmentor(self._rng, **params)
elif augmentor_type == "shift":
return ShiftPerturbAugmentor(self._rng, **params)
elif augmentor_type == "speed":
return SpeedPerturbAugmentor(self._rng, **params)
elif augmentor_type == "resample":
return ResampleAugmentor(self._rng, **params)
elif augmentor_type == "bayesian_normal":
return OnlineBayesianNormalizationAugmentor(self._rng, **params)
elif augmentor_type == "noise":
return NoisePerturbAugmentor(self._rng, **params)
elif augmentor_type == "impulse":
return ImpulseResponseAugmentor(self._rng, **params)
else:
raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
"""Contains the abstract base class for augmentation models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from abc import ABCMeta, abstractmethod
class AugmentorBase(object):
"""Abstract base class for augmentation model (augmentor) class.
All augmentor classes should inherit from this class, and implement the
following abstract methods.
"""
__metaclass__ = ABCMeta
@abstractmethod
def __init__(self):
pass
@abstractmethod
def transform_audio(self, audio_segment):
"""Adds various effects to the input audio segment. Such effects
will augment the training data to make the model invariant to certain
types of perturbations in the real world, improving model's
generalization ability.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
pass
"""Contains the impulse response augmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from data_utils.augmentor.base import AugmentorBase
from data_utils.utility import read_manifest
from data_utils.audio import AudioSegment
class ImpulseResponseAugmentor(AugmentorBase):
"""Augmentation model for adding impulse response effect.
:param rng: Random generator object.
:type rng: random.Random
:param impulse_manifest_path: Manifest path for impulse audio data.
:type impulse_manifest_path: basestring
"""
def __init__(self, rng, impulse_manifest_path):
self._rng = rng
self._impulse_manifest = read_manifest(impulse_manifest_path)
def transform_audio(self, audio_segment):
"""Add impulse response effect.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
impulse_json = self._rng.sample(self._impulse_manifest, 1)[0]
impulse_segment = AudioSegment.from_file(impulse_json['audio_filepath'])
audio_segment.convolve(impulse_segment, allow_resample=True)
"""Contains the noise perturb augmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from data_utils.augmentor.base import AugmentorBase
from data_utils.utility import read_manifest
from data_utils.audio import AudioSegment
class NoisePerturbAugmentor(AugmentorBase):
"""Augmentation model for adding background noise.
:param rng: Random generator object.
:type rng: random.Random
:param min_snr_dB: Minimal signal noise ratio, in decibels.
:type min_snr_dB: float
:param max_snr_dB: Maximal signal noise ratio, in decibels.
:type max_snr_dB: float
:param noise_manifest_path: Manifest path for noise audio data.
:type noise_manifest_path: basestring
"""
def __init__(self, rng, min_snr_dB, max_snr_dB, noise_manifest_path):
self._min_snr_dB = min_snr_dB
self._max_snr_dB = max_snr_dB
self._rng = rng
self._noise_manifest = read_manifest(manifest_path=noise_manifest_path)
def transform_audio(self, audio_segment):
"""Add background noise audio.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
noise_json = self._rng.sample(self._noise_manifest, 1)[0]
if noise_json['duration'] < audio_segment.duration:
raise RuntimeError("The duration of sampled noise audio is smaller "
"than the audio segment to add effects to.")
diff_duration = noise_json['duration'] - audio_segment.duration
start = self._rng.uniform(0, diff_duration)
end = start + audio_segment.duration
noise_segment = AudioSegment.slice_from_file(
noise_json['audio_filepath'], start=start, end=end)
snr_dB = self._rng.uniform(self._min_snr_dB, self._max_snr_dB)
audio_segment.add_noise(
noise_segment, snr_dB, allow_downsampling=True, rng=self._rng)
"""Contain the online bayesian normalization augmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from data_utils.augmentor.base import AugmentorBase
class OnlineBayesianNormalizationAugmentor(AugmentorBase):
"""Augmentation model for adding online bayesian normalization.
:param rng: Random generator object.
:type rng: random.Random
:param target_db: Target RMS value in decibels.
:type target_db: float
:param prior_db: Prior RMS estimate in decibels.
:type prior_db: float
:param prior_samples: Prior strength in number of samples.
:type prior_samples: int
:param startup_delay: Default 0.0s. If provided, this function will
accrue statistics for the first startup_delay
seconds before applying online normalization.
:type starup_delay: float.
"""
def __init__(self,
rng,
target_db,
prior_db,
prior_samples,
startup_delay=0.0):
self._target_db = target_db
self._prior_db = prior_db
self._prior_samples = prior_samples
self._rng = rng
self._startup_delay = startup_delay
def transform_audio(self, audio_segment):
"""Normalizes the input audio using the online Bayesian approach.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegment|SpeechSegment
"""
audio_segment.normalize_online_bayesian(self._target_db, self._prior_db,
self._prior_samples,
self._startup_delay)
"""Contain the resample augmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from data_utils.augmentor.base import AugmentorBase
class ResampleAugmentor(AugmentorBase):
"""Augmentation model for resampling.
See more info here:
https://ccrma.stanford.edu/~jos/resample/index.html
:param rng: Random generator object.
:type rng: random.Random
:param new_sample_rate: New sample rate in Hz.
:type new_sample_rate: int
"""
def __init__(self, rng, new_sample_rate):
self._new_sample_rate = new_sample_rate
self._rng = rng
def transform_audio(self, audio_segment):
"""Resamples the input audio to a target sample rate.
Note that this is an in-place transformation.
:param audio: Audio segment to add effects to.
:type audio: AudioSegment|SpeechSegment
"""
audio_segment.resample(self._new_sample_rate)
"""Contains the volume perturb augmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from data_utils.augmentor.base import AugmentorBase
class ShiftPerturbAugmentor(AugmentorBase):
"""Augmentation model for adding random shift perturbation.
:param rng: Random generator object.
:type rng: random.Random
:param min_shift_ms: Minimal shift in milliseconds.
:type min_shift_ms: float
:param max_shift_ms: Maximal shift in milliseconds.
:type max_shift_ms: float
"""
def __init__(self, rng, min_shift_ms, max_shift_ms):
self._min_shift_ms = min_shift_ms
self._max_shift_ms = max_shift_ms
self._rng = rng
def transform_audio(self, audio_segment):
"""Shift audio.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
shift_ms = self._rng.uniform(self._min_shift_ms, self._max_shift_ms)
audio_segment.shift(shift_ms)
"""Contain the speech perturbation augmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from data_utils.augmentor.base import AugmentorBase
class SpeedPerturbAugmentor(AugmentorBase):
"""Augmentation model for adding speed perturbation.
See reference paper here:
http://www.danielpovey.com/files/2015_interspeech_augmentation.pdf
:param rng: Random generator object.
:type rng: random.Random
:param min_speed_rate: Lower bound of new speed rate to sample and should
not be smaller than 0.9.
:type min_speed_rate: float
:param max_speed_rate: Upper bound of new speed rate to sample and should
not be larger than 1.1.
:type max_speed_rate: float
"""
def __init__(self, rng, min_speed_rate, max_speed_rate):
if min_speed_rate < 0.9:
raise ValueError(
"Sampling speed below 0.9 can cause unnatural effects")
if max_speed_rate > 1.1:
raise ValueError(
"Sampling speed above 1.1 can cause unnatural effects")
self._min_speed_rate = min_speed_rate
self._max_speed_rate = max_speed_rate
self._rng = rng
def transform_audio(self, audio_segment):
"""Sample a new speed rate from the given range and
changes the speed of the given audio clip.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegment|SpeechSegment
"""
sampled_speed = self._rng.uniform(self._min_speed_rate,
self._max_speed_rate)
audio_segment.change_speed(sampled_speed)
"""Contains the volume perturb augmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from data_utils.augmentor.base import AugmentorBase
class VolumePerturbAugmentor(AugmentorBase):
"""Augmentation model for adding random volume perturbation.
This is used for multi-loudness training of PCEN. See
https://arxiv.org/pdf/1607.05666v1.pdf
for more details.
:param rng: Random generator object.
:type rng: random.Random
:param min_gain_dBFS: Minimal gain in dBFS.
:type min_gain_dBFS: float
:param max_gain_dBFS: Maximal gain in dBFS.
:type max_gain_dBFS: float
"""
def __init__(self, rng, min_gain_dBFS, max_gain_dBFS):
self._min_gain_dBFS = min_gain_dBFS
self._max_gain_dBFS = max_gain_dBFS
self._rng = rng
def transform_audio(self, audio_segment):
"""Change audio loadness.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
audio_segment.gain_db(gain)
"""Contains data generator for orgnaizing various audio data preprocessing
pipeline and offering data reader interface of PaddlePaddle requirements.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import tarfile
import multiprocessing
import numpy as np
import paddle.v2 as paddle
from threading import local
from data_utils.utility import read_manifest
from data_utils.utility import xmap_readers_mp
from data_utils.augmentor.augmentation import AugmentationPipeline
from data_utils.featurizer.speech_featurizer import SpeechFeaturizer
from data_utils.speech import SpeechSegment
from data_utils.normalizer import FeatureNormalizer
class DataGenerator(object):
"""
DataGenerator provides basic audio data preprocessing pipeline, and offers
data reader interfaces of PaddlePaddle requirements.
:param vocab_filepath: Vocabulary filepath for indexing tokenized
transcripts.
:type vocab_filepath: basestring
:param mean_std_filepath: File containing the pre-computed mean and stddev.
:type mean_std_filepath: None|basestring
:param augmentation_config: Augmentation configuration in json string.
Details see AugmentationPipeline.__doc__.
:type augmentation_config: str
:param max_duration: Audio with duration (in seconds) greater than
this will be discarded.
:type max_duration: float
:param min_duration: Audio with duration (in seconds) smaller than
this will be discarded.
:type min_duration: float
:param stride_ms: Striding size (in milliseconds) for generating frames.
:type stride_ms: float
:param window_ms: Window size (in milliseconds) for generating frames.
:type window_ms: float
:param max_freq: Used when specgram_type is 'linear', only FFT bins
corresponding to frequencies between [0, max_freq] are
returned.
:types max_freq: None|float
:param specgram_type: Specgram feature type. Options: 'linear'.
:type specgram_type: str
:param use_dB_normalization: Whether to normalize the audio to -20 dB
before extracting the features.
:type use_dB_normalization: bool
:param num_threads: Number of CPU threads for processing data.
:type num_threads: int
:param random_seed: Random seed.
:type random_seed: int
:param keep_transcription_text: If set to True, transcription text will
be passed forward directly without
converting to index sequence.
:type keep_transcription_text: bool
"""
def __init__(self,
vocab_filepath,
mean_std_filepath,
augmentation_config='{}',
max_duration=float('inf'),
min_duration=0.0,
stride_ms=10.0,
window_ms=20.0,
max_freq=None,
specgram_type='linear',
use_dB_normalization=True,
num_threads=multiprocessing.cpu_count() // 2,
random_seed=0,
keep_transcription_text=False):
self._max_duration = max_duration
self._min_duration = min_duration
self._normalizer = FeatureNormalizer(mean_std_filepath)
self._augmentation_pipeline = AugmentationPipeline(
augmentation_config=augmentation_config, random_seed=random_seed)
self._speech_featurizer = SpeechFeaturizer(
vocab_filepath=vocab_filepath,
specgram_type=specgram_type,
stride_ms=stride_ms,
window_ms=window_ms,
max_freq=max_freq,
use_dB_normalization=use_dB_normalization)
self._num_threads = num_threads
self._rng = random.Random(random_seed)
self._keep_transcription_text = keep_transcription_text
self._epoch = 0
# for caching tar files info
self._local_data = local()
self._local_data.tar2info = {}
self._local_data.tar2object = {}
def process_utterance(self, audio_file, transcript):
"""Load, augment, featurize and normalize for speech data.
:param audio_file: Filepath or file object of audio file.
:type audio_file: basestring | file
:param transcript: Transcription text.
:type transcript: basestring
:return: Tuple of audio feature tensor and data of transcription part,
where transcription part could be token ids or text.
:rtype: tuple of (2darray, list)
"""
if isinstance(audio_file, basestring) and audio_file.startswith('tar:'):
speech_segment = SpeechSegment.from_file(
self._subfile_from_tar(audio_file), transcript)
else:
speech_segment = SpeechSegment.from_file(audio_file, transcript)
self._augmentation_pipeline.transform_audio(speech_segment)
specgram, transcript_part = self._speech_featurizer.featurize(
speech_segment, self._keep_transcription_text)
specgram = self._normalizer.apply(specgram)
return specgram, transcript_part
def batch_reader_creator(self,
manifest_path,
batch_size,
min_batch_size=1,
padding_to=-1,
flatten=False,
sortagrad=False,
shuffle_method="batch_shuffle"):
"""
Batch data reader creator for audio data. Return a callable generator
function to produce batches of data.
Audio features within one batch will be padded with zeros to have the
same shape, or a user-defined shape.
:param manifest_path: Filepath of manifest for audio files.
:type manifest_path: basestring
:param batch_size: Number of instances in a batch.
:type batch_size: int
:param min_batch_size: Any batch with batch size smaller than this will
be discarded. (To be deprecated in the future.)
:type min_batch_size: int
:param padding_to: If set -1, the maximun shape in the batch
will be used as the target shape for padding.
Otherwise, `padding_to` will be the target shape.
:type padding_to: int
:param flatten: If set True, audio features will be flatten to 1darray.
:type flatten: bool
:param sortagrad: If set True, sort the instances by audio duration
in the first epoch for speed up training.
:type sortagrad: bool
:param shuffle_method: Shuffle method. Options:
'' or None: no shuffle.
'instance_shuffle': instance-wise shuffle.
'batch_shuffle': similarly-sized instances are
put into batches, and then
batch-wise shuffle the batches.
For more details, please see
``_batch_shuffle.__doc__``.
'batch_shuffle_clipped': 'batch_shuffle' with
head shift and tail
clipping. For more
details, please see
``_batch_shuffle``.
If sortagrad is True, shuffle is disabled
for the first epoch.
:type shuffle_method: None|str
:return: Batch reader function, producing batches of data when called.
:rtype: callable
"""
def batch_reader():
# read manifest
manifest = read_manifest(
manifest_path=manifest_path,
max_duration=self._max_duration,
min_duration=self._min_duration)
# sort (by duration) or batch-wise shuffle the manifest
if self._epoch == 0 and sortagrad:
manifest.sort(key=lambda x: x["duration"])
else:
if shuffle_method == "batch_shuffle":
manifest = self._batch_shuffle(
manifest, batch_size, clipped=False)
elif shuffle_method == "batch_shuffle_clipped":
manifest = self._batch_shuffle(
manifest, batch_size, clipped=True)
elif shuffle_method == "instance_shuffle":
self._rng.shuffle(manifest)
elif shuffle_method == None:
pass
else:
raise ValueError("Unknown shuffle method %s." %
shuffle_method)
# prepare batches
instance_reader, cleanup = self._instance_reader_creator(manifest)
batch = []
try:
for instance in instance_reader():
batch.append(instance)
if len(batch) == batch_size:
yield self._padding_batch(batch, padding_to, flatten)
batch = []
if len(batch) >= min_batch_size:
yield self._padding_batch(batch, padding_to, flatten)
finally:
cleanup()
self._epoch += 1
return batch_reader
@property
def feeding(self):
"""Returns data reader's feeding dict.
:return: Data feeding dict.
:rtype: dict
"""
feeding_dict = {"audio_spectrogram": 0, "transcript_text": 1}
return feeding_dict
@property
def vocab_size(self):
"""Return the vocabulary size.
:return: Vocabulary size.
:rtype: int
"""
return self._speech_featurizer.vocab_size
@property
def vocab_list(self):
"""Return the vocabulary in list.
:return: Vocabulary in list.
:rtype: list
"""
return self._speech_featurizer.vocab_list
def _parse_tar(self, file):
"""Parse a tar file to get a tarfile object
and a map containing tarinfoes
"""
result = {}
f = tarfile.open(file)
for tarinfo in f.getmembers():
result[tarinfo.name] = tarinfo
return f, result
def _subfile_from_tar(self, file):
"""Get subfile object from tar.
It will return a subfile object from tar file
and cached tar file info for next reading request.
"""
tarpath, filename = file.split(':', 1)[1].split('#', 1)
if 'tar2info' not in self._local_data.__dict__:
self._local_data.tar2info = {}
if 'tar2object' not in self._local_data.__dict__:
self._local_data.tar2object = {}
if tarpath not in self._local_data.tar2info:
object, infoes = self._parse_tar(tarpath)
self._local_data.tar2info[tarpath] = infoes
self._local_data.tar2object[tarpath] = object
return self._local_data.tar2object[tarpath].extractfile(
self._local_data.tar2info[tarpath][filename])
def _instance_reader_creator(self, manifest):
"""
Instance reader creator. Create a callable function to produce
instances of data.
Instance: a tuple of ndarray of audio spectrogram and a list of
token indices for transcript.
"""
def reader():
for instance in manifest:
yield instance
reader, cleanup_callback = xmap_readers_mp(
lambda instance: self.process_utterance(instance["audio_filepath"], instance["text"]),
reader, self._num_threads, 4096)
return reader, cleanup_callback
def _padding_batch(self, batch, padding_to=-1, flatten=False):
"""
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one bach.
If ``padding_to`` is -1, the maximun shape in the batch will be used
as the target shape for padding. Otherwise, `padding_to` will be the
target shape (only refers to the second axis).
If `flatten` is True, features will be flatten to 1darray.
"""
new_batch = []
# get target shape
max_length = max([audio.shape[1] for audio, text in batch])
if padding_to != -1:
if padding_to < max_length:
raise ValueError("If padding_to is not -1, it should be larger "
"than any instance's shape in the batch")
max_length = padding_to
# padding
for audio, text in batch:
padded_audio = np.zeros([audio.shape[0], max_length])
padded_audio[:, :audio.shape[1]] = audio
if flatten:
padded_audio = padded_audio.flatten()
padded_instance = [padded_audio, text, audio.shape[1]]
new_batch.append(padded_instance)
return new_batch
def _batch_shuffle(self, manifest, batch_size, clipped=False):
"""Put similarly-sized instances into minibatches for better efficiency
and make a batch-wise shuffle.
1. Sort the audio clips by duration.
2. Generate a random number `k`, k in [0, batch_size).
3. Randomly shift `k` instances in order to create different batches
for different epochs. Create minibatches.
4. Shuffle the minibatches.
:param manifest: Manifest contents. List of dict.
:type manifest: list
:param batch_size: Batch size. This size is also used for generate
a random number for batch shuffle.
:type batch_size: int
:param clipped: Whether to clip the heading (small shift) and trailing
(incomplete batch) instances.
:type clipped: bool
:return: Batch shuffled mainifest.
:rtype: list
"""
manifest.sort(key=lambda x: x["duration"])
shift_len = self._rng.randint(0, batch_size - 1)
batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size)
self._rng.shuffle(batch_manifest)
batch_manifest = [item for batch in batch_manifest for item in batch]
if not clipped:
res_len = len(manifest) - shift_len - len(batch_manifest)
batch_manifest.extend(manifest[-res_len:])
batch_manifest.extend(manifest[0:shift_len])
return batch_manifest
"""Contains the audio featurizer class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from data_utils.utility import read_manifest
from data_utils.audio import AudioSegment
from python_speech_features import mfcc
from python_speech_features import delta
class AudioFeaturizer(object):
"""Audio featurizer, for extracting features from audio contents of
AudioSegment or SpeechSegment.
Currently, it supports feature types of linear spectrogram and mfcc.
:param specgram_type: Specgram feature type. Options: 'linear'.
:type specgram_type: str
:param stride_ms: Striding size (in milliseconds) for generating frames.
:type stride_ms: float
:param window_ms: Window size (in milliseconds) for generating frames.
:type window_ms: float
:param max_freq: When specgram_type is 'linear', only FFT bins
corresponding to frequencies between [0, max_freq] are
returned; when specgram_type is 'mfcc', max_feq is the
highest band edge of mel filters.
:types max_freq: None|float
:param target_sample_rate: Audio are resampled (if upsampling or
downsampling is allowed) to this before
extracting spectrogram features.
:type target_sample_rate: float
:param use_dB_normalization: Whether to normalize the audio to a certain
decibels before extracting the features.
:type use_dB_normalization: bool
:param target_dB: Target audio decibels for normalization.
:type target_dB: float
"""
def __init__(self,
specgram_type='linear',
stride_ms=10.0,
window_ms=20.0,
max_freq=None,
target_sample_rate=16000,
use_dB_normalization=True,
target_dB=-20):
self._specgram_type = specgram_type
self._stride_ms = stride_ms
self._window_ms = window_ms
self._max_freq = max_freq
self._target_sample_rate = target_sample_rate
self._use_dB_normalization = use_dB_normalization
self._target_dB = target_dB
def featurize(self,
audio_segment,
allow_downsampling=True,
allow_upsampling=True):
"""Extract audio features from AudioSegment or SpeechSegment.
:param audio_segment: Audio/speech segment to extract features from.
:type audio_segment: AudioSegment|SpeechSegment
:param allow_downsampling: Whether to allow audio downsampling before
featurizing.
:type allow_downsampling: bool
:param allow_upsampling: Whether to allow audio upsampling before
featurizing.
:type allow_upsampling: bool
:return: Spectrogram audio feature in 2darray.
:rtype: ndarray
:raises ValueError: If audio sample rate is not supported.
"""
# upsampling or downsampling
if ((audio_segment.sample_rate > self._target_sample_rate and
allow_downsampling) or
(audio_segment.sample_rate < self._target_sample_rate and
allow_upsampling)):
audio_segment.resample(self._target_sample_rate)
if audio_segment.sample_rate != self._target_sample_rate:
raise ValueError("Audio sample rate is not supported. "
"Turn allow_downsampling or allow up_sampling on.")
# decibel normalization
if self._use_dB_normalization:
audio_segment.normalize(target_db=self._target_dB)
# extract spectrogram
return self._compute_specgram(audio_segment.samples,
audio_segment.sample_rate)
def _compute_specgram(self, samples, sample_rate):
"""Extract various audio features."""
if self._specgram_type == 'linear':
return self._compute_linear_specgram(
samples, sample_rate, self._stride_ms, self._window_ms,
self._max_freq)
elif self._specgram_type == 'mfcc':
return self._compute_mfcc(samples, sample_rate, self._stride_ms,
self._window_ms, self._max_freq)
else:
raise ValueError("Unknown specgram_type %s. "
"Supported values: linear." % self._specgram_type)
def _compute_linear_specgram(self,
samples,
sample_rate,
stride_ms=10.0,
window_ms=20.0,
max_freq=None,
eps=1e-14):
"""Compute the linear spectrogram from FFT energy."""
if max_freq is None:
max_freq = sample_rate / 2
if max_freq > sample_rate / 2:
raise ValueError("max_freq must not be greater than half of "
"sample rate.")
if stride_ms > window_ms:
raise ValueError("Stride size must not be greater than "
"window size.")
stride_size = int(0.001 * sample_rate * stride_ms)
window_size = int(0.001 * sample_rate * window_ms)
specgram, freqs = self._specgram_real(
samples,
window_size=window_size,
stride_size=stride_size,
sample_rate=sample_rate)
ind = np.where(freqs <= max_freq)[0][-1] + 1
return np.log(specgram[:ind, :] + eps)
def _specgram_real(self, samples, window_size, stride_size, sample_rate):
"""Compute the spectrogram for samples from a real signal."""
# extract strided windows
truncate_size = (len(samples) - window_size) % stride_size
samples = samples[:len(samples) - truncate_size]
nshape = (window_size, (len(samples) - window_size) // stride_size + 1)
nstrides = (samples.strides[0], samples.strides[0] * stride_size)
windows = np.lib.stride_tricks.as_strided(
samples, shape=nshape, strides=nstrides)
assert np.all(
windows[:, 1] == samples[stride_size:(stride_size + window_size)])
# window weighting, squared Fast Fourier Transform (fft), scaling
weighting = np.hanning(window_size)[:, None]
fft = np.fft.rfft(windows * weighting, axis=0)
fft = np.absolute(fft)
fft = fft**2
scale = np.sum(weighting**2) * sample_rate
fft[1:-1, :] *= (2.0 / scale)
fft[(0, -1), :] /= scale
# prepare fft frequency list
freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
return fft, freqs
def _compute_mfcc(self,
samples,
sample_rate,
stride_ms=10.0,
window_ms=20.0,
max_freq=None):
"""Compute mfcc from samples."""
if max_freq is None:
max_freq = sample_rate / 2
if max_freq > sample_rate / 2:
raise ValueError("max_freq must not be greater than half of "
"sample rate.")
if stride_ms > window_ms:
raise ValueError("Stride size must not be greater than "
"window size.")
# compute the 13 cepstral coefficients, and the first one is replaced
# by log(frame energy)
mfcc_feat = mfcc(
signal=samples,
samplerate=sample_rate,
winlen=0.001 * window_ms,
winstep=0.001 * stride_ms,
highfreq=max_freq)
# Deltas
d_mfcc_feat = delta(mfcc_feat, 2)
# Deltas-Deltas
dd_mfcc_feat = delta(d_mfcc_feat, 2)
# transpose
mfcc_feat = np.transpose(mfcc_feat)
d_mfcc_feat = np.transpose(d_mfcc_feat)
dd_mfcc_feat = np.transpose(dd_mfcc_feat)
# concat above three features
concat_mfcc_feat = np.concatenate(
(mfcc_feat, d_mfcc_feat, dd_mfcc_feat))
return concat_mfcc_feat
"""Contains the speech featurizer class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from data_utils.featurizer.audio_featurizer import AudioFeaturizer
from data_utils.featurizer.text_featurizer import TextFeaturizer
class SpeechFeaturizer(object):
"""Speech featurizer, for extracting features from both audio and transcript
contents of SpeechSegment.
Currently, for audio parts, it supports feature types of linear
spectrogram and mfcc; for transcript parts, it only supports char-level
tokenizing and conversion into a list of token indices. Note that the
token indexing order follows the given vocabulary file.
:param vocab_filepath: Filepath to load vocabulary for token indices
conversion.
:type specgram_type: basestring
:param specgram_type: Specgram feature type. Options: 'linear', 'mfcc'.
:type specgram_type: str
:param stride_ms: Striding size (in milliseconds) for generating frames.
:type stride_ms: float
:param window_ms: Window size (in milliseconds) for generating frames.
:type window_ms: float
:param max_freq: When specgram_type is 'linear', only FFT bins
corresponding to frequencies between [0, max_freq] are
returned; when specgram_type is 'mfcc', max_freq is the
highest band edge of mel filters.
:types max_freq: None|float
:param target_sample_rate: Speech are resampled (if upsampling or
downsampling is allowed) to this before
extracting spectrogram features.
:type target_sample_rate: float
:param use_dB_normalization: Whether to normalize the audio to a certain
decibels before extracting the features.
:type use_dB_normalization: bool
:param target_dB: Target audio decibels for normalization.
:type target_dB: float
"""
def __init__(self,
vocab_filepath,
specgram_type='linear',
stride_ms=10.0,
window_ms=20.0,
max_freq=None,
target_sample_rate=16000,
use_dB_normalization=True,
target_dB=-20):
self._audio_featurizer = AudioFeaturizer(
specgram_type=specgram_type,
stride_ms=stride_ms,
window_ms=window_ms,
max_freq=max_freq,
target_sample_rate=target_sample_rate,
use_dB_normalization=use_dB_normalization,
target_dB=target_dB)
self._text_featurizer = TextFeaturizer(vocab_filepath)
def featurize(self, speech_segment, keep_transcription_text):
"""Extract features for speech segment.
1. For audio parts, extract the audio features.
2. For transcript parts, keep the original text or convert text string
to a list of token indices in char-level.
:param audio_segment: Speech segment to extract features from.
:type audio_segment: SpeechSegment
:return: A tuple of 1) spectrogram audio feature in 2darray, 2) list of
char-level token indices.
:rtype: tuple
"""
audio_feature = self._audio_featurizer.featurize(speech_segment)
if keep_transcription_text:
return audio_feature, speech_segment.transcript
text_ids = self._text_featurizer.featurize(speech_segment.transcript)
return audio_feature, text_ids
@property
def vocab_size(self):
"""Return the vocabulary size.
:return: Vocabulary size.
:rtype: int
"""
return self._text_featurizer.vocab_size
@property
def vocab_list(self):
"""Return the vocabulary in list.
:return: Vocabulary in list.
:rtype: list
"""
return self._text_featurizer.vocab_list
"""Contains the text featurizer class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import codecs
class TextFeaturizer(object):
"""Text featurizer, for processing or extracting features from text.
Currently, it only supports char-level tokenizing and conversion into
a list of token indices. Note that the token indexing order follows the
given vocabulary file.
:param vocab_filepath: Filepath to load vocabulary for token indices
conversion.
:type specgram_type: basestring
"""
def __init__(self, vocab_filepath):
self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file(
vocab_filepath)
def featurize(self, text):
"""Convert text string to a list of token indices in char-level.Note
that the token indexing order follows the given vocabulary file.
:param text: Text to process.
:type text: basestring
:return: List of char-level token indices.
:rtype: list
"""
tokens = self._char_tokenize(text)
return [self._vocab_dict[token] for token in tokens]
@property
def vocab_size(self):
"""Return the vocabulary size.
:return: Vocabulary size.
:rtype: int
"""
return len(self._vocab_list)
@property
def vocab_list(self):
"""Return the vocabulary in list.
:return: Vocabulary in list.
:rtype: list
"""
return self._vocab_list
def _char_tokenize(self, text):
"""Character tokenizer."""
return list(text.strip())
def _load_vocabulary_from_file(self, vocab_filepath):
"""Load vocabulary from file."""
vocab_lines = []
with codecs.open(vocab_filepath, 'r', 'utf-8') as file:
vocab_lines.extend(file.readlines())
vocab_list = [line[:-1] for line in vocab_lines]
vocab_dict = dict(
[(token, id) for (id, token) in enumerate(vocab_list)])
return vocab_dict, vocab_list
"""Contains feature normalizers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import random
from data_utils.utility import read_manifest
from data_utils.audio import AudioSegment
class FeatureNormalizer(object):
"""Feature normalizer. Normalize features to be of zero mean and unit
stddev.
if mean_std_filepath is provided (not None), the normalizer will directly
initilize from the file. Otherwise, both manifest_path and featurize_func
should be given for on-the-fly mean and stddev computing.
:param mean_std_filepath: File containing the pre-computed mean and stddev.
:type mean_std_filepath: None|basestring
:param manifest_path: Manifest of instances for computing mean and stddev.
:type meanifest_path: None|basestring
:param featurize_func: Function to extract features. It should be callable
with ``featurize_func(audio_segment)``.
:type featurize_func: None|callable
:param num_samples: Number of random samples for computing mean and stddev.
:type num_samples: int
:param random_seed: Random seed for sampling instances.
:type random_seed: int
:raises ValueError: If both mean_std_filepath and manifest_path
(or both mean_std_filepath and featurize_func) are None.
"""
def __init__(self,
mean_std_filepath,
manifest_path=None,
featurize_func=None,
num_samples=500,
random_seed=0):
if not mean_std_filepath:
if not (manifest_path and featurize_func):
raise ValueError("If mean_std_filepath is None, meanifest_path "
"and featurize_func should not be None.")
self._rng = random.Random(random_seed)
self._compute_mean_std(manifest_path, featurize_func, num_samples)
else:
self._read_mean_std_from_file(mean_std_filepath)
def apply(self, features, eps=1e-14):
"""Normalize features to be of zero mean and unit stddev.
:param features: Input features to be normalized.
:type features: ndarray
:param eps: added to stddev to provide numerical stablibity.
:type eps: float
:return: Normalized features.
:rtype: ndarray
"""
return (features - self._mean) / (self._std + eps)
def write_to_file(self, filepath):
"""Write the mean and stddev to the file.
:param filepath: File to write mean and stddev.
:type filepath: basestring
"""
np.savez(filepath, mean=self._mean, std=self._std)
def _read_mean_std_from_file(self, filepath):
"""Load mean and std from file."""
npzfile = np.load(filepath)
self._mean = npzfile["mean"]
self._std = npzfile["std"]
def _compute_mean_std(self, manifest_path, featurize_func, num_samples):
"""Compute mean and std from randomly sampled instances."""
manifest = read_manifest(manifest_path)
sampled_manifest = self._rng.sample(manifest, num_samples)
features = []
for instance in sampled_manifest:
features.append(
featurize_func(
AudioSegment.from_file(instance["audio_filepath"])))
features = np.hstack(features)
self._mean = np.mean(features, axis=1).reshape([-1, 1])
self._std = np.std(features, axis=1).reshape([-1, 1])
"""Contains the speech segment class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from data_utils.audio import AudioSegment
class SpeechSegment(AudioSegment):
"""Speech segment abstraction, a subclass of AudioSegment,
with an additional transcript.
:param samples: Audio samples [num_samples x num_channels].
:type samples: ndarray.float32
:param sample_rate: Audio sample rate.
:type sample_rate: int
:param transcript: Transcript text for the speech.
:type transript: basestring
:raises TypeError: If the sample data type is not float or int.
"""
def __init__(self, samples, sample_rate, transcript):
AudioSegment.__init__(self, samples, sample_rate)
self._transcript = transcript
def __eq__(self, other):
"""Return whether two objects are equal.
"""
if not AudioSegment.__eq__(self, other):
return False
if self._transcript != other._transcript:
return False
return True
def __ne__(self, other):
"""Return whether two objects are unequal."""
return not self.__eq__(other)
@classmethod
def from_file(cls, filepath, transcript):
"""Create speech segment from audio file and corresponding transcript.
:param filepath: Filepath or file object to audio file.
:type filepath: basestring|file
:param transcript: Transcript text for the speech.
:type transript: basestring
:return: Speech segment instance.
:rtype: SpeechSegment
"""
audio = AudioSegment.from_file(filepath)
return cls(audio.samples, audio.sample_rate, transcript)
@classmethod
def from_bytes(cls, bytes, transcript):
"""Create speech segment from a byte string and corresponding
transcript.
:param bytes: Byte string containing audio samples.
:type bytes: str
:param transcript: Transcript text for the speech.
:type transript: basestring
:return: Speech segment instance.
:rtype: Speech Segment
"""
audio = AudioSegment.from_bytes(bytes)
return cls(audio.samples, audio.sample_rate, transcript)
@classmethod
def concatenate(cls, *segments):
"""Concatenate an arbitrary number of speech segments together, both
audio and transcript will be concatenated.
:param *segments: Input speech segments to be concatenated.
:type *segments: tuple of SpeechSegment
:return: Speech segment instance.
:rtype: SpeechSegment
:raises ValueError: If the number of segments is zero, or if the
sample_rate of any two segments does not match.
:raises TypeError: If any segment is not SpeechSegment instance.
"""
if len(segments) == 0:
raise ValueError("No speech segments are given to concatenate.")
sample_rate = segments[0]._sample_rate
transcripts = ""
for seg in segments:
if sample_rate != seg._sample_rate:
raise ValueError("Can't concatenate segments with "
"different sample rates")
if type(seg) is not cls:
raise TypeError("Only speech segments of the same type "
"instance can be concatenated.")
transcripts += seg._transcript
samples = np.concatenate([seg.samples for seg in segments])
return cls(samples, sample_rate, transcripts)
@classmethod
def slice_from_file(cls, filepath, transcript, start=None, end=None):
"""Loads a small section of an speech without having to load
the entire file into the memory which can be incredibly wasteful.
:param filepath: Filepath or file object to audio file.
:type filepath: basestring|file
:param start: Start time in seconds. If start is negative, it wraps
around from the end. If not provided, this function
reads from the very beginning.
:type start: float
:param end: End time in seconds. If end is negative, it wraps around
from the end. If not provided, the default behvaior is
to read to the end of the file.
:type end: float
:param transcript: Transcript text for the speech. if not provided,
the defaults is an empty string.
:type transript: basestring
:return: SpeechSegment instance of the specified slice of the input
speech file.
:rtype: SpeechSegment
"""
audio = AudioSegment.slice_from_file(filepath, start, end)
return cls(audio.samples, audio.sample_rate, transcript)
@classmethod
def make_silence(cls, duration, sample_rate):
"""Creates a silent speech segment of the given duration and
sample rate, transcript will be an empty string.
:param duration: Length of silence in seconds.
:type duration: float
:param sample_rate: Sample rate.
:type sample_rate: float
:return: Silence of the given duration.
:rtype: SpeechSegment
"""
audio = AudioSegment.make_silence(duration, sample_rate)
return cls(audio.samples, audio.sample_rate, "")
@property
def transcript(self):
"""Return the transcript text.
:return: Transcript text for the speech.
:rtype: basestring
"""
return self._transcript
"""Contains data helper functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import codecs
import os
import tarfile
import time
from Queue import Queue
from threading import Thread
from multiprocessing import Process, Manager, Value
from paddle.v2.dataset.common import md5file
def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
"""Load and parse manifest file.
Instances with durations outside [min_duration, max_duration] will be
filtered out.
:param manifest_path: Manifest file to load and parse.
:type manifest_path: basestring
:param max_duration: Maximal duration in seconds for instance filter.
:type max_duration: float
:param min_duration: Minimal duration in seconds for instance filter.
:type min_duration: float
:return: Manifest parsing results. List of dict.
:rtype: list
:raises IOError: If failed to parse the manifest.
"""
manifest = []
for json_line in codecs.open(manifest_path, 'r', 'utf-8'):
try:
json_data = json.loads(json_line)
except Exception as e:
raise IOError("Error reading manifest: %s" % str(e))
if (json_data["duration"] <= max_duration and
json_data["duration"] >= min_duration):
manifest.append(json_data)
return manifest
def getfile_insensitive(path):
"""Get the actual file path when given insensitive filename."""
directory, filename = os.path.split(path)
directory, filename = (directory or '.'), filename.lower()
for f in os.listdir(directory):
newpath = os.path.join(directory, f)
if os.path.isfile(newpath) and f.lower() == filename:
return newpath
def download_multi(url, target_dir, extra_args):
"""Download multiple files from url to target_dir."""
if not os.path.exists(target_dir): os.makedirs(target_dir)
print("Downloading %s ..." % url)
ret_code = os.system("wget -c " + url + ' ' + extra_args + " -P " +
target_dir)
return ret_code
def download(url, md5sum, target_dir):
"""Download file from url to target_dir, and check md5sum."""
if not os.path.exists(target_dir): os.makedirs(target_dir)
filepath = os.path.join(target_dir, url.split("/")[-1])
if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
print("Downloading %s ..." % url)
os.system("wget -c " + url + " -P " + target_dir)
print("\nMD5 Chesksum %s ..." % filepath)
if not md5file(filepath) == md5sum:
raise RuntimeError("MD5 checksum failed.")
else:
print("File exists, skip downloading. (%s)" % filepath)
return filepath
def unpack(filepath, target_dir, rm_tar=False):
"""Unpack the file to the target_dir."""
print("Unpacking %s ..." % filepath)
tar = tarfile.open(filepath)
tar.extractall(target_dir)
tar.close()
if rm_tar == True:
os.remove(filepath)
class XmapEndSignal():
pass
def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
"""A multiprocessing pipeline wrapper for the data reader.
:param mapper: Function to map sample.
:type mapper: callable
:param reader: Given data reader.
:type reader: callable
:param process_num: Number of processes in the pipeline
:type process_num: int
:param buffer_size: Maximal buffer size.
:type buffer_size: int
:return: The wrappered reader and cleanup callback
:rtype: tuple
"""
end_flag = XmapEndSignal()
read_workers = []
handle_workers = []
flush_workers = []
read_exit_flag = Value('i', 0)
handle_exit_flag = Value('i', 0)
flush_exit_flag = Value('i', 0)
# define a worker to read samples from reader to in_queue with order flag
def order_read_worker(reader, in_queue):
for order_id, sample in enumerate(reader()):
if read_exit_flag.value == 1: break
in_queue.put((order_id, sample))
in_queue.put(end_flag)
# the reading worker should not exit until all handling work exited
while handle_exit_flag.value == 0 or read_exit_flag.value == 0:
time.sleep(0.001)
# define a worker to handle samples from in_queue by mapper and put results
# to out_queue with order
def order_handle_worker(in_queue, out_queue, mapper, out_order):
ins = in_queue.get()
while not isinstance(ins, XmapEndSignal):
if handle_exit_flag.value == 1: break
order_id, sample = ins
result = mapper(sample)
while order_id != out_order[0]:
time.sleep(0.001)
out_queue.put(result)
out_order[0] += 1
ins = in_queue.get()
in_queue.put(end_flag)
out_queue.put(end_flag)
# wait for exit of flushing worker
while flush_exit_flag.value == 0 or handle_exit_flag.value == 0:
time.sleep(0.001)
read_exit_flag.value = 1
handle_exit_flag.value = 1
# define a thread worker to flush samples from Manager.Queue to Queue
# for acceleration
def flush_worker(in_queue, out_queue):
finish = 0
while finish < process_num and flush_exit_flag.value == 0:
sample = in_queue.get()
if isinstance(sample, XmapEndSignal):
finish += 1
else:
out_queue.put(sample)
out_queue.put(end_flag)
handle_exit_flag.value = 1
flush_exit_flag.value = 1
def cleanup():
# first exit flushing workers
flush_exit_flag.value = 1
for w in flush_workers:
w.join()
# next exit handling workers
handle_exit_flag.value = 1
for w in handle_workers:
w.join()
# last exit reading workers
read_exit_flag.value = 1
for w in read_workers:
w.join()
def xreader():
# prepare shared memory
manager = Manager()
in_queue = manager.Queue(buffer_size)
out_queue = manager.Queue(buffer_size)
out_order = manager.list([0])
# start a read worker in a process
target = order_read_worker
p = Process(target=target, args=(reader, in_queue))
p.daemon = True
p.start()
read_workers.append(p)
# start handle_workers with multiple processes
target = order_handle_worker
args = (in_queue, out_queue, mapper, out_order)
workers = [
Process(target=target, args=args) for _ in xrange(process_num)
]
for w in workers:
w.daemon = True
w.start()
handle_workers.append(w)
# start a thread to read data from slow Manager.Queue
flush_queue = Queue(buffer_size)
t = Thread(target=flush_worker, args=(out_queue, flush_queue))
t.daemon = True
t.start()
flush_workers.append(t)
# get results
sample = flush_queue.get()
while not isinstance(sample, XmapEndSignal):
yield sample
sample = flush_queue.get()
return xreader, cleanup
"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from itertools import groupby
import numpy as np
from math import log
import multiprocessing
def ctc_greedy_decoder(probs_seq, vocabulary):
"""CTC greedy (best path) decoder.
Path consisting of the most probable tokens are further post-processed to
remove consecutive repetitions and all blanks.
:param probs_seq: 2-D list of probabilities over the vocabulary for each
character. Each element is a list of float probabilities
for one character.
:type probs_seq: list
:param vocabulary: Vocabulary list.
:type vocabulary: list
:return: Decoding result string.
:rtype: baseline
"""
# dimension verification
for probs in probs_seq:
if not len(probs) == len(vocabulary) + 1:
raise ValueError("probs_seq dimension mismatchedd with vocabulary")
# argmax to get the best index for each time step
max_index_list = list(np.array(probs_seq).argmax(axis=1))
# remove consecutive duplicate indexes
index_list = [index_group[0] for index_group in groupby(max_index_list)]
# remove blank indexes
blank_index = len(vocabulary)
index_list = [index for index in index_list if index != blank_index]
# convert index list to string
return ''.join([vocabulary[index] for index in index_list])
def ctc_beam_search_decoder(probs_seq,
beam_size,
vocabulary,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None,
nproc=False):
"""CTC Beam search decoder.
It utilizes beam search to approximately select top best decoding
labels and returning results in the descending order.
The implementation is based on Prefix Beam Search
(https://arxiv.org/abs/1408.2873), and the unclear part is
redesigned. Two important modifications: 1) in the iterative computation
of probabilities, the assignment operation is changed to accumulation for
one prefix may comes from different paths; 2) the if condition "if l^+ not
in A_prev then" after probabilities' computation is deprecated for it is
hard to understand and seems unnecessary.
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
probabilities over vocabulary and blank.
:type probs_seq: 2-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
or language model.
:type external_scoring_func: callable
:param nproc: Whether the decoder used in multiprocesses.
:type nproc: bool
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
"""
# dimension check
for prob_list in probs_seq:
if not len(prob_list) == len(vocabulary) + 1:
raise ValueError("The shape of prob_seq does not match with the "
"shape of the vocabulary.")
# blank_id assign
blank_id = len(vocabulary)
# If the decoder called in the multiprocesses, then use the global scorer
# instantiated in ctc_beam_search_decoder_batch().
if nproc is True:
global ext_nproc_scorer
ext_scoring_func = ext_nproc_scorer
## initialize
# prefix_set_prev: the set containing selected prefixes
# probs_b_prev: prefixes' probability ending with blank in previous step
# probs_nb_prev: prefixes' probability ending with non-blank in previous step
prefix_set_prev = {'\t': 1.0}
probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}
## extend prefix in loop
for time_step in xrange(len(probs_seq)):
# prefix_set_next: the set containing candidate prefixes
# probs_b_cur: prefixes' probability ending with blank in current step
# probs_nb_cur: prefixes' probability ending with non-blank in current step
prefix_set_next, probs_b_cur, probs_nb_cur = {}, {}, {}
prob_idx = list(enumerate(probs_seq[time_step]))
cutoff_len = len(prob_idx)
#If pruning is enabled
if cutoff_prob < 1.0 or cutoff_top_n < cutoff_len:
prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
cutoff_len, cum_prob = 0, 0.0
for i in xrange(len(prob_idx)):
cum_prob += prob_idx[i][1]
cutoff_len += 1
if cum_prob >= cutoff_prob:
break
cutoff_len = min(cutoff_len, cutoff_top_n)
prob_idx = prob_idx[0:cutoff_len]
for l in prefix_set_prev:
if not prefix_set_next.has_key(l):
probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0
# extend prefix by travering prob_idx
for index in xrange(cutoff_len):
c, prob_c = prob_idx[index][0], prob_idx[index][1]
if c == blank_id:
probs_b_cur[l] += prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
else:
last_char = l[-1]
new_char = vocabulary[c]
l_plus = l + new_char
if not prefix_set_next.has_key(l_plus):
probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0
if new_char == last_char:
probs_nb_cur[l_plus] += prob_c * probs_b_prev[l]
probs_nb_cur[l] += prob_c * probs_nb_prev[l]
elif new_char == ' ':
if (ext_scoring_func is None) or (len(l) == 1):
score = 1.0
else:
prefix = l[1:]
score = ext_scoring_func(prefix)
probs_nb_cur[l_plus] += score * prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
else:
probs_nb_cur[l_plus] += prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
# add l_plus into prefix_set_next
prefix_set_next[l_plus] = probs_nb_cur[
l_plus] + probs_b_cur[l_plus]
# add l into prefix_set_next
prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
# update probs
probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur
## store top beam_size prefixes
prefix_set_prev = sorted(
prefix_set_next.iteritems(), key=lambda asd: asd[1], reverse=True)
if beam_size < len(prefix_set_prev):
prefix_set_prev = prefix_set_prev[:beam_size]
prefix_set_prev = dict(prefix_set_prev)
beam_result = []
for seq, prob in prefix_set_prev.items():
if prob > 0.0 and len(seq) > 1:
result = seq[1:]
# score last word by external scorer
if (ext_scoring_func is not None) and (result[-1] != ' '):
prob = prob * ext_scoring_func(result)
log_prob = log(prob)
beam_result.append((log_prob, result))
else:
beam_result.append((float('-inf'), ''))
## output top beam_size decoding results
beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
return beam_result
def ctc_beam_search_decoder_batch(probs_split,
beam_size,
vocabulary,
num_processes,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None):
"""CTC beam search decoder using multiple processes.
:param probs_seq: 3-D list with each element as an instance of 2-D list
of probabilities used by ctc_beam_search_decoder().
:type probs_seq: 3-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param num_processes: Number of parallel processes.
:type num_processes: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param num_processes: Number of parallel processes.
:type num_processes: int
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
or language model.
:type external_scoring_function: callable
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
"""
if not num_processes > 0:
raise ValueError("Number of processes must be positive!")
# use global variable to pass the externnal scorer to beam search decoder
global ext_nproc_scorer
ext_nproc_scorer = ext_scoring_func
nproc = True
pool = multiprocessing.Pool(processes=num_processes)
results = []
for i, probs_list in enumerate(probs_split):
args = (probs_list, beam_size, vocabulary, cutoff_prob, cutoff_top_n,
None, nproc)
results.append(pool.apply_async(ctc_beam_search_decoder, args))
pool.close()
pool.join()
beam_search_results = [result.get() for result in results]
return beam_search_results
"""External Scorer for Beam Search Decoder."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import kenlm
import numpy as np
class Scorer(object):
"""External scorer to evaluate a prefix or whole sentence in
beam search decoding, including the score from n-gram language
model and word count.
:param alpha: Parameter associated with language model. Don't use
language model when alpha = 0.
:type alpha: float
:param beta: Parameter associated with word count. Don't use word
count when beta = 0.
:type beta: float
:model_path: Path to load language model.
:type model_path: basestring
"""
def __init__(self, alpha, beta, model_path):
self._alpha = alpha
self._beta = beta
if not os.path.isfile(model_path):
raise IOError("Invaid language model path: %s" % model_path)
self._language_model = kenlm.LanguageModel(model_path)
# n-gram language model scoring
def _language_model_score(self, sentence):
#log10 prob of last word
log_cond_prob = list(
self._language_model.full_scores(sentence, eos=False))[-1][0]
return np.power(10, log_cond_prob)
# word insertion term
def _word_count(self, sentence):
words = sentence.strip().split(' ')
return len(words)
# reset alpha and beta
def reset_params(self, alpha, beta):
self._alpha = alpha
self._beta = beta
# execute evaluation
def __call__(self, sentence, log=False):
"""Evaluation function, gathering all the different scores
and return the final one.
:param sentence: The input sentence for evalutation
:type sentence: basestring
:param log: Whether return the score in log representation.
:type log: bool
:return: Evaluation score, in the decimal or log.
:rtype: float
"""
lm = self._language_model_score(sentence)
word_cnt = self._word_count(sentence)
if log == False:
score = np.power(lm, self._alpha) * np.power(word_cnt, self._beta)
else:
score = self._alpha * np.log(lm) + self._beta * np.log(word_cnt)
return score
"""Set up paths for DS2"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import sys
def add_path(path):
if path not in sys.path:
sys.path.insert(0, path)
this_dir = os.path.dirname(__file__)
# Add project path to PYTHONPATH
proj_path = os.path.join(this_dir, '..')
add_path(proj_path)
此差异已折叠。
#ifndef CTC_BEAM_SEARCH_DECODER_H_
#define CTC_BEAM_SEARCH_DECODER_H_
#include <string>
#include <utility>
#include <vector>
#include "scorer.h"
/* CTC Beam Search Decoder
* Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities
* over vocabulary of one time step.
* vocabulary: A vector of vocabulary.
* beam_size: The width of beam search.
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A vector that each element is a pair of score and decoding result,
* in desending order.
*/
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary,
size_t beam_size,
double cutoff_prob = 1.0,
size_t cutoff_top_n = 40,
Scorer *ext_scorer = nullptr);
/* CTC Beam Search Decoder for batch data
* Parameters:
* probs_seq: 3-D vector that each element is a 2-D vector that can be used
* by ctc_beam_search_decoder().
* vocabulary: A vector of vocabulary.
* beam_size: The width of beam search.
* num_processes: Number of threads for beam search.
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A 2-D vector that each element is a vector of beam search decoding
* result for one audio sample.
*/
std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split,
const std::vector<std::string> &vocabulary,
size_t beam_size,
size_t num_processes,
double cutoff_prob = 1.0,
size_t cutoff_top_n = 40,
Scorer *ext_scorer = nullptr);
#endif // CTC_BEAM_SEARCH_DECODER_H_
#include "ctc_greedy_decoder.h"
#include "decoder_utils.h"
std::string ctc_greedy_decoder(
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary) {
// dimension check
size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(),
vocabulary.size() + 1,
"The shape of probs_seq does not match with "
"the shape of the vocabulary");
}
size_t blank_id = vocabulary.size();
std::vector<size_t> max_idx_vec(num_time_steps, 0);
std::vector<size_t> idx_vec;
for (size_t i = 0; i < num_time_steps; ++i) {
double max_prob = 0.0;
size_t max_idx = 0;
const std::vector<double> &probs_step = probs_seq[i];
for (size_t j = 0; j < probs_step.size(); ++j) {
if (max_prob < probs_step[j]) {
max_idx = j;
max_prob = probs_step[j];
}
}
// id with maximum probability in current time step
max_idx_vec[i] = max_idx;
// deduplicate
if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
idx_vec.push_back(max_idx_vec[i]);
}
}
std::string best_path_result;
for (size_t i = 0; i < idx_vec.size(); ++i) {
if (idx_vec[i] != blank_id) {
best_path_result += vocabulary[idx_vec[i]];
}
}
return best_path_result;
}
#ifndef CTC_GREEDY_DECODER_H
#define CTC_GREEDY_DECODER_H
#include <string>
#include <vector>
/* CTC Greedy (Best Path) Decoder
*
* Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities
* over vocabulary of one time step.
* vocabulary: A vector of vocabulary.
* Return:
* The decoding result in string
*/
std::string ctc_greedy_decoder(
const std::vector<std::vector<double>>& probs_seq,
const std::vector<std::string>& vocabulary);
#endif // CTC_GREEDY_DECODER_H
此差异已折叠。
此差异已折叠。
%module swig_decoders
%{
#include "scorer.h"
#include "ctc_greedy_decoder.h"
#include "ctc_beam_search_decoder.h"
#include "decoder_utils.h"
%}
%include "std_vector.i"
%include "std_pair.i"
%include "std_string.i"
%import "decoder_utils.h"
namespace std {
%template(DoubleVector) std::vector<double>;
%template(IntVector) std::vector<int>;
%template(StringVector) std::vector<std::string>;
%template(VectorOfStructVector) std::vector<std::vector<double> >;
%template(FloatVector) std::vector<float>;
%template(Pair) std::pair<float, std::string>;
%template(PairFloatStringVector) std::vector<std::pair<float, std::string> >;
%template(PairDoubleStringVector) std::vector<std::pair<double, std::string> >;
%template(PairDoubleStringVector2) std::vector<std::vector<std::pair<double, std::string> > >;
%template(DoubleVector3) std::vector<std::vector<std::vector<double> > >;
}
%template(IntDoublePairCompSecondRev) pair_comp_second_rev<int, double>;
%template(StringDoublePairCompSecondRev) pair_comp_second_rev<std::string, double>;
%template(DoubleStringPairCompFirstRev) pair_comp_first_rev<double, std::string>;
%include "scorer.h"
%include "ctc_greedy_decoder.h"
%include "ctc_beam_search_decoder.h"
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册