Commit b6cf2786 authored by sibo2rr

Rename MultiScaleSamplerDDP to MultiScaleSampler

Parent fe9f519b
......@@ -66,12 +66,14 @@ DataLoader:
order: ''
# support to specify width and height respectively:
# scales: [(160,160), (192,192), (256,256) (288,288) (320,320)]
# scales: [(160,160), (192,192), (224,224), (288,288), (320,320)]
sampler:
name: MultiScaleSamplerDDP
scales: [160, 192, 256, 288, 320]
name: MultiScaleSampler
scales: [160, 192, 224, 288, 320]
# first_bs: batch size for the first image resolution in the scales list
# divided_factor: to ensure the width and height dimensions can be divided by the down-sampling factor
first_bs: 64
down_sample: 32
divided_factor: 32
is_training: True
loader:
......
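For reference, a minimal sketch of how the renamed sampler lines up with the config keys above. The dummy dataset and the constructor call are illustrative assumptions, not code from this commit; any object with `__len__` can serve as `data_source`.

```python
# Minimal sketch, assuming ppcls and paddle are installed.
from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler

class DummyDataset(object):
    # stands in for the real training dataset; only __len__ is needed here
    def __len__(self):
        return 1281167  # ImageNet-1k training set size

sampler = MultiScaleSampler(
    data_source=DummyDataset(),
    scales=[160, 192, 224, 288, 320],  # per-edge sizes; (w, h) tuples also work
    first_bs=64,           # batch size at the first (160 x 160) scale
    divided_factor=32,     # keep width/height divisible by the network stride
    is_training=True)
```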
......@@ -34,7 +34,7 @@ from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
from ppcls.data.dataloader.pk_sampler import PKSampler
from ppcls.data.dataloader.mix_sampler import MixSampler
from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSamplerDDP
from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler
from ppcls.data import preprocess
from ppcls.data.preprocess import transform
......
......@@ -7,5 +7,5 @@ from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
from ppcls.data.dataloader.mix_dataset import MixDataset
from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
from ppcls.data.dataloader.mix_sampler import MixSampler
from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSamplerDDP
from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler
from ppcls.data.dataloader.pk_sampler import PKSampler
......@@ -41,6 +41,7 @@ class MultiScaleDataset(Dataset):
self.images = []
self.labels = []
self._load_anno()
self.has_crop_flag = 1
def _load_anno(self, seed=None):
assert os.path.exists(self._cls_path)
......@@ -70,9 +71,15 @@ class MultiScaleDataset(Dataset):
resize_op = ['RandCropImage', 'ResizeImage', 'CropImage']
for resize in resize_op:
if resize in op:
logger.error("Multi scale dataset will crop image according to the multi scale resolution")
self.transform_ops[i][resize] = {'size': (img_height, img_width)}
if self.has_crop_flag:
logger.error(
"Multi scale dataset will crop image according to the multi scale resolution"
)
self.transform_ops[i][resize] = {
'size': (img_height, img_width)
}
has_crop = True
self.has_crop_flag = 0
if has_crop == False:
logger.error("Multi scale dateset requests RandCropImage")
raise RuntimeError("Multi scale dateset requests RandCropImage")
......@@ -82,7 +89,7 @@ class MultiScaleDataset(Dataset):
with open(self.images[index], 'rb') as f:
img = f.read()
if self._transform_ops:
img = transform(img, self._transform_ops)
img = transform(img, self._transform_ops)
img = img.transpose((2, 0, 1))
return (img, self.labels[index])
......
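The hunk above adds `has_crop_flag` so the size-override warning is emitted only once per dataset instead of on every sample; the override logic itself is unchanged. A standalone illustration of that override follows (the op list is a made-up example, not taken from this commit):

```python
# Illustration only: a RandCropImage op is rewritten to the resolution
# that the sampler picked for the current batch.
transform_ops = [{'DecodeImage': {'to_rgb': True}},
                 {'RandCropImage': {'size': 224}},
                 {'RandFlipImage': {'flip_code': 1}}]

img_width, img_height = 192, 192  # resolution chosen by MultiScaleSampler
for i, op in enumerate(transform_ops):
    for resize in ('RandCropImage', 'ResizeImage', 'CropImage'):
        if resize in op:
            transform_ops[i][resize] = {'size': (img_height, img_width)}

print(transform_ops[1])  # {'RandCropImage': {'size': (192, 192)}}
```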
......@@ -7,8 +7,15 @@ import numpy as np
from ppcls import data
class MultiScaleSamplerDDP(Sampler):
def __init__(self, data_source, scales, first_bs, divided_factor=32, is_training = True, seed=None):
class MultiScaleSampler(Sampler):
def __init__(self,
data_source,
scales,
first_bs,
divided_factor=32,
is_training=True,
seed=None):
"""
multi scale sampler
Args:
......@@ -21,7 +28,7 @@ class MultiScaleSamplerDDP(Sampler):
# min. and max. spatial dimensions
self.data_source = data_source
self.n_data_samples = len(self.data_source)
if isinstance(scales[0], tuple):
width_dims = [i[0] for i in scales]
height_dims = [i[1] for i in scales]
......@@ -31,12 +38,13 @@ class MultiScaleSamplerDDP(Sampler):
base_im_w = width_dims[0]
base_im_h = height_dims[0]
base_batch_size = first_bs
# Get the GPU and node related information
num_replicas =dist.get_world_size()
num_replicas = dist.get_world_size()
rank = dist.get_rank()
# adjust the total samples to avoid batch dropping
num_samples_per_replica = int(math.ceil(self.n_data_samples * 1.0 / num_replicas))
num_samples_per_replica = int(
math.ceil(self.n_data_samples * 1.0 / num_replicas))
img_indices = [idx for idx in range(self.n_data_samples)]
self.shuffle = False
......@@ -44,8 +52,13 @@ class MultiScaleSamplerDDP(Sampler):
# compute the spatial dimensions and corresponding batch size
# ImageNet models down-sample images by a factor of 32.
# Ensure that width and height dimensions are multiples of 32.
width_dims = [int((w // divided_factor) * divided_factor) for w in width_dims]
height_dims = [int((h // divided_factor) * divided_factor) for h in height_dims]
width_dims = [
int((w // divided_factor) * divided_factor) for w in width_dims
]
height_dims = [
int((h // divided_factor) * divided_factor)
for h in height_dims
]
img_batch_pairs = list()
base_elements = base_im_w * base_im_h * base_batch_size
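`base_elements` drives a constant-pixel-budget rule during training: larger resolutions get proportionally smaller batch sizes so that h * w * batch_size stays close to the budget set by `first_bs` at the base scale. A self-contained sketch of that computation (an approximation of the training branch, not a verbatim copy):

```python
# Sketch: per-scale batch sizes under a fixed pixel budget.
divided_factor = 32
scales = [160, 192, 224, 288, 320]
dims = [(s // divided_factor) * divided_factor for s in scales]

base_h, base_w, base_bs = dims[0], dims[0], 64
base_elements = base_h * base_w * base_bs

img_batch_pairs = [(d, d, max(1, base_elements // (d * d))) for d in dims]
print(img_batch_pairs)
# [(160, 160, 64), (192, 192, 44), (224, 224, 32), (288, 288, 19), (320, 320, 16)]
```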
......@@ -55,8 +68,8 @@ class MultiScaleSamplerDDP(Sampler):
self.img_batch_pairs = img_batch_pairs
self.shuffle = True
else:
self.img_batch_pairs = [(base_im_h , base_im_w , base_batch_size)]
self.img_batch_pairs = [(base_im_h, base_im_w, base_batch_size)]
self.img_indices = img_indices
self.n_samples_per_replica = num_samples_per_replica
self.epoch = 0
......@@ -65,21 +78,23 @@ class MultiScaleSamplerDDP(Sampler):
self.seed = seed
self.batch_list = []
self.current = 0
indices_rank_i = self.img_indices[self.rank : len(self.img_indices) : self.num_replicas]
indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
self.num_replicas]
while self.current < self.n_samples_per_replica:
curr_h, curr_w, curr_bsz = random.choice(self.img_batch_pairs)
end_index = min(self.current + curr_bsz, self.n_samples_per_replica)
end_index = min(self.current + curr_bsz,
self.n_samples_per_replica)
batch_ids = indices_rank_i[self.current:end_index]
n_batch_samples = len(batch_ids)
if n_batch_samples != curr_bsz:
batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
self.current += curr_bsz
if len(batch_ids) > 0:
batch = [curr_h, curr_w, len(batch_ids)]
self.batch_list.append(batch)
batch = [curr_h, curr_w, len(batch_ids)]
self.batch_list.append(batch)
self.length = len(self.batch_list)
def __iter__(self):
......@@ -90,9 +105,11 @@ class MultiScaleSamplerDDP(Sampler):
random.seed(self.epoch)
random.shuffle(self.img_indices)
random.shuffle(self.img_batch_pairs)
indices_rank_i = self.img_indices[self.rank : len(self.img_indices) : self.num_replicas]
indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
self.num_replicas]
else:
indices_rank_i = self.img_indices[self.rank : len(self.img_indices) : self.num_replicas]
indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
self.num_replicas]
start_index = 0
for batch_tuple in self.batch_list:
......@@ -101,16 +118,15 @@ class MultiScaleSamplerDDP(Sampler):
batch_ids = indices_rank_i[start_index:end_index]
n_batch_samples = len(batch_ids)
if n_batch_samples != curr_bsz:
batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
start_index += curr_bsz
if len(batch_ids) > 0:
batch = [(curr_h, curr_w, b_id) for b_id in batch_ids]
yield batch
batch = [(curr_h, curr_w, b_id) for b_id in batch_ids]
yield batch
def set_epoch(self, epoch: int):
self.epoch = epoch
def __len__(self):
return self.length
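Finally, a hedged end-to-end sketch of wiring the renamed sampler into `paddle.io.DataLoader` as a batch sampler. The dataset paths and transform list are placeholders, and the `MultiScaleDataset` constructor arguments are assumed to mirror `ImageNetDataset`; none of this is taken verbatim from the diff.

```python
# Sketch under assumptions: ppcls and paddle installed, ImageNet list file present.
from paddle.io import DataLoader

from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler

dataset = MultiScaleDataset(
    image_root="./dataset/ILSVRC2012/",                    # placeholder path
    cls_label_path="./dataset/ILSVRC2012/train_list.txt",  # placeholder path
    transform_ops=[{'DecodeImage': {'to_rgb': True, 'channel_first': False}},
                   {'RandCropImage': {'size': 224}},  # size is overridden per batch
                   {'RandFlipImage': {'flip_code': 1}}])

sampler = MultiScaleSampler(dataset,
                            scales=[160, 192, 224, 288, 320],
                            first_bs=64,
                            divided_factor=32,
                            is_training=True)

# The sampler yields lists of (h, w, index) triples, so it plugs in as a
# batch_sampler and the dataset crops each image at the sampled resolution.
loader = DataLoader(dataset, batch_sampler=sampler, num_workers=4)

for images, labels in loader:
    pass  # training step goes here
```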