Commit 68caa8ca authored by Xinghai Sun, committed by GitHub

Merge pull request #114 from xinghai-sun/ds2_feature

Improve audio featurizer and add shift augmentor for DS2.
...@@ -51,13 +51,13 @@ python compute_mean_std.py --help
For GPU Training:
```
-CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --trainer_count 4
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py
```
For CPU Training:
```
-python train.py --trainer_count 8 --use_gpu False
+python train.py --use_gpu False
```
More help for arguments:
...
...@@ -66,6 +66,54 @@ class AudioSegment(object):
        samples, sample_rate = soundfile.read(file, dtype='float32')
        return cls(samples, sample_rate)
+    @classmethod
+    def slice_from_file(cls, file, start=None, end=None):
+        """Loads a small section of an audio file without having to load
+        the entire file into memory, which can be incredibly wasteful.
+
+        :param file: Input audio filepath or file object.
+        :type file: basestring|file
+        :param start: Start time in seconds. If start is negative, it wraps
+                      around from the end. If not provided, this function
+                      reads from the very beginning.
+        :type start: float
+        :param end: End time in seconds. If end is negative, it wraps around
+                    from the end. If not provided, the default behavior is
+                    to read to the end of the file.
+        :type end: float
+        :return: AudioSegment instance of the specified slice of the input
+                 audio file.
+        :rtype: AudioSegment
+        :raise ValueError: If start or end is incorrectly set, e.g. out of
+                           bounds in time.
+        """
+        sndfile = soundfile.SoundFile(file)
+        sample_rate = sndfile.samplerate
+        duration = float(len(sndfile)) / sample_rate
+        start = 0. if start is None else start
+        end = duration if end is None else end
+        if start < 0.0:
+            start += duration
+        if end < 0.0:
+            end += duration
+        if start < 0.0:
+            raise ValueError("The slice start position (%f s) is out of "
+                             "bounds." % start)
+        if end < 0.0:
+            raise ValueError("The slice end position (%f s) is out of bounds." %
+                             end)
+        if start > end:
+            raise ValueError("The slice start position (%f s) is later than "
+                             "the slice end position (%f s)." % (start, end))
+        if end > duration:
+            raise ValueError("The slice end position (%f s) is out of bounds "
+                             "(> %f s)" % (end, duration))
+        start_frame = int(start * sample_rate)
+        end_frame = int(end * sample_rate)
+        sndfile.seek(start_frame)
+        data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
+        return cls(data, sample_rate)
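For reference, a minimal usage sketch of the new `slice_from_file` classmethod. The `data_utils.audio` import path and the `sample.wav` file are assumptions for illustration:

```
from data_utils.audio import AudioSegment

# Read only the 2.5 s - 4.0 s slice without decoding the whole file.
segment = AudioSegment.slice_from_file("sample.wav", start=2.5, end=4.0)
print(segment.duration)  # ~1.5 seconds

# Negative offsets wrap around from the end: the last half second of the file.
tail = AudioSegment.slice_from_file("sample.wav", start=-0.5)
```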
    @classmethod
    def from_bytes(cls, bytes):
        """Create audio segment from a byte string containing audio samples.
...@@ -105,6 +153,20 @@ class AudioSegment(object):
        samples = np.concatenate([seg.samples for seg in segments])
        return cls(samples, sample_rate)
+    @classmethod
+    def make_silence(cls, duration, sample_rate):
+        """Creates a silent audio segment of the given duration and sample rate.
+
+        :param duration: Length of silence in seconds.
+        :type duration: float
+        :param sample_rate: Sample rate.
+        :type sample_rate: float
+        :return: Silent AudioSegment instance of the given duration.
+        :rtype: AudioSegment
+        """
+        samples = np.zeros(int(duration * sample_rate))
+        return cls(samples, sample_rate)
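A quick sketch of `make_silence` (same assumed import path as above), e.g. for building padding before concatenating or mixing segments:

```
from data_utils.audio import AudioSegment

silence = AudioSegment.make_silence(duration=0.5, sample_rate=16000)
print(silence.duration)      # 0.5 seconds
print(len(silence.samples))  # 8000 all-zero samples
```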
    def to_wav_file(self, filepath, dtype='float32'):
        """Save audio segment to disk as wav file.
...@@ -130,68 +192,6 @@ class AudioSegment(object):
            format='WAV',
            subtype=subtype_map[dtype])
-    @classmethod
-    def slice_from_file(cls, file, start=None, end=None):
-        """Loads a small section of an audio without having to load
-        the entire file into the memory which can be incredibly wasteful.
-        :param file: Input audio filepath or file object.
-        :type file: basestring|file
-        :param start: Start time in seconds. If start is negative, it wraps
-                      around from the end. If not provided, this function
-                      reads from the very beginning.
-        :type start: float
-        :param end: End time in seconds. If end is negative, it wraps around
-                    from the end. If not provided, the default behvaior is
-                    to read to the end of the file.
-        :type end: float
-        :return: AudioSegment instance of the specified slice of the input
-                 audio file.
-        :rtype: AudioSegment
-        :raise ValueError: If start or end is incorrectly set, e.g. out of
-                           bounds in time.
-        """
-        sndfile = soundfile.SoundFile(file)
-        sample_rate = sndfile.samplerate
-        duration = float(len(sndfile)) / sample_rate
-        start = 0. if start is None else start
-        end = 0. if end is None else end
-        if start < 0.0:
-            start += duration
-        if end < 0.0:
-            end += duration
-        if start < 0.0:
-            raise ValueError("The slice start position (%f s) is out of "
-                             "bounds." % start)
-        if end < 0.0:
-            raise ValueError("The slice end position (%f s) is out of bounds." %
-                             end)
-        if start > end:
-            raise ValueError("The slice start position (%f s) is later than "
-                             "the slice end position (%f s)." % (start, end))
-        if end > duration:
-            raise ValueError("The slice end position (%f s) is out of bounds "
-                             "(> %f s)" % (end, duration))
-        start_frame = int(start * sample_rate)
-        end_frame = int(end * sample_rate)
-        sndfile.seek(start_frame)
-        data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
-        return cls(data, sample_rate)
-
-    @classmethod
-    def make_silence(cls, duration, sample_rate):
-        """Creates a silent audio segment of the given duration and sample rate.
-        :param duration: Length of silence in seconds.
-        :type duration: float
-        :param sample_rate: Sample rate.
-        :type sample_rate: float
-        :return: Silent AudioSegment instance of the given duration.
-        :rtype: AudioSegment
-        """
-        samples = np.zeros(int(duration * sample_rate))
-        return cls(samples, sample_rate)
    def superimpose(self, other):
        """Add samples from another segment to those of this segment
        (sample-wise addition, not segment concatenation).
...@@ -225,7 +225,7 @@ class AudioSegment(object):
        samples = self._convert_samples_from_float32(self._samples, dtype)
        return samples.tostring()
-    def apply_gain(self, gain):
+    def gain_db(self, gain):
        """Apply gain in decibels to samples.

        Note that this is an in-place transformation.
...@@ -278,7 +278,7 @@ class AudioSegment(object):
                "Unable to normalize segment to %f dB because the "
                "probable gain exceeds max_gain_db (%f dB)" %
                (target_db, max_gain_db))
-        self.apply_gain(min(max_gain_db, target_db - self.rms_db))
+        self.gain_db(min(max_gain_db, target_db - self.rms_db))
    def normalize_online_bayesian(self,
                                  target_db,
...@@ -319,7 +319,7 @@ class AudioSegment(object):
        rms_estimate_db = 10 * np.log10(mean_squared_estimate)
        # Compute required time-varying gain.
        gain_db = target_db - rms_estimate_db
-        self.apply_gain(gain_db)
+        self.gain_db(gain_db)
    def resample(self, target_sample_rate, quality='sinc_medium'):
        """Resample the audio to a target sample rate.
...@@ -366,6 +366,31 @@ class AudioSegment(object):
            raise ValueError("Unknown value for the sides %s" % sides)
        self._samples = padded._samples
+    def shift(self, shift_ms):
+        """Shift the audio in time. If `shift_ms` is positive, shift with time
+        advance; if negative, shift with time delay. Silence is padded to
+        keep the duration unchanged.
+
+        Note that this is an in-place transformation.
+
+        :param shift_ms: Shift time in milliseconds. If positive, shift with
+                         time advance; if negative, shift with time delay.
+        :type shift_ms: float
+        :raises ValueError: If shift_ms is longer than audio duration.
+        """
+        if abs(shift_ms) / 1000.0 > self.duration:
+            raise ValueError("Absolute value of shift_ms should be smaller "
+                             "than audio duration.")
+        shift_samples = int(shift_ms * self._sample_rate / 1000)
+        if shift_samples > 0:
+            # time advance
+            self._samples[:-shift_samples] = self._samples[shift_samples:]
+            self._samples[-shift_samples:] = 0
+        elif shift_samples < 0:
+            # time delay
+            self._samples[-shift_samples:] = self._samples[:shift_samples]
+            self._samples[:-shift_samples] = 0
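A usage sketch of the new in-place `shift` method (assuming `AudioSegment.from_file` and a hypothetical `sample.wav`):

```
from data_utils.audio import AudioSegment

audio = AudioSegment.from_file("sample.wav")
before = audio.duration
audio.shift(shift_ms=100)    # time advance: drops the first 100 ms, zero-pads the tail
audio.shift(shift_ms=-100)   # time delay: zero-pads the head instead
assert audio.duration == before  # duration is preserved either way
```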
    def subsegment(self, start_sec=None, end_sec=None):
        """Cut the AudioSegment between given boundaries.
...@@ -505,7 +530,7 @@ class AudioSegment(object):
        noise_gain_db = min(self.rms_db - noise.rms_db - snr_dB, max_gain_db)
        noise_new = copy.deepcopy(noise)
        noise_new.random_subsegment(self.duration, rng=rng)
-        noise_new.apply_gain(noise_gain_db)
+        noise_new.gain_db(noise_gain_db)
        self.superimpose(noise_new)

    @property
...
...@@ -6,6 +6,7 @@ from __future__ import print_function
import json
import random
from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor
+from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor


class AugmentationPipeline(object):
...@@ -76,5 +77,7 @@ class AugmentationPipeline(object):
        """Return an augmentation model by the type name, and pass in params."""
        if augmentor_type == "volume":
            return VolumePerturbAugmentor(self._rng, **params)
+        elif augmentor_type == "shift":
+            return ShiftPerturbAugmentor(self._rng, **params)
        else:
            raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
...@@ -36,5 +36,5 @@ class VolumePerturbAugmentor(AugmentorBase):
        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
-        gain = self._rng.uniform(min_gain_dBFS, max_gain_dBFS)
+        gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
        audio_segment.apply_gain(gain)
...@@ -45,6 +45,9 @@ class DataGenerator(object):
    :type max_freq: None|float
    :param specgram_type: Specgram feature type. Options: 'linear'.
    :type specgram_type: str
+    :param use_dB_normalization: Whether to normalize the audio to -20 dB
+                                 before extracting the features.
+    :type use_dB_normalization: bool
    :param num_threads: Number of CPU threads for processing data.
    :type num_threads: int
    :param random_seed: Random seed.
...@@ -61,6 +64,7 @@
                 window_ms=20.0,
                 max_freq=None,
                 specgram_type='linear',
+                 use_dB_normalization=True,
                 num_threads=multiprocessing.cpu_count(),
                 random_seed=0):
        self._max_duration = max_duration
...@@ -73,7 +77,8 @@
            specgram_type=specgram_type,
            stride_ms=stride_ms,
            window_ms=window_ms,
-            max_freq=max_freq)
+            max_freq=max_freq,
+            use_dB_normalization=use_dB_normalization)
        self._num_threads = num_threads
        self._rng = random.Random(random_seed)
        self._epoch = 0
...
...@@ -24,26 +24,64 @@ class AudioFeaturizer(object):
                     corresponding to frequencies between [0, max_freq] are
                     returned.
    :type max_freq: None|float
+    :param target_sample_rate: Audio is resampled (if upsampling or
+                               downsampling is allowed) to this sample rate
+                               before extracting spectrogram features.
+    :type target_sample_rate: float
+    :param use_dB_normalization: Whether to normalize the audio to a certain
+                                 decibel level before extracting the features.
+    :type use_dB_normalization: bool
+    :param target_dB: Target audio decibels for normalization.
+    :type target_dB: float
    """

    def __init__(self,
                 specgram_type='linear',
                 stride_ms=10.0,
                 window_ms=20.0,
-                 max_freq=None):
+                 max_freq=None,
+                 target_sample_rate=16000,
+                 use_dB_normalization=True,
+                 target_dB=-20):
        self._specgram_type = specgram_type
        self._stride_ms = stride_ms
        self._window_ms = window_ms
        self._max_freq = max_freq
+        self._target_sample_rate = target_sample_rate
+        self._use_dB_normalization = use_dB_normalization
+        self._target_dB = target_dB

-    def featurize(self, audio_segment):
+    def featurize(self,
+                  audio_segment,
+                  allow_downsampling=True,
+                  allow_upsampling=True):
        """Extract audio features from AudioSegment or SpeechSegment.

        :param audio_segment: Audio/speech segment to extract features from.
        :type audio_segment: AudioSegment|SpeechSegment
+        :param allow_downsampling: Whether to allow audio downsampling before
+                                   featurizing.
+        :type allow_downsampling: bool
+        :param allow_upsampling: Whether to allow audio upsampling before
+                                 featurizing.
+        :type allow_upsampling: bool
        :return: Spectrogram audio feature in 2darray.
        :rtype: ndarray
+        :raises ValueError: If audio sample rate is not supported.
        """
+        # upsampling or downsampling
+        if ((audio_segment.sample_rate > self._target_sample_rate and
+             allow_downsampling) or
+            (audio_segment.sample_rate < self._target_sample_rate and
+             allow_upsampling)):
+            audio_segment.resample(self._target_sample_rate)
+        if audio_segment.sample_rate != self._target_sample_rate:
+            raise ValueError("Audio sample rate is not supported. "
+                             "Turn allow_downsampling or allow_upsampling on.")
+        # decibel normalization
+        if self._use_dB_normalization:
+            audio_segment.normalize(target_db=self._target_dB)
+        # extract spectrogram
        return self._compute_specgram(audio_segment.samples,
                                      audio_segment.sample_rate)
...
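A sketch of the featurizer with the new options. The `data_utils.featurizer.audio_featurizer` import path, `AudioSegment.from_file`, and `sample.wav` are assumptions for illustration:

```
from data_utils.audio import AudioSegment
from data_utils.featurizer.audio_featurizer import AudioFeaturizer

featurizer = AudioFeaturizer(
    specgram_type='linear',
    target_sample_rate=16000,   # resample to 16 kHz when allowed
    use_dB_normalization=True,  # normalize to target_dB before featurizing
    target_dB=-20)
audio = AudioSegment.from_file("sample.wav")
specgram = featurizer.featurize(
    audio, allow_downsampling=True, allow_upsampling=True)
print(specgram.shape)  # 2-D spectrogram array
```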
...@@ -29,6 +29,15 @@ class SpeechFeaturizer(object):
                     corresponding to frequencies between [0, max_freq] are
                     returned.
    :type max_freq: None|float
+    :param target_sample_rate: Speech is resampled (if upsampling or
+                               downsampling is allowed) to this sample rate
+                               before extracting spectrogram features.
+    :type target_sample_rate: float
+    :param use_dB_normalization: Whether to normalize the audio to a certain
+                                 decibel level before extracting the features.
+    :type use_dB_normalization: bool
+    :param target_dB: Target audio decibels for normalization.
+    :type target_dB: float
    """

    def __init__(self,
...@@ -36,9 +45,18 @@
                 specgram_type='linear',
                 stride_ms=10.0,
                 window_ms=20.0,
-                 max_freq=None):
-        self._audio_featurizer = AudioFeaturizer(specgram_type, stride_ms,
-                                                 window_ms, max_freq)
+                 max_freq=None,
+                 target_sample_rate=16000,
+                 use_dB_normalization=True,
+                 target_dB=-20):
+        self._audio_featurizer = AudioFeaturizer(
+            specgram_type=specgram_type,
+            stride_ms=stride_ms,
+            window_ms=window_ms,
+            max_freq=max_freq,
+            target_sample_rate=target_sample_rate,
+            use_dB_normalization=use_dB_normalization,
+            target_dB=target_dB)
        self._text_featurizer = TextFeaturizer(vocab_filepath)

    def featurize(self, speech_segment):
...
...@@ -56,7 +56,7 @@ parser.add_argument(
    help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument(
    "--model_filepath",
-    default='./params.tar.gz',
+    default='checkpoints/params.latest.tar.gz',
    type=str,
    help="Model filepath. (default: %(default)s)")
parser.add_argument(
...
...@@ -27,4 +27,7 @@ if [ $? != 0 ]; then
    exit 1
fi

+# prepare ./checkpoints
+mkdir checkpoints

echo "Install all dependencies successfully."
...@@ -17,10 +17,10 @@ import utils
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
-    "--batch_size", default=32, type=int, help="Minibatch size.")
+    "--batch_size", default=256, type=int, help="Minibatch size.")
parser.add_argument(
    "--num_passes",
-    default=20,
+    default=200,
    type=int,
    help="Training pass number. (default: %(default)s)")
parser.add_argument(
...@@ -55,7 +55,7 @@ parser.add_argument(
    help="Use sortagrad or not. (default: %(default)s)")
parser.add_argument(
    "--max_duration",
-    default=100.0,
+    default=27.0,
    type=float,
    help="Audios with duration larger than this will be discarded. "
    "(default: %(default)s)")
...@@ -67,13 +67,13 @@
    "(default: %(default)s)")
parser.add_argument(
    "--shuffle_method",
-    default='instance_shuffle',
+    default='batch_shuffle_clipped',
    type=str,
    help="Shuffle method: 'instance_shuffle', 'batch_shuffle', "
    "'batch_shuffle_clipped'. (default: %(default)s)")
parser.add_argument(
    "--trainer_count",
-    default=4,
+    default=8,
    type=int,
    help="Trainer number. (default: %(default)s)")
parser.add_argument(
...@@ -110,7 +110,9 @@
    "the existing model of this path. (default: %(default)s)")
parser.add_argument(
    "--augmentation_config",
-    default='{}',
+    default='[{"type": "shift", '
+    '"params": {"min_shift_ms": -5, "max_shift_ms": 5},'
+    '"prob": 1.0}]',
    type=str,
    help="Augmentation configuration in json-format. "
    "(default: %(default)s)")
...@@ -189,7 +191,7 @@ def train():
                print("\nPass: %d, Batch: %d, TrainCost: %f" % (
                    event.pass_id, event.batch_id + 1, cost_sum / cost_counter))
                cost_sum, cost_counter = 0.0, 0
-                with gzip.open("params.tar.gz", 'w') as f:
+                with gzip.open("checkpoints/params.latest.tar.gz", 'w') as f:
                    parameters.to_tar(f)
            else:
                sys.stdout.write('.')
...@@ -202,6 +204,9 @@ def train():
                reader=test_batch_reader, feeding=test_generator.feeding)
            print("\n------- Time: %d sec, Pass: %d, ValidationCost: %s" %
                  (time.time() - start_time, event.pass_id, result.cost))
+            with gzip.open("checkpoints/params.pass-%d.tar.gz" % event.pass_id,
+                           'w') as f:
+                parameters.to_tar(f)

    # run train
    trainer.train(
...