提交 a84bdf64 编写于 作者: chrisxu2014's avatar chrisxu2014

add augmentation

上级 31a36018
......@@ -6,6 +6,8 @@ from __future__ import print_function
import numpy as np
import io
import soundfile
import scikits.samplerate
from scipy import signal
class AudioSegment(object):
......@@ -62,6 +64,69 @@ class AudioSegment(object):
samples, sample_rate = soundfile.read(file, dtype='float32')
return cls(samples, sample_rate)
@classmethod
def slice_from_file(cls, fname, start=None, end=None):
"""
Loads a small section of an audio without having to load
the entire file into the memory which can be incredibly wasteful.
:param fname: input audio file name
:type fname: bsaestring
:param start: start time in seconds (supported granularity is ms)
If start is negative, it wraps around from the end. If not
provided, this function reads from the very beginning.
:type start: float
:param end: start time in seconds (supported granularity is ms)
If end is negative, it wraps around from the end. If not
provided, the default behvaior is to read to the end of the
file.
:type end: float
:return:the specified slice of input audio in the audio.AudioSegment
format.
"""
sndfile = soundfile.SoundFile(fname)
sample_rate = sndfile.samplerate
if sndfile.channels != 1:
raise TypeError("{} has more than 1 channel.".format(fname))
duration = float(len(sndfile)) / sample_rate
if start is None:
start = 0.0
if end is None:
end = duration
if start < 0.0:
start += duration
if end < 0.0:
end += duration
if start < 0.0:
raise IndexError("The slice start position ({} s) is out of "
"bounds. Filename: {}".format(start, fname))
if end < 0.0:
raise IndexError("The slice end position ({} s) is out of bounds "
"Filename: {}".format(end, fname))
if start > end:
raise IndexError("The slice start position ({} s) is later than "
"the slice end position ({} s)."
.format(start, end))
if end > duration:
raise ValueError("The slice end time ({} s) is out of "
"bounds (> {} s) Filename: {}"
.format(end, duration, fname))
start_frame = int(start * sample_rate)
end_frame = int(end * sample_rate)
sndfile.seek(start_frame)
data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
return cls(data, sample_rate)
@classmethod
def from_bytes(cls, bytes):
"""Create audio segment from a byte string containing audio samples.
......@@ -75,6 +140,44 @@ class AudioSegment(object):
io.BytesIO(bytes), dtype='float32')
return cls(samples, sample_rate)
@classmethod
def make_silence(cls, duration, sample_rate):
"""Creates a silent audio segment of the given duration and
sample rate.
:param duration: length of silence in seconds
:type duration: scalar
:param sample_rate: sample rate
:type sample_rate: scalar
:returns: silence of the given duration
:rtype: AudioSegment
"""
samples = np.zeros(int(float(duration) * sample_rate))
return cls(samples, sample_rate)
@classmethod
def concatenate(cls, *segments):
"""Concatenate an arbitrary number of audio segments together.
:param *segments: input audio segments
:type *segments: [AudioSegment]
"""
# Perform basic sanity-checks.
N = len(segments)
if N == 0:
raise ValueError("No audio segments are given to concatenate.")
sample_rate = segments[0]._sample_rate
for segment in segments:
if sample_rate != segment._sample_rate:
raise ValueError("Can't concatenate segments with "
"different sample rates")
if type(segment) is not cls:
raise TypeError("Only audio segments of the same type "
"instance can be concatenated.")
samples = np.concatenate([seg.samples for seg in segments])
return cls(samples, sample_rate)
def to_wav_file(self, filepath, dtype='float32'):
"""Save audio segment to disk as wav file.
......@@ -143,23 +246,288 @@ class AudioSegment(object):
new_indices = np.linspace(start=0, stop=old_length, num=new_length)
self._samples = np.interp(new_indices, old_indices, self._samples)
def normalize(self, target_sample_rate):
raise NotImplementedError()
def normalize(self, target_db=-20, max_gain_db=300.0):
"""Normalize audio to desired RMS value in decibels.
Note that this is an in-place transformation.
:param target_db: Target RMS value in decibels.This value
should be less than 0.0 as 0.0 is full-scale audio.
:type target_db: float, optional
:param max_gain_db: Max amount of gain in dB that can be applied
for normalization. This is to prevent nans when attempting
to normalize a signal consisting of all zeros.
:type max_gain_db: float, optional
def resample(self, target_sample_rate):
raise NotImplementedError()
:raises NormalizationWarning: if the required gain to normalize the
segment to the target_db value exceeds max_gain_db.
"""
gain = target_db - self.rms_db
if gain > max_gain_db:
raise ValueError(
"Unable to normalize segment to {} dB because it has an RMS "
"value of {} dB and the difference exceeds max_gain_db ({} dB)"
.format(target_db, self.rms_db, max_gain_db))
gain = min(max_gain_db, target_db - self.rms_db)
self.apply_gain(gain)
def normalize_online_bayesian(self,
target_db,
prior_db,
prior_samples,
startup_delay=0.0):
"""
Normalize audio using a production-compatible online/causal algorithm.
This uses an exponential likelihood and gamma prior to make
online estimates of the RMS even when there are very few samples.
Note that this is an in-place transformation.
:param target_db: Target RMS value in decibels
:type target_bd: scalar
:param prior_db: Prior RMS estimate in decibels
:type prior_db: scalar
:param prior_samples: Prior strength in number of samples
:type prior_samples: scalar
:param startup_delay: Default: 0.0 s. If provided, this
function will accrue statistics for the first startup_delay
seconds before applying online normalization.
:type startup_delay: scalar
"""
# Estimate total RMS online
startup_sample_idx = min(self.num_samples - 1,
int(self.sample_rate * startup_delay))
prior_mean_squared = 10.**(prior_db / 10.)
prior_sum_of_squares = prior_mean_squared * prior_samples
cumsum_of_squares = np.cumsum(self.samples**2)
sample_count = np.arange(len(self)) + 1
if startup_sample_idx > 0:
cumsum_of_squares[:startup_sample_idx] = \
cumsum_of_squares[startup_sample_idx]
sample_count[:startup_sample_idx] = \
sample_count[startup_sample_idx]
mean_squared_estimate = ((cumsum_of_squares + prior_sum_of_squares) /
(sample_count + prior_samples))
rms_estimate_db = 10 * np.log10(mean_squared_estimate)
# Compute required time-varying gain
gain_db = target_db - rms_estimate_db
# Apply gain to new segment
self.apply_gain(gain_db)
def normalize_ewma(self,
target_db,
decay_rate,
startup_delay,
rms_eps=1e-6,
max_gain_db=300.0):
startup_sample_idx = min(self.num_samples - 1,
int(self.sample_rate * startup_delay))
mean_sq = self.samples**2
if startup_sample_idx > 0:
mean_sq[:startup_sample_idx] = \
np.sum(mean_sq[:startup_sample_idx]) / startup_sample_idx
idx_start = max(0, startup_sample_idx - 1)
initial_condition = mean_sq[idx_start] * decay_rate
mean_sq[idx_start:] = lfilter(
[1.0 - decay_rate], [1.0, -decay_rate],
mean_sq[idx_start:],
axis=0,
zi=[initial_condition])[0]
rms_estimate_db = 10.0 * np.log10(mean_sq + rms_eps)
gain_db = target_db - rms_estimate_db
if np.any(gain_db > max_gain_db):
warnings.warn(
"Unable to normalize segment to {} dB because it has an RMS "
"value of {} dB and the difference exceeds max_gain_db ({} dB)"
.format(target_db, self.rms_db, max_gain_db),
NormalizationWarning)
gain_db = np.minimum(gain_db, max_gain_db)
self.apply_gain(gain_db)
def resample(self, target_sample_rate, quality='sinc_medium'):
"""Resample audio and return new AudioSegment.
This resamples the audio to a new sample rate and returns a brand
new AudioSegment. The existing AudioSegment is unchanged.
Note that this is an in-place transformation.
:param new_sample_rate: target sample rate
:type new_sample_rate: scalar
:param quality: One of {'sinc_fastest', 'sinc_medium', 'sinc_best'}.
Sets resampling speed/quality tradeoff.
See http://www.mega-nerd.com/SRC/api_misc.html#Converters
:type quality: basestring
"""
resample_ratio = target_sample_rate / self._sample_rate
new_samples = scikits.samplerate.resample(
self._samples, r=resample_ratio, type=quality)
self._samples = new_samples
self._sample_rate = new_sample_rate
def pad_silence(self, duration, sides='both'):
raise NotImplementedError()
"""Pads this audio sample with a period of silence.
Note that this is an in-place transformation.
:param duration: length of silence in seconds to pad
:type duration: float
:param sides:
'beginning' - adds silence in the beginning
'end' - adds silence in the end
'both' - adds silence in both the beginning and the end.
:type sides: basestring
"""
if duration == 0.0:
return self
cls = type(self)
silence = cls.make_silence(duration, self._sample_rate)
if sides == "beginning":
padded = cls.concatenate(silence, self)
elif sides == "end":
padded = cls.concatenate(self, silence)
elif sides == "both":
padded = cls.concatenate(silence, self, silence)
else:
raise ValueError("Unknown value for the kwarg 'sides'")
self._samples = padded._samples
self._sample_rate = padded._sample_rate
def subsegment(self, start_sec=None, end_sec=None):
raise NotImplementedError()
"""Return new AudioSegment containing audio between given boundaries.
:param start_sec: Beginning of subsegment in seconds,
(beginning of segment if None).
:type start_sec: scalar
:param end_sec: End of subsegment in seconds,
(end of segment if None).
:type end_sec: scalar
:return: New AudioSegment containing specified
subsegment.
:trype: AudioSegment
"""
# Default boundaries
if start_sec is None:
start_sec = 0.0
if end_sec is None:
end_sec = self.duration
# negative boundaries are relative to end of segment
if start_sec < 0.0:
start_sec = self.duration + start_sec
if end_sec < 0.0:
end_sec = self.duration + end_sec
def convolve(self, filter, allow_resample=False):
raise NotImplementedError()
start_sample = int(round(start_sec * self._sample_rate))
end_sample = int(round(end_sec * self._sample_rate))
samples = self._samples[start_sample:end_sample]
def convolve_and_normalize(self, filter, allow_resample=False):
raise NotImplementedError()
return type(self)(samples, sample_rate=self._sample_rate)
def random_subsegment(self, subsegment_length, rng=None):
"""
Return a random subsegment of a specified length in seconds.
:param subsegment_length: Subsegment length in seconds.
:type subsegment_length: scalar
:param rng: Random number generator state
:type rng: random.Random [optional]
:return:clip (SpeechDLSegment): New SpeechDLSegmen containing random
subsegment of original segment.
"""
if rng is None:
rng = random.Random()
if subsegment_length > self.duration:
raise ValueError("Length of subsegment must not be greater "
"than original segment.")
start_time = rng.uniform(0.0, self.duration - subsegment_length)
return self.subsegment(start_time, start_time + subsegment_length)
def convolve(self, ir, allow_resampling=False):
"""Convolve this audio segment with the given filter.
:param ir: impulse response
:type ir: AudioSegment
:param allow_resampling: indicates whether resampling is allowed
when the ir has a different sample rate from this signal.
:type allow_resampling: boolean
"""
if allow_resampling and self.sample_rate != ir.sample_rate:
ir = ir.resample(self.sample_rate)
if self.sample_rate != ir.sample_rate:
raise ValueError("Impulse response sample rate ({}Hz) is "
"equal to base signal sample rate ({}Hz)."
.format(ir.sample_rate, self.sample_rate))
samples = signal.fftconvolve(self.samples, ir.samples, "full")
self._samples = samples
def convolve_and_normalize(self, ir, allow_resample=False):
"""Convolve and normalize the resulting audio segment so that it
has the same average power as the input signal.
:param ir: impulse response
:type ir: AudioSegment
:param allow_resampling: indicates whether resampling is allowed
when the ir has a different sample rate from this signal.
:type allow_resampling: boolean
"""
self.convolve(ir, allow_resampling=allow_resampling)
self.normalize(target_db=self.rms_db)
def add_noise(self,
noise,
snr_dB,
allow_downsampling=False,
max_gain_db=300.0,
rng=None):
"""Adds the given noise segment at a specific signal-to-noise ratio.
If the noise segment is longer than this segment, a random subsegment
of matching length is sampled from it and used instead.
:param noise: Noise signal to add.
:type noise: SpeechDLSegment
:param snr_dB: Signal-to-Noise Ratio, in decibels.
:type snr_dB: scalar
:param allow_downsampling: whether to allow the noise signal
to be downsampled to match the base signal sample rate.
:type allow_downsampling: boolean
:param max_gain_db: Maximum amount of gain to apply to noise
signal before adding it in. This is to prevent attempting
to apply infinite gain to a zero signal.
:type max_gain_db: scalar
:param rng: Random number generator state.
:type rng: random.Random
Returns:
SpeechDLSegment: signal with noise added.
"""
if rng is None:
rng = random.Random()
if allow_downsampling and noise.sample_rate > self.sample_rate:
noise = noise.resample(self.sample_rate)
if noise.sample_rate != self.sample_rate:
raise ValueError("Noise sample rate ({}Hz) is not equal to "
"base signal sample rate ({}Hz)."
.format(noise.sample_rate, self.sample_rate))
if noise.duration < self.duration:
raise ValueError("Noise signal ({} sec) must be at "
"least as long as base signal ({} sec)."
.format(noise.duration, self.duration))
noise_gain_db = self.rms_db - noise.rms_db - snr_dB
noise_gain_db = min(max_gain_db, noise_gain_db)
noise_subsegment = noise.random_subsegment(self.duration, rng=rng)
output = self + self.tranform_noise(noise_subsegment, noise_gain_db)
self._samples = output._samples
self._sample_rate = output._sample_rate
@property
def samples(self):
......@@ -186,7 +554,7 @@ class AudioSegment(object):
:return: Number of samples.
:rtype: int
"""
return self._samples.shape(0)
return self._samples.shape[0]
@property
def duration(self):
......@@ -250,3 +618,9 @@ class AudioSegment(object):
else:
raise TypeError("Unsupported sample type: %s." % samples.dtype)
return output_samples.astype(dtype)
def tranform_noise(self, noise_subsegment, noise_gain_db):
""" tranform noise file
"""
return type(self)(noise_subsegment._samples * (10.**(
noise_gain_db / 20.)), noise_subsegment._sample_rate)
from __future__ import print_function
from collections import defaultdict
import bisect
import logging
import numpy as np
import os
import random
import sys
UNK_TAG = "<UNK>"
def stream_audio_index(fname, UNK=UNK_TAG):
"""Reads an audio index file and emits one record in the index at a time.
:param fname: audio index path
:type fname: basestring
:param UNK: UNK token to denote that certain audios are not tagged.
:type UNK: basesring
Yields:
idx, duration, size, relpath, tags (int, float, int, str, list(str)):
audio file id, length of the audio in seconds, size in byte,
relative path w.r.t. to the root noise directory, list of tags
"""
with open(fname) as audio_index_file:
for i, line in enumerate(audio_index_file):
tok = line.strip().split("\t")
assert len(tok) >= 4, \
"Invalid line at line {} in file {}".format(
i + 1, audio_index_file)
idx = int(tok[0])
duration = float(tok[1])
# Sometimes, the duration can round down to 0.0
assert duration >= 0.0, \
"Invalid duration at line {} in file {}".format(
i + 1, audio_index_file)
size = int(tok[2])
assert size > 0, \
"Invalid size at line {} in file {}".format(
i + 1, audio_index_file)
relpath = tok[3]
if len(tok) == 4:
tags = [UNK_TAG]
else:
tags = tok[4:]
yield idx, duration, size, relpath, tags
def truncate_float(val, ndigits=6):
""" Truncates a floating-point value to have the desired number of
digits after the decimal point.
:param val: input value.
:type val: float
:parma ndigits: desired number of digits.
:type ndigits: int
:return: truncated value
:rtype: float
"""
p = 10.0**ndigits
return float(int(val * p)) / p
def print_audio_index(idx, duration, size, relpath, tags, file=sys.stdout):
"""Prints an audio record to the index file.
:param idx: Audio file id.
:type idx: int
:param duration: length of the audio in seconds
:type duration: float
:param size: size of the file in bytes
:type size: int
:param relpath: relative path w.r.t. to the root noise directory.
:type relpath: basestring
:parma tags: list of tags
:parma tags: list(str)
:parma file: file to which we want to write an audio record.
:type file: sys.stdout
"""
file.write("{}\t{:.6f}\t{}\t{}"
.format(idx, truncate_float(duration, ndigits=6), size, relpath))
for tag in tags:
file.write("\t{}".format(tag))
file.write("\n")
class AudioIndex(object):
""" In-memory index of audio files that do not have annotations.
This supports duration-based sampling and sampling from a target
distribution.
Each line in the index file consists of the following fields:
(id (int), duration (float), size (int), relative path (str),
list of tags ([str]))
"""
def __init__(self):
self.audio_dir = None
self.index_fname = None
self.tags = None
self.bin_size = 2.0
self.clear()
def clear(self):
""" Clears the index
Returns:
None
"""
self.idx_to_record = {}
# The list of indices correspond to audio files whose duration is
# greater than or equal to the key.
self.duration_to_id_set = {}
self.duration_to_id_set_per_tag = defaultdict(lambda: {})
self.duration_to_list = defaultdict(lambda: [])
self.duration_to_list_per_tag = defaultdict(
lambda: defaultdict(lambda: []))
self.tag_to_id_set = defaultdict(lambda: set())
self.shared_duration_bins = []
self.id_set_complete = set()
self.id_set = set()
self.duration_bins = []
def has_audio(self, distr=None):
"""
:param distr: The target distribution of audio tags that we want to
match. If this is not supplied, the function simply checks that
there are some audio files.
:parma distr: dict
:return: True if there are audio files.
:rtype: boolean
"""
if distr is None:
return len(self.id_set) > 0
else:
for tag in distr:
if tag not in self.duration_to_list_per_tag:
return False
return True
def _load_all_records_from_disk(self, audio_dir, idx_fname, bin_size):
"""Loads all audio records from the disk into memory and groups them
into chunks based on their duration and the bin_size granalarity.
Once all the records are read, indices are built from these records
by another function so that the audio samples can be drawn efficiently.
Updates:
self.audio_dir (path): audio root directory
self.idx_fname (path): audio database index filename
self.bin_size (float): granularity of bins
self.idx_to_record (dict): maps from the audio id to
(duration, file_size, relative_path, tags)
self.tag_to_id_set (dict): maps from the tag to
the set of id's of audios that have this tag.
self.id_set_complete (set): set of all audio id's in the index file
self.min_duration (float): minimum audio duration observed in the
index file
self.duration_bins (list): the lower bounds on the duration of
audio files falling in each bin
self.duration_to_id_set (dict): contains (k, v) where v is the set
of id's of audios whose lengths are longer than or equal to k.
(e.g. k is the duration lower bound of this bin).
self.duration_to_id_set_per_tag (dict): Something like above but
has a finer granularity mapping from the tag to
duration_to_id_set.
self.shared_duration_bins (list): list of sets where each set
contains duration lower bounds whose audio id sets are the
same. The rationale for having this is that there are a few
but extremely long audio files which lead to a lot of bins.
When the id sets do not change across various minimum duration
boundaries, we
cluster these together and make them point to the same id set
reference.
:return: whether the records were read from the disk. The assumption is
that the audio index file on disk and the actual audio files
are constructed once and never change during training. We only
re-read when either the directory or the index file path change.
"""
if self.audio_dir == audio_dir and self.idx_fname == idx_fname and \
self.bin_size == bin_size:
# The audio directory and/or the list of audio files
# haven't changed. No need to load the list again.
return False
# Remember where the audio index is most recently read from.
self.audio_dir = audio_dir
self.idx_fname = idx_fname
self.bin_size = bin_size
# Read in the idx and compute the number of bins necessary
self.clear()
rank = []
min_duration = float('inf')
max_duration = float('-inf')
for idx, duration, file_size, relpath, tags in \
stream_audio_index(idx_fname):
self.idx_to_record[idx] = (duration, file_size, relpath, tags)
max_duration = max(max_duration, duration)
min_duration = min(min_duration, duration)
rank.append((duration, idx))
for tag in tags:
self.tag_to_id_set[tag].add(idx)
if len(rank) == 0:
# file is empty
raise IOError("Index file {} is empty".format(idx_fname))
for tag in self.tag_to_id_set:
self.id_set_complete |= self.tag_to_id_set[tag]
dur = min_duration
self.min_duration = min_duration
while dur < max_duration + bin_size:
self.duration_bins.append(dur)
dur += bin_size
# Sort in decreasing order of duration and populate
# the cumulative indices lists.
rank.sort(reverse=True)
# These are indices for `rank` and used to keep track of whether
# there are new records to add in the current bin.
last = 0
cur = 0
# The set of audios falling in the previous bin; in the case,
# where we don't find new audios for the current bin, we store
# the reference to the last set so as to conserve memory.
# This is not such a big problem if the audio duration is
# bounded by a small number like 30 seconds and the
# bin size is big enough. But, for raw freesound audios,
# some audios can be as long as a few hours!
last_audio_set = set()
# The same but for each tag so that we can pick audios based on
# tags and also some user-specified tag distribution.
last_audio_set_per_tag = defaultdict(lambda: set())
# Set of lists of bins sharing the same audio sets.
shared = set()
for i in range(len(self.duration_bins) - 1, -1, -1):
lower_bound = self.duration_bins[i]
new_audio_idxs = set()
new_audio_idxs_per_tag = defaultdict(lambda: set())
while cur < len(rank) and rank[cur][0] >= lower_bound:
idx = rank[cur][1]
tags = self.idx_to_record[idx][3]
new_audio_idxs.add(idx)
for tag in tags:
new_audio_idxs_per_tag[tag].add(idx)
cur += 1
# This makes certain that the same list is shared across
# different bins if no new indices are added.
if cur == last:
shared.add(lower_bound)
else:
last_audio_set = last_audio_set | new_audio_idxs
for tag in new_audio_idxs_per_tag:
last_audio_set_per_tag[tag] = \
last_audio_set_per_tag[tag] | \
new_audio_idxs_per_tag[tag]
if len(shared) > 0:
self.shared_duration_bins.append(shared)
shared = set([lower_bound])
### last_audio_set = set() should set blank
last = cur
self.duration_to_id_set[lower_bound] = last_audio_set
for tag in last_audio_set_per_tag:
self.duration_to_id_set_per_tag[lower_bound][tag] = \
last_audio_set_per_tag[tag]
# The last `shared` record isn't added to the `shared_duration_bins`.
self.shared_duration_bins.append(shared)
# We make sure that the while loop above has exhausted through the
# `rank` list by checking if the `cur`rent index in `rank` equals
# the length of the array, which is the halting condition.
assert cur == len(rank)
return True
def _build_index_from_records(self, tag_list):
""" Uses the in-memory records read from the index file to build
an in-memory index restricted to the given tag list.
:param tag_list: List of tags we are interested in sampling from.
:type tag_list: list(str)
Updates:
self.id_set (set): the set of all audio id's that can be sampled.
self.duration_to_list (dict): maps from the duration lower bound
to the id's of audios longer than this duration.
self.duration_to_list_per_tag (dict): maps from the tag to
the same structure as self.duration_to_list. This is to support
sampling from a target noise distribution.
:return: whether the index was built from scratch
"""
if self.tags == tag_list:
return False
self.tags = tag_list
if len(tag_list) == 0:
self.id_set = self.id_set_complete
else:
self.id_set = set()
for tag in tag_list:
self.id_set |= self.tag_to_id_set[tag]
# Next, we need to take a subset of the audio files
for shared in self.shared_duration_bins:
# All bins in `shared' have the same index lists
# so we can intersect once and set all of them to this list.
lb = list(shared)[0]
intersected = list(self.id_set & self.duration_to_id_set[lb])
duration_to_id_set = self.duration_to_id_set_per_tag[lb]
intersected_per_tag = {
tag: self.tag_to_id_set[tag] & duration_to_id_set[tag]
for tag in duration_to_id_set
}
for bin_key in shared:
self.duration_to_list[bin_key] = intersected
for tag in intersected_per_tag:
self.duration_to_list_per_tag[tag][bin_key] = \
intersected_per_tag[tag]
assert len(self.duration_to_list) == len(self.duration_to_id_set)
return True
def refresh_records_from_index_file(self,
audio_dir,
idx_fname,
tag_list,
bin_size=2.0):
""" Loads the index file and populates the records
for building the internal index.
If the audio directory or index file name has changed, the whole index
is reloaded from scratch. If only the tag_list is changed, then the
desired index is built from the complete, in-memory record.
:param audio_dir: audio directory
:type audio_dir: basestring
:param idx_fname: audio index file name
:type idex_fname: basestring
:param tag_list: list of tags we are interested in loading;
if empty, we load all.
:type tag_list: list
:param bin_size: optional argument for controlling the granularity
of duration bins
:type bin_size: float
"""
if tag_list is None:
tag_list = []
reloaded_records = self._load_all_records_from_disk(audio_dir,
idx_fname, bin_size)
if reloaded_records or self.tags != tag_list:
self._build_index_from_records(tag_list)
logger.info('loaded {} audio files from {}'
.format(len(self.id_set), idx_fname))
def sample_audio(self, duration, rng=None, distr=None):
""" Uniformly draws an audio record of at least the desired duration
:param duration: minimum desired audio duration
:type duration: float
:param rng: random number generator
:type rng: random.Random
:param distr: target distribution of audio tags. If not provided,
:type distr: dict
all audio files are sampled uniformly at random.
:returns: success, (duration, file_size, path)
"""
if duration < 0.0:
duration = self.min_duration
i = bisect.bisect_left(self.duration_bins, duration)
if i == len(self.duration_bins):
return False, None
bin_key = self.duration_bins[i]
if distr is None:
indices = self.duration_to_list[bin_key]
else:
# If a desired audio distribution is given, we sample from it.
if rng is None:
rng = random.Random()
nprng = np.random.RandomState(rng.getrandbits(32))
prob_masses = distr.values()
prob_masses /= np.sum(prob_masses)
tag = nprng.choice(distr.keys(), p=prob_masses)
indices = self.duration_to_list_per_tag[tag][bin_key]
if len(indices) == 0:
return False, None
else:
if rng is None:
rng = random.Random()
# duration, file size and relative path from root
s = self.idx_to_record[rng.sample(indices, 1)[0]]
s = (s[0], s[1], os.path.join(self.audio_dir, s[2]))
return True, s
......@@ -6,6 +6,11 @@ from __future__ import print_function
import json
import random
from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor
from data_utils.augmentor.resamler import ResamplerAugmentor
from data_utils.augmentor.speed_perturb import SpeedPerturbatioAugmentor
from data_utils.augmentor.online_bayesian_normalization import OnlineBayesianNormalizationAugmentor
from data_utils.augmentor.Impulse_response import ImpulseResponseAugmentor
from data_utils.augmentor.noise_speech import NoiseSpeechAugmentor
class AugmentationPipeline(object):
......@@ -76,5 +81,15 @@ class AugmentationPipeline(object):
"""Return an augmentation model by the type name, and pass in params."""
if augmentor_type == "volume":
return VolumePerturbAugmentor(self._rng, **params)
if augmentor_type == "resamle":
return ResamplerAugmentor(self._rng, **params)
if augmentor_type == "speed":
return SpeedPerturbatioAugmentor(self._rng, **params)
if augmentor_type == "online_bayesian_normalization":
return OnlineBayesianNormalizationAugmentor(self._rng, **params)
if augmentor_type == "Impulse_response":
return ImpulseResponseAugmentor(self._rng, **params)
if augmentor_type == "noise_speech":
return NoiseSpeechAugmentor(self._rng, **params)
else:
raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
""" Impulse response"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import base
from . import audio_database
from data_utils.speech import SpeechSegment
class ImpulseResponseAugmentor(base.AugmentorBase):
""" Instantiates an impulse response model
:param ir_dir: directory containing impulse responses
:type ir_dir: basestring
:param tags: optional parameter for specifying what
particular impulse responses to apply.
:type tags: list
:parm tag_distr: optional noise distribution
:type tag_distr: dict
"""
def __init__(self, rng, ir_dir, index_file, tags=None, tag_distr=None):
# Define all required parameter maps here.
self.ir_dir = ir_dir
self.index_file = index_file
self.tags = tags
self.tag_distr = tag_distr
self.audio_index = audio_database.AudioIndex()
self.rng = rng
def _init_data(self):
""" Preloads stuff from disk in an attempt (e.g. list of files, etc)
to make later loading faster. If the data configuration remains the
same, this function does nothing.
"""
self.audio_index.refresh_records_from_index_file(
self.ir_dir, self.index_file, self.tags)
def transform_audio(self, audio_segment):
""" Convolves the input audio with an impulse response.
:param audio_segment: input audio
:type audio_segment: AudioSegemnt
"""
# This handles the cases where the data source or directories change.
self._init_data()
read_size = 0
tag_distr = self.tag_distr
if not self.audio_index.has_audio(tag_distr):
if tag_distr is None:
if not self.tags:
raise RuntimeError("The ir index does not have audio "
"files to sample from.")
else:
raise RuntimeError("The ir index does not have audio "
"files of the given tags to sample "
"from.")
else:
raise RuntimeError("The ir index does not have audio "
"files to match the target ir "
"distribution.")
else:
# Querying with a negative duration triggers the index to search
# from all impulse responses.
success, record = self.audio_index.sample_audio(
-1.0, rng=self.rng, distr=tag_distr)
if success is True:
_, read_size, ir_fname = record
ir_wav = SpeechSegment.from_file(ir_fname)
audio_segment.convolve(ir_wav, allow_resampling=True)
""" noise speech
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import numpy as np
import os
from collections import defaultdict
from . import base
from . import audio_database
from data_utils.speech import SpeechSegment
TURK = "turk"
USE_AUDIO_DATABASE_SOURCES = frozenset(["freesound", "chime"])
HALF_NOISE_LENGTH_MIN_THRESHOLD = 3.0
FIND_NOISE_MAX_ATTEMPTS = 20
logger = logging.getLogger(__name__)
def get_first_smaller(items, value):
index = bisect.bisect_left(items, value) - 1
assert items[index] < value, \
'get_first_smaller failed! %d %d' % (items[index], value)
return items[index]
def get_first_larger(items, value):
'Find leftmost value greater than value'
index = bisect.bisect_right(items, value)
assert index < len(items), \
"no noise bin exists for this audio length (%f)" % value
assert items[index] > value, \
'get_first_larger failed! %d %d' % (items[index], value)
return items[index]
def _get_turk_noise_files(noise_dir, index_file):
""" Creates a map from duration => a list of noise filenames
:param noise_dir: Directory of noise files which contains
"noise-samples-list"
:type noise_dir: basestring
:param index_file: Noise list
:type index_file: basestring
returns:noise_files (defaultdict): A map of bins to noise files.
Each key is the duration, and the value is a list of noise
files binned to this duration. Each bin is 2 secs.
Note: noise-samples-list should contain one line per noise (wav) file
along with its duration in milliseconds
"""
noise_files = defaultdict(list)
if not os.path.exists(index_file):
logger.error('No noise files were found at {}'.format(index_file))
return noise_files
num_noise_files = 0
rounded_durations = list(range(0, 65, 2))
with open(index_file, 'r') as fl:
for line in fl:
fname = os.path.join(noise_dir, line.strip().split()[0])
duration = float(line.strip().split()[1]) / 1000
# bin the noise files into length bins rounded by 2 sec
bin_id = get_first_smaller(rounded_durations, duration)
noise_files[bin_id].append(fname)
num_noise_files += 1
logger.info('Loaded {} turk noise files'.format(num_noise_files))
return noise_files
class NoiseSpeechAugmentor(base.AugmentorBase):
""" Noise addition block
:param snr_min: minimum signal-to-noise ratio
:type snr_min: float
:param snr_max: maximum signal-to-noise ratio
:type snr_max: float
:param noise_dir: root of where noise files are stored
:type noise_fir: basestring
:param index_file: index of noises of interest in noise_dir
:type index_file: basestring
:param source: select one from
- turk
- freesound
- chime
Note that this field is no longer required for the freesound
and chime
:type source: string
:param tags: optional parameter for specifying what
particular noises we want to add. See above for the available tags.
:type tags: list
:param tag_distr: optional noise distribution
:type tag_distr: dict
"""
def __init__(self,
rng,
snr_min,
snr_max,
noise_dir,
source,
allow_downsampling=None,
index_file=None,
tags=None,
tag_distr=None):
# Define all required parameter maps here.
self.rng = rng
self.snr_min = snr_min
self.snr_max = snr_max
self.noise_dir = noise_dir
self.source = source
self.allow_downsampling = allow_downsampling
self.index_file = index_file
self.tags = tags
self.tag_distr = tag_distr
# When new noise sources are added, make sure to define the
# associated bookkeeping variables here.
self.turk_noise_files = []
self.turk_noise_dir = None
self.audio_index = audio_database.AudioIndex()
def _init_data(self):
""" Preloads stuff from disk in an attempt (e.g. list of files, etc)
to make later loading faster. If the data configuration remains the
same, this function does nothing.
"""
noise_dir = self.noise_dir
index_file = self.index_file
source = self.source
if not index_file:
if source == TURK:
index_file = os.path.join(noise_dir, 'noise-samples-list')
logger.debug("index_file not provided; " + "defaulting to " +
index_file)
else:
if source != "":
assert source in USE_AUDIO_DATABASE_SOURCES, \
"{} not supported by audio_database".format(source)
index_file = os.path.join(noise_dir,
"audio_index_commercial.txt")
logger.debug("index_file not provided; " + "defaulting to " +
index_file)
if source == TURK:
if self.turk_noise_dir != noise_dir:
self.turk_noise_dir = noise_dir
self.turk_noise_files = _get_turk_noise_files(noise_dir,
index_file)
# elif source == TODO_SUPPORT_NON_AUDIO_DATABASE_BASED_SOURCES:
else:
if source != "":
assert source in USE_AUDIO_DATABASE_SOURCES, \
"{} not supported by audio_database".format(source)
self.audio_index.refresh_records_from_index_file(
self.noise_dir, index_file, self.tags)
def transform_audio(self, audio_segment):
"""Adds walla noise
:param audio_segment: Input audio
:type audio_segment: SpeechSegment
"""
# This handles the cases where the data source or directories change.
self._init_data
source = self.source
allow_downsampling = self.allow_downsampling
if source == TURK:
self._add_turk_noise(audio_segment, self.rng, allow_downsampling)
# elif source == TODO_SUPPORT_NON_AUDIO_DATABASE_BASED_SOURCES:
else:
self._add_noise(audio_segment, self.rng, allow_downsampling)
def _sample_snr(self):
""" Returns a float sampled in [`self.snr_min`, `self.snr_max`]
if both `self.snr_min` and `self.snr_max` are non-zero.
"""
snr_min = self.snr_min
snr_max = self.snr_max
sampled_snr = self.rng.uniform(snr_min, snr_max)
return sampled_snr
def _add_turk_noise(self, audio_segment, allow_downsampling):
""" Adds a turk noise to the input audio.
:param audio_segment: input audio
:type audio_segment: audiosegment
:param allow_downsampling: indicates whether downsampling
is allowed
:type allow_downsampling: boolean
"""
read_size = 0
if len(self.turk_noise_files) > 0:
snr = self._sample_snr(self.rng)
# Draw the noise file randomly from noise files that are
# slightly longer than the utterance
noise_bins = sorted(self.turk_noise_files.keys())
# note some bins can be empty, so we can't just round up
# to the nearest 2-sec interval
rounded_duration = get_first_larger(noise_bins,
audio_segment.duration)
noise_fname = \
self.rng.sample(self.turk_noise_files[rounded_duration], 1)[0]
noise = SpeechSegment.from_wav_file(noise_fname)
logger.debug('noise_fname {}'.format(noise_fname))
logger.debug('snr {}'.format(snr))
read_size = len(noise) * 2
# May throw exceptions, but this is caught by
# AudioFeaturizer.get_audio_files.
audio_segment.add_noise(
noise, snr, rng=self.rng, allow_downsampling=allow_downsampling)
def _add_noise(self, audio_segment, allow_downsampling):
""" Adds a noise indexed in audio_database.AudioIndex.
:param audio_segment: input audio
:type audio_segment: SpeechSegment
:param allow_downsampling: indicates whether downsampling
is allowed
:type allow_downsampling: boolean
Returns:
(SpeechSegment, int)
- sound with turk noise added
- number of bytes read from disk
"""
read_size = 0
tag_distr = self.tag_distr
if not self.audio_index.has_audio(tag_distr):
if tag_distr is None:
if not self.tags:
raise RuntimeError("The noise index does not have audio "
"files to sample from.")
else:
raise RuntimeError("The noise index does not have audio "
"files of the given tags to sample "
"from.")
else:
raise RuntimeError("The noise index does not have audio "
"files to match the target noise "
"distribution.")
else:
# Compute audio segment related statistics
audio_duration = audio_segment.duration
# Sample relevant augmentation parameters.
snr = self._sample_snr(self.rng)
# Perhaps, we may not have a sufficiently long noise, so we need
# to search iteratively.
min_duration = audio_duration + 0.25
for _ in range(FIND_NOISE_MAX_ATTEMPTS):
logger.debug("attempting to find noise of length "
"at least {}".format(min_duration))
success, record = \
self.audio_index.sample_audio(min_duration,
rng=self.rng,
distr=tag_distr)
if success is True:
noise_duration, read_size, noise_fname = record
# Assert after logging so we know
# what caused augmentation to fail.
logger.debug("noise_fname {}".format(noise_fname))
logger.debug("snr {}".format(snr))
assert noise_duration >= min_duration
break
# Decrease the desired minimum duration linearly.
# If the value becomes smaller than some threshold,
# we half the value instead.
if min_duration > HALF_NOISE_LENGTH_MIN_THRESHOLD:
min_duration -= 2.0
else:
min_duration *= 0.5
if success is False:
logger.info("Failed to find a noise file")
return
diff_duration = audio_duration + 0.25 - noise_duration
if diff_duration >= 0.0:
# Here, the noise is shorter than the audio file, so
# we pad with zeros to make sure the noise sound is applied
# with a uniformly random shift.
noise = SpeechSegment.from_file(noise_fname)
noise = noise.pad_silence(diff_duration, sides="both")
else:
# The noise clip is at least ~25 ms longer than the audio
# segment here.
diff_duration = int(noise_duration * audio_segment.sample_rate) - \
int(audio_duration * audio_segment.sample_rate) - \
int(0.02 * audio_segment.sample_rate)
start = float(self.rng.randint(0, diff_duration)) / \
audio.sample_rate
finish = min(start + audio_duration + 0.2, noise_duration)
noise = SpeechSegment.slice_from_file(noise_fname, start,
finish)
if len(noise) < len(audio_segment):
# This is to ensure that the noise clip is at least as
# long as the audio segment.
num_samples_to_pad = len(audio_segment) - len(noise)
# Padding this amount of silence on both ends ensures that
# the placement of the noise clip is uniformly random.
silence = SpeechSegment(
np.zeros(num_samples_to_pad), audio_segment.sample_rate)
noise = SpeechSegment.concatenate(silence, noise, silence)
audio_segment.add_noise(
noise, snr, rng=self.rng, allow_downsampling=allow_downsampling)
""" Online bayesian normalization
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import base
class OnlineBayesianNormalizationAugmentor(base.AugmentorBase):
"""
Instantiates an online bayesian normalization module.
:param target_db: Target RMS value in decibels
:type target_db: func[int->scalar]
:param prior_db: Prior RMS estimate in decibels
:type prior_db: func[int->scalar]
:param prior_samples: Prior strength in number of samples
:type prior_samples: func[int->scalar]
:param startup_delay: Start-up delay in seconds during
which normalization statistics is accrued.
:type starup_delay: func[int->scalar]
"""
def __init__(self,
rng,
target_db,
prior_db,
prior_samples,
startup_delay=base.parse_parameter_from(0.0)):
self.target_db = target_db
self.prior_db = prior_db
self.prior_samples = prior_samples
self.startup_delay = startup_delay
self.rng = rng
def transform_audio(self, audio_segment):
"""
Normalizes the input audio using the online Bayesian approach.
:param audio_segment: input audio
:type audio_segment: SpeechSegment
:param iteration: current iteration
:type iteration: int
:param text: audio transcription
:type text: basestring
:param rng: RNG to use for augmentation
:type rng: random.Random
"""
read_size = 0
target_db = self.target_db(iteration)
prior_db = self.prior_db(iteration)
prior_samples = self.prior_samples(iteration)
startup_delay = self.startup_delay(iteration)
audio.normalize_online_bayesian(
target_db, prior_db, prior_samples, startup_delay=startup_delay)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import base
class ResamplerAugmentor(base.AugmentorBase):
""" Instantiates a resampler module.
:param new_sample_rate: New sample rate in Hz
:type new_sample_rate: func[int->scalar]
:param rng: Random generator object.
:type rng: random.Random
"""
def __init__(self, rng, new_sample_rate):
self.new_sample_rate = new_sample_rate
self._rng = rng
def transform_audio(self, audio_segment):
""" Resamples the input audio to the target sample rate.
Note that this is an in-place transformation.
:param audio: input audio
:type audio: SpeechDLSegment
"""
new_sample_rate = self.new_sample_rate
audio.resample(new_sample_rate)
\ No newline at end of file
"""Speed perturbation module for making ASR robust to different voice
types (high pitched, low pitched, etc)
Samples uniformly between speed_min and speed_max
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import base
class SpeedPerturbatioAugmentor(base.AugmentorBase):
"""
Instantiates a speed perturbation module.
See reference paper here:
http://www.danielpovey.com/files/2015_interspeech_augmentation.pdf
:param speed_min: Lower bound on new rate to sample
:type speed_min: func[int->scalar]
:param speed_max: Upper bound on new rate to sample
:type speed_max: func[int->scalar]
"""
def __init__(self, rng, speed_min, speed_max):
if (speed_min < 0.9):
raise ValueError(
"Sampling speed below 0.9 can cause unnatural effects")
if (speed_min > 1.1):
raise ValueError(
"Sampling speed above 1.1 can cause unnatural effects")
self.speed_min = speed_min
self.speed_max = speed_max
self.rng = rng
def transform_audio(self, audio_segment):
"""
Samples a new speed rate from the given range and
changes the speed of the given audio clip.
Note that this is an in-place transformation.
:param audio_segment: input audio
:type audio_segment: SpeechDLSegment
"""
read_size = 0
speed_min = self.speed_min(iteration)
speed_max = self.speed_max(iteration)
sampled_speed = rng.uniform(speed_min, speed_max)
audio = audio.change_speed(sampled_speed)
......@@ -3,10 +3,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from data_utils.augmentor.base import AugmentorBase
from . import base
class VolumePerturbAugmentor(AugmentorBase):
class VolumePerturbAugmentor(base.AugmentorBase):
"""Augmentation model for adding random volume perturbation.
This is used for multi-loudness training of PCEN. See
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册