提交 5c04b3e6 编写于 作者: J jiangjiajun

composed transforms enhancement

上级 c41152ab
...@@ -39,14 +39,14 @@ class EasyDataCls(ImageNet): ...@@ -39,14 +39,14 @@ class EasyDataCls(ImageNet):
线程和'process'进程两种方式。默认为'process'(Windows和Mac下会强制使用thread,该参数无效)。 线程和'process'进程两种方式。默认为'process'(Windows和Mac下会强制使用thread,该参数无效)。
shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。 shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
""" """
def __init__(self, def __init__(self,
data_dir, data_dir,
file_list, file_list,
label_list, label_list,
transforms=None, transforms=None,
num_workers='auto', num_workers='auto',
buffer_size=100, buffer_size=8,
parallel_method='process', parallel_method='process',
shuffle=False): shuffle=False):
super(ImageNet, self).__init__( super(ImageNet, self).__init__(
...@@ -58,7 +58,7 @@ class EasyDataCls(ImageNet): ...@@ -58,7 +58,7 @@ class EasyDataCls(ImageNet):
self.file_list = list() self.file_list = list()
self.labels = list() self.labels = list()
self._epoch = 0 self._epoch = 0
with open(label_list, encoding=get_encoding(label_list)) as f: with open(label_list, encoding=get_encoding(label_list)) as f:
for line in f: for line in f:
item = line.strip() item = line.strip()
...@@ -73,8 +73,8 @@ class EasyDataCls(ImageNet): ...@@ -73,8 +73,8 @@ class EasyDataCls(ImageNet):
if not osp.isfile(json_file): if not osp.isfile(json_file):
continue continue
if not osp.exists(img_file): if not osp.exists(img_file):
raise IOError( raise IOError('The image file {} is not exist!'.format(
'The image file {} is not exist!'.format(img_file)) img_file))
with open(json_file, mode='r', \ with open(json_file, mode='r', \
encoding=get_encoding(json_file)) as j: encoding=get_encoding(json_file)) as j:
json_info = json.load(j) json_info = json.load(j)
...@@ -83,4 +83,3 @@ class EasyDataCls(ImageNet): ...@@ -83,4 +83,3 @@ class EasyDataCls(ImageNet):
self.num_samples = len(self.file_list) self.num_samples = len(self.file_list)
logging.info("{} samples in file {}".format( logging.info("{} samples in file {}".format(
len(self.file_list), file_list)) len(self.file_list), file_list))
\ No newline at end of file
...@@ -45,7 +45,7 @@ class ImageNet(Dataset): ...@@ -45,7 +45,7 @@ class ImageNet(Dataset):
label_list, label_list,
transforms=None, transforms=None,
num_workers='auto', num_workers='auto',
buffer_size=100, buffer_size=8,
parallel_method='process', parallel_method='process',
shuffle=False): shuffle=False):
super(ImageNet, self).__init__( super(ImageNet, self).__init__(
...@@ -70,8 +70,8 @@ class ImageNet(Dataset): ...@@ -70,8 +70,8 @@ class ImageNet(Dataset):
continue continue
full_path = osp.join(data_dir, items[0]) full_path = osp.join(data_dir, items[0])
if not osp.exists(full_path): if not osp.exists(full_path):
raise IOError( raise IOError('The image file {} is not exist!'.format(
'The image file {} is not exist!'.format(full_path)) full_path))
self.file_list.append([full_path, int(items[1])]) self.file_list.append([full_path, int(items[1])])
self.num_samples = len(self.file_list) self.num_samples = len(self.file_list)
logging.info("{} samples in file {}".format( logging.info("{} samples in file {}".format(
......
...@@ -70,8 +70,8 @@ class Compose(ClsTransform): ...@@ -70,8 +70,8 @@ class Compose(ClsTransform):
if isinstance(im, np.ndarray): if isinstance(im, np.ndarray):
if len(im.shape) != 3: if len(im.shape) != 3:
raise Exception( raise Exception(
"im should be 3-dimension, but now is {}-dimensions". "im should be 3-dimension, but now is {}-dimensions".format(
format(len(im.shape))) len(im.shape)))
else: else:
try: try:
im = cv2.imread(im).astype('float32') im = cv2.imread(im).astype('float32')
...@@ -100,7 +100,9 @@ class Compose(ClsTransform): ...@@ -100,7 +100,9 @@ class Compose(ClsTransform):
transform_names = [type(x).__name__ for x in self.transforms] transform_names = [type(x).__name__ for x in self.transforms]
for aug in augmenters: for aug in augmenters:
if type(aug).__name__ in transform_names: if type(aug).__name__ in transform_names:
logging.error("{} is already in ComposedTransforms, need to remove it from add_augmenters().".format(type(aug).__name__)) logging.error(
"{} is already in ComposedTransforms, need to remove it from add_augmenters().".
format(type(aug).__name__))
self.transforms = augmenters + self.transforms self.transforms = augmenters + self.transforms
...@@ -139,8 +141,8 @@ class RandomCrop(ClsTransform): ...@@ -139,8 +141,8 @@ class RandomCrop(ClsTransform):
tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据; tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
""" """
im = random_crop(im, self.crop_size, self.lower_scale, im = random_crop(im, self.crop_size, self.lower_scale, self.lower_ratio,
self.lower_ratio, self.upper_ratio) self.upper_ratio)
if label is None: if label is None:
return (im, ) return (im, )
else: else:
...@@ -270,14 +272,12 @@ class ResizeByShort(ClsTransform): ...@@ -270,14 +272,12 @@ class ResizeByShort(ClsTransform):
im_short_size = min(im.shape[0], im.shape[1]) im_short_size = min(im.shape[0], im.shape[1])
im_long_size = max(im.shape[0], im.shape[1]) im_long_size = max(im.shape[0], im.shape[1])
scale = float(self.short_size) / im_short_size scale = float(self.short_size) / im_short_size
if self.max_size > 0 and np.round(scale * if self.max_size > 0 and np.round(scale * im_long_size) > self.max_size:
im_long_size) > self.max_size:
scale = float(self.max_size) / float(im_long_size) scale = float(self.max_size) / float(im_long_size)
resized_width = int(round(im.shape[1] * scale)) resized_width = int(round(im.shape[1] * scale))
resized_height = int(round(im.shape[0] * scale)) resized_height = int(round(im.shape[0] * scale))
im = cv2.resize( im = cv2.resize(
im, (resized_width, resized_height), im, (resized_width, resized_height), interpolation=cv2.INTER_LINEAR)
interpolation=cv2.INTER_LINEAR)
if label is None: if label is None:
return (im, ) return (im, )
...@@ -490,13 +490,15 @@ class ComposedClsTransforms(Compose): ...@@ -490,13 +490,15 @@ class ComposedClsTransforms(Compose):
crop_size(int|list): 输入模型里的图像大小 crop_size(int|list): 输入模型里的图像大小
mean(list): 图像均值 mean(list): 图像均值
std(list): 图像方差 std(list): 图像方差
random_horizontal_flip(bool): 是否以0.5的概率使用随机水平翻转增强,该参数仅在mode为`train`时生效,默认为True
""" """
def __init__(self, def __init__(self,
mode, mode,
crop_size=[224, 224], crop_size=[224, 224],
mean=[0.485, 0.456, 0.406], mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]): std=[0.229, 0.224, 0.225],
random_horizontal_flip=True):
width = crop_size width = crop_size
if isinstance(crop_size, list): if isinstance(crop_size, list):
if crop_size[0] != crop_size[1]: if crop_size[0] != crop_size[1]:
...@@ -512,10 +514,11 @@ class ComposedClsTransforms(Compose): ...@@ -512,10 +514,11 @@ class ComposedClsTransforms(Compose):
if mode == 'train': if mode == 'train':
# 训练时的transforms,包含数据增强 # 训练时的transforms,包含数据增强
transforms = [ transforms = [
RandomCrop(crop_size=width), RandomHorizontalFlip(prob=0.5), RandomCrop(crop_size=width), Normalize(
Normalize(
mean=mean, std=std) mean=mean, std=std)
] ]
if random_horizontal_flip:
transforms.insert(0, RandomHorizontalFlip())
else: else:
# 验证/预测时的transforms # 验证/预测时的transforms
transforms = [ transforms = [
......
...@@ -160,7 +160,9 @@ class Compose(DetTransform): ...@@ -160,7 +160,9 @@ class Compose(DetTransform):
transform_names = [type(x).__name__ for x in self.transforms] transform_names = [type(x).__name__ for x in self.transforms]
for aug in augmenters: for aug in augmenters:
if type(aug).__name__ in transform_names: if type(aug).__name__ in transform_names:
logging.error("{} is already in ComposedTransforms, need to remove it from add_augmenters().".format(type(aug).__name__)) logging.error(
"{} is already in ComposedTransforms, need to remove it from add_augmenters().".
format(type(aug).__name__))
self.transforms = augmenters + self.transforms self.transforms = augmenters + self.transforms
...@@ -220,15 +222,13 @@ class ResizeByShort(DetTransform): ...@@ -220,15 +222,13 @@ class ResizeByShort(DetTransform):
im_short_size = min(im.shape[0], im.shape[1]) im_short_size = min(im.shape[0], im.shape[1])
im_long_size = max(im.shape[0], im.shape[1]) im_long_size = max(im.shape[0], im.shape[1])
scale = float(self.short_size) / im_short_size scale = float(self.short_size) / im_short_size
if self.max_size > 0 and np.round(scale * if self.max_size > 0 and np.round(scale * im_long_size) > self.max_size:
im_long_size) > self.max_size:
scale = float(self.max_size) / float(im_long_size) scale = float(self.max_size) / float(im_long_size)
resized_width = int(round(im.shape[1] * scale)) resized_width = int(round(im.shape[1] * scale))
resized_height = int(round(im.shape[0] * scale)) resized_height = int(round(im.shape[0] * scale))
im_resize_info = [resized_height, resized_width, scale] im_resize_info = [resized_height, resized_width, scale]
im = cv2.resize( im = cv2.resize(
im, (resized_width, resized_height), im, (resized_width, resized_height), interpolation=cv2.INTER_LINEAR)
interpolation=cv2.INTER_LINEAR)
im_info['im_resize_info'] = np.array(im_resize_info).astype(np.float32) im_info['im_resize_info'] = np.array(im_resize_info).astype(np.float32)
if label_info is None: if label_info is None:
return (im, im_info) return (im, im_info)
...@@ -268,8 +268,7 @@ class Padding(DetTransform): ...@@ -268,8 +268,7 @@ class Padding(DetTransform):
if not isinstance(target_size, tuple) and not isinstance( if not isinstance(target_size, tuple) and not isinstance(
target_size, list): target_size, list):
raise TypeError( raise TypeError(
"Padding: Type of target_size must in (int|list|tuple)." "Padding: Type of target_size must in (int|list|tuple).")
)
elif len(target_size) != 2: elif len(target_size) != 2:
raise ValueError( raise ValueError(
"Padding: Length of target_size must equal 2.") "Padding: Length of target_size must equal 2.")
...@@ -454,8 +453,7 @@ class RandomHorizontalFlip(DetTransform): ...@@ -454,8 +453,7 @@ class RandomHorizontalFlip(DetTransform):
ValueError: 数据长度不匹配。 ValueError: 数据长度不匹配。
""" """
if not isinstance(im, np.ndarray): if not isinstance(im, np.ndarray):
raise TypeError( raise TypeError("RandomHorizontalFlip: image is not a numpy array.")
"RandomHorizontalFlip: image is not a numpy array.")
if len(im.shape) != 3: if len(im.shape) != 3:
raise ValueError( raise ValueError(
"RandomHorizontalFlip: image is not 3-dimensional.") "RandomHorizontalFlip: image is not 3-dimensional.")
...@@ -736,7 +734,7 @@ class MixupImage(DetTransform): ...@@ -736,7 +734,7 @@ class MixupImage(DetTransform):
gt_poly2 = im_info['mixup'][2]['gt_poly'] gt_poly2 = im_info['mixup'][2]['gt_poly']
is_crowd1 = label_info['is_crowd'] is_crowd1 = label_info['is_crowd']
is_crowd2 = im_info['mixup'][2]['is_crowd'] is_crowd2 = im_info['mixup'][2]['is_crowd']
if 0 not in gt_class1 and 0 not in gt_class2: if 0 not in gt_class1 and 0 not in gt_class2:
gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0) gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
gt_class = np.concatenate((gt_class1, gt_class2), axis=0) gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
...@@ -785,9 +783,7 @@ class RandomExpand(DetTransform): ...@@ -785,9 +783,7 @@ class RandomExpand(DetTransform):
fill_value (list): 扩张图像的初始填充值(0-255)。默认为[123.675, 116.28, 103.53]。 fill_value (list): 扩张图像的初始填充值(0-255)。默认为[123.675, 116.28, 103.53]。
""" """
def __init__(self, def __init__(self, ratio=4., prob=0.5,
ratio=4.,
prob=0.5,
fill_value=[123.675, 116.28, 103.53]): fill_value=[123.675, 116.28, 103.53]):
super(RandomExpand, self).__init__() super(RandomExpand, self).__init__()
assert ratio > 1.01, "expand ratio must be larger than 1.01" assert ratio > 1.01, "expand ratio must be larger than 1.01"
...@@ -1281,21 +1277,25 @@ class ComposedRCNNTransforms(Compose): ...@@ -1281,21 +1277,25 @@ class ComposedRCNNTransforms(Compose):
min_max_size(list): 图像在缩放时,最小边和最大边的约束条件 min_max_size(list): 图像在缩放时,最小边和最大边的约束条件
mean(list): 图像均值 mean(list): 图像均值
std(list): 图像方差 std(list): 图像方差
random_horizontal_flip(bool): 是否以0.5的概率使用随机水平翻转增强,该参数仅在mode为`train`时生效,默认为True
""" """
def __init__(self, def __init__(self,
mode, mode,
min_max_size=[800, 1333], min_max_size=[800, 1333],
mean=[0.485, 0.456, 0.406], mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]): std=[0.229, 0.224, 0.225],
random_horizontal_flip=True):
if mode == 'train': if mode == 'train':
# 训练时的transforms,包含数据增强 # 训练时的transforms,包含数据增强
transforms = [ transforms = [
RandomHorizontalFlip(prob=0.5), Normalize( Normalize(
mean=mean, std=std), ResizeByShort( mean=mean, std=std), ResizeByShort(
short_size=min_max_size[0], max_size=min_max_size[1]), short_size=min_max_size[0], max_size=min_max_size[1]),
Padding(coarsest_stride=32) Padding(coarsest_stride=32)
] ]
if random_horizontal_flip:
transforms.insert(0, RandomHorizontalFlip())
else: else:
# 验证/预测时的transforms # 验证/预测时的transforms
transforms = [ transforms = [
...@@ -1325,9 +1325,14 @@ class ComposedYOLOv3Transforms(Compose): ...@@ -1325,9 +1325,14 @@ class ComposedYOLOv3Transforms(Compose):
Args: Args:
mode(str): 图像处理流程所处阶段,训练/验证/预测,分别对应'train', 'eval', 'test' mode(str): 图像处理流程所处阶段,训练/验证/预测,分别对应'train', 'eval', 'test'
shape(list): 输入模型中图像的大小,输入模型的图像会被Resize成此大小 shape(list): 输入模型中图像的大小,输入模型的图像会被Resize成此大小
mixup_epoch(int): 模型训练过程中,前mixup_epoch会使用mixup策略 mixup_epoch(int): 模型训练过程中,前mixup_epoch会使用mixup策略, 若设为-1,则表示不使用该策略
mean(list): 图像均值 mean(list): 图像均值
std(list): 图像方差 std(list): 图像方差
random_distort(bool): 数据增强方式,参数仅在mode为`train`时生效,表示是否在训练过程中随机扰动图像,默认为True
random_expand(bool): 数据增强方式,参数仅在mode为`train`时生效,表示是否在训练过程中随机扩张图像,默认为True
random_crop(bool): 数据增强方式,参数仅在mode为`train`时生效,表示是否在训练过程中随机裁剪图像,默认为True
random_horizontal_flip(bool): 数据增强方式,参数仅在mode为`train`时生效,表示是否在训练过程中随机水平翻转图像,默认为True
""" """
def __init__(self, def __init__(self,
...@@ -1335,7 +1340,11 @@ class ComposedYOLOv3Transforms(Compose): ...@@ -1335,7 +1340,11 @@ class ComposedYOLOv3Transforms(Compose):
shape=[608, 608], shape=[608, 608],
mixup_epoch=250, mixup_epoch=250,
mean=[0.485, 0.456, 0.406], mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]): std=[0.229, 0.224, 0.225],
random_distort=True,
random_expand=True,
random_crop=True,
random_horizontal_flip=True):
width = shape width = shape
if isinstance(shape, list): if isinstance(shape, list):
if shape[0] != shape[1]: if shape[0] != shape[1]:
...@@ -1350,12 +1359,18 @@ class ComposedYOLOv3Transforms(Compose): ...@@ -1350,12 +1359,18 @@ class ComposedYOLOv3Transforms(Compose):
if mode == 'train': if mode == 'train':
# 训练时的transforms,包含数据增强 # 训练时的transforms,包含数据增强
transforms = [ transforms = [
MixupImage(mixup_epoch=mixup_epoch), RandomDistort(), MixupImage(mixup_epoch=mixup_epoch), Resize(
RandomExpand(), RandomCrop(), Resize( target_size=width, interp='RANDOM'), Normalize(
target_size=width,
interp='RANDOM'), RandomHorizontalFlip(), Normalize(
mean=mean, std=std) mean=mean, std=std)
] ]
if random_horizontal_flip:
transforms.insert(1, RandomHorizontalFlip())
if random_crop:
transforms.insert(1, RandomCrop())
if random_expand:
transforms.insert(1, RandomExpand())
if random_distort:
transforms.insert(1, RandomDistort())
else: else:
# 验证/预测时的transforms # 验证/预测时的transforms
transforms = [ transforms = [
......
...@@ -116,7 +116,9 @@ class Compose(SegTransform): ...@@ -116,7 +116,9 @@ class Compose(SegTransform):
transform_names = [type(x).__name__ for x in self.transforms] transform_names = [type(x).__name__ for x in self.transforms]
for aug in augmenters: for aug in augmenters:
if type(aug).__name__ in transform_names: if type(aug).__name__ in transform_names:
logging.error("{} is already in ComposedTransforms, need to remove it from add_augmenters().".format(type(aug).__name__)) logging.error(
"{} is already in ComposedTransforms, need to remove it from add_augmenters().".
format(type(aug).__name__))
self.transforms = augmenters + self.transforms self.transforms = augmenters + self.transforms
...@@ -401,8 +403,7 @@ class ResizeByShort(SegTransform): ...@@ -401,8 +403,7 @@ class ResizeByShort(SegTransform):
im_short_size = min(im.shape[0], im.shape[1]) im_short_size = min(im.shape[0], im.shape[1])
im_long_size = max(im.shape[0], im.shape[1]) im_long_size = max(im.shape[0], im.shape[1])
scale = float(self.short_size) / im_short_size scale = float(self.short_size) / im_short_size
if self.max_size > 0 and np.round(scale * if self.max_size > 0 and np.round(scale * im_long_size) > self.max_size:
im_long_size) > self.max_size:
scale = float(self.max_size) / float(im_long_size) scale = float(self.max_size) / float(im_long_size)
resized_width = int(round(im.shape[1] * scale)) resized_width = int(round(im.shape[1] * scale))
resized_height = int(round(im.shape[0] * scale)) resized_height = int(round(im.shape[0] * scale))
...@@ -1113,25 +1114,35 @@ class ComposedSegTransforms(Compose): ...@@ -1113,25 +1114,35 @@ class ComposedSegTransforms(Compose):
Args: Args:
mode(str): 图像处理所处阶段,训练/验证/预测,分别对应'train', 'eval', 'test' mode(str): 图像处理所处阶段,训练/验证/预测,分别对应'train', 'eval', 'test'
train_crop_size(list): 模型训练阶段,随机从原图crop的大小 min_max_size(list): 训练过程中,图像的最长边会随机resize至此区间(短边按比例相应resize);预测阶段,图像最长边会resize至此区间中间值,即(min_size+max_size)/2。默认为[400, 600]
train_crop_size(list): 仅在mode为`train`时生效,训练过程中,随机从图像中裁剪出对应大小的子图(如若原图小于此大小,则会padding到此大小),默认为[512, 512]
mean(list): 图像均值 mean(list): 图像均值
std(list): 图像方差 std(list): 图像方差
random_horizontal_flip(bool): 数据增强方式,仅在mode为`train`时生效,表示训练过程是否随机水平翻转图像,默认为True
""" """
def __init__(self, def __init__(self,
mode, mode,
train_crop_size=[769, 769], min_max_size=[400, 600],
train_crop_size=[512, 512],
mean=[0.5, 0.5, 0.5], mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5]): std=[0.5, 0.5, 0.5],
random_horizontal_flip=True):
if mode == 'train': if mode == 'train':
# 训练时的transforms,包含数据增强 # 训练时的transforms,包含数据增强
transforms = [ transforms = [
RandomHorizontalFlip(prob=0.5), ResizeStepScaling(), ResizeRangeScaling(
min_value=min(min_max_size), max_value=max(min_max_size)),
RandomPaddingCrop(crop_size=train_crop_size), Normalize( RandomPaddingCrop(crop_size=train_crop_size), Normalize(
mean=mean, std=std) mean=mean, std=std)
] ]
if random_horizontal_flip:
transforms.insert(0, RandomHorizontalFlip())
else: else:
# 验证/预测时的transforms # 验证/预测时的transforms
transforms = [Normalize(mean=mean, std=std)] long_size = (min(min_max_size) + max(min_max_size)) // 2
transforms = [
ResizeByLong(long_size=long_size), Normalize(
mean=mean, std=std)
]
super(ComposedSegTransforms, self).__init__(transforms) super(ComposedSegTransforms, self).__init__(transforms)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册