提交 fe9f519b 编写于 作者: S sibo2rr

modify according to review

上级 cccd13af
......@@ -64,11 +64,14 @@ DataLoader:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
# supports specifying width and height separately, e.g.:
# scales: [(160,160), (192,192), (256,256), (288,288), (320,320)]
sampler:
name: MultiScaleSamplerDDP
scales: [160, 192, 256, 288, 320]
first_bs: 64
down_sample: 32
is_training: True
loader:
......
......@@ -26,25 +26,7 @@ from ppcls.data import preprocess
from ppcls.data.preprocess import transform
from ppcls.data.preprocess.ops.operators import DecodeImage
from ppcls.utils import logger
def create_operators(params):
    """
    Create operators based on the config.

    Args:
        params(list): a list of single-key dicts. Each key names an operator
            that is looked up on ``preprocess``; the value holds its keyword
            arguments (``None`` means no arguments).

    Returns:
        list: the instantiated operators, in the same order as ``params``.

    Raises:
        AssertionError: if ``params`` is not a list, or an entry is not a
            dict with exactly one key.
    """
    # Raise explicitly instead of using ``assert`` so the validation is not
    # stripped when Python runs with -O; the exception type stays the same.
    if not isinstance(params, list):
        raise AssertionError('operator config should be a list')

    ops = []
    for operator in params:
        if not (isinstance(operator, dict) and len(operator) == 1):
            raise AssertionError("yaml format error")
        op_name = list(operator)[0]
        param = {} if operator[op_name] is None else operator[op_name]
        ops.append(getattr(preprocess, op_name)(**param))
    return ops
from ppcls.data.dataloader.common_dataset import create_operators
class MultiScaleDataset(Dataset):
......@@ -56,9 +38,6 @@ class MultiScaleDataset(Dataset):
self._img_root = image_root
self._cls_path = cls_label_path
self.transform_ops = transform_ops
# if transform_ops:
# self._transform_ops = create_operators(transform_ops)
self.images = []
self.labels = []
self._load_anno()
......@@ -79,7 +58,6 @@ class MultiScaleDataset(Dataset):
self.labels.append(np.int64(l[1]))
assert os.path.exists(self.images[-1])
def __getitem__(self, properties):
# properties is a tuple, contains (width, height, index)
img_width = properties[0]
......@@ -89,11 +67,14 @@ class MultiScaleDataset(Dataset):
if self.transform_ops:
for i in range(len(self.transform_ops)):
op = self.transform_ops[i]
if 'RandCropImage' in op:
warnings.warn("Multi scale dataset will crop image according to the multi scale resolution")
self.transform_ops[i]['RandCropImage'] = {'size': img_width}
has_crop = True
resize_op = ['RandCropImage', 'ResizeImage', 'CropImage']
for resize in resize_op:
if resize in op:
logger.error("Multi scale dataset will crop image according to the multi scale resolution")
self.transform_ops[i][resize] = {'size': (img_height, img_width)}
has_crop = True
if has_crop == False:
logger.error("Multi scale dateset requests RandCropImage")
raise RuntimeError("Multi scale dateset requests RandCropImage")
self._transform_ops = create_operators(self.transform_ops)
......
......@@ -8,8 +8,16 @@ import numpy as np
from ppcls import data
class MultiScaleSamplerDDP(Sampler):
def __init__(self, data_source, scales, first_bs, g):
print(scales)
def __init__(self, data_source, scales, first_bs, divided_factor=32, is_training = True, seed=None):
"""
Multi-scale sampler.
Args:
data_source(dataset)
scales(list): several scales for image resolution
first_bs(int): batch size for the first scale in scales
divided_factor(int): ImageNet models down-sample images by a factor; width and height dimensions are rounded down to multiples of divided_factor.
is_training(boolean): mode
seed(int): if not None, used to seed the random module before shuffling instead of the epoch number
"""
# min. and max. spatial dimensions
self.data_source = data_source
self.n_data_samples = len(self.data_source)
......@@ -36,8 +44,8 @@ class MultiScaleSamplerDDP(Sampler):
# compute the spatial dimensions and corresponding batch size
# ImageNet models down-sample images by a factor of 32 (the default divided_factor).
# Ensure that width and height dimensions are multiples of divided_factor.
width_dims = [int((w // 32) * 32) for w in width_dims]
height_dims = [int((h // 32) * 32) for h in height_dims]
width_dims = [int((w // divided_factor) * divided_factor) for w in width_dims]
height_dims = [int((h // divided_factor) * divided_factor) for h in height_dims]
img_batch_pairs = list()
base_elements = base_im_w * base_im_h * base_batch_size
......@@ -54,7 +62,7 @@ class MultiScaleSamplerDDP(Sampler):
self.epoch = 0
self.rank = rank
self.num_replicas = num_replicas
self.seed = seed
self.batch_list = []
self.current = 0
indices_rank_i = self.img_indices[self.rank : len(self.img_indices) : self.num_replicas]
......@@ -76,7 +84,10 @@ class MultiScaleSamplerDDP(Sampler):
def __iter__(self):
if self.shuffle:
random.seed(self.epoch)
if self.seed is not None:
random.seed(self.seed)
else:
random.seed(self.epoch)
random.shuffle(self.img_indices)
random.shuffle(self.img_batch_pairs)
indices_rank_i = self.img_indices[self.rank : len(self.img_indices) : self.num_replicas]
......
......@@ -50,17 +50,22 @@ class UnifiedResize(object):
}
def _pil_resize(src, size, resample):
    """Resize an ndarray image with PIL and return the result as an ndarray.

    To be consistent with the opencv backend, the input ``size`` is (h, w).
    ``PIL.Image.resize`` expects (width, height), so the pair is reversed
    before the call; otherwise non-square sizes would come out transposed.
    """
    pil_img = Image.fromarray(src)
    pil_img = pil_img.resize(size[::-1], resample)
    return np.asarray(pil_img)
def _cv2_resize(src, size, interpolation):
    """Resize an image with OpenCV.

    ``size`` is given as (h, w) and reversed here because ``cv2.resize``
    takes its ``dsize`` argument as (width, height).
    """
    # The third positional parameter of cv2.resize is ``dst``, not the
    # interpolation flag, so the flag must be passed by keyword to be applied.
    return cv2.resize(src, size[::-1], interpolation=interpolation)
if backend.lower() == "cv2":
if isinstance(interpolation, str):
interpolation = _cv2_interp_from_str[interpolation.lower()]
# compatible with opencv < version 4.4.0
elif interpolation is None:
interpolation = cv2.INTER_LINEAR
self.resize_func = partial(cv2.resize, interpolation=interpolation)
self.resize_func = partial(_cv2_resize, interpolation=interpolation)
elif backend.lower() == "pil":
if isinstance(interpolation, str):
interpolation = _pil_interp_from_str[interpolation.lower()]
......@@ -123,8 +128,8 @@ class ResizeImage(object):
self.h = None
elif size is not None:
self.resize_short = None
self.w = size if type(size) is int else size[0]
self.h = size if type(size) is int else size[1]
self.h = size if type(size) is int else size[0]
self.w = size if type(size) is int else size[1]
else:
raise OperatorParamError("invalid params for ReisizeImage for '\
'both 'size' and 'resize_short' are None")
......@@ -141,7 +146,7 @@ class ResizeImage(object):
else:
w = self.w
h = self.h
return self._resize_func(img, (w, h))
return self._resize_func(img, (h, w))
class CropImage(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册