Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
b07ee84a
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
1 年多 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b07ee84a
编写于
6月 13, 2017
作者:
X
Xinghai Sun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add function, class and module docs for data parts in DS2.
上级
cd3617ae
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
662 addition
and
162 deletion
+662
-162
compute_mean_std.py
compute_mean_std.py
+2
-1
data_utils/audio.py
data_utils/audio.py
+208
-24
data_utils/augmentor/augmentation.py
data_utils/augmentor/augmentation.py
+51
-9
data_utils/augmentor/base.py
data_utils/augmentor/base.py
+16
-0
data_utils/augmentor/volume_perturb.py
data_utils/augmentor/volume_perturb.py
+40
-0
data_utils/data.py
data_utils/data.py
+83
-83
data_utils/featurizer/audio_featurizer.py
data_utils/featurizer/audio_featurizer.py
+29
-9
data_utils/featurizer/speech_featurizer.py
data_utils/featurizer/speech_featurizer.py
+50
-5
data_utils/featurizer/text_featurizer.py
data_utils/featurizer/text_featurizer.py
+32
-4
data_utils/normalizer.py
data_utils/normalizer.py
+39
-1
data_utils/speech.py
data_utils/speech.py
+75
-0
data_utils/utils.py
data_utils/utils.py
+16
-1
datasets/librispeech/librispeech.py
datasets/librispeech/librispeech.py
+9
-7
decoder.py
decoder.py
+5
-4
infer.py
infer.py
+1
-4
model.py
model.py
+4
-5
train.py
train.py
+2
-5
未找到文件。
compute_mean_std.py
浏览文件 @
b07ee84a
"""Compute mean and std for feature normalizer, and save to file."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
@@ -17,7 +18,7 @@ parser.add_argument(
"(default: %(default)s)"
)
parser
.
add_argument
(
"--num_samples"
,
default
=
5
00
,
default
=
20
00
,
type
=
int
,
help
=
"Number of samples for computing mean and stddev. "
"(default: %(default)s)"
)
...
...
data_utils/audio.py
浏览文件 @
b07ee84a
"""Contains the audio segment class."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
io
import
soundfile
...
...
@@ -5,64 +10,243 @@ import soundfile
class
AudioSegment
(
object
):
"""Monaural audio segment abstraction.
:param samples: Audio samples [num_samples x num_channels].
:type samples: ndarray.float32
:param sample_rate: Audio sample rate.
:type sample_rate: int
:raises TypeError: If the sample data type is not float or int.
"""
def
__init__
(
self
,
samples
,
sample_rate
):
if
not
samples
.
dtype
==
np
.
float32
:
raise
ValueError
(
"Sample data type of [%s] is not supported."
)
self
.
_samples
=
samples
"""Create audio segment from samples.
Samples are converted to float32 internally, with integer types scaled to [-1, 1].
"""
self
.
_samples
=
self
.
_convert_samples_to_float32
(
samples
)
self
.
_sample_rate
=
sample_rate
if
self
.
_samples
.
ndim
>=
2
:
self
.
_samples
=
np
.
mean
(
self
.
_samples
,
1
)
def
__eq__
(
self
,
other
):
"""Return whether two objects are equal."""
if
type
(
other
)
is
not
type
(
self
):
return
False
if
self
.
_sample_rate
!=
other
.
_sample_rate
:
return
False
if
self
.
_samples
.
shape
!=
other
.
_samples
.
shape
:
return
False
if
np
.
any
(
self
.
samples
!=
other
.
_samples
):
return
False
return
True
def
__ne__
(
self
,
other
):
"""Return whether two objects are unequal."""
return
not
self
.
__eq__
(
other
)
def
__str__
(
self
):
"""Return human-readable representation of segment."""
return
(
"%s: num_samples=%d, sample_rate=%d, duration=%.2fsec, "
"rms=%.2fdB"
%
(
type
(
self
),
self
.
num_samples
,
self
.
sample_rate
,
self
.
duration
,
self
.
rms_db
))
@
classmethod
def
from_file
(
cls
,
filepath
):
samples
,
sample_rate
=
soundfile
.
read
(
filepath
,
dtype
=
'float32'
)
def
from_file
(
cls
,
file
):
"""Create audio segment from audio file.
:param file: Filepath or file object to audio file.
:type file: basestring|file
:return: Audio segment instance.
:rtype: AudioSegment
"""
samples
,
sample_rate
=
soundfile
.
read
(
file
,
dtype
=
'float32'
)
return
cls
(
samples
,
sample_rate
)
@
classmethod
def
from_bytes
(
cls
,
bytes
):
"""Create audio segment from a byte string containing audio samples.
:param bytes: Byte string containing audio samples.
:type bytes: str
:return: Audio segment instance.
:rtype: AudioSegment
"""
samples
,
sample_rate
=
soundfile
.
read
(
io
.
BytesIO
(
bytes
),
dtype
=
'float32'
)
return
cls
(
samples
,
sample_rate
)
def to_wav_file(self, filepath, dtype='float32'):
    """Save audio segment to disk as wav file.

    :param filepath: WAV filepath or file object to save the
                     audio segment.
    :type filepath: basestring|file
    :param dtype: Subtype for audio file. Options: 'int16', 'int32',
                  'float32', 'float64'. Default is 'float32'.
    :type dtype: str
    :raises TypeError: If dtype is not supported.
    """
    subtype_map = {
        'int16': 'PCM_16',
        'int32': 'PCM_32',
        'float32': 'FLOAT',
        'float64': 'DOUBLE',
    }
    # Validate before converting: a dtype such as 'int8' converts fine but
    # has no WAV subtype here, and used to surface as a KeyError instead of
    # the documented TypeError.
    if dtype not in subtype_map:
        raise TypeError("Unsupported sample type: %s." % dtype)
    samples = self._convert_samples_from_float32(self._samples, dtype)
    soundfile.write(
        filepath,
        samples,
        self._sample_rate,
        format='WAV',
        subtype=subtype_map[dtype])
def to_bytes(self, dtype='float32'):
    """Create a byte string containing the audio content.

    :param dtype: Data type for export samples. Options: 'int16', 'int32',
                  'float32', 'float64'. Default is 'float32'.
    :type dtype: str
    :return: Byte string containing audio content.
    :rtype: str
    """
    samples = self._convert_samples_from_float32(self._samples, dtype)
    # ndarray.tostring() is a deprecated alias that was removed in NumPy 2.0;
    # tobytes() returns the same raw buffer contents.
    return samples.tobytes()
def
apply_gain
(
self
,
gain
):
self
.
samples
*=
10.
**
(
gain
/
20.
)
"""Apply gain in decibels to samples.
Note that this is an in-place transformation.
:param gain: Gain in decibels to apply to samples.
:type gain: float
"""
self
.
_samples
*=
10.
**
(
gain
/
20.
)
def change_speed(self, speed_rate):
    """Change the audio speed by linear interpolation.

    Note that this is an in-place transformation.

    :param speed_rate: Rate of speed change:
                       speed_rate > 1.0, speed up the audio;
                       speed_rate = 1.0, unchanged;
                       speed_rate < 1.0, slow down the audio;
                       speed_rate <= 0.0, not allowed, raise ValueError.
    :type speed_rate: float
    :raises ValueError: If speed_rate <= 0.0.
    """
    if speed_rate <= 0:
        raise ValueError("speed_rate should be greater than zero.")
    old_length = self._samples.shape[0]
    if old_length == 0:
        # Nothing to resample; np.interp rejects empty sample points.
        return
    new_length = int(old_length / speed_rate)
    old_indices = np.arange(old_length)
    # The last valid sample index is old_length - 1. Stopping there keeps
    # the new grid inside the signal, so speed_rate == 1.0 reproduces the
    # samples exactly instead of resampling on a stretched grid.
    new_indices = np.linspace(start=0, stop=old_length - 1, num=new_length)
    # np.interp returns float64; cast back so the samples stay float32.
    self._samples = np.interp(new_indices, old_indices,
                              self._samples).astype(np.float32)
def
normalize
(
self
,
target_sample_rate
):
raise
NotImplementedError
()
def
resample
(
self
,
target_sample_rate
):
raise
NotImplementedError
()
def
change_speed
(
self
,
rate
):
def
pad_silence
(
self
,
duration
,
sides
=
'both'
):
raise
NotImplementedError
()
def
subsegment
(
self
,
start_sec
=
None
,
end_sec
=
None
):
raise
NotImplementedError
()
def
convolve
(
self
,
filter
,
allow_resample
=
False
):
raise
NotImplementedError
()
def
convolve_and_normalize
(
self
,
filter
,
allow_resample
=
False
):
raise
NotImplementedError
()
@
property
def
samples
(
self
):
"""Return audio samples.
:return: Audio samples.
:rtype: ndarray
"""
return
self
.
_samples
.
copy
()
@
property
def
sample_rate
(
self
):
"""Return audio sample rate.
:return: Audio sample rate.
:rtype: int
"""
return
self
.
_sample_rate
@
property
def
duration
(
self
):
return
self
.
_samples
.
shape
[
0
]
/
float
(
self
.
_sample_rate
)
def
num_samples
(
self
):
"""Return number of samples.
class
SpeechSegment
(
AudioSegment
):
def
__init__
(
self
,
samples
,
sample_rate
,
transcript
):
AudioSegment
.
__init__
(
self
,
samples
,
sample_rate
)
self
.
_transcript
=
transcript
:return: Number of samples.
:rtype: int
"""
return
self
.
_samples
.
shape
(
0
)
@
classmethod
def
from_file
(
cls
,
filepath
,
transcript
):
audio
=
AudioSegment
.
from_file
(
filepath
)
return
cls
(
audio
.
samples
,
audio
.
sample_rate
,
transcript
)
@
property
def
duration
(
self
):
"""Return audio duration.
@
classmethod
def
from_bytes
(
cls
,
bytes
,
transcript
):
audio
=
AudioSegment
.
from_bytes
(
bytes
)
return
cls
(
audio
.
samples
,
audio
.
sample_rate
,
transcript
)
:return: Audio duration in seconds.
:rtype: float
"""
return
self
.
_samples
.
shape
[
0
]
/
float
(
self
.
_sample_rate
)
@
property
def
transcript
(
self
):
return
self
.
_transcript
def
rms_db
(
self
):
"""Return root mean square energy of the audio in decibels.
:return: Root mean square energy in decibels.
:rtype: float
"""
# square root => multiply by 10 instead of 20 for dBs
mean_square
=
np
.
mean
(
self
.
_samples
**
2
)
return
10
*
np
.
log10
(
mean_square
)
def
_convert_samples_to_float32
(
self
,
samples
):
"""Convert sample type to float32.
Audio sample type is usually integer or float-point.
Integers will be scaled to [-1, 1] in float32.
"""
float32_samples
=
samples
.
astype
(
'float32'
)
if
samples
.
dtype
in
np
.
sctypes
[
'int'
]:
bits
=
np
.
iinfo
(
samples
.
dtype
).
bits
float32_samples
*=
(
1.
/
2
**
(
bits
-
1
))
elif
samples
.
dtype
in
np
.
sctypes
[
'float'
]:
pass
else
:
raise
TypeError
(
"Unsupported sample type: %s."
%
samples
.
dtype
)
return
float32_samples
def
_convert_samples_from_float32
(
self
,
samples
,
dtype
):
"""Convert sample type from float32 to dtype.
Audio sample type is usually integer or float-point. For integer
type, float32 will be rescaled from [-1, 1] to the maximum range
supported by the integer type.
This is for writing a audio file.
"""
dtype
=
np
.
dtype
(
dtype
)
output_samples
=
samples
.
copy
()
if
dtype
in
np
.
sctypes
[
'int'
]:
bits
=
np
.
iinfo
(
dtype
).
bits
output_samples
*=
(
2
**
(
bits
-
1
)
/
1.
)
min_val
=
np
.
iinfo
(
dtype
).
min
max_val
=
np
.
iinfo
(
dtype
).
max
output_samples
[
output_samples
>
max_val
]
=
max_val
output_samples
[
output_samples
<
min_val
]
=
min_val
elif
samples
.
dtype
in
np
.
sctypes
[
'float'
]:
min_val
=
np
.
finfo
(
dtype
).
min
max_val
=
np
.
finfo
(
dtype
).
max
output_samples
[
output_samples
>
max_val
]
=
max_val
output_samples
[
output_samples
<
min_val
]
=
min_val
else
:
raise
TypeError
(
"Unsupported sample type: %s."
%
samples
.
dtype
)
return
output_samples
.
astype
(
dtype
)
data_utils/augmentor/augmentation.py
浏览文件 @
b07ee84a
"""Contains the data augmentation pipeline."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
json
import
random
from
data_utils.augmentor.volum
n_perturb
import
Volumn
PerturbAugmentor
from
data_utils.augmentor.volum
e_perturb
import
Volume
PerturbAugmentor
class
AugmentationPipeline
(
object
):
"""Build a pre-processing pipeline with various augmentation models.Such a
data augmentation pipeline is often leveraged to augment the training
samples to make the model invariant to certain types of perturbations in the
real world, improving the model's generalization ability.
The pipeline is built according to the augmentation configuration in json
string, e.g.
.. code-block::
'[{"type": "volume",
"params": {"min_gain_dBFS": -15,
"max_gain_dBFS": 15},
"prob": 0.5},
{"type": "speed",
"params": {"min_speed_rate": 0.8,
"max_speed_rate": 1.2},
"prob": 0.5}
]'
This augmentation configuration inserts two augmentation models
into the pipeline, with one is VolumePerturbAugmentor and the other
SpeedPerturbAugmentor. "prob" indicates the probability of the current
augmentor to take effect.
:param augmentation_config: Augmentation configuration in json string.
:type augmentation_config: str
:param random_seed: Random seed.
:type random_seed: int
:raises ValueError: If the augmentation json config is in incorrect format.
"""
def
__init__
(
self
,
augmentation_config
,
random_seed
=
0
):
self
.
_rng
=
random
.
Random
(
random_seed
)
self
.
_augmentors
,
self
.
_rates
=
self
.
_parse_pipeline_from
(
augmentation_config
)
def
transform_audio
(
self
,
audio_segment
):
"""Run the pre-processing pipeline for data augmentation.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to process.
:type audio_segment: AudioSegment|SpeechSegment
"""
for
augmentor
,
rate
in
zip
(
self
.
_augmentors
,
self
.
_rates
):
if
self
.
_rng
.
uniform
(
0.
,
1.
)
<=
rate
:
augmentor
.
transform_audio
(
audio_segment
)
def
_parse_pipeline_from
(
self
,
config_json
):
"""Parse the config json to build a augmentation pipelien."""
try
:
configs
=
json
.
loads
(
config_json
)
augmentors
=
[
self
.
_get_augmentor
(
config
[
"type"
],
config
[
"params"
])
for
config
in
configs
]
rates
=
[
config
[
"prob"
]
for
config
in
configs
]
except
Exception
as
e
:
raise
ValueError
(
"
Augmentation config json format error
: "
raise
ValueError
(
"
Failed to parse the augmentation config json
: "
"%s"
%
str
(
e
))
augmentors
=
[
self
.
_get_augmentor
(
config
[
"type"
],
config
[
"params"
])
for
config
in
configs
]
rates
=
[
config
[
"rate"
]
for
config
in
configs
]
return
augmentors
,
rates
def
_get_augmentor
(
self
,
augmentor_type
,
params
):
if
augmentor_type
==
"volumn"
:
return
VolumnPerturbAugmentor
(
self
.
_rng
,
**
params
)
"""Return an augmentation model by the type name, and pass in params."""
if
augmentor_type
==
"volume"
:
return
VolumePerturbAugmentor
(
self
.
_rng
,
**
params
)
else
:
raise
ValueError
(
"Unknown augmentor type [%s]."
%
augmentor_type
)
data_utils/augmentor/base.py
浏览文件 @
b07ee84a
"""Contains the abstract base class for augmentation models."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
@@ -6,6 +7,11 @@ from abc import ABCMeta, abstractmethod
class
AugmentorBase
(
object
):
"""Abstract base class for augmentation model (augmentor) class.
All augmentor classes should inherit from this class, and implement the
following abstract methods.
"""
__metaclass__
=
ABCMeta
@
abstractmethod
...
...
@@ -14,4 +20,14 @@ class AugmentorBase(object):
@
abstractmethod
def
transform_audio
(
self
,
audio_segment
):
"""Adds various effects to the input audio segment. Such effects
will augment the training data to make the model invariant to certain
types of perturbations in the real world, improving model's
generalization ability.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegment|SpeechSegment
"""
pass
data_utils/augmentor/volum
n
_perturb.py
→
data_utils/augmentor/volum
e
_perturb.py
浏览文件 @
b07ee84a
"""Contains the volume perturb augmentation model."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
random
from
data_utils.augmentor.base
import
AugmentorBase
class
VolumnPerturbAugmentor
(
AugmentorBase
):
class
VolumePerturbAugmentor
(
AugmentorBase
):
"""Augmentation model for adding random volume perturbation.
This is used for multi-loudness training of PCEN. See
https://arxiv.org/pdf/1607.05666v1.pdf
for more details.
:param rng: Random generator object.
:type rng: random.Random
:param min_gain_dBFS: Minimal gain in dBFS.
:type min_gain_dBFS: float
:param max_gain_dBFS: Maximal gain in dBFS.
:type max_gain_dBFS: float
"""
def
__init__
(
self
,
rng
,
min_gain_dBFS
,
max_gain_dBFS
):
self
.
_min_gain_dBFS
=
min_gain_dBFS
self
.
_max_gain_dBFS
=
max_gain_dBFS
self
.
_rng
=
rng
def transform_audio(self, audio_segment):
    """Change audio loudness.

    Draws a gain uniformly from [min_gain_dBFS, max_gain_dBFS] using the
    augmentor's random generator and applies it to the segment.

    Note that this is an in-place transformation.

    :param audio_segment: Audio segment to add effects to.
    :type audio_segment: AudioSegment|SpeechSegment
    """
    # The bounds live on the instance; the bare names min_gain_dBFS and
    # max_gain_dBFS are not defined in this scope and raised NameError.
    gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
    audio_segment.apply_gain(gain)
data_utils/data.py
浏览文件 @
b07ee84a
"""Contains data generator for orgnaizing various audio data preprocessing
pipeline and offering data reader interface of PaddlePaddle requirements.
"""
Providing basic audio data preprocessing pipeline, and offering
both instance-level and batch-level data reader interfaces.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
@@ -13,42 +11,41 @@ import paddle.v2 as paddle
from
data_utils
import
utils
from
data_utils.augmentor.augmentation
import
AugmentationPipeline
from
data_utils.featurizer.speech_featurizer
import
SpeechFeaturizer
from
data_utils.
audio
import
SpeechSegment
from
data_utils.
speech
import
SpeechSegment
from
data_utils.normalizer
import
FeatureNormalizer
class
DataGenerator
(
object
):
"""
DataGenerator provides basic audio data preprocessing pipeline, and offers
both instance-level and batch-level data reader interfaces.
Normalized FFT are used as audio features here.
data reader interfaces of PaddlePaddle requirements.
:param vocab_filepath: Vocabulary file
path for indexing tokenized
transcript
ion
s.
:param vocab_filepath: Vocabulary filepath for indexing tokenized
transcripts.
:type vocab_filepath: basestring
:param normalizer_manifest_path: Manifest filepath for collecting feature
normalization statistics, e.g. mean, std.
:type normalizer_manifest_path: basestring
:param normalizer_num_samples: Number of instances sampled for collecting
feature normalization statistics.
Default is 100.
:type normalizer_num_samples: int
:param max_duration: Audio clips with duration (in seconds) greater than
this will be discarded. Default is 20.0.
:param mean_std_filepath: File containing the pre-computed mean and stddev.
:type mean_std_filepath: None|basestring
:param augmentation_config: Augmentation configuration in json string.
Details see AugmentationPipeline.__doc__.
:type augmentation_config: str
:param max_duration: Audio with duration (in seconds) greater than
this will be discarded.
:type max_duration: float
:param min_duration: Audio
clips
with duration (in seconds) smaller than
this will be discarded.
Default is 0.0.
:param min_duration: Audio with duration (in seconds) smaller than
this will be discarded.
:type min_duration: float
:param stride_ms: Striding size (in milliseconds) for generating frames.
Default is 10.0.
:type stride_ms: float
:param window_ms: Window size (in milliseconds) for
frames. Default is 20.0
.
:param window_ms: Window size (in milliseconds) for
generating frames
.
:type window_ms: float
:param max_frequency: Maximun frequency for FFT features. FFT features of
frequency larger than this will be discarded.
If set None, all features will be kept.
Default is None.
:type max_frequency: float
:param max_freq: Used when specgram_type is 'linear', only FFT bins
corresponding to frequencies between [0, max_freq] are
returned.
:type max_freq: None|float
:param specgram_type: Specgram feature type. Options: 'linear'.
:type specgram_type: str
:param random_seed: Random seed.
:type random_seed: int
"""
def
__init__
(
self
,
...
...
@@ -60,6 +57,7 @@ class DataGenerator(object):
stride_ms
=
10.0
,
window_ms
=
20.0
,
max_freq
=
None
,
specgram_type
=
'linear'
,
random_seed
=
0
):
self
.
_max_duration
=
max_duration
self
.
_min_duration
=
min_duration
...
...
@@ -68,46 +66,49 @@ class DataGenerator(object):
augmentation_config
=
augmentation_config
,
random_seed
=
random_seed
)
self
.
_speech_featurizer
=
SpeechFeaturizer
(
vocab_filepath
=
vocab_filepath
,
specgram_type
=
specgram_type
,
stride_ms
=
stride_ms
,
window_ms
=
window_ms
,
max_freq
=
max_freq
,
random_seed
=
random_seed
)
max_freq
=
max_freq
)
self
.
_rng
=
random
.
Random
(
random_seed
)
self
.
_epoch
=
0
def
batch_reader_creator
(
self
,
manifest_path
,
batch_size
,
min_batch_size
=
1
,
padding_to
=-
1
,
flatten
=
False
,
sortagrad
=
False
,
batch_shuffle
=
False
):
"""
Batch data reader creator for audio data.
Creat a callable function to
produce batches of data.
Batch data reader creator for audio data.
Return a callable generator
function to
produce batches of data.
Audio features wi
ll be padded with zeros to make each instance in
the
batch to share the same audio feature
shape.
Audio features wi
thin one batch will be padded with zeros to have
the
same shape, or a user-defined
shape.
:param manifest_path: Filepath of manifest for audio
clip
files.
:param manifest_path: Filepath of manifest for audio files.
:type manifest_path: basestring
:param batch_size:
Instance number
in a batch.
:param batch_size:
Number of instances
in a batch.
:type batch_size: int
:param padding_to: If set -1, the maximun column numbers in the batch
will be used as the target size for padding.
Otherwise, `padding_to` will be the target size.
Default is -1.
:param min_batch_size: Any batch with batch size smaller than this will
be discarded. (To be deprecated in the future.)
:type min_batch_size: int
:param padding_to: If set -1, the maximun shape in the batch
will be used as the target shape for padding.
Otherwise, `padding_to` will be the target shape.
:type padding_to: int
:param flatten: If set True, audio data will be flatten to be a 1-dim
ndarray. Otherwise, 2-dim ndarray. Default is False.
:param flatten: If set True, audio features will be flatten to 1darray.
:type flatten: bool
:param sortagrad:
Sort the audio clips by duration in the first epoc
i
f set True
.
:param sortagrad:
If set True, sort the instances by audio duration
i
n the first epoch for speed up training
.
:type sortagrad: bool
:param batch_shuffle: Shuffle the audio clips if set True. It is
not a thorough instance-wise shuffle, but a
specific batch-wise shuffle. For more details,
please see `_batch_shuffle` function.
:param batch_shuffle: If set True, instances are batch-wise shuffled.
For more details, please see
``_batch_shuffle.__doc__``.
If sortagrad is True, batch_shuffle is disabled
for the first epoch.
:type batch_shuffle: bool
:return: Batch reader function, producing batches of data when called.
:rtype: callable
...
...
@@ -132,7 +133,7 @@ class DataGenerator(object):
if
len
(
batch
)
==
batch_size
:
yield
self
.
_padding_batch
(
batch
,
padding_to
,
flatten
)
batch
=
[]
if
len
(
batch
)
>
0
:
if
len
(
batch
)
>
=
min_batch_size
:
yield
self
.
_padding_batch
(
batch
,
padding_to
,
flatten
)
self
.
_epoch
+=
1
...
...
@@ -140,20 +141,33 @@ class DataGenerator(object):
@
property
def
feeding
(
self
):
"""Returns data_reader's feeding dict."""
"""Returns data reader's feeding dict.
:return: Data feeding dict.
:rtype: dict
"""
return
{
"audio_spectrogram"
:
0
,
"transcript_text"
:
1
}
@
property
def
vocab_size
(
self
):
"""Returns vocabulary size."""
"""Return the vocabulary size.
:return: Vocabulary size.
:rtype: int
"""
return
self
.
_speech_featurizer
.
vocab_size
@
property
def
vocab_list
(
self
):
"""Returns vocabulary list."""
"""Return the vocabulary in list.
:return: Vocabulary in list.
:rtype: list
"""
return
self
.
_speech_featurizer
.
vocab_list
def
_process_utterance
(
self
,
filename
,
transcript
):
"""Load, augment, featurize and normalize for speech data."""
speech_segment
=
SpeechSegment
.
from_file
(
filename
,
transcript
)
self
.
_augmentation_pipeline
.
transform_audio
(
speech_segment
)
specgram
,
text_ids
=
self
.
_speech_featurizer
.
featurize
(
speech_segment
)
...
...
@@ -162,16 +176,11 @@ class DataGenerator(object):
def
_instance_reader_creator
(
self
,
manifest
):
"""
Instance reader creator
for audio data. Creat a callable function to
produce
instances of data.
Instance reader creator
. Create a callable function to produce
instances of data.
Instance: a tuple of a numpy ndarray of audio spectrogram and a list of
tokenized and indexed transcription text.
:param manifest: Filepath of manifest for audio clip files.
:type manifest: basestring
:return: Data reader function.
:rtype: callable
Instance: a tuple of ndarray of audio spectrogram and a list of
token indices for transcript.
"""
def
reader
():
...
...
@@ -183,24 +192,22 @@ class DataGenerator(object):
def
_padding_batch
(
self
,
batch
,
padding_to
=-
1
,
flatten
=
False
):
"""
Padding audio part of features (only in the time axis -- column axis)
with zeros, to make each instance in the batch share the same
audio feature shape.
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one batch.
If `
padding_to` is set -1, the maximun column numbers in the batch will
be used as the target size. Otherwise, `padding_to` will be the target
size. Default is -1
.
If `
`padding_to`` is -1, the maximun shape in the batch will be used
as the target shape for padding. Otherwise, `padding_to` will be the
target shape (only refers to the second axis)
.
If `flatten` is set True, audio data will be flatten to be a 1-dim
ndarray. Default is False.
If `flatten` is True, features will be flatten to 1darray.
"""
new_batch
=
[]
# get target shape
max_length
=
max
([
audio
.
shape
[
1
]
for
audio
,
text
in
batch
])
if
padding_to
!=
-
1
:
if
padding_to
<
max_length
:
raise
ValueError
(
"If padding_to is not -1, it should be
greater
"
"
or equal to the original instance length.
"
)
raise
ValueError
(
"If padding_to is not -1, it should be
larger
"
"
than any instance's shape in the batch
"
)
max_length
=
padding_to
# padding
for
audio
,
text
in
batch
:
...
...
@@ -212,28 +219,21 @@ class DataGenerator(object):
return
new_batch
def
_batch_shuffle
(
self
,
manifest
,
batch_size
):
"""
The instances have different lengths and they cannot be
combined into a single matrix multiplication. It usually
sorts the training examples by length and combines only
similarly-sized instances into minibatches, pads with
silence when necessary so that all instances in a batch
have the same length. This batch shuffle fuction is used
to make similarly-sized instances into minibatches and
make a batch-wise shuffle.
"""Put similarly-sized instances into minibatches for better efficiency
and make a batch-wise shuffle.
1. Sort the audio clips by duration.
2. Generate a random number `k`, k in [0, batch_size).
3. Randomly
remove `k` instances in order to make different mini-batches,
then make minibatches and each minibatch size is batch_size
.
3. Randomly
shift `k` instances in order to create different batches
for different epochs. Create minibatches
.
4. Shuffle the minibatches.
:param manifest:
manifest file
.
:param manifest:
Manifest contents. List of dict
.
:type manifest: list
:param batch_size: Batch size. This size is also used for generate
a random number for batch shuffle.
:type batch_size: int
:return:
b
atch shuffled mainifest.
:return:
B
atch shuffled mainifest.
:rtype: list
"""
manifest
.
sort
(
key
=
lambda
x
:
x
[
"duration"
])
...
...
data_utils/featurizer/audio_featurizer.py
浏览文件 @
b07ee84a
"""Contains the audio featurizer class."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
random
from
data_utils
import
utils
from
data_utils.audio
import
AudioSegment
class
AudioFeaturizer
(
object
):
"""Audio featurizer, for extracting features from audio contents of
AudioSegment or SpeechSegment.
Currently, it only supports feature type of linear spectrogram.
:param specgram_type: Specgram feature type. Options: 'linear'.
:type specgram_type: str
:param stride_ms: Striding size (in milliseconds) for generating frames.
:type stride_ms: float
:param window_ms: Window size (in milliseconds) for generating frames.
:type window_ms: float
:param max_freq: Used when specgram_type is 'linear', only FFT bins
corresponding to frequencies between [0, max_freq] are
returned.
:type max_freq: None|float
"""
def
__init__
(
self
,
specgram_type
=
'linear'
,
stride_ms
=
10.0
,
window_ms
=
20.0
,
max_freq
=
None
,
random_seed
=
0
):
max_freq
=
None
):
self
.
_specgram_type
=
specgram_type
self
.
_stride_ms
=
stride_ms
self
.
_window_ms
=
window_ms
self
.
_max_freq
=
max_freq
def
featurize
(
self
,
audio_segment
):
"""Extract audio features from AudioSegment or SpeechSegment.
:param audio_segment: Audio/speech segment to extract features from.
:type audio_segment: AudioSegment|SpeechSegment
:return: Spectrogram audio feature in 2darray.
:rtype: ndarray
"""
return
self
.
_compute_specgram
(
audio_segment
.
samples
,
audio_segment
.
sample_rate
)
def
_compute_specgram
(
self
,
samples
,
sample_rate
):
"""Extract various audio features."""
if
self
.
_specgram_type
==
'linear'
:
return
self
.
_compute_linear_specgram
(
samples
,
sample_rate
,
self
.
_stride_ms
,
self
.
_window_ms
,
...
...
@@ -40,9 +64,7 @@ class AudioFeaturizer(object):
window_ms
=
20.0
,
max_freq
=
None
,
eps
=
1e-14
):
"""Laod audio data and calculate the log of spectrogram by FFT.
Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
"""
"""Compute the linear spectrogram from FFT energy."""
if
max_freq
is
None
:
max_freq
=
sample_rate
/
2
if
max_freq
>
sample_rate
/
2
:
...
...
@@ -62,9 +84,7 @@ class AudioFeaturizer(object):
return
np
.
log
(
specgram
[:
ind
,
:]
+
eps
)
def
_specgram_real
(
self
,
samples
,
window_size
,
stride_size
,
sample_rate
):
"""Compute the spectrogram by FFT for a discrete real signal.
Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
"""
"""Compute the spectrogram for samples from a real signal."""
# extract strided windows
truncate_size
=
(
len
(
samples
)
-
window_size
)
%
stride_size
samples
=
samples
[:
len
(
samples
)
-
truncate_size
]
...
...
data_utils/featurizer/speech_featurizer.py
浏览文件 @
b07ee84a
"""Contains the speech featurizer class."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
@@ -7,26 +8,70 @@ from data_utils.featurizer.text_featurizer import TextFeaturizer
class
SpeechFeaturizer
(
object
):
"""Speech featurizer, for extracting features from both audio and transcript
contents of SpeechSegment.
Currently, for audio parts, it only supports feature type of linear
spectrogram; for transcript parts, it only supports char-level tokenizing
and conversion into a list of token indices. Note that the token indexing
order follows the given vocabulary file.
:param vocab_filepath: Filepath to load vocabulary for token indices
conversion.
:type vocab_filepath: basestring
:param specgram_type: Specgram feature type. Options: 'linear'.
:type specgram_type: str
:param stride_ms: Striding size (in milliseconds) for generating frames.
:type stride_ms: float
:param window_ms: Window size (in milliseconds) for generating frames.
:type window_ms: float
:param max_freq: Used when specgram_type is 'linear', only FFT bins
corresponding to frequencies between [0, max_freq] are
returned.
:type max_freq: None|float
"""
def
__init__
(
self
,
vocab_filepath
,
specgram_type
=
'linear'
,
stride_ms
=
10.0
,
window_ms
=
20.0
,
max_freq
=
None
,
random_seed
=
0
):
self
.
_audio_featurizer
=
AudioFeaturizer
(
specgram_type
,
stride_ms
,
window_ms
,
max_freq
,
random_seed
)
max_freq
=
None
):
self
.
_audio_featurizer
=
AudioFeaturizer
(
specgram_type
,
stride_ms
,
window_ms
,
max_freq
)
self
.
_text_featurizer
=
TextFeaturizer
(
vocab_filepath
)
def
featurize
(
self
,
speech_segment
):
"""Extract features for speech segment.
1. For audio parts, extract the audio features.
2. For transcript parts, convert text string to a list of token indices
in char-level.
:param speech_segment: Speech segment to extract features from.
:type speech_segment: SpeechSegment
:return: A tuple of 1) spectrogram audio feature in 2darray, 2) list of
char-level token indices.
:rtype: tuple
"""
audio_feature
=
self
.
_audio_featurizer
.
featurize
(
speech_segment
)
text_ids
=
self
.
_text_featurizer
.
text2ids
(
speech_segment
.
transcript
)
text_ids
=
self
.
_text_featurizer
.
featurize
(
speech_segment
.
transcript
)
return
audio_feature
,
text_ids
@
property
def
vocab_size
(
self
):
"""Return the vocabulary size.
:return: Vocabulary size.
:rtype: int
"""
return
self
.
_text_featurizer
.
vocab_size
@
property
def
vocab_list
(
self
):
"""Return the vocabulary in list.
:return: Vocabulary in list.
:rtype: list
"""
return
self
.
_text_featurizer
.
vocab_list
data_utils/featurizer/text_featurizer.py
浏览文件 @
b07ee84a
"""Contains the text featurizer class."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
@@ -6,26 +7,53 @@ import os
class
TextFeaturizer
(
object
):
"""Text featurizer, for processing or extracting features from text.
Currently, it only supports char-level tokenizing and conversion into
a list of token indices. Note that the token indexing order follows the
given vocabulary file.
:param vocab_filepath: Filepath to load vocabulary for token indices
conversion.
:type vocab_filepath: basestring
"""
def
__init__
(
self
,
vocab_filepath
):
self
.
_vocab_dict
,
self
.
_vocab_list
=
self
.
_load_vocabulary_from_file
(
vocab_filepath
)
def
text2ids
(
self
,
text
):
def
featurize
(
self
,
text
):
"""Convert text string to a list of token indices in char-level.Note
that the token indexing order follows the given vocabulary file.
:param text: Text to process.
:type text: basestring
:return: List of char-level token indices.
:rtype: list
"""
tokens
=
self
.
_char_tokenize
(
text
)
return
[
self
.
_vocab_dict
[
token
]
for
token
in
tokens
]
def
ids2text
(
self
,
ids
):
return
''
.
join
([
self
.
_vocab_list
[
id
]
for
id
in
ids
])
@
property
def
vocab_size
(
self
):
"""Return the vocabulary size.
:return: Vocabulary size.
:rtype: int
"""
return
len
(
self
.
_vocab_list
)
@
property
def
vocab_list
(
self
):
"""Return the vocabulary in list.
:return: Vocabulary in list.
:rtype: list
"""
return
self
.
_vocab_list
def
_char_tokenize
(
self
,
text
):
"""Character tokenizer."""
return
list
(
text
.
strip
())
def
_load_vocabulary_from_file
(
self
,
vocab_filepath
):
...
...
data_utils/normalizer.py
浏览文件 @
b07ee84a
"""Contains feature normalizers."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
@@ -9,6 +10,28 @@ from data_utils.audio import AudioSegment
class
FeatureNormalizer
(
object
):
"""Feature normalizer. Normalize features to be of zero mean and unit
stddev.
If mean_std_filepath is provided (not None), the normalizer will directly
initialize from the file. Otherwise, both manifest_path and featurize_func
should be given for on-the-fly mean and stddev computing.
:param mean_std_filepath: File containing the pre-computed mean and stddev.
:type mean_std_filepath: None|basestring
:param manifest_path: Manifest of instances for computing mean and stddev.
:type manifest_path: None|basestring
:param featurize_func: Function to extract features. It should be callable
with ``featurize_func(audio_segment)``.
:type featurize_func: None|callable
:param num_samples: Number of random samples for computing mean and stddev.
:type num_samples: int
:param random_seed: Random seed for sampling instances.
:type random_seed: int
:raises ValueError: If both mean_std_filepath and manifest_path
(or both mean_std_filepath and featurize_func) are None.
"""
def
__init__
(
self
,
mean_std_filepath
,
manifest_path
=
None
,
...
...
@@ -25,18 +48,33 @@ class FeatureNormalizer(object):
self
.
_read_mean_std_from_file
(
mean_std_filepath
)
def
apply
(
self
,
features
,
eps
=
1e-14
):
"""Normalize features to be of zero mean and unit stddev."""
"""Normalize features to be of zero mean and unit stddev.
:param features: Input features to be normalized.
:type features: ndarray
:param eps: Added to stddev to provide numerical stability.
:type eps: float
:return: Normalized features.
:rtype: ndarray
"""
return
(
features
-
self
.
_mean
)
/
(
self
.
_std
+
eps
)
def
write_to_file
(
self
,
filepath
):
"""Write the mean and stddev to the file.
:param filepath: File to write mean and stddev.
:type filepath: basestring
"""
np
.
savez
(
filepath
,
mean
=
self
.
_mean
,
std
=
self
.
_std
)
def
_read_mean_std_from_file
(
self
,
filepath
):
"""Load mean and std from file."""
npzfile
=
np
.
load
(
filepath
)
self
.
_mean
=
npzfile
[
"mean"
]
self
.
_std
=
npzfile
[
"std"
]
def
_compute_mean_std
(
self
,
manifest_path
,
featurize_func
,
num_samples
):
"""Compute mean and std from randomly sampled instances."""
manifest
=
utils
.
read_manifest
(
manifest_path
)
sampled_manifest
=
self
.
_rng
.
sample
(
manifest
,
num_samples
)
features
=
[]
...
...
data_utils/speech.py
0 → 100755
浏览文件 @
b07ee84a
"""Contains the speech segment class."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
data_utils.audio
import
AudioSegment
class
SpeechSegment
(
AudioSegment
):
"""Speech segment abstraction, a subclass of AudioSegment,
with an additional transcript.
:param samples: Audio samples [num_samples x num_channels].
:type samples: ndarray.float32
:param sample_rate: Audio sample rate.
:type sample_rate: int
:param transcript: Transcript text for the speech.
:type transcript: basestring
:raises TypeError: If the sample data type is not float or int.
"""
def
__init__
(
self
,
samples
,
sample_rate
,
transcript
):
AudioSegment
.
__init__
(
self
,
samples
,
sample_rate
)
self
.
_transcript
=
transcript
def
__eq__
(
self
,
other
):
"""Return whether two objects are equal.
"""
if
not
AudioSegment
.
__eq__
(
self
,
other
):
return
False
if
self
.
_transcript
!=
other
.
_transcript
:
return
False
return
True
def
__ne__
(
self
,
other
):
"""Return whether two objects are unequal."""
return
not
self
.
__eq__
(
other
)
@
classmethod
def
from_file
(
cls
,
filepath
,
transcript
):
"""Create speech segment from audio file and corresponding transcript.
:param filepath: Filepath or file object to audio file.
:type filepath: basestring|file
:param transcript: Transcript text for the speech.
:type transcript: basestring
:return: Speech segment instance.
:rtype: SpeechSegment
"""
audio
=
AudioSegment
.
from_file
(
filepath
)
return
cls
(
audio
.
samples
,
audio
.
sample_rate
,
transcript
)
@
classmethod
def
from_bytes
(
cls
,
bytes
,
transcript
):
"""Create speech segment from a byte string and corresponding
transcript.
:param bytes: Byte string containing audio samples.
:type bytes: str
:param transcript: Transcript text for the speech.
:type transcript: basestring
:return: Speech segment instance.
:rtype: SpeechSegment
"""
audio
=
AudioSegment
.
from_bytes
(
bytes
)
return
cls
(
audio
.
samples
,
audio
.
sample_rate
,
transcript
)
@
property
def
transcript
(
self
):
"""Return the transcript text.
:return: Transcript text for the speech.
:rtype: basestring
"""
return
self
.
_transcript
data_utils/utils.py
浏览文件 @
b07ee84a
"""Contains data helper functions."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
@@ -6,7 +7,21 @@ import json
def
read_manifest
(
manifest_path
,
max_duration
=
float
(
'inf'
),
min_duration
=
0.0
):
"""Load and parse manifest file."""
"""Load and parse manifest file.
Instances with durations outside [min_duration, max_duration] will be
filtered out.
:param manifest_path: Manifest file to load and parse.
:type manifest_path: basestring
:param max_duration: Maximal duration in seconds for instance filter.
:type max_duration: float
:param min_duration: Minimal duration in seconds for instance filter.
:type min_duration: float
:return: Manifest parsing results. List of dict.
:rtype: list
:raises IOError: If failed to parse the manifest.
"""
manifest
=
[]
for
json_line
in
open
(
manifest_path
):
try
:
...
...
datasets/librispeech/librispeech.py
浏览文件 @
b07ee84a
"""
Download, unpack and create manifest json files for the Librespeech dataset.
"""Prepare Librispeech ASR datasets.
A manifest is a json file summarizing filelist in a data set, with each line
containing the meta data (i.e. audio filepath, transcription text, audio
duration) of each audio file in the data set.
Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle.v2
as
paddle
from
paddle.v2.dataset.common
import
md5file
import
distutils.util
import
os
import
wget
...
...
@@ -15,6 +16,7 @@ import tarfile
import
argparse
import
soundfile
import
json
from
paddle.v2.dataset.common
import
md5file
DATA_HOME
=
os
.
path
.
expanduser
(
'~/.cache/paddle/dataset/speech'
)
...
...
decoder.py
浏览文件 @
b07ee84a
"""
CTC-like decoder utilitis.
"""
"""Contains various CTC decoder."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
itertools
import
groupby
import
numpy
as
np
from
itertools
import
groupby
def
ctc_best_path_decode
(
probs_seq
,
vocabulary
):
...
...
infer.py
浏览文件 @
b07ee84a
"""
Inference for a simplifed version of Baidu DeepSpeech2 model.
"""
"""Inferer for DeepSpeech2 model."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
model.py
浏览文件 @
b07ee84a
"""
A simplifed version of Baidu DeepSpeech2 model.
"""
"""Contains DeepSpeech2 model."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle.v2
as
paddle
#TODO: add bidirectional rnn.
def
conv_bn_layer
(
input
,
filter_size
,
num_channels_in
,
num_channels_out
,
stride
,
padding
,
act
):
...
...
train.py
浏览文件 @
b07ee84a
"""
Trainer for a simplifed version of Baidu DeepSpeech2 model.
"""
"""Trainer for DeepSpeech2 model."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
@@ -164,7 +161,7 @@ def train():
print
(
"
\n
Pass: %d, Batch: %d, TrainCost: %f"
%
(
event
.
pass_id
,
event
.
batch_id
,
cost_sum
/
cost_counter
))
cost_sum
,
cost_counter
=
0.0
,
0
with
gzip
.
open
(
"params.tar.gz"
,
'w'
)
as
f
:
with
gzip
.
open
(
"params
_tmp
.tar.gz"
,
'w'
)
as
f
:
parameters
.
to_tar
(
f
)
else
:
sys
.
stdout
.
write
(
'.'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录