# utils.py
import torch
import torch.utils.data as tdata
from data.random_erasing import RandomErasingTorch
from data.transforms import *


def fast_collate(batch):
    """Collate ``(image, label)`` samples into a uint8 image batch and int64 labels.

    Images are not converted to float here; they are copied into one
    preallocated ``torch.uint8`` tensor so normalization can happen later
    (e.g. on the GPU).

    Args:
        batch: non-empty sequence of samples where ``sample[0]`` is an image
            (NumPy array or ``torch.Tensor``) and ``sample[1]`` is an integer
            label. All images must have the shape of the first one.

    Returns:
        ``(images, targets)``: ``images`` has shape
        ``(len(batch), *batch[0][0].shape)`` and dtype ``torch.uint8``;
        ``targets`` is a 1-D ``torch.int64`` tensor of the labels.
    """
    targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
    images = torch.zeros((len(batch), *batch[0][0].shape), dtype=torch.uint8)
    for i, sample in enumerate(batch):
        # as_tensor accepts NumPy arrays (zero-copy) and tensors alike.
        # Adding into a zeroed uint8 slot keeps torch's in-place dtype check,
        # so float images raise instead of being silently truncated.
        images[i] += torch.as_tensor(sample[0])

    return images, targets


class PrefetchLoader:
    """Wrap a loader so batches are moved to the GPU and normalized one step ahead.

    While the caller works on batch ``k``, batch ``k + 1`` is copied to the
    current CUDA device and normalized on a separate CUDA stream.

    The wrapped loader must yield ``(input, target)`` pairs of CPU tensors.
    Assumes ``input`` is an NCHW, 3-channel image batch with raw pixel values
    in [0, 255] (mean/std are scaled by 255 and viewed as (1, 3, 1, 1)) —
    TODO confirm against the collate function in use.

    Args:
        loader: iterable of ``(input, target)`` tensor pairs; must support
            ``len()`` if ``len(prefetch_loader)`` is used.
        fp16: if True, inputs and the mean/std buffers are converted to half
            precision; otherwise inputs are converted to float32.
        random_erasing: probability passed to ``RandomErasingTorch``
            (applied per pixel after normalization); falsy disables it.
        mean: per-channel mean, presumably in [0, 1] units (scaled by 255).
        std: per-channel std, presumably in [0, 1] units (scaled by 255).
    """

    def __init__(self,
            loader,
            fp16=False,
            random_erasing=0.,
            mean=IMAGENET_DEFAULT_MEAN,
            std=IMAGENET_DEFAULT_STD):
        self.loader = loader
        self.fp16 = fp16
        # Scale statistics to the 0-255 input range and shape them to
        # broadcast over the channel dimension of an NCHW batch.
        self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
        if random_erasing:
            self.random_erasing = RandomErasingTorch(
                probability=random_erasing, per_pixel=True)
        else:
            self.random_erasing = None

        if self.fp16:
            self.mean = self.mean.half()
            self.std = self.std.half()

    def __iter__(self):
        """Yield ``(input, target)`` CUDA batches, prefetching one batch ahead.

        An empty underlying loader yields nothing.
        """
        stream = torch.cuda.Stream()
        have_batch = False
        batch_input = batch_target = None

        for next_input, next_target in self.loader:
            # Enqueue transfer + preprocessing of the next batch on the side
            # stream so it can overlap with the consumer's work.
            with torch.cuda.stream(stream):
                next_input = next_input.cuda(non_blocking=True)
                next_target = next_target.cuda(non_blocking=True)
                if self.fp16:
                    next_input = next_input.half()
                else:
                    next_input = next_input.float()
                next_input = next_input.sub_(self.mean).div_(self.std)
                if self.random_erasing is not None:
                    next_input = self.random_erasing(next_input)

            if have_batch:
                yield batch_input, batch_target
            have_batch = True

            consumer_stream = torch.cuda.current_stream()
            # The consumer must not read the batch before the side stream
            # has finished producing it.
            consumer_stream.wait_stream(stream)
            # These tensors were allocated on the side stream; record their
            # use on the consumer stream so the caching allocator does not
            # hand their memory back to the side stream while consumer
            # kernels may still be reading it.
            next_input.record_stream(consumer_stream)
            next_target.record_stream(consumer_stream)
            batch_input = next_input
            batch_target = next_target

        # Flush the last prefetched batch, if any batch was produced at all.
        if have_batch:
            yield batch_input, batch_target

    def __len__(self):
        return len(self.loader)


def create_loader(
        dataset,
        img_size,
        batch_size,
        is_training=False,
        use_prefetcher=True,
        random_erasing=0.,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=1,
        pin_memory=False,
        fp16=False,
):
    """Build a data loader for ``dataset`` with ImageNet-style transforms.

    Side effect: overwrites ``dataset.transform`` with the train or eval
    transform built from ``img_size``, ``mean`` and ``std``.

    Args:
        dataset: dataset object exposing a writable ``transform`` attribute.
        img_size: image size forwarded to the transform factory.
        batch_size: samples per batch.
        is_training: selects the training transforms and enables shuffling
            (and random erasing, when the prefetcher is used).
        use_prefetcher: if True, collate with ``fast_collate`` and wrap the
            loader in ``PrefetchLoader`` (normalization happens on the GPU);
            otherwise use the default collate function.
        random_erasing: random-erasing probability; only applied when both
            ``use_prefetcher`` and ``is_training`` are True.
        mean: normalization mean passed to the transforms and prefetcher.
        std: normalization std passed to the transforms and prefetcher.
        num_workers: ``DataLoader`` worker processes.
        pin_memory: forwarded to ``DataLoader``. Pinned host memory is what
            lets ``PrefetchLoader``'s ``non_blocking`` copies actually overlap.
            Defaults to False to preserve previous behavior.
        fp16: forwarded to ``PrefetchLoader`` to produce half-precision
            inputs; ignored when ``use_prefetcher`` is False.

    Returns:
        A ``PrefetchLoader`` if ``use_prefetcher`` is True, else a plain
        ``torch.utils.data.DataLoader``.
    """
    if is_training:
        transform_factory = transforms_imagenet_train
    else:
        transform_factory = transforms_imagenet_eval
    dataset.transform = transform_factory(
        img_size,
        use_prefetcher=use_prefetcher,
        mean=mean,
        std=std)

    loader = tdata.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_training,
        num_workers=num_workers,
        collate_fn=fast_collate if use_prefetcher else tdata.dataloader.default_collate,
        pin_memory=pin_memory,
    )
    if use_prefetcher:
        loader = PrefetchLoader(
            loader,
            fp16=fp16,
            random_erasing=random_erasing if is_training else 0.,
            mean=mean,
            std=std)

    return loader