# batch.py
"""
functions to make batch for arrays which satisfy some conditions.
"""
import numpy as np

class TextIDBatcher(object):
    """Callable that remembers padding settings and applies ``batch_text_id`` with them."""

    def __init__(self, pad_id=0, dtype=np.int64):
        # Stored once so every call pads and casts the same way.
        self.dtype = dtype
        self.pad_id = pad_id

    def __call__(self, minibatch):
        """Return ``batch_text_id(minibatch)`` using the stored ``pad_id`` and ``dtype``."""
        return batch_text_id(minibatch, dtype=self.dtype, pad_id=self.pad_id)

def batch_text_id(minibatch, pad_id=0, dtype=np.int64):
    """Pad 1-D text-id arrays to a common length and stack them into one batch.

    Args:
        minibatch: Non-empty sequence of 1-D arrays of shape (T,); T may
            differ between examples.
        pad_id: Value written to the positions after the end of each example.
        dtype: dtype of the returned array; examples are cast to it.

    Returns:
        ndarray of shape (B, T_max), where B is ``len(minibatch)`` and T_max
        is the length of the longest example.

    Raises:
        ValueError: If ``minibatch`` is empty or any example is not 1-D.
    """
    if len(minibatch) == 0:
        raise ValueError("minibatch must contain at least one example")

    # Cast before padding: np.pad writes the constant in the dtype of the
    # array being padded, so padding in the output dtype keeps pad_id exact
    # even when the examples use a narrower dtype.
    examples = [np.asarray(example, dtype=dtype) for example in minibatch]
    for example in examples:
        # An explicit raise (not assert) so the check survives `python -O`.
        if example.ndim != 1:
            raise ValueError(
                "text examples must be 1-D arrays, got shape {}".format(
                    example.shape))

    max_len = max(example.shape[0] for example in examples)
    batch = [
        np.pad(example, [(0, max_len - example.shape[0])],
               mode='constant', constant_values=pad_id)
        for example in examples
    ]
    return np.stack(batch)


class WavBatcher(object):
    """Callable that remembers padding settings and applies ``batch_wav`` with them."""

    def __init__(self, pad_value=0., dtype=np.float32):
        # Stored once so every call pads and casts the same way.
        self.dtype = dtype
        self.pad_value = pad_value

    def __call__(self, minibatch):
        """Return ``batch_wav(minibatch)`` using the stored ``pad_value`` and ``dtype``."""
        return batch_wav(minibatch, dtype=self.dtype, pad_value=self.pad_value)


def batch_wav(minibatch, pad_value=0., dtype=np.float32):
    """Pad waveforms along the time axis and stack them into one batch.

    Args:
        minibatch: Non-empty sequence of arrays, all of shape (T,) for
            mono-channel wavs or all of shape (C, T) for multi-channel wavs.
            T (the last axis) may differ between examples.
        pad_value: Value written to the positions after the end of each
            example along the last axis.
        dtype: dtype of the returned array; examples are cast to it.

    Returns:
        ndarray of shape (B, T_max) or (B, C, T_max), where B is
        ``len(minibatch)`` and T_max is the longest last-axis length.

    Raises:
        ValueError: If ``minibatch`` is empty, an example is neither 1-D nor
            2-D, examples differ in rank, or (via ``np.stack``) examples
            differ in their leading dimensions.
    """
    if len(minibatch) == 0:
        raise ValueError("minibatch must contain at least one example")

    # Cast before padding: np.pad writes the constant in the dtype of the
    # array being padded, so padding in the output dtype keeps pad_value
    # exact (e.g. a fractional pad_value with integer input).
    # NOTE(review): the original code carried an open question about PCM
    # input next to the padding call; confirm pad_value semantics for it.
    examples = [np.asarray(example, dtype=dtype) for example in minibatch]

    # The data layout is detected from the first example.
    ndim = examples[0].ndim
    if ndim not in (1, 2):
        raise ValueError(
            "wav examples must have shape (T,) or (C, T), got shape {}".format(
                examples[0].shape))
    for example in examples:
        if example.ndim != ndim:
            raise ValueError(
                "all wav examples must have the same rank, got shapes {} "
                "and {}".format(examples[0].shape, example.shape))

    max_len = max(example.shape[-1] for example in examples)
    # Only the last (time) axis is padded; leading axes are left untouched.
    leading = [(0, 0)] * (ndim - 1)
    batch = [
        np.pad(example, leading + [(0, max_len - example.shape[-1])],
               mode='constant', constant_values=pad_value)
        for example in examples
    ]
    return np.stack(batch)


class SpecBatcher(object):
    """Callable that remembers padding settings and applies ``batch_spec`` with them."""

    def __init__(self, pad_value=0., dtype=np.float32):
        # Stored once so every call pads and casts the same way.
        self.dtype = dtype
        self.pad_value = pad_value

    def __call__(self, minibatch):
        """Return ``batch_spec(minibatch)`` using the stored ``pad_value`` and ``dtype``."""
        return batch_spec(minibatch, dtype=self.dtype, pad_value=self.pad_value)


def batch_spec(minibatch, pad_value=0., dtype=np.float32):
    """Pad spectrograms along the frame axis and stack them into one batch.

    Args:
        minibatch: Non-empty sequence of arrays, all of shape (F, T) for
            mono-channel spectrograms or all of shape (C, F, T) for
            multi-channel spectrograms. T (the last axis) may differ between
            examples.
        pad_value: Value written to the positions after the end of each
            example along the last axis.
        dtype: dtype of the returned array; examples are cast to it.

    Returns:
        ndarray of shape (B, F, T_max) or (B, C, F, T_max), where B is
        ``len(minibatch)`` and T_max is the longest last-axis length.

    Raises:
        ValueError: If ``minibatch`` is empty, an example is neither 2-D nor
            3-D, examples differ in rank, or (via ``np.stack``) examples
            differ in their leading dimensions.
    """
    if len(minibatch) == 0:
        raise ValueError("minibatch must contain at least one example")

    # Cast before padding: np.pad writes the constant in the dtype of the
    # array being padded, so padding in the output dtype keeps pad_value
    # exact (e.g. a fractional pad_value with integer input).
    examples = [np.asarray(example, dtype=dtype) for example in minibatch]

    # The data layout is detected from the first example.
    ndim = examples[0].ndim
    if ndim not in (2, 3):
        raise ValueError(
            "spectrogram examples must have shape (F, T) or (C, F, T), "
            "got shape {}".format(examples[0].shape))
    for example in examples:
        if example.ndim != ndim:
            raise ValueError(
                "all spectrogram examples must have the same rank, got "
                "shapes {} and {}".format(examples[0].shape, example.shape))

    max_len = max(example.shape[-1] for example in examples)
    # Only the last (frame) axis is padded; leading axes are left untouched.
    leading = [(0, 0)] * (ndim - 1)
    batch = [
        np.pad(example, leading + [(0, max_len - example.shape[-1])],
               mode='constant', constant_values=pad_value)
        for example in examples
    ]
    return np.stack(batch)