提交 86e42f3d 编写于 作者: H Hui Zhang

more data utils

上级 7b649af8
......@@ -13,7 +13,6 @@
# limitations under the License.
import itertools
import logger
import numpy as np
from deepspeech.utils.log import Log
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.io import DataLoader
from deepspeech.frontend.utility import read_manifest
from deepspeech.io.batchfy import make_batchset
from deepspeech.io.dataset import TransformDataset
from deepspeech.io.utility import LoadInputsAndTargets
from deepspeech.io.utility import pad_list
from deepspeech.utils.log import Log
__all__ = ["CustomConverter", "BatchDataLoader"]
logger = Log(__name__).getlog()
class CustomConverter():
"""Custom batch converter.
Args:
subsampling_factor (int): The subsampling factor.
dtype (paddle.dtype): Data type to convert.
"""
def __init__(self, subsampling_factor=1, dtype=paddle.float32):
"""Construct a CustomConverter object."""
self.subsampling_factor = subsampling_factor
self.ignore_id = -1
self.dtype = dtype
def __call__(self, batch):
"""Transform a batch and send it to a device.
Args:
batch (list): The batch to transform.
Returns:
tuple(paddle.Tensor, paddle.Tensor, paddle.Tensor)
"""
# batch should be located in list
assert len(batch) == 1
xs, ys = batch[0]
# perform subsampling
if self.subsampling_factor > 1:
xs = [x[::self.subsampling_factor, :] for x in xs]
# get batch of lengths of input sequences
ilens = np.array([x.shape[0] for x in xs])
# perform padding and convert to tensor
# currently only support real number
if xs[0].dtype.kind == "c":
xs_pad_real = pad_list([x.real for x in xs], 0).astype(self.dtype)
xs_pad_imag = pad_list([x.imag for x in xs], 0).astype(self.dtype)
# Note(kamo):
# {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.
# Don't create ComplexTensor and give it E2E here
# because torch.nn.DataParellel can't handle it.
xs_pad = {"real": xs_pad_real, "imag": xs_pad_imag}
else:
xs_pad = pad_list(xs, 0).astype(self.dtype)
ilens = paddle.to_tensor(ilens)
# NOTE: this is for multi-output (e.g., speech translation)
ys_pad = pad_list(
[np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys],
self.ignore_id)
olens = np.array([y.shape[0] for y in ys])
return xs_pad, ilens, ys_pad, olens
class BatchDataLoader():
def __init__(self,
json_file: str,
train_mode: bool,
sortagrad: bool=False,
batch_size: int=0,
maxlen_in: float=float('inf'),
maxlen_out: float=float('inf'),
minibatches: int=0,
mini_batch_size: int=1,
batch_count: str='auto',
batch_bins: int=0,
batch_frames_in: int=0,
batch_frames_out: int=0,
batch_frames_inout: int=0,
preprocess_conf=None,
n_iter_processes: int=1,
subsampling_factor: int=1,
num_encs: int=1):
self.json_file = json_file
self.train_mode = train_mode
self.use_sortagrad = sortagrad == -1 or sortagrad > 0
self.batch_size = batch_size
self.maxlen_in = maxlen_in
self.maxlen_out = maxlen_out
self.batch_count = batch_count
self.batch_bins = batch_bins
self.batch_frames_in = batch_frames_in
self.batch_frames_out = batch_frames_out
self.batch_frames_inout = batch_frames_inout
self.subsampling_factor = subsampling_factor
self.num_encs = num_encs
self.preprocess_conf = preprocess_conf
self.n_iter_processes = n_iter_processes
# read json data
data_json = read_manifest(json_file)
logger.info(f"load {json_file} file.")
# make minibatch list (variable length)
self.data = make_batchset(
data_json,
batch_size,
maxlen_in,
maxlen_out,
minibatches, # for debug
min_batch_size=mini_batch_size,
shortest_first=self.use_sortagrad,
count=batch_count,
batch_bins=batch_bins,
batch_frames_in=batch_frames_in,
batch_frames_out=batch_frames_out,
batch_frames_inout=batch_frames_inout,
iaxis=0,
oaxis=0, )
logger.info(f"batchfy data {json_file}: {len(self.data)}.")
self.load = LoadInputsAndTargets(
mode="asr",
load_output=True,
preprocess_conf=preprocess_conf,
preprocess_args={"train":
train_mode}, # Switch the mode of preprocessing
)
# Setup a converter
if num_encs == 1:
self.converter = CustomConverter(
subsampling_factor=subsampling_factor, dtype=dtype)
else:
assert NotImplementedError("not impl CustomConverterMulEnc.")
# hack to make batchsize argument as 1
# actual bathsize is included in a list
# default collate function converts numpy array to pytorch tensor
# we used an empty collate function instead which returns list
self.train_loader = DataLoader(
dataset=TransformDataset(
self.data, lambda data: self.converter([self.load(data)])),
batch_size=1,
shuffle=not use_sortagrad if train_mode else False,
collate_fn=lambda x: x[0],
num_workers=n_iter_processes, )
logger.info(f"dataloader for {json_file}.")
def __repr__(self):
return f"DataLoader {self.json_file}-{self.train_mode}-{self.use_sortagrad}"
......@@ -19,7 +19,7 @@ from yacs.config import CfgNode
from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.log import Log
__all__ = ["ManifestDataset", "TripletManifestDataset"]
__all__ = ["ManifestDataset", "TripletManifestDataset", "TransformDataset"]
logger = Log(__name__).getlog()
......@@ -116,3 +116,27 @@ class TripletManifestDataset(ManifestDataset):
instance = self._manifest[idx]
return instance["utt"], instance["feat"], instance["text"], instance[
"text1"]
class TransformDataset(Dataset):
"""Transform Dataset.
Args:
data: list object from make_batchset
transfrom: transform function
"""
def __init__(self, data, transform):
"""Init function."""
super().__init__()
self.data = data
self.transform = transform
def __len__(self):
"""Len function."""
return len(self.data)
def __getitem__(self, idx):
"""[] operator."""
return self.transform(self.data[idx])
......@@ -11,17 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import List
import numpy as np
from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
from deepspeech.utils.log import Log
__all__ = ["pad_sequence"]
__all__ = ["pad_list", "pad_sequence", "LoadInputsAndTargets"]
logger = Log(__name__).getlog()
def pad_list(sequences: List[np.ndarray],
padding_value: float=0.0) -> np.ndarray:
return pad_sequence(sequences, True, padding_value)
def pad_sequence(sequences: List[np.ndarray],
batch_first: bool=True,
padding_value: float=0.0) -> np.ndarray:
......@@ -80,3 +87,299 @@ def pad_sequence(sequences: List[np.ndarray],
out_tensor[:length, i, ...] = tensor
return out_tensor
class LoadInputsAndTargets():
"""Create a mini-batch from a list of dicts
>>> batch = [('utt1',
... dict(input=[dict(feat='some.ark:123',
... filetype='mat',
... name='input1',
... shape=[100, 80])],
... output=[dict(tokenid='1 2 3 4',
... name='target1',
... shape=[4, 31])]]))
>>> l = LoadInputsAndTargets()
>>> feat, target = l(batch)
:param: str mode: Specify the task mode, "asr" or "tts"
:param: str preprocess_conf: The path of a json file for pre-processing
:param: bool load_input: If False, not to load the input data
:param: bool load_output: If False, not to load the output data
:param: bool sort_in_input_length: Sort the mini-batch in descending order
of the input length
:param: bool use_speaker_embedding: Used for tts mode only
:param: bool use_second_target: Used for tts mode only
:param: dict preprocess_args: Set some optional arguments for preprocessing
:param: Optional[dict] preprocess_args: Used for tts mode only
"""
def __init__(
self,
mode="asr",
preprocess_conf=None,
load_input=True,
load_output=True,
sort_in_input_length=True,
preprocess_args=None,
keep_all_data_on_mem=False, ):
self._loaders = {}
if mode not in ["asr"]:
raise ValueError("Only asr are allowed: mode={}".format(mode))
if preprocess_conf is not None:
self.preprocessing = AugmentationPipeline(preprocess_conf)
logging.warning(
"[Experimental feature] Some preprocessing will be done "
"for the mini-batch creation using {}".format(
self.preprocessing))
else:
# If conf doesn't exist, this function don't touch anything.
self.preprocessing = None
self.mode = mode
self.load_output = load_output
self.load_input = load_input
self.sort_in_input_length = sort_in_input_length
if preprocess_args is None:
self.preprocess_args = {}
else:
assert isinstance(preprocess_args, dict), type(preprocess_args)
self.preprocess_args = dict(preprocess_args)
self.keep_all_data_on_mem = keep_all_data_on_mem
def __call__(self, batch, return_uttid=False):
"""Function to load inputs and targets from list of dicts
:param List[Tuple[str, dict]] batch: list of dict which is subset of
loaded data.json
:param bool return_uttid: return utterance ID information for visualization
:return: list of input token id sequences [(L_1), (L_2), ..., (L_B)]
:return: list of input feature sequences
[(T_1, D), (T_2, D), ..., (T_B, D)]
:rtype: list of float ndarray
:return: list of target token id sequences [(L_1), (L_2), ..., (L_B)]
:rtype: list of int ndarray
"""
x_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]
y_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]
uttid_list = [] # List[str]
for uttid, info in batch:
uttid_list.append(uttid)
if self.load_input:
# Note(kamo): This for-loop is for multiple inputs
for idx, inp in enumerate(info["input"]):
# {"input":
# [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
# "filetype": "hdf5",
# "name": "input1", ...}], ...}
x = self._get_from_loader(
filepath=inp["feat"],
filetype=inp.get("filetype", "mat"))
x_feats_dict.setdefault(inp["name"], []).append(x)
if self.load_output:
for idx, inp in enumerate(info["output"]):
if "tokenid" in inp:
# ======= Legacy format for output =======
# {"output": [{"tokenid": "1 2 3 4"}])
x = np.fromiter(
map(int, inp["tokenid"].split()), dtype=np.int64)
else:
# ======= New format =======
# {"input":
# [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
# "filetype": "hdf5",
# "name": "target1", ...}], ...}
x = self._get_from_loader(
filepath=inp["feat"],
filetype=inp.get("filetype", "mat"))
y_feats_dict.setdefault(inp["name"], []).append(x)
if self.mode == "asr":
return_batch, uttid_list = self._create_batch_asr(
x_feats_dict, y_feats_dict, uttid_list)
else:
raise NotImplementedError(self.mode)
if self.preprocessing is not None:
# Apply pre-processing all input features
for x_name in return_batch.keys():
if x_name.startswith("input"):
return_batch[x_name] = self.preprocessing(
return_batch[x_name], uttid_list,
**self.preprocess_args)
if return_uttid:
return tuple(return_batch.values()), uttid_list
# Doesn't return the names now.
return tuple(return_batch.values())
def _create_batch_asr(self, x_feats_dict, y_feats_dict, uttid_list):
"""Create a OrderedDict for the mini-batch
:param OrderedDict x_feats_dict:
e.g. {"input1": [ndarray, ndarray, ...],
"input2": [ndarray, ndarray, ...]}
:param OrderedDict y_feats_dict:
e.g. {"target1": [ndarray, ndarray, ...],
"target2": [ndarray, ndarray, ...]}
:param: List[str] uttid_list:
Give uttid_list to sort in the same order as the mini-batch
:return: batch, uttid_list
:rtype: Tuple[OrderedDict, List[str]]
"""
# handle single-input and multi-input (paralell) asr mode
xs = list(x_feats_dict.values())
if self.load_output:
ys = list(y_feats_dict.values())
assert len(xs[0]) == len(ys[0]), (len(xs[0]), len(ys[0]))
# get index of non-zero length samples
nonzero_idx = list(
filter(lambda i: len(ys[0][i]) > 0, range(len(ys[0]))))
for n in range(1, len(y_feats_dict)):
nonzero_idx = filter(lambda i: len(ys[n][i]) > 0, nonzero_idx)
else:
# Note(kamo): Be careful not to make nonzero_idx to a generator
nonzero_idx = list(range(len(xs[0])))
if self.sort_in_input_length:
# sort in input lengths based on the first input
nonzero_sorted_idx = sorted(
nonzero_idx, key=lambda i: -len(xs[0][i]))
else:
nonzero_sorted_idx = nonzero_idx
if len(nonzero_sorted_idx) != len(xs[0]):
logging.warning(
"Target sequences include empty tokenid (batch {} -> {}).".
format(len(xs[0]), len(nonzero_sorted_idx)))
# remove zero-length samples
xs = [[x[i] for i in nonzero_sorted_idx] for x in xs]
uttid_list = [uttid_list[i] for i in nonzero_sorted_idx]
x_names = list(x_feats_dict.keys())
if self.load_output:
ys = [[y[i] for i in nonzero_sorted_idx] for y in ys]
y_names = list(y_feats_dict.keys())
# Keeping x_name and y_name, e.g. input1, for future extension
return_batch = OrderedDict([
* [(x_name, x) for x_name, x in zip(x_names, xs)],
* [(y_name, y) for y_name, y in zip(y_names, ys)],
])
else:
return_batch = OrderedDict(
[(x_name, x) for x_name, x in zip(x_names, xs)])
return return_batch, uttid_list
def _get_from_loader(self, filepath, filetype):
"""Return ndarray
In order to make the fds to be opened only at the first referring,
the loader are stored in self._loaders
>>> ndarray = loader.get_from_loader(
... 'some/path.h5:F01_050C0101_PED_REAL', filetype='hdf5')
:param: str filepath:
:param: str filetype:
:return:
:rtype: np.ndarray
"""
if filetype == "hdf5":
# e.g.
# {"input": [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
# "filetype": "hdf5",
# -> filepath = "some/path.h5", key = "F01_050C0101_PED_REAL"
filepath, key = filepath.split(":", 1)
loader = self._loaders.get(filepath)
if loader is None:
# To avoid disk access, create loader only for the first time
loader = h5py.File(filepath, "r")
self._loaders[filepath] = loader
return loader[key][()]
elif filetype == "sound.hdf5":
# e.g.
# {"input": [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
# "filetype": "sound.hdf5",
# -> filepath = "some/path.h5", key = "F01_050C0101_PED_REAL"
filepath, key = filepath.split(":", 1)
loader = self._loaders.get(filepath)
if loader is None:
# To avoid disk access, create loader only for the first time
loader = SoundHDF5File(filepath, "r", dtype="int16")
self._loaders[filepath] = loader
array, rate = loader[key]
return array
elif filetype == "sound":
# e.g.
# {"input": [{"feat": "some/path.wav",
# "filetype": "sound"},
# Assume PCM16
if not self.keep_all_data_on_mem:
array, _ = soundfile.read(filepath, dtype="int16")
return array
if filepath not in self._loaders:
array, _ = soundfile.read(filepath, dtype="int16")
self._loaders[filepath] = array
return self._loaders[filepath]
elif filetype == "npz":
# e.g.
# {"input": [{"feat": "some/path.npz:F01_050C0101_PED_REAL",
# "filetype": "npz",
filepath, key = filepath.split(":", 1)
loader = self._loaders.get(filepath)
if loader is None:
# To avoid disk access, create loader only for the first time
loader = np.load(filepath)
self._loaders[filepath] = loader
return loader[key]
elif filetype == "npy":
# e.g.
# {"input": [{"feat": "some/path.npy",
# "filetype": "npy"},
if not self.keep_all_data_on_mem:
return np.load(filepath)
if filepath not in self._loaders:
self._loaders[filepath] = np.load(filepath)
return self._loaders[filepath]
elif filetype in ["mat", "vec"]:
# e.g.
# {"input": [{"feat": "some/path.ark:123",
# "filetype": "mat"}]},
# In this case, "123" indicates the starting points of the matrix
# load_mat can load both matrix and vector
if not self.keep_all_data_on_mem:
return kaldiio.load_mat(filepath)
if filepath not in self._loaders:
self._loaders[filepath] = kaldiio.load_mat(filepath)
return self._loaders[filepath]
elif filetype == "scp":
# e.g.
# {"input": [{"feat": "some/path.scp:F01_050C0101_PED_REAL",
# "filetype": "scp",
filepath, key = filepath.split(":", 1)
loader = self._loaders.get(filepath)
if loader is None:
# To avoid disk access, create loader only for the first time
loader = kaldiio.load_scp(filepath)
self._loaders[filepath] = loader
return loader[key]
else:
raise NotImplementedError(
"Not supported: loader_type={}".format(filetype))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册