提交 2934885a 编写于 作者: J jiangjiajun

modify ComposedTransforms

上级 421266a7
...@@ -18,7 +18,6 @@ import random ...@@ -18,7 +18,6 @@ import random
import os.path as osp import os.path as osp
import numpy as np import numpy as np
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
from .template import TemplateTransforms
class ClsTransform: class ClsTransform:
...@@ -93,6 +92,12 @@ class Compose(ClsTransform): ...@@ -93,6 +92,12 @@ class Compose(ClsTransform):
outputs = (im, label) outputs = (im, label)
return outputs return outputs
def add_augmenters(self, augmenters):
    """Prepend user-supplied augmentation operators to this pipeline.

    Args:
        augmenters (list): transform operators to run before the
            currently configured ones.

    Raises:
        Exception: if ``augmenters`` is not a list.
    """
    if not isinstance(augmenters, list):
        raise Exception(
            "augmenters should be list type in func add_augmenters()")
    # BUG FIX: ``self.transforms`` is already the plain list handed to
    # Compose.__init__ (see the Composed* subclasses which pass a list),
    # so it has no ``.transforms`` attribute; the old
    # ``augmenters + self.transforms.transforms`` raised AttributeError.
    self.transforms = augmenters + self.transforms
class RandomCrop(ClsTransform): class RandomCrop(ClsTransform):
"""对图像进行随机剪裁,模型训练时的数据增强操作。 """对图像进行随机剪裁,模型训练时的数据增强操作。
...@@ -464,7 +469,7 @@ class ArrangeClassifier(ClsTransform): ...@@ -464,7 +469,7 @@ class ArrangeClassifier(ClsTransform):
return outputs return outputs
class ComposedClsTransforms(Compose):
    """Standard transform pipeline for classification models.

    Train mode:
        1. Randomly crop a patch from the image and resize it to
           ``crop_size``.
        2. Randomly flip horizontally with probability 0.5.
        3. Normalize with ``mean``/``std``.
    Eval/predict mode:
        1. Resize by short edge to ``int(crop_size * 1.14)``.
        2. Center-crop to ``crop_size``.
        3. Normalize with ``mean``/``std``.

    Args:
        mode (str): 'train' selects the augmented pipeline, any other
            value selects the eval/predict pipeline.
        crop_size (int|list): target square size; a list must contain two
            equal values.
        mean (list): per-channel mean used by Normalize.
        std (list): per-channel std used by Normalize.
    """

    def __init__(self,
                 mode,
                 crop_size=[224, 224],
                 mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225]):
        width = crop_size
        if isinstance(crop_size, list):
            # BUG FIX: the old code compared ``shape[0] != shape[1]`` but no
            # ``shape`` variable exists in this scope — the parameter is
            # ``crop_size`` (NameError at runtime).
            if crop_size[0] != crop_size[1]:
                # NOTE(review): the exact message was elided by the diff
                # hunk — confirm against the original file.
                raise Exception(
                    "In classifier model, width and height should be equal")
            width = crop_size[0]
        if width % 32 != 0:
            raise Exception(
                "In classifier model, width and height should be multiple of 32, e.g 224、256、320...."
            )
        if mode == 'train':
            # Training pipeline, includes data augmentation.
            transforms = [
                RandomCrop(crop_size=width),
                RandomHorizontalFlip(prob=0.5),
                Normalize(
                    mean=mean, std=std),
            ]
        else:
            # Eval/predict pipeline.
            # BUG FIX: ``ReiszeByShort`` was a typo for ``ResizeByShort``
            # and would raise NameError the first time eval mode is used.
            transforms = [
                ResizeByShort(short_size=int(width * 1.14)),
                CenterCrop(crop_size=width),
                Normalize(
                    mean=mean, std=std),
            ]
        super(ComposedClsTransforms, self).__init__(transforms)
...@@ -27,7 +27,6 @@ from PIL import Image, ImageEnhance ...@@ -27,7 +27,6 @@ from PIL import Image, ImageEnhance
from .imgaug_support import execute_imgaug from .imgaug_support import execute_imgaug
from .ops import * from .ops import *
from .box_utils import * from .box_utils import *
from .template import TemplateTransforms
class DetTransform: class DetTransform:
...@@ -153,6 +152,13 @@ class Compose(DetTransform): ...@@ -153,6 +152,13 @@ class Compose(DetTransform):
outputs = (im, im_info) outputs = (im, im_info)
return outputs return outputs
def add_augmenters(self, augmenters):
    """Prepend user-supplied augmentation operators to this pipeline.

    Args:
        augmenters (list): transform operators to run before the
            currently configured ones.

    Raises:
        Exception: if ``augmenters`` is not a list.
    """
    if not isinstance(augmenters, list):
        raise Exception(
            "augmenters should be list type in func add_augmenters()")
    # BUG FIX: the old ``assert mode == 'train'`` referenced an undefined
    # name ``mode`` (NameError — no such local or attribute exists here).
    # If a train-only guard is wanted, it must check state actually stored
    # on the instance; dropped pending that decision.
    # BUG FIX: ``self.transforms`` is a plain list, so the old
    # ``self.transforms.transforms`` raised AttributeError.
    self.transforms = augmenters + self.transforms
class ResizeByShort(DetTransform): class ResizeByShort(DetTransform):
"""根据图像的短边调整图像大小(resize)。 """根据图像的短边调整图像大小(resize)。
...@@ -1230,7 +1236,7 @@ class ArrangeYOLOv3(DetTransform): ...@@ -1230,7 +1236,7 @@ class ArrangeYOLOv3(DetTransform):
return outputs return outputs
class ComposedRCNNTransforms(Compose):
    """Standard transform pipeline for RCNN models (faster-rcnn/mask-rcnn).

    Train mode:
        1. Randomly flip horizontally with probability 0.5.
        2. Normalize with ``mean``/``std``.
        3. Resize by short edge to ``min_max_size[0]``, capping the long
           edge at ``min_max_size[1]``.
        4. Pad to a multiple of the coarsest stride (32).
    Eval/predict mode: same as train without the random flip.

    Args:
        mode (str): 'train' selects the augmented pipeline, any other
            value selects the eval/predict pipeline.
        min_max_size (list): [short-edge target, long-edge cap] for
            ResizeByShort.
        mean (list): per-channel mean used by Normalize.
        std (list): per-channel std used by Normalize.
    """

    def __init__(self,
                 mode,
                 min_max_size=[800, 1333],
                 mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225]):
        if mode == 'train':
            # Training pipeline, includes data augmentation.
            transforms = [
                RandomHorizontalFlip(prob=0.5),
                Normalize(
                    mean=mean, std=std),
                ResizeByShort(
                    short_size=min_max_size[0], max_size=min_max_size[1]),
                Padding(coarsest_stride=32),
            ]
        else:
            # Eval/predict pipeline.
            transforms = [
                Normalize(
                    mean=mean, std=std),
                ResizeByShort(
                    short_size=min_max_size[0], max_size=min_max_size[1]),
                Padding(coarsest_stride=32),
            ]
        # BUG FIX: the old code called ``super(RCNNTransforms, self)`` but
        # no class named ``RCNNTransforms`` exists — NameError on every
        # instantiation. It must name this class.
        super(ComposedRCNNTransforms, self).__init__(transforms)
class ComposedYOLOTransforms(Compose):
    """Standard transform pipeline for YOLOv3 models.

    Train mode:
        1. MixupImage for the first ``mixup_epoch`` epochs.
        2. RandomDistort, RandomExpand, RandomCrop.
        3. Resize to ``shape`` with a random interpolation method.
        4. Random horizontal flip.
        5. Normalize with ``mean``/``std``.
    Eval/predict mode:
        1. Resize to ``shape`` with cubic interpolation.
        2. Normalize with ``mean``/``std``.

    Args:
        mode (str): 'train' selects the augmented pipeline, any other
            value selects the eval/predict pipeline.
        shape (int|list): target square size; a list must contain two
            equal values, each a multiple of 32.
            NOTE(review): default restored from context elided by the
            diff hunk — confirm against the original file.
        mixup_epoch (int): number of initial epochs that use mixup.
        mean (list): per-channel mean used by Normalize.
        std (list): per-channel std used by Normalize.
    """

    def __init__(self,
                 mode,
                 shape=[608, 608],
                 mixup_epoch=250,
                 mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225]):
        width = shape
        if isinstance(shape, list):
            if shape[0] != shape[1]:
                # NOTE(review): the exact message was elided by the diff
                # hunk — confirm against the original file.
                raise Exception(
                    "In YOLOv3 model, width and height should be equal")
            width = shape[0]
        if width % 32 != 0:
            raise Exception(
                "In YOLOv3 model, width and height should be multiple of 32, e.g 224、256、320...."
            )
        if mode == 'train':
            # Training pipeline, includes data augmentation.
            transforms = [
                MixupImage(mixup_epoch=mixup_epoch),
                RandomDistort(),
                RandomExpand(),
                RandomCrop(),
                Resize(
                    target_size=width, interp='RANDOM'),
                RandomHorizontalFlip(),
                Normalize(
                    mean=mean, std=std),
            ]
        else:
            # Eval/predict pipeline.
            transforms = [
                Resize(
                    target_size=width, interp='CUBIC'),
                Normalize(
                    mean=mean, std=std),
            ]
        # BUG FIX: the old code called ``super(YOLOTransforms, self)`` but
        # no class named ``YOLOTransforms`` exists — NameError on every
        # instantiation. It must name this class.
        super(ComposedYOLOTransforms, self).__init__(transforms)
...@@ -1091,7 +1091,7 @@ class ArrangeSegmenter(SegTransform): ...@@ -1091,7 +1091,7 @@ class ArrangeSegmenter(SegTransform):
return (im, ) return (im, )
class ComposedTransforms(Compose):
    """Standard transform pipeline for segmentation models (UNet/DeepLabv3p).

    Train mode:
        1. Randomly flip horizontally with probability 0.5.
        2. ResizeStepScaling.
        3. RandomPaddingCrop to ``train_crop_size``.
        4. Normalize with ``mean``/``std``.
    Eval/predict mode:
        1. Normalize with ``mean``/``std``.

    Args:
        mode (str): 'train' selects the augmented pipeline, any other
            value selects the eval/predict pipeline.
        train_crop_size (list): crop size used by RandomPaddingCrop in
            train mode.
        mean (list): per-channel mean used by Normalize.
        std (list): per-channel std used by Normalize.
    """

    def __init__(self,
                 mode,
                 train_crop_size=[769, 769],
                 mean=[0.5, 0.5, 0.5],
                 std=[0.5, 0.5, 0.5]):
        # BUG FIX: the old code tested ``self.mode`` before any parent
        # __init__ ran, but no such attribute is ever set — AttributeError.
        # Use the ``mode`` argument directly, as the other Composed*
        # classes in this change do.
        if mode == 'train':
            # Training pipeline, includes data augmentation.
            transforms = [
                RandomHorizontalFlip(prob=0.5),
                ResizeStepScaling(),
                RandomPaddingCrop(crop_size=train_crop_size),
                Normalize(
                    mean=mean, std=std),
            ]
        else:
            # Eval/predict pipeline.
            # BUG FIX: the old branch built ``[transforms.Normalize(...)]``,
            # but ``transforms`` here is the local list being constructed,
            # not a module — AttributeError. The operator class is
            # ``Normalize``.
            transforms = [Normalize(mean=mean, std=std)]
        # BUG FIX: the old code called ``super(ComposedSegTransforms, self)``
        # but this class is named ``ComposedTransforms`` and no
        # ``ComposedSegTransforms`` exists — NameError on instantiation.
        super(ComposedTransforms, self).__init__(transforms)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册