batch.py 5.9 KB
Newer Older
L
lifuchen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility functions that pad variable-length arrays and stack them into batches.
Batch functions for text sequences, audio and spectrograms are provided.
"""
import numpy as np

L
lifuchen 已提交
20

21
class TextIDBatcher(object):
    """Callable wrapper around `batch_text_id`.

    The padding id and output dtype are fixed at construction time, so an
    instance can be used directly as a ``minibatch -> batch`` function.

    Args:
        pad_id (int, optional): id used to pad shorter sequences. Defaults to 0.
        dtype (np.dtype, optional): dtype of the produced batch. Defaults to np.int64.
    """

    def __init__(self, pad_id=0, dtype=np.int64):
        self.pad_id, self.dtype = pad_id, dtype

    def __call__(self, minibatch):
        """Pad and stack ``minibatch`` with the stored ``pad_id`` and ``dtype``."""
        return batch_text_id(minibatch, pad_id=self.pad_id, dtype=self.dtype)

L
lifuchen 已提交
32

33
def batch_text_id(minibatch, pad_id=0, dtype=np.int64):
    """Pad sequences of text ids to the largest length and batch them.

    Args:
        minibatch (List[np.ndarray]): list of rank-1 arrays of text ids,
            shape(T,), dtype: np.int64.
        pad_id (int, optional): the id which corresponds to the special pad
            token. Defaults to 0.
        dtype (np.dtype, optional): the data type of the output. Defaults to
            np.int64.

    Returns:
        np.ndarray: rank-2 array of text ids, shape(B, T), where B is the
        batch size and T is the length of the longest sequence. Shorter
        sequences are right-padded with ``pad_id``.

    Raises:
        ValueError: if ``minibatch`` is empty or any example is not rank-1.
    """
    # len() rather than truthiness: minibatch may itself be a numpy array.
    if len(minibatch) == 0:
        raise ValueError("minibatch is empty")
    examples = [np.asarray(example) for example in minibatch]

    # Validate every example, not just the first. A non-1D example would
    # otherwise be padded along all of its axes by np.pad and fail
    # obscurely. Raise instead of assert so the check survives `python -O`.
    for example in examples:
        if example.ndim != 1:
            raise ValueError(
                "each text example must be a 1D array, got shape {}".format(
                    example.shape))

    # each example has shape (T_i,)
    max_len = max(example.shape[0] for example in examples)

    # Pre-fill with the pad id and copy each sequence into its prefix. This
    # avoids one padded copy per example plus a final stacking copy.
    batch = np.full((len(examples), max_len), pad_id, dtype=dtype)
    for i, example in enumerate(examples):
        batch[i, :example.shape[0]] = example
    return batch
60

L
lifuchen 已提交
61

62
class WavBatcher(object):
    """Callable wrapper around `batch_wav`.

    Holds the padding value and output dtype so the instance itself can be
    handed to anything expecting a ``minibatch -> batch`` function.

    Args:
        pad_value (float, optional): value used to pad shorter audios. Defaults to 0..
        dtype (np.dtype, optional): dtype of the produced batch. Defaults to np.float32.
    """

    def __init__(self, pad_value=0., dtype=np.float32):
        self.pad_value, self.dtype = pad_value, dtype

    def __call__(self, minibatch):
        """Pad and stack ``minibatch`` using the stored settings."""
        return batch_wav(minibatch, pad_value=self.pad_value, dtype=self.dtype)
72

L
lifuchen 已提交
73

74
def batch_wav(minibatch, pad_value=0., dtype=np.float32):
    """Pad audios to the largest length and batch them.

    Args:
        minibatch (List[np.ndarray]): list of rank-1 float arrays
            (mono-channel audio, shape(T,)) or list of rank-2 float arrays
            (multi-channel audio, shape(C, T), C stands for number of
            channels, T stands for length), dtype: float. All examples must
            have the same rank and, for multi-channel audio, the same C.
        pad_value (float, optional): the pad value. Defaults to 0..
        dtype (np.dtype, optional): the data type of the output. Defaults to
            np.float32.

    Returns:
        np.ndarray: the output batch. It is a rank-2 float array of
        shape(B, T) if the minibatch is a list of mono-channel audios, or a
        rank-3 float array of shape(B, C, T) if the minibatch is a list of
        multi-channel audios. Only the last (time) axis is padded.

    Raises:
        ValueError: if ``minibatch`` is empty, if an example is neither
            rank-1 nor rank-2, or if the examples disagree in rank or in
            number of channels.
    """
    # len() rather than truthiness: minibatch may itself be a numpy array.
    if len(minibatch) == 0:
        raise ValueError("minibatch is empty")
    examples = [np.asarray(example) for example in minibatch]

    # assume (channel, n_samples) or (n_samples, ). Reject other ranks up
    # front; previously they left the mono/multi flag unbound.
    ndim = examples[0].ndim
    if ndim not in (1, 2):
        raise ValueError(
            "audio examples must be rank-1 (T,) or rank-2 (C, T), "
            "got shape {}".format(examples[0].shape))

    # Only the time axis is padded, so every other axis must agree across
    # examples: all mono, or all multi-channel with the same C.
    leading_shape = examples[0].shape[:-1]
    for example in examples:
        if example.ndim != ndim or example.shape[:-1] != leading_shape:
            raise ValueError(
                "all audio examples must share rank and channel count; "
                "got shapes {} and {}".format(examples[0].shape,
                                              example.shape))

    max_len = max(example.shape[-1] for example in examples)

    # Pre-fill with pad_value and copy each audio into its time prefix. The
    # `...` index covers both the mono (T,) and multi-channel (C, T) cases.
    batch = np.full((len(examples), ) + leading_shape + (max_len, ),
                    pad_value,
                    dtype=dtype)
    for i, example in enumerate(examples):
        batch[i, ..., :example.shape[-1]] = example
    return batch


class SpecBatcher(object):
    """Callable wrapper around `batch_spec`.

    Captures the padding value and output dtype so that calling the instance
    on a minibatch of spectrograms yields the padded batch.

    Args:
        pad_value (float, optional): value used to pad shorter spectrograms. Defaults to 0..
        dtype (np.dtype, optional): dtype of the produced batch. Defaults to np.float32.
    """

    def __init__(self, pad_value=0., dtype=np.float32):
        self.pad_value, self.dtype = pad_value, dtype

    def __call__(self, minibatch):
        """Pad and stack ``minibatch`` using the stored settings."""
        return batch_spec(minibatch, pad_value=self.pad_value, dtype=self.dtype)
123

L
lifuchen 已提交
124

125
def batch_spec(minibatch, pad_value=0., dtype=np.float32):
    """Pad spectra to the largest length and batch them.

    Args:
        minibatch (List[np.ndarray]): list of rank-2 arrays of shape(F, T)
            for mono-channel spectrograms, or list of rank-3 arrays of
            shape(C, F, T) for multi-channel spectrograms (F stands for
            frequency bands), dtype: float. All examples must have the same
            rank and agree on every axis except T.
        pad_value (float, optional): the pad value. Defaults to 0..
        dtype (np.dtype, optional): data type of the output. Defaults to
            np.float32.

    Returns:
        np.ndarray: a rank-3 array of shape(B, F, T) when the minibatch is a
        list of mono-channel spectrograms, or a rank-4 array of
        shape(B, C, F, T) when the minibatch is a list of multi-channel
        spectrograms. Only the last (frame) axis is padded.

    Raises:
        ValueError: if ``minibatch`` is empty, if an example is neither
            rank-2 nor rank-3, or if the examples disagree in rank, number of
            channels or number of frequency bands.
    """
    # len() rather than truthiness: minibatch may itself be a numpy array.
    if len(minibatch) == 0:
        raise ValueError("minibatch is empty")
    examples = [np.asarray(example) for example in minibatch]

    # assume (F, T) or (C, F, T). Reject other ranks up front; previously
    # they left the mono/multi flag unbound.
    ndim = examples[0].ndim
    if ndim not in (2, 3):
        raise ValueError(
            "spectrogram examples must be rank-2 (F, T) or rank-3 (C, F, T), "
            "got shape {}".format(examples[0].shape))

    # Only the frame axis is padded, so (F,) or (C, F) must match everywhere.
    leading_shape = examples[0].shape[:-1]
    for example in examples:
        if example.ndim != ndim or example.shape[:-1] != leading_shape:
            raise ValueError(
                "all spectrogram examples must share every axis but the "
                "last; got shapes {} and {}".format(examples[0].shape,
                                                    example.shape))

    # assume (channel, F, n_frame) or (F, n_frame)
    max_len = max(example.shape[-1] for example in examples)

    # Pre-fill with pad_value and copy each spectrogram into its frame
    # prefix. The `...` index covers both the (F, T) and (C, F, T) cases.
    batch = np.full((len(examples), ) + leading_shape + (max_len, ),
                    pad_value,
                    dtype=dtype)
    for i, example in enumerate(examples):
        batch[i, ..., :example.shape[-1]] = example
    return batch