提交 d40f2092 编写于 作者: H Hui Zhang

epsnet io

上级 fa70a024
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import numpy as np
from deepspeech.utils.log import Log
__all__ = ["make_batchset"]
logger = Log(__name__).getlog()
def batchfy_by_seq(
sorted_data,
batch_size,
max_length_in,
max_length_out,
min_batch_size=1,
shortest_first=False,
ikey="input",
iaxis=0,
okey="output",
oaxis=0, ):
"""Make batch set from json dictionary
:param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json
:param int batch_size: batch size
:param int max_length_in: maximum length of input to decide adaptive batch size
:param int max_length_out: maximum length of output to decide adaptive batch size
:param int min_batch_size: mininum batch size (for multi-gpu)
:param bool shortest_first: Sort from batch with shortest samples
to longest if true, otherwise reverse
:param str ikey: key to access input
(for ASR ikey="input", for TTS, MT ikey="output".)
:param int iaxis: dimension to access input
(for ASR, TTS iaxis=0, for MT iaxis="1".)
:param str okey: key to access output
(for ASR, MT okey="output". for TTS okey="input".)
:param int oaxis: dimension to access output
(for ASR, TTS, MT oaxis=0, reserved for future research, -1 means all axis.)
:return: List[List[Tuple[str, dict]]] list of batches
"""
if batch_size <= 0:
raise ValueError(f"Invalid batch_size={batch_size}")
# check #utts is more than min_batch_size
if len(sorted_data) < min_batch_size:
raise ValueError(
f"#utts({len(sorted_data)}) is less than min_batch_size({min_batch_size})."
)
# make list of minibatches
minibatches = []
start = 0
while True:
_, info = sorted_data[start]
ilen = int(info[ikey][iaxis]["shape"][0])
olen = (int(info[okey][oaxis]["shape"][0]) if oaxis >= 0 else
max(map(lambda x: int(x["shape"][0]), info[okey])))
factor = max(int(ilen / max_length_in), int(olen / max_length_out))
# change batchsize depending on the input and output length
# if ilen = 1000 and max_length_in = 800
# then b = batchsize / 2
# and max(min_batches, .) avoids batchsize = 0
bs = max(min_batch_size, int(batch_size / (1 + factor)))
end = min(len(sorted_data), start + bs)
minibatch = sorted_data[start:end]
if shortest_first:
minibatch.reverse()
# check each batch is more than minimum batchsize
if len(minibatch) < min_batch_size:
mod = min_batch_size - len(minibatch) % min_batch_size
additional_minibatch = [
sorted_data[i] for i in np.random.randint(0, start, mod)
]
if shortest_first:
additional_minibatch.reverse()
minibatch.extend(additional_minibatch)
minibatches.append(minibatch)
if end == len(sorted_data):
break
start = end
# batch: List[List[Tuple[str, dict]]]
return minibatches
def batchfy_by_bin(
sorted_data,
batch_bins,
num_batches=0,
min_batch_size=1,
shortest_first=False,
ikey="input",
okey="output", ):
"""Make variably sized batch set, which maximizes
the number of bins up to `batch_bins`.
:param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json
:param int batch_bins: Maximum frames of a batch
:param int num_batches: # number of batches to use (for debug)
:param int min_batch_size: minimum batch size (for multi-gpu)
:param int test: Return only every `test` batches
:param bool shortest_first: Sort from batch with shortest samples
to longest if true, otherwise reverse
:param str ikey: key to access input (for ASR ikey="input", for TTS ikey="output".)
:param str okey: key to access output (for ASR okey="output". for TTS okey="input".)
:return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches
"""
if batch_bins <= 0:
raise ValueError(f"invalid batch_bins={batch_bins}")
length = len(sorted_data)
idim = int(sorted_data[0][1][ikey][0]["shape"][1])
odim = int(sorted_data[0][1][okey][0]["shape"][1])
logger.info("# utts: " + str(len(sorted_data)))
minibatches = []
start = 0
n = 0
while True:
# Dynamic batch size depending on size of samples
b = 0
next_size = 0
max_olen = 0
while next_size < batch_bins and (start + b) < length:
ilen = int(sorted_data[start + b][1][ikey][0]["shape"][0]) * idim
olen = int(sorted_data[start + b][1][okey][0]["shape"][0]) * odim
if olen > max_olen:
max_olen = olen
next_size = (max_olen + ilen) * (b + 1)
if next_size <= batch_bins:
b += 1
elif next_size == 0:
raise ValueError(
f"Can't fit one sample in batch_bins ({batch_bins}): "
f"Please increase the value")
end = min(length, start + max(min_batch_size, b))
batch = sorted_data[start:end]
if shortest_first:
batch.reverse()
minibatches.append(batch)
# Check for min_batch_size and fixes the batches if needed
i = -1
while len(minibatches[i]) < min_batch_size:
missing = min_batch_size - len(minibatches[i])
if -i == len(minibatches):
minibatches[i + 1].extend(minibatches[i])
minibatches = minibatches[1:]
break
else:
minibatches[i].extend(minibatches[i - 1][:missing])
minibatches[i - 1] = minibatches[i - 1][missing:]
i -= 1
if end == length:
break
start = end
n += 1
if num_batches > 0:
minibatches = minibatches[:num_batches]
lengths = [len(x) for x in minibatches]
logger.info(
str(len(minibatches)) + " batches containing from " + str(min(lengths))
+ " to " + str(max(lengths)) + " samples " + "(avg " + str(
int(np.mean(lengths))) + " samples).")
return minibatches
def batchfy_by_frame(
sorted_data,
max_frames_in,
max_frames_out,
max_frames_inout,
num_batches=0,
min_batch_size=1,
shortest_first=False,
ikey="input",
okey="output", ):
"""Make variable batch set, which maximizes the number of frames to max_batch_frame.
:param List[(str, Dict[str, Any])] sorteddata: dictionary loaded from data.json
:param int max_frames_in: Maximum input frames of a batch
:param int max_frames_out: Maximum output frames of a batch
:param int max_frames_inout: Maximum input+output frames of a batch
:param int num_batches: # number of batches to use (for debug)
:param int min_batch_size: minimum batch size (for multi-gpu)
:param int test: Return only every `test` batches
:param bool shortest_first: Sort from batch with shortest samples
to longest if true, otherwise reverse
:param str ikey: key to access input (for ASR ikey="input", for TTS ikey="output".)
:param str okey: key to access output (for ASR okey="output". for TTS okey="input".)
:return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches
"""
if max_frames_in <= 0 and max_frames_out <= 0 and max_frames_inout <= 0:
raise ValueError(
"At least, one of `--batch-frames-in`, `--batch-frames-out` or "
"`--batch-frames-inout` should be > 0")
length = len(sorted_data)
minibatches = []
start = 0
end = 0
while end != length:
# Dynamic batch size depending on size of samples
b = 0
max_olen = 0
max_ilen = 0
while (start + b) < length:
ilen = int(sorted_data[start + b][1][ikey][0]["shape"][0])
if ilen > max_frames_in and max_frames_in != 0:
raise ValueError(
f"Can't fit one sample in --batch-frames-in ({max_frames_in}): "
f"Please increase the value")
olen = int(sorted_data[start + b][1][okey][0]["shape"][0])
if olen > max_frames_out and max_frames_out != 0:
raise ValueError(
f"Can't fit one sample in --batch-frames-out ({max_frames_out}): "
f"Please increase the value")
if ilen + olen > max_frames_inout and max_frames_inout != 0:
raise ValueError(
f"Can't fit one sample in --batch-frames-out ({max_frames_inout}): "
f"Please increase the value")
max_olen = max(max_olen, olen)
max_ilen = max(max_ilen, ilen)
in_ok = max_ilen * (b + 1) <= max_frames_in or max_frames_in == 0
out_ok = max_olen * (b + 1) <= max_frames_out or max_frames_out == 0
inout_ok = (max_ilen + max_olen) * (
b + 1) <= max_frames_inout or max_frames_inout == 0
if in_ok and out_ok and inout_ok:
# add more seq in the minibatch
b += 1
else:
# no more seq in the minibatch
break
end = min(length, start + b)
batch = sorted_data[start:end]
if shortest_first:
batch.reverse()
minibatches.append(batch)
# Check for min_batch_size and fixes the batches if needed
i = -1
while len(minibatches[i]) < min_batch_size:
missing = min_batch_size - len(minibatches[i])
if -i == len(minibatches):
minibatches[i + 1].extend(minibatches[i])
minibatches = minibatches[1:]
break
else:
minibatches[i].extend(minibatches[i - 1][:missing])
minibatches[i - 1] = minibatches[i - 1][missing:]
i -= 1
start = end
if num_batches > 0:
minibatches = minibatches[:num_batches]
lengths = [len(x) for x in minibatches]
logger.info(
str(len(minibatches)) + " batches containing from " + str(min(lengths))
+ " to " + str(max(lengths)) + " samples" + "(avg " + str(
int(np.mean(lengths))) + " samples).")
return minibatches
def batchfy_shuffle(data, batch_size, min_batch_size, num_batches,
shortest_first):
import random
logger.info("use shuffled batch.")
sorted_data = random.sample(data.items(), len(data.items()))
logger.info("# utts: " + str(len(sorted_data)))
# make list of minibatches
minibatches = []
start = 0
while True:
end = min(len(sorted_data), start + batch_size)
# check each batch is more than minimum batchsize
minibatch = sorted_data[start:end]
if shortest_first:
minibatch.reverse()
if len(minibatch) < min_batch_size:
mod = min_batch_size - len(minibatch) % min_batch_size
additional_minibatch = [
sorted_data[i] for i in np.random.randint(0, start, mod)
]
if shortest_first:
additional_minibatch.reverse()
minibatch.extend(additional_minibatch)
minibatches.append(minibatch)
if end == len(sorted_data):
break
start = end
# for debugging
if num_batches > 0:
minibatches = minibatches[:num_batches]
logger.info("# minibatches: " + str(len(minibatches)))
return minibatches
BATCH_COUNT_CHOICES = ["auto", "seq", "bin", "frame"]
BATCH_SORT_KEY_CHOICES = ["input", "output", "shuffle"]
def make_batchset(
data,
batch_size=0,
max_length_in=float("inf"),
max_length_out=float("inf"),
num_batches=0,
min_batch_size=1,
shortest_first=False,
batch_sort_key="input",
count="auto",
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
iaxis=0,
oaxis=0, ):
"""Make batch set from json dictionary
if utts have "category" value,
>>> data = [{'category': 'A', 'input': ..., 'utt':'utt1'},
... {'category': 'B', 'input': ..., 'utt':'utt2'},
... {'category': 'B', 'input': ..., 'utt':'utt3'},
... {'category': 'A', 'input': ..., 'utt':'utt4'}]
>>> make_batchset(data, batchsize=2, ...)
[[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3': ...)]]
Note that if any utts doesn't have "category",
perform as same as batchfy_by_{count}
:param List[Dict[str, Any]] data: dictionary loaded from data.json
:param int batch_size: maximum number of sequences in a minibatch.
:param int batch_bins: maximum number of bins (frames x dim) in a minibatch.
:param int batch_frames_in: maximum number of input frames in a minibatch.
:param int batch_frames_out: maximum number of output frames in a minibatch.
:param int batch_frames_out: maximum number of input+output frames in a minibatch.
:param str count: strategy to count maximum size of batch.
For choices, see espnet.asr.batchfy.BATCH_COUNT_CHOICES
:param int max_length_in: maximum length of input to decide adaptive batch size
:param int max_length_out: maximum length of output to decide adaptive batch size
:param int num_batches: # number of batches to use (for debug)
:param int min_batch_size: minimum batch size (for multi-gpu)
:param bool shortest_first: Sort from batch with shortest samples
to longest if true, otherwise reverse
:param str batch_sort_key: how to sort data before creating minibatches
["input", "output", "shuffle"]
:param bool swap_io: if True, use "input" as output and "output"
as input in `data` dict
:param bool mt: if True, use 0-axis of "output" as output and 1-axis of "output"
as input in `data` dict
:param int iaxis: dimension to access input
(for ASR, TTS iaxis=0, for MT iaxis="1".)
:param int oaxis: dimension to access output (for ASR, TTS, MT oaxis=0,
reserved for future research, -1 means all axis.)
:return: List[List[Tuple[str, dict]]] list of batches
"""
# check args
if count not in BATCH_COUNT_CHOICES:
raise ValueError(
f"arg 'count' ({count}) should be one of {BATCH_COUNT_CHOICES}")
if batch_sort_key not in BATCH_SORT_KEY_CHOICES:
raise ValueError(f"arg 'batch_sort_key' ({batch_sort_key}) should be "
f"one of {BATCH_SORT_KEY_CHOICES}")
ikey = "input"
okey = "output"
batch_sort_axis = 0 # index of list
if count == "auto":
if batch_size != 0:
count = "seq"
elif batch_bins != 0:
count = "bin"
elif batch_frames_in != 0 or batch_frames_out != 0 or batch_frames_inout != 0:
count = "frame"
else:
raise ValueError(
f"cannot detect `count` manually set one of {BATCH_COUNT_CHOICES}"
)
logger.info(f"count is auto detected as {count}")
if count != "seq" and batch_sort_key == "shuffle":
raise ValueError(
"batch_sort_key=shuffle is only available if batch_count=seq")
category2data = {} # Dict[str, dict]
for v in data:
k = v['utt']
category2data.setdefault(v.get("category"), {})[k] = v
batches_list = [] # List[List[List[Tuple[str, dict]]]]
for d in category2data.values():
if batch_sort_key == "shuffle":
batches = batchfy_shuffle(d, batch_size, min_batch_size,
num_batches, shortest_first)
batches_list.append(batches)
continue
# sort it by input lengths (long to short)
sorted_data = sorted(
d.items(),
key=lambda data: int(data[1][batch_sort_key][batch_sort_axis]["shape"][0]),
reverse=not shortest_first, )
logger.info("# utts: " + str(len(sorted_data)))
if count == "seq":
batches = batchfy_by_seq(
sorted_data,
batch_size=batch_size,
max_length_in=max_length_in,
max_length_out=max_length_out,
min_batch_size=min_batch_size,
shortest_first=shortest_first,
ikey=ikey,
iaxis=iaxis,
okey=okey,
oaxis=oaxis, )
if count == "bin":
batches = batchfy_by_bin(
sorted_data,
batch_bins=batch_bins,
min_batch_size=min_batch_size,
shortest_first=shortest_first,
ikey=ikey,
okey=okey, )
if count == "frame":
batches = batchfy_by_frame(
sorted_data,
max_frames_in=batch_frames_in,
max_frames_out=batch_frames_out,
max_frames_inout=batch_frames_inout,
min_batch_size=min_batch_size,
shortest_first=shortest_first,
ikey=ikey,
okey=okey, )
batches_list.append(batches)
if len(batches_list) == 1:
batches = batches_list[0]
else:
# Concat list. This way is faster than "sum(batch_list, [])"
batches = list(itertools.chain(*batches_list))
# for debugging
if num_batches > 0:
batches = batches[:num_batches]
logger.info("# minibatches: " + str(len(batches)))
# batch: List[List[Tuple[str, dict]]]
return batches
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from deepspeech.io.utility import pad_list
from deepspeech.utils.log import Log
__all__ = ["CustomConverter"]
logger = Log(__name__).getlog()
class CustomConverter():
"""Custom batch converter.
Args:
subsampling_factor (int): The subsampling factor.
dtype (np.dtype): Data type to convert.
"""
def __init__(self, subsampling_factor=1, dtype=np.float32):
"""Construct a CustomConverter object."""
self.subsampling_factor = subsampling_factor
self.ignore_id = -1
self.dtype = dtype
def __call__(self, batch):
"""Transform a batch and send it to a device.
Args:
batch (list): The batch to transform.
Returns:
tuple(np.ndarray, nn.ndarray, nn.ndarray)
"""
# batch should be located in list
assert len(batch) == 1
(xs, ys), utts = batch[0]
assert xs[0] is not None, "please check Reader and Augmentation impl."
# perform subsampling
if self.subsampling_factor > 1:
xs = [x[::self.subsampling_factor, :] for x in xs]
# get batch of lengths of input sequences
ilens = np.array([x.shape[0] for x in xs])
# perform padding and convert to tensor
# currently only support real number
if xs[0].dtype.kind == "c":
xs_pad_real = pad_list([x.real for x in xs], 0).astype(self.dtype)
xs_pad_imag = pad_list([x.imag for x in xs], 0).astype(self.dtype)
# Note(kamo):
# {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.
# Don't create ComplexTensor and give it E2E here
# because torch.nn.DataParellel can't handle it.
xs_pad = {"real": xs_pad_real, "imag": xs_pad_imag}
else:
xs_pad = pad_list(xs, 0).astype(self.dtype)
# NOTE: this is for multi-output (e.g., speech translation)
ys_pad = pad_list(
[np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys],
self.ignore_id)
olens = np.array(
[y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys])
return utts, xs_pad, ilens, ys_pad, olens
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
from typing import List
from typing import Text
import numpy as np
from paddle.io import DataLoader
from deepspeech.frontend.utility import read_manifest
from deepspeech.io.batchfy import make_batchset
from deepspeech.io.converter import CustomConverter
from deepspeech.io.dataset import TransformDataset
from deepspeech.io.reader import LoadInputsAndTargets
from deepspeech.utils.log import Log
__all__ = ["BatchDataLoader"]
logger = Log(__name__).getlog()
def feat_dim_and_vocab_size(data_json: List[Dict[Text, Any]],
mode: Text="asr",
iaxis=0,
oaxis=0):
if mode == 'asr':
feat_dim = data_json[0]['input'][oaxis]['shape'][1]
vocab_size = data_json[0]['output'][oaxis]['shape'][1]
else:
raise ValueError(f"{mode} mode not support!")
return feat_dim, vocab_size
def batch_collate(x):
"""de-minibatch, since user compose batch.
Args:
x (List[Tuple]): [(utts, xs, ilens, ys, olens)]
Returns:
Tuple: (utts, xs, ilens, ys, olens)
"""
return x[0]
class BatchDataLoader():
def __init__(self,
json_file: str,
train_mode: bool,
sortagrad: bool=False,
batch_size: int=0,
maxlen_in: float=float('inf'),
maxlen_out: float=float('inf'),
minibatches: int=0,
mini_batch_size: int=1,
batch_count: str='auto',
batch_bins: int=0,
batch_frames_in: int=0,
batch_frames_out: int=0,
batch_frames_inout: int=0,
preprocess_conf=None,
n_iter_processes: int=1,
subsampling_factor: int=1,
num_encs: int=1):
self.json_file = json_file
self.train_mode = train_mode
self.use_sortagrad = sortagrad == -1 or sortagrad > 0
self.batch_size = batch_size
self.maxlen_in = maxlen_in
self.maxlen_out = maxlen_out
self.batch_count = batch_count
self.batch_bins = batch_bins
self.batch_frames_in = batch_frames_in
self.batch_frames_out = batch_frames_out
self.batch_frames_inout = batch_frames_inout
self.subsampling_factor = subsampling_factor
self.num_encs = num_encs
self.preprocess_conf = preprocess_conf
self.n_iter_processes = n_iter_processes
# read json data
self.data_json = read_manifest(json_file)
self.feat_dim, self.vocab_size = feat_dim_and_vocab_size(
self.data_json, mode='asr')
# make minibatch list (variable length)
self.minibaches = make_batchset(
self.data_json,
batch_size,
maxlen_in,
maxlen_out,
minibatches, # for debug
min_batch_size=mini_batch_size,
shortest_first=self.use_sortagrad,
count=batch_count,
batch_bins=batch_bins,
batch_frames_in=batch_frames_in,
batch_frames_out=batch_frames_out,
batch_frames_inout=batch_frames_inout,
iaxis=0,
oaxis=0, )
# data reader
self.reader = LoadInputsAndTargets(
mode="asr",
load_output=True,
preprocess_conf=preprocess_conf,
preprocess_args={"train":
train_mode}, # Switch the mode of preprocessing
)
# Setup a converter
if num_encs == 1:
self.converter = CustomConverter(
subsampling_factor=subsampling_factor, dtype=np.float32)
else:
assert NotImplementedError("not impl CustomConverterMulEnc.")
# hack to make batchsize argument as 1
# actual bathsize is included in a list
# default collate function converts numpy array to pytorch tensor
# we used an empty collate function instead which returns list
self.dataset = TransformDataset(self.minibaches, self.converter,
self.reader)
self.dataloader = DataLoader(
dataset=self.dataset,
batch_size=1,
shuffle=not self.use_sortagrad if self.train_mode else False,
collate_fn=batch_collate,
num_workers=self.n_iter_processes, )
def __repr__(self):
echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> "
echo += f"train_mode: {self.train_mode}, "
echo += f"sortagrad: {self.use_sortagrad}, "
echo += f"batch_size: {self.batch_size}, "
echo += f"maxlen_in: {self.maxlen_in}, "
echo += f"maxlen_out: {self.maxlen_out}, "
echo += f"batch_count: {self.batch_count}, "
echo += f"batch_bins: {self.batch_bins}, "
echo += f"batch_frames_in: {self.batch_frames_in}, "
echo += f"batch_frames_out: {self.batch_frames_out}, "
echo += f"batch_frames_inout: {self.batch_frames_inout}, "
echo += f"subsampling_factor: {self.subsampling_factor}, "
echo += f"num_encs: {self.num_encs}, "
echo += f"num_workers: {self.n_iter_processes}, "
echo += f"file: {self.json_file}"
return echo
def __len__(self):
return len(self.dataloader)
def __iter__(self):
return self.dataloader.__iter__()
def __call__(self):
return self.__iter__()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
import kaldiio
import numpy as np
import soundfile
from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
from deepspeech.utils.log import Log
__all__ = ["LoadInputsAndTargets"]
logger = Log(__name__).getlog()
class LoadInputsAndTargets():
"""Create a mini-batch from a list of dicts
>>> batch = [('utt1',
... dict(input=[dict(feat='some.ark:123',
... filetype='mat',
... name='input1',
... shape=[100, 80])],
... output=[dict(tokenid='1 2 3 4',
... name='target1',
... shape=[4, 31])]]))
>>> l = LoadInputsAndTargets()
>>> feat, target = l(batch)
:param: str mode: Specify the task mode, "asr" or "tts"
:param: str preprocess_conf: The path of a json file for pre-processing
:param: bool load_input: If False, not to load the input data
:param: bool load_output: If False, not to load the output data
:param: bool sort_in_input_length: Sort the mini-batch in descending order
of the input length
:param: bool use_speaker_embedding: Used for tts mode only
:param: bool use_second_target: Used for tts mode only
:param: dict preprocess_args: Set some optional arguments for preprocessing
:param: Optional[dict] preprocess_args: Used for tts mode only
"""
def __init__(
self,
mode="asr",
preprocess_conf=None,
load_input=True,
load_output=True,
sort_in_input_length=True,
preprocess_args=None,
keep_all_data_on_mem=False, ):
self._loaders = {}
if mode not in ["asr"]:
raise ValueError("Only asr are allowed: mode={}".format(mode))
if preprocess_conf is not None:
with open(preprocess_conf, 'r') as fin:
self.preprocessing = AugmentationPipeline(fin.read())
logger.warning(
"[Experimental feature] Some preprocessing will be done "
"for the mini-batch creation using {}".format(
self.preprocessing))
else:
# If conf doesn't exist, this function don't touch anything.
self.preprocessing = None
self.mode = mode
self.load_output = load_output
self.load_input = load_input
self.sort_in_input_length = sort_in_input_length
if preprocess_args is None:
self.preprocess_args = {}
else:
assert isinstance(preprocess_args, dict), type(preprocess_args)
self.preprocess_args = dict(preprocess_args)
self.keep_all_data_on_mem = keep_all_data_on_mem
def __call__(self, batch, return_uttid=False):
"""Function to load inputs and targets from list of dicts
:param List[Tuple[str, dict]] batch: list of dict which is subset of
loaded data.json
:param bool return_uttid: return utterance ID information for visualization
:return: list of input token id sequences [(L_1), (L_2), ..., (L_B)]
:return: list of input feature sequences
[(T_1, D), (T_2, D), ..., (T_B, D)]
:rtype: list of float ndarray
:return: list of target token id sequences [(L_1), (L_2), ..., (L_B)]
:rtype: list of int ndarray
"""
x_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]
y_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]
uttid_list = [] # List[str]
for uttid, info in batch:
uttid_list.append(uttid)
if self.load_input:
# Note(kamo): This for-loop is for multiple inputs
for idx, inp in enumerate(info["input"]):
# {"input":
# [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
# "filetype": "hdf5",
# "name": "input1", ...}], ...}
x = self._get_from_loader(
filepath=inp["feat"],
filetype=inp.get("filetype", "mat"))
x_feats_dict.setdefault(inp["name"], []).append(x)
if self.load_output:
for idx, inp in enumerate(info["output"]):
if "tokenid" in inp:
# ======= Legacy format for output =======
# {"output": [{"tokenid": "1 2 3 4"}])
x = np.fromiter(
map(int, inp["tokenid"].split()), dtype=np.int64)
else:
# ======= New format =======
# {"input":
# [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
# "filetype": "hdf5",
# "name": "target1", ...}], ...}
x = self._get_from_loader(
filepath=inp["feat"],
filetype=inp.get("filetype", "mat"))
y_feats_dict.setdefault(inp["name"], []).append(x)
if self.mode == "asr":
return_batch, uttid_list = self._create_batch_asr(
x_feats_dict, y_feats_dict, uttid_list)
else:
raise NotImplementedError(self.mode)
if self.preprocessing is not None:
# Apply pre-processing all input features
for x_name in return_batch.keys():
if x_name.startswith("input"):
return_batch[x_name] = self.preprocessing(
return_batch[x_name], uttid_list,
**self.preprocess_args)
if return_uttid:
return tuple(return_batch.values()), uttid_list
# Doesn't return the names now.
return tuple(return_batch.values())
def _create_batch_asr(self, x_feats_dict, y_feats_dict, uttid_list):
"""Create a OrderedDict for the mini-batch
:param OrderedDict x_feats_dict:
e.g. {"input1": [ndarray, ndarray, ...],
"input2": [ndarray, ndarray, ...]}
:param OrderedDict y_feats_dict:
e.g. {"target1": [ndarray, ndarray, ...],
"target2": [ndarray, ndarray, ...]}
:param: List[str] uttid_list:
Give uttid_list to sort in the same order as the mini-batch
:return: batch, uttid_list
:rtype: Tuple[OrderedDict, List[str]]
"""
# handle single-input and multi-input (paralell) asr mode
xs = list(x_feats_dict.values())
if self.load_output:
ys = list(y_feats_dict.values())
assert len(xs[0]) == len(ys[0]), (len(xs[0]), len(ys[0]))
# get index of non-zero length samples
nonzero_idx = list(
filter(lambda i: len(ys[0][i]) > 0, range(len(ys[0]))))
for n in range(1, len(y_feats_dict)):
nonzero_idx = filter(lambda i: len(ys[n][i]) > 0, nonzero_idx)
else:
# Note(kamo): Be careful not to make nonzero_idx to a generator
nonzero_idx = list(range(len(xs[0])))
if self.sort_in_input_length:
# sort in input lengths based on the first input
nonzero_sorted_idx = sorted(
nonzero_idx, key=lambda i: -len(xs[0][i]))
else:
nonzero_sorted_idx = nonzero_idx
if len(nonzero_sorted_idx) != len(xs[0]):
logger.warning(
"Target sequences include empty tokenid (batch {} -> {}).".
format(len(xs[0]), len(nonzero_sorted_idx)))
# remove zero-length samples
xs = [[x[i] for i in nonzero_sorted_idx] for x in xs]
uttid_list = [uttid_list[i] for i in nonzero_sorted_idx]
x_names = list(x_feats_dict.keys())
if self.load_output:
ys = [[y[i] for i in nonzero_sorted_idx] for y in ys]
y_names = list(y_feats_dict.keys())
# Keeping x_name and y_name, e.g. input1, for future extension
return_batch = OrderedDict([
* [(x_name, x) for x_name, x in zip(x_names, xs)],
* [(y_name, y) for y_name, y in zip(y_names, ys)],
])
else:
return_batch = OrderedDict(
[(x_name, x) for x_name, x in zip(x_names, xs)])
return return_batch, uttid_list
def _get_from_loader(self, filepath, filetype):
"""Return ndarray
In order to make the fds to be opened only at the first referring,
the loader are stored in self._loaders
>>> ndarray = loader.get_from_loader(
... 'some/path.h5:F01_050C0101_PED_REAL', filetype='hdf5')
:param: str filepath:
:param: str filetype:
:return:
:rtype: np.ndarray
"""
if filetype == "hdf5":
# e.g.
# {"input": [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
# "filetype": "hdf5",
# -> filepath = "some/path.h5", key = "F01_050C0101_PED_REAL"
filepath, key = filepath.split(":", 1)
loader = self._loaders.get(filepath)
if loader is None:
# To avoid disk access, create loader only for the first time
loader = h5py.File(filepath, "r")
self._loaders[filepath] = loader
return loader[key][()]
elif filetype == "sound.hdf5":
# e.g.
# {"input": [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
# "filetype": "sound.hdf5",
# -> filepath = "some/path.h5", key = "F01_050C0101_PED_REAL"
filepath, key = filepath.split(":", 1)
loader = self._loaders.get(filepath)
if loader is None:
# To avoid disk access, create loader only for the first time
loader = SoundHDF5File(filepath, "r", dtype="int16")
self._loaders[filepath] = loader
array, rate = loader[key]
return array
elif filetype == "sound":
# e.g.
# {"input": [{"feat": "some/path.wav",
# "filetype": "sound"},
# Assume PCM16
if not self.keep_all_data_on_mem:
array, _ = soundfile.read(filepath, dtype="int16")
return array
if filepath not in self._loaders:
array, _ = soundfile.read(filepath, dtype="int16")
self._loaders[filepath] = array
return self._loaders[filepath]
elif filetype == "npz":
# e.g.
# {"input": [{"feat": "some/path.npz:F01_050C0101_PED_REAL",
# "filetype": "npz",
filepath, key = filepath.split(":", 1)
loader = self._loaders.get(filepath)
if loader is None:
# To avoid disk access, create loader only for the first time
loader = np.load(filepath)
self._loaders[filepath] = loader
return loader[key]
elif filetype == "npy":
# e.g.
# {"input": [{"feat": "some/path.npy",
# "filetype": "npy"},
if not self.keep_all_data_on_mem:
return np.load(filepath)
if filepath not in self._loaders:
self._loaders[filepath] = np.load(filepath)
return self._loaders[filepath]
elif filetype in ["mat", "vec"]:
# e.g.
# {"input": [{"feat": "some/path.ark:123",
# "filetype": "mat"}]},
# In this case, "123" indicates the starting points of the matrix
# load_mat can load both matrix and vector
if not self.keep_all_data_on_mem:
return kaldiio.load_mat(filepath)
if filepath not in self._loaders:
self._loaders[filepath] = kaldiio.load_mat(filepath)
return self._loaders[filepath]
elif filetype == "scp":
# e.g.
# {"input": [{"feat": "some/path.scp:F01_050C0101_PED_REAL",
# "filetype": "scp",
filepath, key = filepath.split(":", 1)
loader = self._loaders.get(filepath)
if loader is None:
# To avoid disk access, create loader only for the first time
loader = kaldiio.load_scp(filepath)
self._loaders[filepath] = loader
return loader[key]
else:
raise NotImplementedError(
"Not supported: loader_type={}".format(filetype))
def file_type(self, filepath):
suffix = filepath.split(":")[0].split('.')[-1].lower()
if suffix == 'ark':
return 'mat'
elif suffix == 'scp':
return 'scp'
elif suffix == 'npy':
return 'npy'
elif suffix == 'npz':
return 'npz'
elif suffix in ['wav', 'flac']:
# PCM16
return 'sound'
else:
raise ValueError(f"Not support filetype: {suffix}")
class SoundHDF5File():
"""Collecting sound files to a HDF5 file
>>> f = SoundHDF5File('a.flac.h5', mode='a')
>>> array = np.random.randint(0, 100, 100, dtype=np.int16)
>>> f['id'] = (array, 16000)
>>> array, rate = f['id']
:param: str filepath:
:param: str mode:
:param: str format: The type used when saving wav. flac, nist, htk, etc.
:param: str dtype:
"""
def __init__(self,
filepath,
mode="r+",
format=None,
dtype="int16",
**kwargs):
self.filepath = filepath
self.mode = mode
self.dtype = dtype
self.file = h5py.File(filepath, mode, **kwargs)
if format is None:
# filepath = a.flac.h5 -> format = flac
second_ext = os.path.splitext(os.path.splitext(filepath)[0])[1]
format = second_ext[1:]
if format.upper() not in soundfile.available_formats():
# If not found, flac is selected
format = "flac"
# This format affects only saving
self.format = format
def __repr__(self):
return '<SoundHDF5 file "{}" (mode {}, format {}, type {})>'.format(
self.filepath, self.mode, self.format, self.dtype)
def create_dataset(self, name, shape=None, data=None, **kwds):
f = io.BytesIO()
array, rate = data
soundfile.write(f, array, rate, format=self.format)
self.file.create_dataset(
name, shape=shape, data=np.void(f.getvalue()), **kwds)
def __setitem__(self, name, data):
self.create_dataset(name, data=data)
def __getitem__(self, key):
data = self.file[key][()]
f = io.BytesIO(data.tobytes())
array, rate = soundfile.read(f, dtype=self.dtype)
return array, rate
def keys(self):
return self.file.keys()
def values(self):
for k in self.file:
yield self[k]
def items(self):
for k in self.file:
yield k, self[k]
def __iter__(self):
return iter(self.file)
def __contains__(self, item):
return item in self.file
def __len__(self, item):
return len(self.file)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.file.close()
def close(self):
self.file.close()
......@@ -17,11 +17,16 @@ import numpy as np
from deepspeech.utils.log import Log
__all__ = ["pad_sequence"]
__all__ = ["pad_list", "pad_sequence"]
logger = Log(__name__).getlog()
def pad_list(sequences: List[np.ndarray],
padding_value: float=0.0) -> np.ndarray:
return pad_sequence(sequences, True, padding_value)
def pad_sequence(sequences: List[np.ndarray],
batch_first: bool=True,
padding_value: float=0.0) -> np.ndarray:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册