提交 68caa8ca 编写于 作者: X Xinghai Sun 提交者: GitHub

Merge pull request #114 from xinghai-sun/ds2_feature

Improve audio featurizer and add shift augmentor for DS2.
......@@ -51,13 +51,13 @@ python compute_mean_std.py --help
For GPU Training:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --trainer_count 4
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py
```
For CPU Training:
```
python train.py --trainer_count 8 --use_gpu False
python train.py --use_gpu False
```
More help for arguments:
......
......@@ -66,6 +66,54 @@ class AudioSegment(object):
samples, sample_rate = soundfile.read(file, dtype='float32')
return cls(samples, sample_rate)
@classmethod
def slice_from_file(cls, file, start=None, end=None):
"""Loads a small section of an audio without having to load
the entire file into the memory which can be incredibly wasteful.
:param file: Input audio filepath or file object.
:type file: basestring|file
:param start: Start time in seconds. If start is negative, it wraps
around from the end. If not provided, this function
reads from the very beginning.
:type start: float
:param end: End time in seconds. If end is negative, it wraps around
from the end. If not provided, the default behvaior is
to read to the end of the file.
:type end: float
:return: AudioSegment instance of the specified slice of the input
audio file.
:rtype: AudioSegment
:raise ValueError: If start or end is incorrectly set, e.g. out of
bounds in time.
"""
sndfile = soundfile.SoundFile(file)
sample_rate = sndfile.samplerate
duration = float(len(sndfile)) / sample_rate
start = 0. if start is None else start
end = 0. if end is None else end
if start < 0.0:
start += duration
if end < 0.0:
end += duration
if start < 0.0:
raise ValueError("The slice start position (%f s) is out of "
"bounds." % start)
if end < 0.0:
raise ValueError("The slice end position (%f s) is out of bounds." %
end)
if start > end:
raise ValueError("The slice start position (%f s) is later than "
"the slice end position (%f s)." % (start, end))
if end > duration:
raise ValueError("The slice end position (%f s) is out of bounds "
"(> %f s)" % (end, duration))
start_frame = int(start * sample_rate)
end_frame = int(end * sample_rate)
sndfile.seek(start_frame)
data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
return cls(data, sample_rate)
@classmethod
def from_bytes(cls, bytes):
"""Create audio segment from a byte string containing audio samples.
......@@ -105,6 +153,20 @@ class AudioSegment(object):
samples = np.concatenate([seg.samples for seg in segments])
return cls(samples, sample_rate)
@classmethod
def make_silence(cls, duration, sample_rate):
"""Creates a silent audio segment of the given duration and sample rate.
:param duration: Length of silence in seconds.
:type duration: float
:param sample_rate: Sample rate.
:type sample_rate: float
:return: Silent AudioSegment instance of the given duration.
:rtype: AudioSegment
"""
samples = np.zeros(int(duration * sample_rate))
return cls(samples, sample_rate)
def to_wav_file(self, filepath, dtype='float32'):
"""Save audio segment to disk as wav file.
......@@ -130,68 +192,6 @@ class AudioSegment(object):
format='WAV',
subtype=subtype_map[dtype])
@classmethod
def slice_from_file(cls, file, start=None, end=None):
"""Loads a small section of an audio without having to load
the entire file into the memory which can be incredibly wasteful.
:param file: Input audio filepath or file object.
:type file: basestring|file
:param start: Start time in seconds. If start is negative, it wraps
around from the end. If not provided, this function
reads from the very beginning.
:type start: float
:param end: End time in seconds. If end is negative, it wraps around
from the end. If not provided, the default behvaior is
to read to the end of the file.
:type end: float
:return: AudioSegment instance of the specified slice of the input
audio file.
:rtype: AudioSegment
:raise ValueError: If start or end is incorrectly set, e.g. out of
bounds in time.
"""
sndfile = soundfile.SoundFile(file)
sample_rate = sndfile.samplerate
duration = float(len(sndfile)) / sample_rate
start = 0. if start is None else start
end = 0. if end is None else end
if start < 0.0:
start += duration
if end < 0.0:
end += duration
if start < 0.0:
raise ValueError("The slice start position (%f s) is out of "
"bounds." % start)
if end < 0.0:
raise ValueError("The slice end position (%f s) is out of bounds." %
end)
if start > end:
raise ValueError("The slice start position (%f s) is later than "
"the slice end position (%f s)." % (start, end))
if end > duration:
raise ValueError("The slice end position (%f s) is out of bounds "
"(> %f s)" % (end, duration))
start_frame = int(start * sample_rate)
end_frame = int(end * sample_rate)
sndfile.seek(start_frame)
data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
return cls(data, sample_rate)
@classmethod
def make_silence(cls, duration, sample_rate):
"""Creates a silent audio segment of the given duration and sample rate.
:param duration: Length of silence in seconds.
:type duration: float
:param sample_rate: Sample rate.
:type sample_rate: float
:return: Silent AudioSegment instance of the given duration.
:rtype: AudioSegment
"""
samples = np.zeros(int(duration * sample_rate))
return cls(samples, sample_rate)
def superimpose(self, other):
"""Add samples from another segment to those of this segment
(sample-wise addition, not segment concatenation).
......@@ -225,7 +225,7 @@ class AudioSegment(object):
samples = self._convert_samples_from_float32(self._samples, dtype)
return samples.tostring()
def apply_gain(self, gain):
def gain_db(self, gain):
"""Apply gain in decibels to samples.
Note that this is an in-place transformation.
......@@ -278,7 +278,7 @@ class AudioSegment(object):
"Unable to normalize segment to %f dB because the "
"the probable gain have exceeds max_gain_db (%f dB)" %
(target_db, max_gain_db))
self.apply_gain(min(max_gain_db, target_db - self.rms_db))
self.gain_db(min(max_gain_db, target_db - self.rms_db))
def normalize_online_bayesian(self,
target_db,
......@@ -319,7 +319,7 @@ class AudioSegment(object):
rms_estimate_db = 10 * np.log10(mean_squared_estimate)
# Compute required time-varying gain.
gain_db = target_db - rms_estimate_db
self.apply_gain(gain_db)
self.gain_db(gain_db)
def resample(self, target_sample_rate, quality='sinc_medium'):
"""Resample the audio to a target sample rate.
......@@ -366,6 +366,31 @@ class AudioSegment(object):
raise ValueError("Unknown value for the sides %s" % sides)
self._samples = padded._samples
def shift(self, shift_ms):
"""Shift the audio in time. If `shift_ms` is positive, shift with time
advance; if negative, shift with time delay. Silence are padded to
keep the duration unchanged.
Note that this is an in-place transformation.
:param shift_ms: Shift time in millseconds. If positive, shift with
time advance; if negative; shift with time delay.
:type shift_ms: float
:raises ValueError: If shift_ms is longer than audio duration.
"""
if abs(shift_ms) / 1000.0 > self.duration:
raise ValueError("Absolute value of shift_ms should be smaller "
"than audio duration.")
shift_samples = int(shift_ms * self._sample_rate / 1000)
if shift_samples > 0:
# time advance
self._samples[:-shift_samples] = self._samples[shift_samples:]
self._samples[-shift_samples:] = 0
elif shift_samples < 0:
# time delay
self._samples[-shift_samples:] = self._samples[:shift_samples]
self._samples[:-shift_samples] = 0
def subsegment(self, start_sec=None, end_sec=None):
"""Cut the AudioSegment between given boundaries.
......@@ -505,7 +530,7 @@ class AudioSegment(object):
noise_gain_db = min(self.rms_db - noise.rms_db - snr_dB, max_gain_db)
noise_new = copy.deepcopy(noise)
noise_new.random_subsegment(self.duration, rng=rng)
noise_new.apply_gain(noise_gain_db)
noise_new.gain_db(noise_gain_db)
self.superimpose(noise_new)
@property
......
......@@ -6,6 +6,7 @@ from __future__ import print_function
import json
import random
from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor
from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor
class AugmentationPipeline(object):
......@@ -76,5 +77,7 @@ class AugmentationPipeline(object):
"""Return an augmentation model by the type name, and pass in params."""
if augmentor_type == "volume":
return VolumePerturbAugmentor(self._rng, **params)
elif augmentor_type == "shift":
return ShiftPerturbAugmentor(self._rng, **params)
else:
raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
......@@ -36,5 +36,5 @@ class VolumePerturbAugmentor(AugmentorBase):
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
gain = self._rng.uniform(min_gain_dBFS, max_gain_dBFS)
gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
audio_segment.apply_gain(gain)
......@@ -45,6 +45,9 @@ class DataGenerator(object):
:types max_freq: None|float
:param specgram_type: Specgram feature type. Options: 'linear'.
:type specgram_type: str
:param use_dB_normalization: Whether to normalize the audio to -20 dB
before extracting the features.
:type use_dB_normalization: bool
:param num_threads: Number of CPU threads for processing data.
:type num_threads: int
:param random_seed: Random seed.
......@@ -61,6 +64,7 @@ class DataGenerator(object):
window_ms=20.0,
max_freq=None,
specgram_type='linear',
use_dB_normalization=True,
num_threads=multiprocessing.cpu_count(),
random_seed=0):
self._max_duration = max_duration
......@@ -73,7 +77,8 @@ class DataGenerator(object):
specgram_type=specgram_type,
stride_ms=stride_ms,
window_ms=window_ms,
max_freq=max_freq)
max_freq=max_freq,
use_dB_normalization=use_dB_normalization)
self._num_threads = num_threads
self._rng = random.Random(random_seed)
self._epoch = 0
......
......@@ -24,26 +24,64 @@ class AudioFeaturizer(object):
corresponding to frequencies between [0, max_freq] are
returned.
:types max_freq: None|float
:param target_sample_rate: Audio are resampled (if upsampling or
downsampling is allowed) to this before
extracting spectrogram features.
:type target_sample_rate: float
:param use_dB_normalization: Whether to normalize the audio to a certain
decibels before extracting the features.
:type use_dB_normalization: bool
:param target_dB: Target audio decibels for normalization.
:type target_dB: float
"""
def __init__(self,
specgram_type='linear',
stride_ms=10.0,
window_ms=20.0,
max_freq=None):
max_freq=None,
target_sample_rate=16000,
use_dB_normalization=True,
target_dB=-20):
self._specgram_type = specgram_type
self._stride_ms = stride_ms
self._window_ms = window_ms
self._max_freq = max_freq
self._target_sample_rate = target_sample_rate
self._use_dB_normalization = use_dB_normalization
self._target_dB = target_dB
def featurize(self, audio_segment):
def featurize(self,
audio_segment,
allow_downsampling=True,
allow_upsamplling=True):
"""Extract audio features from AudioSegment or SpeechSegment.
:param audio_segment: Audio/speech segment to extract features from.
:type audio_segment: AudioSegment|SpeechSegment
:param allow_downsampling: Whether to allow audio downsampling before
featurizing.
:type allow_downsampling: bool
:param allow_upsampling: Whether to allow audio upsampling before
featurizing.
:type allow_upsampling: bool
:return: Spectrogram audio feature in 2darray.
:rtype: ndarray
:raises ValueError: If audio sample rate is not supported.
"""
# upsampling or downsampling
if ((audio_segment.sample_rate > self._target_sample_rate and
allow_downsampling) or
(audio_segment.sample_rate < self._target_sample_rate and
allow_upsampling)):
audio_segment.resample(self._target_sample_rate)
if audio_segment.sample_rate != self._target_sample_rate:
raise ValueError("Audio sample rate is not supported. "
"Turn allow_downsampling or allow up_sampling on.")
# decibel normalization
if self._use_dB_normalization:
audio_segment.normalize(target_db=self._target_dB)
# extract spectrogram
return self._compute_specgram(audio_segment.samples,
audio_segment.sample_rate)
......
......@@ -29,6 +29,15 @@ class SpeechFeaturizer(object):
corresponding to frequencies between [0, max_freq] are
returned.
:types max_freq: None|float
:param target_sample_rate: Speech are resampled (if upsampling or
downsampling is allowed) to this before
extracting spectrogram features.
:type target_sample_rate: float
:param use_dB_normalization: Whether to normalize the audio to a certain
decibels before extracting the features.
:type use_dB_normalization: bool
:param target_dB: Target audio decibels for normalization.
:type target_dB: float
"""
def __init__(self,
......@@ -36,9 +45,18 @@ class SpeechFeaturizer(object):
specgram_type='linear',
stride_ms=10.0,
window_ms=20.0,
max_freq=None):
self._audio_featurizer = AudioFeaturizer(specgram_type, stride_ms,
window_ms, max_freq)
max_freq=None,
target_sample_rate=16000,
use_dB_normalization=True,
target_dB=-20):
self._audio_featurizer = AudioFeaturizer(
specgram_type=specgram_type,
stride_ms=stride_ms,
window_ms=window_ms,
max_freq=max_freq,
target_sample_rate=target_sample_rate,
use_dB_normalization=use_dB_normalization,
target_dB=target_dB)
self._text_featurizer = TextFeaturizer(vocab_filepath)
def featurize(self, speech_segment):
......
......@@ -56,7 +56,7 @@ parser.add_argument(
help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument(
"--model_filepath",
default='./params.tar.gz',
default='checkpoints/params.latest.tar.gz',
type=str,
help="Model filepath. (default: %(default)s)")
parser.add_argument(
......
......@@ -27,4 +27,7 @@ if [ $? != 0 ]; then
exit 1
fi
# prepare ./checkpoints
mkdir checkpoints
echo "Install all dependencies successfully."
......@@ -17,10 +17,10 @@ import utils
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--batch_size", default=32, type=int, help="Minibatch size.")
"--batch_size", default=256, type=int, help="Minibatch size.")
parser.add_argument(
"--num_passes",
default=20,
default=200,
type=int,
help="Training pass number. (default: %(default)s)")
parser.add_argument(
......@@ -55,7 +55,7 @@ parser.add_argument(
help="Use sortagrad or not. (default: %(default)s)")
parser.add_argument(
"--max_duration",
default=100.0,
default=27.0,
type=float,
help="Audios with duration larger than this will be discarded. "
"(default: %(default)s)")
......@@ -67,13 +67,13 @@ parser.add_argument(
"(default: %(default)s)")
parser.add_argument(
"--shuffle_method",
default='instance_shuffle',
default='batch_shuffle_clipped',
type=str,
help="Shuffle method: 'instance_shuffle', 'batch_shuffle', "
"'batch_shuffle_batch'. (default: %(default)s)")
parser.add_argument(
"--trainer_count",
default=4,
default=8,
type=int,
help="Trainer number. (default: %(default)s)")
parser.add_argument(
......@@ -110,7 +110,9 @@ parser.add_argument(
"the existing model of this path. (default: %(default)s)")
parser.add_argument(
"--augmentation_config",
default='{}',
default='[{"type": "shift", '
'"params": {"min_shift_ms": -5, "max_shift_ms": 5},'
'"prob": 1.0}]',
type=str,
help="Augmentation configuration in json-format. "
"(default: %(default)s)")
......@@ -189,7 +191,7 @@ def train():
print("\nPass: %d, Batch: %d, TrainCost: %f" % (
event.pass_id, event.batch_id + 1, cost_sum / cost_counter))
cost_sum, cost_counter = 0.0, 0
with gzip.open("params.tar.gz", 'w') as f:
with gzip.open("checkpoints/params.latest.tar.gz", 'w') as f:
parameters.to_tar(f)
else:
sys.stdout.write('.')
......@@ -202,6 +204,9 @@ def train():
reader=test_batch_reader, feeding=test_generator.feeding)
print("\n------- Time: %d sec, Pass: %d, ValidationCost: %s" %
(time.time() - start_time, event.pass_id, result.cost))
with gzip.open("checkpoints/params.pass-%d.tar.gz" % event.pass_id,
'w') as f:
parameters.to_tar(f)
# run train
trainer.train(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册