dataloaders.py 2.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
from torch.utils.data import DataLoader, WeightedRandomSampler
import numpy as np


def get_class_weights(dataset):
    """Helper function to get weights of classes in the imbalanced dataset.
    Args:
        dataset (WideDeepDataset): dataset containing target classes in dataset.Y 
    Returns:
        weights (np.array): numpy array with weights
        minor_class_count (int): count of samples in the smallest class for undersampling
        num_classes (int): number of classes
    """
    weights = 1/np.unique(dataset.Y, return_counts=True)[1]
    minor_class_count = min(np.unique(dataset.Y, return_counts=True)[1])
    num_classes = len(np.unique(dataset.Y))
    return weights, minor_class_count, num_classes


class DataLoader_default(DataLoader):

    def __init__(self, dataset, batch_size, num_workers, **kwargs):
        super().__init__(dataset, batch_size, num_workers)


class DataLoader_imbalanced(DataLoader):
    """Helper function to load and shuffle tensors into models in
    batches with adjusted weights to "fight" against imbalance of the classes.
    If the classes do not begin from 0 remapping is necessary, see:
    https://towardsdatascience.com/pytorch-tabular-multiclass-classification-9f8211a123ab
    Args:
        dataset (WideDeepDataset): dataset containing target classes in dataset.Y 
        batch_size (int): size of batch
        num_workers (int): number of workers
        oversample_mul (float): multiplicator for random oversampling of minority class
    Returns:
        DataLoader: PyTorch DataLoader object
    """
    def __init__(self, dataset, batch_size, num_workers, **kwargs):
        if 'oversample_mul' in kwargs: 
            oversample_mul = kwargs['oversample_mul']
        else:
            oversample_mul = 1
        weights, minor_cls_cnt, num_clss = get_class_weights(dataset)
        num_samples = int(minor_cls_cnt * num_clss * oversample_mul)
        # weight for each sample
        samples_weight = np.array([weights[i] for i in dataset.Y])
        # draw len(dataset) samples with given weights
        sampler = WeightedRandomSampler(
            samples_weight, num_samples, replacement=True)
        # sampler option is mutually exclusive with shuffle, can't set shuffle to
        # false/true
        # setting num_worker>0 with sampler causes error "DataLoader worker (pid 1362) is killed by signal: Segmentation fault"
        # I could not find a workaround, seems its related to sampling/multiprocessing
        super().__init__(dataset, batch_size, num_workers=0, sampler=sampler)