提交 13f70873 编写于 作者: X Xinghai Sun

Improve audio featurizer and add shift augmentor.

1. Improve audio featurizer.
2. Add shift augmentor.
3. Update default argument to be the current best seggestion.
4. Add checkpoints with pass id.
上级 d104eccf
......@@ -51,13 +51,13 @@ python compute_mean_std.py --help
For GPU Training:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --trainer_count 4
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py
```
For CPU Training:
```
python train.py --trainer_count 8 --use_gpu False
python train.py --use_gpu False
```
More help for arguments:
......
......@@ -66,6 +66,54 @@ class AudioSegment(object):
samples, sample_rate = soundfile.read(file, dtype='float32')
return cls(samples, sample_rate)
@classmethod
def slice_from_file(cls, file, start=None, end=None):
"""Loads a small section of an audio without having to load
the entire file into the memory which can be incredibly wasteful.
:param file: Input audio filepath or file object.
:type file: basestring|file
:param start: Start time in seconds. If start is negative, it wraps
around from the end. If not provided, this function
reads from the very beginning.
:type start: float
:param end: End time in seconds. If end is negative, it wraps around
from the end. If not provided, the default behvaior is
to read to the end of the file.
:type end: float
:return: AudioSegment instance of the specified slice of the input
audio file.
:rtype: AudioSegment
:raise ValueError: If start or end is incorrectly set, e.g. out of
bounds in time.
"""
sndfile = soundfile.SoundFile(file)
sample_rate = sndfile.samplerate
duration = float(len(sndfile)) / sample_rate
start = 0. if start is None else start
end = 0. if end is None else end
if start < 0.0:
start += duration
if end < 0.0:
end += duration
if start < 0.0:
raise ValueError("The slice start position (%f s) is out of "
"bounds." % start)
if end < 0.0:
raise ValueError("The slice end position (%f s) is out of bounds." %
end)
if start > end:
raise ValueError("The slice start position (%f s) is later than "
"the slice end position (%f s)." % (start, end))
if end > duration:
raise ValueError("The slice end position (%f s) is out of bounds "
"(> %f s)" % (end, duration))
start_frame = int(start * sample_rate)
end_frame = int(end * sample_rate)
sndfile.seek(start_frame)
data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
return cls(data, sample_rate)
@classmethod
def from_bytes(cls, bytes):
"""Create audio segment from a byte string containing audio samples.
......@@ -105,6 +153,20 @@ class AudioSegment(object):
samples = np.concatenate([seg.samples for seg in segments])
return cls(samples, sample_rate)
@classmethod
def make_silence(cls, duration, sample_rate):
"""Creates a silent audio segment of the given duration and sample rate.
:param duration: Length of silence in seconds.
:type duration: float
:param sample_rate: Sample rate.
:type sample_rate: float
:return: Silent AudioSegment instance of the given duration.
:rtype: AudioSegment
"""
samples = np.zeros(int(duration * sample_rate))
return cls(samples, sample_rate)
def to_wav_file(self, filepath, dtype='float32'):
"""Save audio segment to disk as wav file.
......@@ -130,68 +192,6 @@ class AudioSegment(object):
format='WAV',
subtype=subtype_map[dtype])
@classmethod
def slice_from_file(cls, file, start=None, end=None):
"""Loads a small section of an audio without having to load
the entire file into the memory which can be incredibly wasteful.
:param file: Input audio filepath or file object.
:type file: basestring|file
:param start: Start time in seconds. If start is negative, it wraps
around from the end. If not provided, this function
reads from the very beginning.
:type start: float
:param end: End time in seconds. If end is negative, it wraps around
from the end. If not provided, the default behvaior is
to read to the end of the file.
:type end: float
:return: AudioSegment instance of the specified slice of the input
audio file.
:rtype: AudioSegment
:raise ValueError: If start or end is incorrectly set, e.g. out of
bounds in time.
"""
sndfile = soundfile.SoundFile(file)
sample_rate = sndfile.samplerate
duration = float(len(sndfile)) / sample_rate
start = 0. if start is None else start
end = 0. if end is None else end
if start < 0.0:
start += duration
if end < 0.0:
end += duration
if start < 0.0:
raise ValueError("The slice start position (%f s) is out of "
"bounds." % start)
if end < 0.0:
raise ValueError("The slice end position (%f s) is out of bounds." %
end)
if start > end:
raise ValueError("The slice start position (%f s) is later than "
"the slice end position (%f s)." % (start, end))
if end > duration:
raise ValueError("The slice end position (%f s) is out of bounds "
"(> %f s)" % (end, duration))
start_frame = int(start * sample_rate)
end_frame = int(end * sample_rate)
sndfile.seek(start_frame)
data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
return cls(data, sample_rate)
@classmethod
def make_silence(cls, duration, sample_rate):
"""Creates a silent audio segment of the given duration and sample rate.
:param duration: Length of silence in seconds.
:type duration: float
:param sample_rate: Sample rate.
:type sample_rate: float
:return: Silent AudioSegment instance of the given duration.
:rtype: AudioSegment
"""
samples = np.zeros(int(duration * sample_rate))
return cls(samples, sample_rate)
def superimpose(self, other):
"""Add samples from another segment to those of this segment
(sample-wise addition, not segment concatenation).
......@@ -225,7 +225,7 @@ class AudioSegment(object):
samples = self._convert_samples_from_float32(self._samples, dtype)
return samples.tostring()
def apply_gain(self, gain):
def gain_db(self, gain):
"""Apply gain in decibels to samples.
Note that this is an in-place transformation.
......@@ -278,7 +278,7 @@ class AudioSegment(object):
"Unable to normalize segment to %f dB because the "
"the probable gain have exceeds max_gain_db (%f dB)" %
(target_db, max_gain_db))
self.apply_gain(min(max_gain_db, target_db - self.rms_db))
self.gain_db(min(max_gain_db, target_db - self.rms_db))
def normalize_online_bayesian(self,
target_db,
......@@ -319,7 +319,7 @@ class AudioSegment(object):
rms_estimate_db = 10 * np.log10(mean_squared_estimate)
# Compute required time-varying gain.
gain_db = target_db - rms_estimate_db
self.apply_gain(gain_db)
self.gain_db(gain_db)
def resample(self, target_sample_rate, quality='sinc_medium'):
"""Resample the audio to a target sample rate.
......@@ -366,6 +366,31 @@ class AudioSegment(object):
raise ValueError("Unknown value for the sides %s" % sides)
self._samples = padded._samples
def shift(self, shift_ms):
"""Shift the audio in time. If `shift_ms` is positive, shift with time
advance; if negative, shift with time delay. Silence are padded to
keep the duration unchanged.
Note that this is an in-place transformation.
:param shift_ms: Shift time in millseconds. If positive, shift with
time advance; if negative; shift with time delay.
:type shift_ms: float
:raises ValueError: If shift_ms is longer than audio duration.
"""
if shift_ms / 1000.0 > self.duration:
raise ValueError("Absolute value of shift_ms should be smaller "
"than audio duration.")
shift_samples = int(shift_ms * self._sample_rate / 1000)
if shift_samples > 0:
# time advance
self._samples[:-shift_samples] = self._samples[shift_samples:]
self._samples[-shift_samples:] = 0
elif shift_samples < 0:
# time delay
self._samples[-shift_samples:] = self._samples[:shift_samples]
self._samples[:-shift_samples] = 0
def subsegment(self, start_sec=None, end_sec=None):
"""Cut the AudioSegment between given boundaries.
......@@ -505,7 +530,7 @@ class AudioSegment(object):
noise_gain_db = min(self.rms_db - noise.rms_db - snr_dB, max_gain_db)
noise_new = copy.deepcopy(noise)
noise_new.random_subsegment(self.duration, rng=rng)
noise_new.apply_gain(noise_gain_db)
noise_new.gain_db(noise_gain_db)
self.superimpose(noise_new)
@property
......
......@@ -6,6 +6,7 @@ from __future__ import print_function
import json
import random
from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor
from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor
class AugmentationPipeline(object):
......@@ -76,5 +77,7 @@ class AugmentationPipeline(object):
"""Return an augmentation model by the type name, and pass in params."""
if augmentor_type == "volume":
return VolumePerturbAugmentor(self._rng, **params)
elif augmentor_type == "shift":
return ShiftPerturbAugmentor(self._rng, **params)
else:
raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
......@@ -36,5 +36,5 @@ class VolumePerturbAugmentor(AugmentorBase):
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
gain = self._rng.uniform(min_gain_dBFS, max_gain_dBFS)
gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
audio_segment.apply_gain(gain)
......@@ -45,6 +45,9 @@ class DataGenerator(object):
:types max_freq: None|float
:param specgram_type: Specgram feature type. Options: 'linear'.
:type specgram_type: str
:param use_dB_normalization: Whether to normalize the audio to -20 dB
before extracting the features.
:type use_dB_normalization: bool
:param num_threads: Number of CPU threads for processing data.
:type num_threads: int
:param random_seed: Random seed.
......@@ -61,6 +64,7 @@ class DataGenerator(object):
window_ms=20.0,
max_freq=None,
specgram_type='linear',
use_dB_normalization=True,
num_threads=multiprocessing.cpu_count(),
random_seed=0):
self._max_duration = max_duration
......@@ -73,7 +77,8 @@ class DataGenerator(object):
specgram_type=specgram_type,
stride_ms=stride_ms,
window_ms=window_ms,
max_freq=max_freq)
max_freq=max_freq,
use_dB_normalization=use_dB_normalization)
self._num_threads = num_threads
self._rng = random.Random(random_seed)
self._epoch = 0
......
......@@ -24,26 +24,64 @@ class AudioFeaturizer(object):
corresponding to frequencies between [0, max_freq] are
returned.
:types max_freq: None|float
:param target_sample_rate: Audio are resampled (if upsampling or
downsampling is allowed) to this before
extracting spectrogram features.
:type target_sample_rate: float
:param use_dB_normalization: Whether to normalize the audio to a certain
decibels before extracting the features.
:type use_dB_normalization: bool
:param target_dB: Target audio decibels for normalization.
:type target_dB: float
"""
def __init__(self,
specgram_type='linear',
stride_ms=10.0,
window_ms=20.0,
max_freq=None):
max_freq=None,
target_sample_rate=16000,
use_dB_normalization=True,
target_dB=-20):
self._specgram_type = specgram_type
self._stride_ms = stride_ms
self._window_ms = window_ms
self._max_freq = max_freq
self._target_sample_rate = target_sample_rate
self._use_dB_normalization = use_dB_normalization
self._target_dB = target_dB
def featurize(self, audio_segment):
def featurize(self,
audio_segment,
allow_downsampling=True,
allow_upsamplling=True):
"""Extract audio features from AudioSegment or SpeechSegment.
:param audio_segment: Audio/speech segment to extract features from.
:type audio_segment: AudioSegment|SpeechSegment
:param allow_downsampling: Whether to allow audio downsampling before
featurizing.
:type allow_downsampling: bool
:param allow_upsampling: Whether to allow audio upsampling before
featurizing.
:type allow_upsampling: bool
:return: Spectrogram audio feature in 2darray.
:rtype: ndarray
:raises ValueError: If audio sample rate is not supported.
"""
# upsampling or downsampling
if ((audio_segment.sample_rate > self._target_sample_rate and
allow_downsampling) or
(audio_segment.sample_rate < self._target_sample_rate and
allow_upsampling)):
audio_segment.resample(self._target_sample_rate)
if audio_segment.sample_rate != self._target_sample_rate:
raise ValueError("Audio sample rate is not supported. "
"Turn allow_downsampling or allow up_sampling on.")
# decibel normalization
if self._use_dB_normalization:
audio_segment.normalize(target_db=self._target_dB)
# extract spectrogram
return self._compute_specgram(audio_segment.samples,
audio_segment.sample_rate)
......
......@@ -29,6 +29,15 @@ class SpeechFeaturizer(object):
corresponding to frequencies between [0, max_freq] are
returned.
:types max_freq: None|float
:param target_sample_rate: Speech are resampled (if upsampling or
downsampling is allowed) to this before
extracting spectrogram features.
:type target_sample_rate: float
:param use_dB_normalization: Whether to normalize the audio to a certain
decibels before extracting the features.
:type use_dB_normalization: bool
:param target_dB: Target audio decibels for normalization.
:type target_dB: float
"""
def __init__(self,
......@@ -36,9 +45,18 @@ class SpeechFeaturizer(object):
specgram_type='linear',
stride_ms=10.0,
window_ms=20.0,
max_freq=None):
self._audio_featurizer = AudioFeaturizer(specgram_type, stride_ms,
window_ms, max_freq)
max_freq=None,
target_sample_rate=16000,
use_dB_normalization=True,
target_dB=-20):
self._audio_featurizer = AudioFeaturizer(
specgram_type=specgram_type,
stride_ms=stride_ms,
window_ms=window_ms,
max_freq=max_freq,
target_sample_rate=target_sample_rate,
use_dB_normalization=use_dB_normalization,
target_dB=target_dB)
self._text_featurizer = TextFeaturizer(vocab_filepath)
def featurize(self, speech_segment):
......
......@@ -56,7 +56,7 @@ parser.add_argument(
help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument(
"--model_filepath",
default='./params.tar.gz',
default='checkpoints/params.latest.tar.gz',
type=str,
help="Model filepath. (default: %(default)s)")
parser.add_argument(
......
......@@ -27,4 +27,7 @@ if [ $? != 0 ]; then
exit 1
fi
# prepare ./checkpoints
mkdir checkpoints
echo "Install all dependencies successfully."
......@@ -17,10 +17,10 @@ import utils
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--batch_size", default=32, type=int, help="Minibatch size.")
"--batch_size", default=256, type=int, help="Minibatch size.")
parser.add_argument(
"--num_passes",
default=20,
default=200,
type=int,
help="Training pass number. (default: %(default)s)")
parser.add_argument(
......@@ -55,7 +55,7 @@ parser.add_argument(
help="Use sortagrad or not. (default: %(default)s)")
parser.add_argument(
"--max_duration",
default=100.0,
default=27.0,
type=float,
help="Audios with duration larger than this will be discarded. "
"(default: %(default)s)")
......@@ -67,13 +67,13 @@ parser.add_argument(
"(default: %(default)s)")
parser.add_argument(
"--shuffle_method",
default='instance_shuffle',
default='batch_shuffle_clipped',
type=str,
help="Shuffle method: 'instance_shuffle', 'batch_shuffle', "
"'batch_shuffle_batch'. (default: %(default)s)")
parser.add_argument(
"--trainer_count",
default=4,
default=8,
type=int,
help="Trainer number. (default: %(default)s)")
parser.add_argument(
......@@ -110,7 +110,9 @@ parser.add_argument(
"the existing model of this path. (default: %(default)s)")
parser.add_argument(
"--augmentation_config",
default='{}',
default='[{"type": "shift", '
'"params": {"min_shift_ms": -5, "max_shift_ms": 5},'
'"prob": 1.0}]',
type=str,
help="Augmentation configuration in json-format. "
"(default: %(default)s)")
......@@ -189,7 +191,7 @@ def train():
print("\nPass: %d, Batch: %d, TrainCost: %f" % (
event.pass_id, event.batch_id + 1, cost_sum / cost_counter))
cost_sum, cost_counter = 0.0, 0
with gzip.open("params.tar.gz", 'w') as f:
with gzip.open("checkpoints/params.latest.tar.gz", 'w') as f:
parameters.to_tar(f)
else:
sys.stdout.write('.')
......@@ -202,6 +204,9 @@ def train():
reader=test_batch_reader, feeding=test_generator.feeding)
print("\n------- Time: %d sec, Pass: %d, ValidationCost: %s" %
(time.time() - start_time, event.pass_id, result.cost))
with gzip.open("checkpoints/params.pass-%d.tar.gz" % event.pass_id,
'w') as f:
parameters.to_tar(f)
# run train
trainer.train(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册