import math
import random

import paddle.distributed as dist
from paddle.io import Sampler


class MultiScaleSamplerDDP(Sampler):
    """Distributed batch sampler that varies the input resolution per batch.

    Each batch is assigned a (height, width) pair from `scales`, and its
    batch size is scaled so that the pixel count per batch stays roughly
    constant relative to the base resolution and `first_bs`.
    """

    def __init__(self, data_source, scales, first_bs, is_training=True):
        self.data_source = data_source
        self.n_data_samples = len(self.data_source)
        
        # `scales` may be given as (width, height) tuples or as square sizes
        if isinstance(scales[0], tuple):
            width_dims = [i[0] for i in scales]
            height_dims = [i[1] for i in scales]
        elif isinstance(scales[0], int):
            width_dims = scales
            height_dims = scales
        else:
            raise ValueError("scales must be a list of tuples or ints")
        # the first scale defines the base resolution and base batch size
        base_im_w = width_dims[0]
        base_im_h = height_dims[0]
        base_batch_size = first_bs
        
        # get the GPU and node related information
        num_replicas = dist.get_world_size()
        rank = dist.get_rank()
        # round the per-replica sample count up to avoid dropping batches
        num_samples_per_replica = int(math.ceil(self.n_data_samples * 1.0 / num_replicas))
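        # e.g. 1000 samples on 3 replicas -> ceil(1000 / 3) = 334 per replica;
        # replicas whose shard is shorter pad batches by wrapping around
        # their index list (see the batch construction below)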
        img_indices = list(range(self.n_data_samples))

        self.shuffle = False
        if is_training:
            # compute the spatial dimensions and corresponding batch sizes;
            # ImageNet models down-sample images by a factor of 32, so
            # ensure that width and height are multiples of 32
            width_dims = [int((w // 32) * 32) for w in width_dims]
            height_dims = [int((h // 32) * 32) for h in height_dims]

            img_batch_pairs = []
            # scale the batch size so that pixels per batch stay constant
            base_elements = base_im_w * base_im_h * base_batch_size
            for (h, w) in zip(height_dims, width_dims):
                batch_size = int(max(1, base_elements / (h * w)))
                img_batch_pairs.append((h, w, batch_size))
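            # illustrative arithmetic (not in the original file): with a base
            # scale of 224x224 and first_bs = 64, a 320x320 scale would get
            # batch_size = int(max(1, 224 * 224 * 64 / (320 * 320))) = 31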
            self.img_batch_pairs = img_batch_pairs
            self.shuffle = True
        else:
            # evaluation uses a single fixed scale and batch size
            self.img_batch_pairs = [(base_im_h, base_im_w, base_batch_size)]
        
        self.img_indices = img_indices
        self.n_samples_per_replica = num_samples_per_replica
        self.epoch = 0
        self.rank = rank
        self.num_replicas = num_replicas
        
        self.batch_list = []
        self.current = 0
        # shard the indices: each replica takes a strided slice, e.g. with
        # 2 replicas rank 0 gets 0, 2, 4, ... and rank 1 gets 1, 3, 5, ...
        indices_rank_i = self.img_indices[self.rank : len(self.img_indices) : self.num_replicas]
        while self.current < self.n_samples_per_replica:
            # pick a random scale (and its batch size) for this batch
            curr_h, curr_w, curr_bsz = random.choice(self.img_batch_pairs)

            end_index = min(self.current + curr_bsz, self.n_samples_per_replica)

            batch_ids = indices_rank_i[self.current:end_index]
            n_batch_samples = len(batch_ids)
            if n_batch_samples != curr_bsz:
                # wrap around so the last batch is padded to full size
                batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
            self.current += curr_bsz

            if len(batch_ids) > 0:
                # only the batch shape is stored here; the concrete sample
                # ids are drawn in __iter__
                batch = [curr_h, curr_w, len(batch_ids)]
                self.batch_list.append(batch)
        self.length = len(self.batch_list)

    def __iter__(self):
        if self.shuffle:
            # deterministic shuffle: seeding with the epoch keeps every
            # replica's permutation identical
            random.seed(self.epoch)
            random.shuffle(self.img_indices)
            random.shuffle(self.img_batch_pairs)
        indices_rank_i = self.img_indices[self.rank : len(self.img_indices) : self.num_replicas]

        start_index = 0
        for batch_tuple in self.batch_list:
            curr_h, curr_w, curr_bsz = batch_tuple
            end_index = min(start_index + curr_bsz, self.n_samples_per_replica)
            batch_ids = indices_rank_i[start_index:end_index]
            n_batch_samples = len(batch_ids)
            if n_batch_samples != curr_bsz:
                batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
            start_index += curr_bsz

            if len(batch_ids) > 0:
                # each sample is tagged with its target (height, width)
                batch = [(curr_h, curr_w, b_id) for b_id in batch_ids]
                yield batch

    def set_epoch(self, epoch: int):
        # called once per epoch by the training loop so that every replica
        # reshuffles with the same seed
        self.epoch = epoch

    def __len__(self):
        return self.length
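

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original file):
    # the sampler is passed to paddle.io.DataLoader as `batch_sampler` and
    # yields batches of (height, width, index) tuples, so the dataset's
    # __getitem__ is assumed to accept such a tuple and resize accordingly.
    # `ToyMultiScaleDataset` is a hypothetical stand-in for that contract.
    import paddle
    from paddle.io import DataLoader, Dataset

    class ToyMultiScaleDataset(Dataset):
        def __init__(self, num_samples):
            self.num_samples = num_samples

        def __getitem__(self, properties):
            h, w, idx = properties
            # a real dataset would load image `idx` and resize it to (h, w)
            return paddle.zeros([3, h, w]), 0

        def __len__(self):
            return self.num_samples

    dataset = ToyMultiScaleDataset(1000)
    sampler = MultiScaleSamplerDDP(
        dataset, scales=[(224, 224), (256, 256), (288, 288)], first_bs=64)
    loader = DataLoader(dataset, batch_sampler=sampler)
    for epoch in range(2):
        sampler.set_epoch(epoch)  # re-seed the per-epoch shuffle
        for images, labels in loader:
            pass  # the training step would go here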