batch.py 4.5 KB
Newer Older
L
lifuchen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 15 16 17 18
"""
Functions and callable wrappers that pad variable-length arrays and stack them into batches.
"""
import numpy as np

L
lifuchen 已提交
19

20 21
class TextIDBatcher(object):
    """Callable that batches text-id examples with a fixed configuration.

    The padding id and output dtype are stored at construction time and
    forwarded to ``batch_text_id`` on every call.
    """

    def __init__(self, pad_id=0, dtype=np.int64):
        # Configuration forwarded unchanged to ``batch_text_id``.
        self.dtype = dtype
        self.pad_id = pad_id

    def __call__(self, minibatch):
        """Return ``batch_text_id(minibatch)`` using the stored settings."""
        return batch_text_id(
            minibatch, pad_id=self.pad_id, dtype=self.dtype)

L
lifuchen 已提交
31

32
def batch_text_id(minibatch, pad_id=0, dtype=np.int64):
    """Pad 1-D text-id arrays to a common length and stack them.

    Args:
        minibatch: Non-empty sequence of examples. Each example is an ndarray
            of shape (T,); T may differ between examples.
        pad_id: Value appended to the end of examples shorter than the
            longest one.
        dtype: dtype of the returned array.

    Returns:
        ndarray of shape (B, T_max), where B is the number of examples and
        T_max is the length of the longest example.

    Raises:
        IndexError: If ``minibatch`` is empty.
        ValueError: If any example is not 1-D.
    """
    if len(minibatch) == 0:
        # Same exception type that indexing the first example used to raise.
        raise IndexError("minibatch is empty")

    # Validate every example, not only the first one: np.pad broadcasts a
    # single (before, after) pair to all axes, so a stray 2-D example would
    # otherwise be padded along every axis instead of being rejected.
    for index, example in enumerate(minibatch):
        if len(example.shape) != 1:
            raise ValueError(
                "text example is an 1D tensor, but example {} has shape {}".
                format(index, tuple(example.shape)))

    max_len = max(example.shape[0] for example in minibatch)

    batch = [
        np.pad(example, [(0, max_len - example.shape[0])],
               mode='constant',
               constant_values=pad_id) for example in minibatch
    ]
    return np.array(batch, dtype=dtype)
53

L
lifuchen 已提交
54

55 56 57 58
class WavBatcher(object):
    """Callable that batches waveform examples with a fixed configuration.

    The padding value and output dtype are stored at construction time and
    forwarded to ``batch_wav`` on every call.
    """

    def __init__(self, pad_value=0., dtype=np.float32):
        # Configuration forwarded unchanged to ``batch_wav``.
        self.dtype = dtype
        self.pad_value = pad_value

    def __call__(self, minibatch):
        """Return ``batch_wav(minibatch)`` using the stored settings."""
        return batch_wav(
            minibatch, pad_value=self.pad_value, dtype=self.dtype)
63

L
lifuchen 已提交
64

65
def batch_wav(minibatch, pad_value=0., dtype=np.float32):
    """Pad waveforms along the last axis to a common length and stack them.

    Args:
        minibatch: Non-empty sequence of examples. Each example is an ndarray
            of shape (T,) for mono-channel wav or (C, T) for multi-channel
            wav; T may differ between examples. All examples must have the
            same number of dimensions.
        pad_value: Value appended along the last axis of examples shorter
            than the longest one.
        dtype: dtype of the returned array.

    Returns:
        ndarray whose first axis indexes the examples and whose last axis has
        the length of the longest example.

    Raises:
        IndexError: If ``minibatch`` is empty.
        ValueError: If the first example is neither 1-D nor 2-D, or if the
            examples do not all have the same number of dimensions.
    """
    if len(minibatch) == 0:
        # Same exception type that indexing the first example used to raise.
        raise IndexError("minibatch is empty")

    # The data format is detected from the first example. Reject unsupported
    # ranks explicitly instead of failing later on an unbound variable.
    expected_ndim = len(minibatch[0].shape)
    if expected_ndim not in (1, 2):
        raise ValueError(
            "wav example must have shape (T,) or (C, T), got shape {}".format(
                tuple(minibatch[0].shape)))

    # Every example must share that rank; otherwise the pad spec below would
    # not match its axes.
    for index, example in enumerate(minibatch):
        if len(example.shape) != expected_ndim:
            raise ValueError(
                "all wav examples must be {}-D, but example {} has shape {}".
                format(expected_ndim, index, tuple(example.shape)))

    max_len = max(example.shape[-1] for example in minibatch)

    # Only the last (sample) axis is padded; any leading axis is untouched.
    leading_pad = [(0, 0)] * (expected_ndim - 1)
    batch = [
        np.pad(example,
               leading_pad + [(0, max_len - example.shape[-1])],
               mode='constant',
               constant_values=pad_value) for example in minibatch
    ]
    return np.array(batch, dtype=dtype)


class SpecBatcher(object):
    """Callable that batches spectrogram examples with a fixed configuration.

    The padding value and output dtype are stored at construction time and
    forwarded to ``batch_spec`` on every call.
    """

    def __init__(self, pad_value=0., dtype=np.float32):
        # Configuration forwarded unchanged to ``batch_spec``.
        self.dtype = dtype
        self.pad_value = pad_value

    def __call__(self, minibatch):
        """Return ``batch_spec(minibatch)`` using the stored settings."""
        return batch_spec(
            minibatch, pad_value=self.pad_value, dtype=self.dtype)
106

L
lifuchen 已提交
107

108
def batch_spec(minibatch, pad_value=0., dtype=np.float32):
    """Pad spectrograms along the last axis to a common length and stack them.

    Args:
        minibatch: Non-empty sequence of examples. Each example is an ndarray
            of shape (F, T) for mono-channel spectrogram or (C, F, T) for
            multi-channel spectrogram; T may differ between examples. All
            examples must have the same number of dimensions and are expected
            to agree on every axis except the last.
        pad_value: Value appended along the last axis of examples shorter
            than the longest one.
        dtype: dtype of the returned array.

    Returns:
        ndarray whose first axis indexes the examples and whose last axis has
        the length of the longest example.

    Raises:
        IndexError: If ``minibatch`` is empty.
        ValueError: If the first example is neither 2-D nor 3-D, or if the
            examples do not all have the same number of dimensions.
    """
    if len(minibatch) == 0:
        # Same exception type that indexing the first example used to raise.
        raise IndexError("minibatch is empty")

    # The data format is detected from the first example. Reject unsupported
    # ranks explicitly instead of failing later on an unbound variable.
    expected_ndim = len(minibatch[0].shape)
    if expected_ndim not in (2, 3):
        raise ValueError(
            "spectrogram example must have shape (F, T) or (C, F, T), "
            "got shape {}".format(tuple(minibatch[0].shape)))

    # Every example must share that rank; otherwise the pad spec below would
    # not match its axes.
    for index, example in enumerate(minibatch):
        if len(example.shape) != expected_ndim:
            raise ValueError(
                "all spectrogram examples must be {}-D, but example {} has "
                "shape {}".format(expected_ndim, index, tuple(example.shape)))

    max_len = max(example.shape[-1] for example in minibatch)

    # Only the last (frame) axis is padded; the leading axes are untouched.
    leading_pad = [(0, 0)] * (expected_ndim - 1)
    batch = [
        np.pad(example,
               leading_pad + [(0, max_len - example.shape[-1])],
               mode='constant',
               constant_values=pad_value) for example in minibatch
    ]
    return np.array(batch, dtype=dtype)