提交 e6e3da6f 编写于 作者: J jiangjiajun

fix composed transforms

上级 a98efcb4
......@@ -232,12 +232,12 @@ eval_transforms = transforms.Composed([
```
## ComposedYOLOTransforms类
## ComposedYOLOv3Transforms类
```python
paddlex.det.transforms.ComposedYOLOTransforms(mode, shape=[608, 608], mixup_epoch=250, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
paddlex.det.transforms.ComposedYOLOv3Transforms(mode, shape=[608, 608], mixup_epoch=250, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
```
目标检测YOLOv3模型中已经组合好的数据处理流程,开发者可以直接使用ComposedYOLOTransforms,简化手动组合transforms的过程, 该类中已经包含了[MixupImage](#MixupImage)、[RandomDistort](#RandomDistort)、[RandomExpand](#RandomExpand)、[RandomCrop](#RandomCrop)、[RandomHorizontalFlip](#RandomHorizontalFlip)5种数据增强方式,你仍可以通过[add_augmenters函数接口](#add_augmenters)添加新的数据增强方式。
ComposedYOLOTransforms共包括以下几个步骤:
目标检测YOLOv3模型中已经组合好的数据处理流程,开发者可以直接使用ComposedYOLOv3Transforms,简化手动组合transforms的过程, 该类中已经包含了[MixupImage](#MixupImage)、[RandomDistort](#RandomDistort)、[RandomExpand](#RandomExpand)、[RandomCrop](#RandomCrop)、[RandomHorizontalFlip](#RandomHorizontalFlip)5种数据增强方式,你仍可以通过[add_augmenters函数接口](#add_augmenters)添加新的数据增强方式。
ComposedYOLOv3Transforms共包括以下几个步骤:
> 训练阶段:
> > 1. 在前mixup_epoch轮迭代中,使用MixupImage策略
> > 2. 对图像进行随机扰动,包括亮度,对比度,饱和度和色调
......@@ -259,7 +259,7 @@ ComposedYOLOTransforms共包括以下几个步骤:
### 添加数据增强方式
```python
ComposedYOLOTransforms.add_augmenters(augmenters)
ComposedYOLOv3Transforms.add_augmenters(augmenters)
```
> **参数**
> * **augmenters**(list): 数据增强方式列表
......@@ -268,8 +268,8 @@ ComposedYOLOTransforms.add_augmenters(augmenters)
```
import paddlex as pdx
from paddlex.det import transforms
train_transforms = transforms.ComposedYOLOTransforms(mode='train', shape=[480, 480])
eval_transforms = transforms.ComposedYOLOTransforms(mode='eval', shape=[480, 480])
train_transforms = transforms.ComposedYOLOv3Transforms(mode='train', shape=[480, 480])
eval_transforms = transforms.ComposedYOLOv3Transforms(mode='eval', shape=[480, 480])
# 添加数据增强
import imgaug.augmenters as iaa
......
......@@ -10,18 +10,12 @@ veg_dataset = 'https://bj.bcebos.com/paddlex/datasets/vegetables_cls.tar.gz'
pdx.utils.download_and_decompress(veg_dataset, path='./')
# 定义训练和验证时的transforms
train_transforms = transforms.Compose([
transforms.RandomCrop(crop_size=224),
transforms.RandomHorizontalFlip(),
transforms.Normalize()
])
eval_transforms = transforms.Compose([
transforms.ResizeByShort(short_size=256),
transforms.CenterCrop(crop_size=224),
transforms.Normalize()
])
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/cls_transforms.html#composedclstransforms
train_transforms = transforms.ComposedClsTransforms(mode='train', crop_size=[224, 224])
eval_transforms = transforms.ComposedClsTransforms(mode='eval', crop_size=[224, 224])
# 定义训练和验证所用的数据集
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/classification.html#imagenet
train_dataset = pdx.datasets.ImageNet(
data_dir='vegetables_cls',
file_list='vegetables_cls/train_list.txt',
......@@ -39,6 +33,8 @@ eval_dataset = pdx.datasets.ImageNet(
# VisualDL启动方式: visualdl --logdir output/mobilenetv2/vdl_log --port 8001
# 浏览器打开 http://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/models/classification.html#resnet50
model = pdx.cls.MobileNetV2(num_classes=len(train_dataset.labels))
model.train(
num_epochs=10,
......
......@@ -11,16 +11,12 @@ veg_dataset = 'https://bj.bcebos.com/paddlex/datasets/vegetables_cls.tar.gz'
pdx.utils.download_and_decompress(veg_dataset, path='./')
# 定义训练和验证时的transforms
train_transforms = transforms.Compose(
[transforms.RandomCrop(crop_size=224),
transforms.Normalize()])
eval_transforms = transforms.Compose([
transforms.ResizeByShort(short_size=256),
transforms.CenterCrop(crop_size=224),
transforms.Normalize()
])
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/cls_transforms.html#composedclstransforms
train_transforms = transforms.ComposedClsTransforms(mode='train', crop_size=[224, 224])
eval_transforms = transforms.ComposedClsTransforms(mode='eval', crop_size=[224, 224])
# 定义训练和验证所用的数据集
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/classification.html#imagenet
train_dataset = pdx.datasets.ImageNet(
data_dir='vegetables_cls',
file_list='vegetables_cls/train_list.txt',
......@@ -47,6 +43,8 @@ optimizer = fluid.optimizer.Momentum(
# VisualDL启动方式: visualdl --logdir output/resnet50/vdl_log --port 8001
# 浏览器打开 http://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/models/classification.html#resnet50
model = pdx.cls.ResNet50(num_classes=len(train_dataset.labels))
model.train(
num_epochs=10,
......
......@@ -10,20 +10,12 @@ insect_dataset = 'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
pdx.utils.download_and_decompress(insect_dataset, path='./')
# 定义训练和验证时的transforms
train_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.Normalize(),
transforms.ResizeByShort(short_size=800, max_size=1333),
transforms.Padding(coarsest_stride=32)
])
eval_transforms = transforms.Compose([
transforms.Normalize(),
transforms.ResizeByShort(short_size=800, max_size=1333),
transforms.Padding(coarsest_stride=32),
])
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/det_transforms.html#composedrcnntransforms
train_transforms = transforms.ComposedRCNNTransforms(mode='train', min_max_size=[800, 1333])
eval_transforms = transforms.ComposedRCNNTransforms(mode='eval', min_max_size=[800, 1333])
# 定义训练和验证所用的数据集
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/detection.html#vocdetection
train_dataset = pdx.datasets.VOCDetection(
data_dir='insect_det',
file_list='insect_det/train_list.txt',
......@@ -42,6 +34,8 @@ eval_dataset = pdx.datasets.VOCDetection(
# 浏览器打开 http://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
# num_classes 需要设置为包含背景类的类别数,即: 目标类别数量 + 1
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/models/detection.html#fasterrcnn
num_classes = len(train_dataset.labels) + 1
model = pdx.det.FasterRCNN(num_classes=num_classes)
model.train(
......
......@@ -10,20 +10,12 @@ xiaoduxiong_dataset = 'https://bj.bcebos.com/paddlex/datasets/xiaoduxiong_ins_de
pdx.utils.download_and_decompress(xiaoduxiong_dataset, path='./')
# 定义训练和验证时的transforms
train_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.Normalize(),
transforms.ResizeByShort(short_size=800, max_size=1333),
transforms.Padding(coarsest_stride=32)
])
eval_transforms = transforms.Compose([
transforms.Normalize(),
transforms.ResizeByShort(short_size=800, max_size=1333),
transforms.Padding(coarsest_stride=32)
])
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/det_transforms.html#composedrcnntransforms
train_transforms = transforms.ComposedRCNNTransforms(mode='train', min_max_size=[800, 1333])
eval_transforms = transforms.ComposedRCNNTransforms(mode='eval', min_max_size=[800, 1333])
# 定义训练和验证所用的数据集
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/detection.html#cocodetection
train_dataset = pdx.datasets.CocoDetection(
data_dir='xiaoduxiong_ins_det/JPEGImages',
ann_file='xiaoduxiong_ins_det/train.json',
......@@ -40,6 +32,8 @@ eval_dataset = pdx.datasets.CocoDetection(
# 浏览器打开 http://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
# num_classes 需要设置为包含背景类的类别数,即: 目标类别数量 + 1
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/models/instance_segmentation.html#maskrcnn
num_classes = len(train_dataset.labels) + 1
model = pdx.det.MaskRCNN(num_classes=num_classes)
model.train(
......
......@@ -10,22 +10,12 @@ insect_dataset = 'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
pdx.utils.download_and_decompress(insect_dataset, path='./')
# 定义训练和验证时的transforms
train_transforms = transforms.Compose([
transforms.MixupImage(mixup_epoch=250),
transforms.RandomDistort(),
transforms.RandomExpand(),
transforms.RandomCrop(),
transforms.Resize(target_size=608, interp='RANDOM'),
transforms.RandomHorizontalFlip(),
transforms.Normalize(),
])
eval_transforms = transforms.Compose([
transforms.Resize(target_size=608, interp='CUBIC'),
transforms.Normalize(),
])
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/det_transforms.html#composedyolov3transforms
train_transforms = transforms.ComposedYOLOv3Transforms(mode='train', shape=[608, 608])
eval_transforms = transforms.ComposedYOLOv3Transforms(mode='eval', shape=[608, 608])
# 定义训练和验证所用的数据集
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/detection.html#vocdetection
train_dataset = pdx.datasets.VOCDetection(
data_dir='insect_det',
file_list='insect_det/train_list.txt',
......@@ -43,6 +33,8 @@ eval_dataset = pdx.datasets.VOCDetection(
# VisualDL启动方式: visualdl --logdir output/yolov3_darknet/vdl_log --port 8001
# 浏览器打开 http://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/models/detection.html#yolov3
num_classes = len(train_dataset.labels)
model = pdx.det.YOLOv3(num_classes=num_classes, backbone='DarkNet53')
model.train(
......
......@@ -10,17 +10,16 @@ optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
pdx.utils.download_and_decompress(optic_dataset, path='./')
# 定义训练和验证时的transforms
train_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.Resize(target_size=512),
transforms.RandomPaddingCrop(crop_size=500),
transforms.Normalize()
])
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/seg_transforms.html#composedsegtransforms
train_transforms = transforms.ComposedSegTransforms(mode='train', train_crop_size=[769, 769])
eval_transforms = transforms.ComposedSegTransforms(mode='eval')
eval_transforms = transforms.Compose(
[transforms.Resize(512), transforms.Normalize()])
train_transforms.add_augmenters([
transforms.RandomRotate()
])
# 定义训练和验证所用的数据集
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/semantic_segmentation.html#segdataset
train_dataset = pdx.datasets.SegDataset(
data_dir='optic_disc_seg',
file_list='optic_disc_seg/train_list.txt',
......@@ -38,6 +37,8 @@ eval_dataset = pdx.datasets.SegDataset(
# VisualDL启动方式: visualdl --logdir output/deeplab/vdl_log --port 8001
# 浏览器打开 http://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/models/semantic_segmentation.html#deeplabv3p
num_classes = len(train_dataset.labels)
model = pdx.seg.DeepLabv3p(num_classes=num_classes)
model.train(
......
......@@ -10,17 +10,12 @@ optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
pdx.utils.download_and_decompress(optic_dataset, path='./')
# 定义训练和验证时的transforms
train_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(), transforms.ResizeRangeScaling(),
transforms.RandomPaddingCrop(crop_size=512), transforms.Normalize()
])
eval_transforms = transforms.Compose([
transforms.ResizeByLong(long_size=512),
transforms.Padding(target_size=512), transforms.Normalize()
])
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/seg_transforms.html#composedsegtransforms
train_transforms = transforms.ComposedSegTransforms(mode='train', train_crop_size=[769, 769])
eval_transforms = transforms.ComposedSegTransforms(mode='eval')
# 定义训练和验证所用的数据集
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/semantic_segmentation.html#segdataset
train_dataset = pdx.datasets.SegDataset(
data_dir='optic_disc_seg',
file_list='optic_disc_seg/train_list.txt',
......@@ -38,6 +33,8 @@ eval_dataset = pdx.datasets.SegDataset(
# VisualDL启动方式: visualdl --logdir output/unet/vdl_log --port 8001
# 浏览器打开 http://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/models/semantic_segmentation.html#hrnet
num_classes = len(train_dataset.labels)
model = pdx.seg.HRNet(num_classes=num_classes)
model.train(
......
......@@ -10,20 +10,12 @@ optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
pdx.utils.download_and_decompress(optic_dataset, path='./')
# 定义训练和验证时的transforms
train_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.ResizeRangeScaling(),
transforms.RandomPaddingCrop(crop_size=512),
transforms.Normalize()
])
eval_transforms = transforms.Compose([
transforms.ResizeByLong(long_size=512),
transforms.Padding(target_size=512),
transforms.Normalize()
])
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/seg_transforms.html#composedsegtransforms
train_transforms = transforms.ComposedSegTransforms(mode='train', train_crop_size=[769, 769])
eval_transforms = transforms.ComposedSegTransforms(mode='eval')
# 定义训练和验证所用的数据集
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/semantic_segmentation.html#segdataset
train_dataset = pdx.datasets.SegDataset(
data_dir='optic_disc_seg',
file_list='optic_disc_seg/train_list.txt',
......@@ -41,6 +33,8 @@ eval_dataset = pdx.datasets.SegDataset(
# VisualDL启动方式: visualdl --logdir output/unet/vdl_log --port 8001
# 浏览器打开 http://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/models/semantic_segmentation.html#unet
num_classes = len(train_dataset.labels)
model = pdx.seg.UNet(num_classes=num_classes)
model.train(
......
......@@ -337,7 +337,8 @@ class DeepLabv3p(BaseAPI):
for d in data:
padding_label = np.zeros(
(1, im_h, im_w)).astype('int64') + self.ignore_index
padding_label[:, :im_h, :im_w] = d[1]
_, label_h, label_w = d[1].shape
padding_label[:, :label_h, :label_w] = d[1]
labels.append(padding_label)
labels = np.array(labels)
......
......@@ -1287,7 +1287,7 @@ class ComposedRCNNTransforms(Compose):
super(ComposedRCNNTransforms, self).__init__(transforms)
class ComposedYOLOTransforms(Compose):
class ComposedYOLOv3Transforms(Compose):
"""YOLOv3模型的图像预处理流程,具体如下,
训练阶段:
1. 在前mixup_epoch轮迭代中,使用MixupImage策略,见https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/det_transforms.html#mixupimage
......@@ -1342,4 +1342,4 @@ class ComposedYOLOTransforms(Compose):
target_size=width, interp='CUBIC'), Normalize(
mean=mean, std=std)
]
super(ComposedYOLOTransforms, self).__init__(transforms)
super(ComposedYOLOv3Transforms, self).__init__(transforms)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册