提交 2934885a 编写于 作者: J jiangjiajun

modify ComposedTransforms

上级 421266a7
......@@ -18,7 +18,6 @@ import random
import os.path as osp
import numpy as np
from PIL import Image, ImageEnhance
from .template import TemplateTransforms
class ClsTransform:
......@@ -93,6 +92,12 @@ class Compose(ClsTransform):
outputs = (im, label)
return outputs
def add_augmenters(self, augmenters):
if not isinstance(augmenters, list):
raise Exception(
"augmenters should be list type in func add_augmenters()")
self.transforms = augmenters + self.transforms.transforms
class RandomCrop(ClsTransform):
"""对图像进行随机剪裁,模型训练时的数据增强操作。
......@@ -464,7 +469,7 @@ class ArrangeClassifier(ClsTransform):
return outputs
class BasicClsTransforms(TemplateTransforms):
class ComposedClsTransforms(Compose):
""" 分类模型的基础Transforms流程,具体如下
训练阶段:
1. 随机从图像中crop一块子图,并resize成crop_size大小
......@@ -487,7 +492,6 @@ class BasicClsTransforms(TemplateTransforms):
crop_size=[224, 224],
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]):
super(TemplateClsTransforms, self).__init__(mode=mode)
width = crop_size
if isinstance(crop_size, list):
if shape[0] != shape[1]:
......@@ -499,18 +503,19 @@ class BasicClsTransforms(TemplateTransforms):
"In classifier model, width and height should be multiple of 32, e.g 224、256、320...."
)
if self.mode == 'train':
if mode == 'train':
# 训练时的transforms,包含数据增强
self.transforms = transforms.Compose([
transforms.RandomCrop(crop_size=width),
transforms.RandomHorizontalFlip(prob=0.5),
transforms.Normalize(
transforms = [
RandomCrop(crop_size=width), RandomHorizontalFlip(prob=0.5),
Normalize(
mean=mean, std=std)
])
]
else:
# 验证/预测时的transforms
self.transforms = transforms.Compose([
transforms.ReiszeByShort(short_size=int(width * 1.14)),
transforms.CenterCrop(crop_size=width), transforms.Normalize(
transforms = [
ReiszeByShort(short_size=int(width * 1.14)),
CenterCrop(crop_size=width), Normalize(
mean=mean, std=std)
])
]
super(ComposedClsTransforms, self).__init__(transforms)
......@@ -27,7 +27,6 @@ from PIL import Image, ImageEnhance
from .imgaug_support import execute_imgaug
from .ops import *
from .box_utils import *
from .template import TemplateTransforms
class DetTransform:
......@@ -153,6 +152,13 @@ class Compose(DetTransform):
outputs = (im, im_info)
return outputs
def add_augmenters(self, augmenters):
if not isinstance(augmenters, list):
raise Exception(
"augmenters should be list type in func add_augmenters()")
assert mode == 'train', "There should be exists augmenters while on train mode"
self.transforms = augmenters + self.transforms.transforms
class ResizeByShort(DetTransform):
"""根据图像的短边调整图像大小(resize)。
......@@ -1230,7 +1236,7 @@ class ArrangeYOLOv3(DetTransform):
return outputs
class BasicRCNNTransforms(TemplateTransforms):
class ComposedRCNNTransforms(Compose):
""" RCNN模型(faster-rcnn/mask-rcnn)图像处理流程,具体如下,
训练阶段:
1. 随机以0.5的概率将图像水平翻转
......@@ -1257,27 +1263,27 @@ class BasicRCNNTransforms(TemplateTransforms):
min_max_size=[800, 1333],
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]):
super(RCNNTransforms, self).__init__(mode=mode)
if self.mode == 'train':
if mode == 'train':
# 训练时的transforms,包含数据增强
self.transforms = transforms.Compose([
transforms.RandomHorizontalFlip(prob=0.5),
transforms.Normalize(
mean=mean, std=std), transforms.ResizeByShort(
transforms = [
RandomHorizontalFlip(prob=0.5), Normalize(
mean=mean, std=std), ResizeByShort(
short_size=min_max_size[0], max_size=min_max_size[1]),
transforms.Padding(coarsest_stride=32)
])
Padding(coarsest_stride=32)
]
else:
# 验证/预测时的transforms
self.transforms = transforms.Compose([
transforms.Normalize(
mean=mean, std=std), transforms.ResizeByShort(
transforms = [
Normalize(
mean=mean, std=std), ResizeByShort(
short_size=min_max_size[0], max_size=min_max_size[1]),
transforms.Padding(coarsest_stride=32)
])
Padding(coarsest_stride=32)
]
super(RCNNTransforms, self).__init__(transforms)
class BasicYOLOTransforms(TemplateTransforms):
class ComposedYOLOTransforms(Compose):
"""YOLOv3模型的图像预处理流程,具体如下,
训练阶段:
1. 在前mixup_epoch轮迭代中,使用MixupImage策略,见https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/det_transforms.html#mixupimage
......@@ -1305,7 +1311,6 @@ class BasicYOLOTransforms(TemplateTransforms):
mixup_epoch=250,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]):
super(YOLOTransforms, self).__init__(mode=mode)
width = shape
if isinstance(shape, list):
if shape[0] != shape[1]:
......@@ -1317,20 +1322,20 @@ class BasicYOLOTransforms(TemplateTransforms):
"In YOLOv3 model, width and height should be multiple of 32, e.g 224、256、320...."
)
if self.mode == 'train':
if mode == 'train':
# 训练时的transforms,包含数据增强
self.transforms = transforms.Compose([
transforms.MixupImage(mixup_epoch=mixup_epoch),
transforms.RandomDistort(), transforms.RandomExpand(),
transforms.RandomCrop(), transforms.Resize(
target_size=width, interp='RANDOM'),
transforms.RandomHorizontalFlip(), transforms.Normalize(
transforms = [
MixupImage(mixup_epoch=mixup_epoch), RandomDistort(),
RandomExpand(), RandomCrop(), Resize(
target_size=width,
interp='RANDOM'), RandomHorizontalFlip(), Normalize(
mean=mean, std=std)
])
]
else:
# 验证/预测时的transforms
self.transforms = transforms.Compose([
transforms.Resize(
target_size=width, interp='CUBIC'), transforms.Normalize(
transforms = [
Resize(
target_size=width, interp='CUBIC'), Normalize(
mean=mean, std=std)
])
]
super(YOLOTransforms, self).__init__(transforms)
......@@ -1091,7 +1091,7 @@ class ArrangeSegmenter(SegTransform):
return (im, )
class BasicSegTransforms(TemplateTransforms):
class ComposedTransforms(Compose):
""" 语义分割模型(UNet/DeepLabv3p)的图像处理流程,具体如下
训练阶段:
1. 随机对图像以0.5的概率水平翻转
......@@ -1113,18 +1113,15 @@ class BasicSegTransforms(TemplateTransforms):
train_crop_size=[769, 769],
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5]):
super(TemplateSegTransforms, self).__init__(mode=mode)
if self.mode == 'train':
# 训练时的transforms,包含数据增强
self.transforms = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.ResizeStepScaling(),
transforms.RandomPaddingCrop(crop_size=train_crop_size),
transforms.Normalize(
transforms = [
RandomHorizontalFlip(prob=0.5), ResizeStepScaling(),
RandomPaddingCrop(crop_size=train_crop_size), Normalize(
mean=mean, std=std)
])
]
else:
# 验证/预测时的transforms
self.transforms = transforms.Compose(
[transforms.Normalize(
mean=mean, std=std)])
transforms = [transforms.Normalize(mean=mean, std=std)]
super(ComposedSegTransforms, self).__init__(transforms)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册