# dataloaders.py
import numpy as np
from torch.utils.data import DataLoader, WeightedRandomSampler

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.training._wd_dataset import WideDeepDataset


def get_class_weights(dataset: WideDeepDataset) -> Tuple[np.ndarray, int, int]:
    """Helper function to get weights of classes in the imbalanced dataset.

    Each class is weighted by the inverse of its number of samples, so that
    sampling proportionally to these weights draws every class equally often
    on average.

    Parameters
    ----------
    dataset: ``WideDeepDataset``
        dataset containing target classes in dataset.Y

    Returns
    -------
    weights: np.ndarray
        one weight (``1 / class count``) per class, ordered like
        ``np.unique(dataset.Y)``, i.e. by ascending label value
    minor_class_count: int
        count of samples in the smallest class for undersampling
    num_classes: int
        number of classes

    Raises
    ------
    ValueError
        if ``dataset.Y`` is ``None`` or empty
    """
    if dataset.Y is None:
        raise ValueError("dataset.Y is None: class weights require a target")
    # A single np.unique call yields both the sorted classes and their counts.
    classes, counts = np.unique(dataset.Y, return_counts=True)
    if counts.size == 0:
        raise ValueError("dataset.Y is empty: cannot compute class weights")
    weights = 1 / counts
    # Cast so the value matches the annotated ``int`` rather than np.int64.
    minor_class_count = int(counts.min())
    num_classes = len(classes)
    return weights, minor_class_count, num_classes


class DataLoaderDefault(DataLoader):
    r"""Plain ``DataLoader`` exposing the same constructor signature as the
    other loaders in this module, ``(dataset, batch_size, num_workers,
    **kwargs)``.

    Parameters
    ----------
    dataset: ``WideDeepDataset``
        see ``pytorch_widedeep.training._wd_dataset``
    batch_size: int
        size of batch
    num_workers: int
        number of workers
    **kwargs
        ``shuffle`` (bool, default ``False``) is forwarded to ``DataLoader``.
        Every other keyword argument (e.g. ``oversample_mul``, which only
        ``DataLoaderImbalanced`` reads) is ignored.
    """

    def __init__(
        self, dataset: WideDeepDataset, batch_size: int, num_workers: int, **kwargs
    ):
        # ``num_workers`` must be passed by keyword: the third positional
        # parameter of ``DataLoader`` is ``shuffle``, so passing it positionally
        # turns the worker count into a shuffle flag and leaves the loader
        # single-process.
        super().__init__(
            dataset,
            batch_size,
            shuffle=kwargs.get("shuffle", False),
            num_workers=num_workers,
        )


class DataLoaderImbalanced(DataLoader):
    r"""Class to load and shuffle batches with adjusted weights for imbalanced
    datasets.

    Every sample is weighted by the inverse frequency of its class (see
    ``get_class_weights``). Samples are drawn with replacement through a
    ``WeightedRandomSampler``, so that on average each class is equally
    represented. Class labels do not need to start from 0 or be contiguous:
    each sample is mapped to the position of its label in
    ``np.unique(dataset.Y)`` before its weight is looked up. See
    `here <https://towardsdatascience.com/pytorch-tabular-multiclass-classification-9f8211a123ab>`_

    Parameters
    ----------
    dataset: ``WideDeepDataset``
        see ``pytorch_widedeep.training._wd_dataset``
    batch_size: int
        size of batch
    num_workers: int
        number of workers
    **kwargs
        ``oversample_mul`` (number, default 1): the sampler draws
        ``int(minor_class_count * num_classes * oversample_mul)`` samples.
        Other keyword arguments are ignored.
    """

    def __init__(
        self, dataset: WideDeepDataset, batch_size: int, num_workers: int, **kwargs
    ):
        oversample_mul = kwargs.get("oversample_mul", 1)
        weights, minor_cls_cnt, num_clss = get_class_weights(dataset)
        num_samples = int(minor_cls_cnt * num_clss * oversample_mul)
        # ``weights`` is ordered like np.unique(dataset.Y). Map every target to
        # its position in that order instead of indexing by the raw label,
        # which fails for float labels and for labels that are not exactly
        # 0..num_classes-1. Raveling keeps the index vector 1-D.
        _, class_idx = np.unique(np.ravel(dataset.Y), return_inverse=True)
        samples_weight = weights[class_idx]
        sampler = WeightedRandomSampler(samples_weight, num_samples, replacement=True)
        super().__init__(dataset, batch_size, num_workers=num_workers, sampler=sampler)