# multi_scale_sampler.py
from paddle.io import Sampler
import paddle.distributed as dist

import math
import random
import numpy as np

from ppcls import data


class MultiScaleSampler(Sampler):
    """Batch sampler that serves each batch at one of several resolutions.

    Each yielded batch is a list of ``(height, width, index)`` tuples that
    share one ``(height, width)``. In training mode the batch size of every
    scale is chosen so that ``height * width * batch_size`` stays close to
    the first scale's ``width * height * first_bs``.
    """

    def __init__(self,
                 data_source,
                 scales,
                 first_bs,
                 divided_factor=32,
                 is_training=True,
                 seed=None):
        """Plan the per-epoch sequence of (height, width, batch size) batches.

        Args:
            data_source (dataset): dataset to sample from; only its ``len()``
                is used here.
            scales (list): image resolutions, either ints (square images) or
                ``(width, height)`` pairs given as tuples or lists.
            first_bs (int): batch size for the first scale in ``scales``.
            divided_factor (int): factor by which the model down-samples
                images; in training mode every width and height is rounded
                down to a multiple of it.
            is_training (bool): if True, use every scale and reshuffle each
                epoch; otherwise use only the first scale, in dataset order.
            seed (int, optional): base seed for the batch plan and the
                per-epoch shuffle; ``None`` is treated as 0.

        Raises:
            ValueError: if ``scales`` is empty or, in training mode, a scale
                rounds down to a non-positive size.
            TypeError: if ``scales`` holds neither ints nor (width, height)
                pairs.
        """
        self.data_source = data_source
        self.n_data_samples = len(self.data_source)

        width_dims, height_dims = self._split_scales(scales)
        # The first scale is the reference: its (unrounded) area times
        # first_bs is the pixel budget every other scale is fitted to.
        base_im_w = width_dims[0]
        base_im_h = height_dims[0]
        base_batch_size = first_bs

        # Get the GPU and node related information
        num_replicas = dist.get_world_size()
        rank = dist.get_rank()
        # Give every replica the same number of samples (rounded up) so no
        # batch has to be dropped; short replicas are padded in _take().
        num_samples_per_replica = int(
            math.ceil(self.n_data_samples * 1.0 / num_replicas))

        self.shuffle = False
        if is_training:
            # The model down-samples images by divided_factor, so keep every
            # width and height a multiple of it.
            width_dims = [
                int((w // divided_factor) * divided_factor) for w in width_dims
            ]
            height_dims = [
                int((h // divided_factor) * divided_factor)
                for h in height_dims
            ]
            if min(width_dims) <= 0 or min(height_dims) <= 0:
                raise ValueError(
                    "every scale must be at least divided_factor={} after "
                    "rounding, got widths {} and heights {}".format(
                        divided_factor, width_dims, height_dims))
            base_elements = base_im_w * base_im_h * base_batch_size
            self.img_batch_pairs = [
                (h, w, int(max(1, base_elements / (h * w))))
                for h, w in zip(height_dims, width_dims)
            ]
            self.shuffle = True
        else:
            self.img_batch_pairs = [(base_im_h, base_im_w, base_batch_size)]

        self.img_indices = list(range(self.n_data_samples))
        self.n_samples_per_replica = num_samples_per_replica
        self.epoch = 0
        self.rank = rank
        self.num_replicas = num_replicas
        self.seed = seed

        # Draw the batch plan from a private RNG seeded identically on every
        # rank so all replicas agree on the number of batches per epoch (the
        # process-global ``random`` state may differ between ranks).
        plan_rng = random.Random(self._base_seed())
        indices_rank_i = self.img_indices[self.rank::self.num_replicas]
        self.batch_list = []
        self.current = 0
        while self.current < self.n_samples_per_replica:
            curr_h, curr_w, curr_bsz = plan_rng.choice(self.img_batch_pairs)
            batch_ids = self._take(indices_rank_i, self.current, curr_bsz)
            self.current += curr_bsz
            if batch_ids:
                self.batch_list.append([curr_h, curr_w, len(batch_ids)])
        self.length = len(self.batch_list)

    @staticmethod
    def _split_scales(scales):
        """Split ``scales`` into parallel ``(width_dims, height_dims)`` lists."""
        if not scales:
            raise ValueError("scales must contain at least one resolution")
        first = scales[0]
        if isinstance(first, (tuple, list)):
            return [s[0] for s in scales], [s[1] for s in scales]
        if isinstance(first, int):
            return list(scales), list(scales)
        raise TypeError(
            "scales must hold ints or (width, height) pairs, got {}".format(
                type(first).__name__))

    def _base_seed(self):
        """Return the configured seed, treating ``None`` as 0."""
        return 0 if self.seed is None else self.seed

    def _take(self, indices, start, batch_size):
        """Return ``batch_size`` ids from ``indices`` starting at ``start``.

        The window is clipped to ``n_samples_per_replica``. A short window is
        topped up with ids from the front of ``indices`` (wrap-around). The
        result is shorter than ``batch_size`` only when ``indices`` itself is.
        """
        end = min(start + batch_size, self.n_samples_per_replica)
        batch_ids = indices[start:end]
        missing = batch_size - len(batch_ids)
        if missing > 0:
            batch_ids += indices[:missing]
        return batch_ids

    def __iter__(self):
        """Yield batches as lists of ``(height, width, index)`` tuples.

        In training mode, both the sample order and the order of the planned
        batches are reshuffled every epoch. The shuffle uses a private RNG
        seeded with ``seed + epoch`` (``seed`` defaults to 0). Every rank
        therefore draws the same permutation, and the global ``random`` state
        is left untouched.
        """
        img_indices = self.img_indices
        batch_list = self.batch_list
        if self.shuffle:
            rng = random.Random(self._base_seed() + self.epoch)
            # Shuffle fresh copies so each epoch's order depends only on
            # (seed, epoch), not on how many epochs ran in this process.
            img_indices = list(img_indices)
            rng.shuffle(img_indices)
            batch_list = list(batch_list)
            rng.shuffle(batch_list)
        indices_rank_i = img_indices[self.rank::self.num_replicas]

        start_index = 0
        for curr_h, curr_w, curr_bsz in batch_list:
            batch_ids = self._take(indices_rank_i, start_index, curr_bsz)
            start_index += curr_bsz
            if batch_ids:
                yield [(curr_h, curr_w, b_id) for b_id in batch_ids]

    def set_epoch(self, epoch: int):
        """Set the epoch used to seed the next ``__iter__`` shuffle."""
        self.epoch = epoch

    def __len__(self):
        """Return the number of batches yielded per epoch."""
        return self.length