Commit 7e6b370f authored by haoyuying

revise transform and pascalvoc dataset

Parent 1f2180b6
@@ -3,20 +3,21 @@
 import paddlehub as hub
 import paddle.nn as nn
 from paddlehub.finetune.trainer import Trainer
 from paddlehub.datasets.pascalvoc import DetectionData
-from paddlehub.process.detect_transforms import Compose, RandomDistort, RandomExpand, RandomCrop, RandomFlip, Normalize, Resize, ShuffleBox
+import paddlehub.process.detect_transforms as T

 if __name__ == "__main__":
     place = paddle.CUDAPlace(0)
     paddle.disable_static()
-    transform = Compose([
-        RandomDistort(),
-        RandomExpand(fill=[0.485, 0.456, 0.406]),
-        RandomCrop(),
-        Resize(target_size=416),
-        RandomFlip(),
-        ShuffleBox(),
-        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-    ])
+
+    transform = T.Compose([
+        T.RandomDistort(),
+        T.RandomExpand(fill=[0.485, 0.456, 0.406]),
+        T.RandomCrop(),
+        T.Resize(target_size=416),
+        T.RandomFlip(),
+        T.ShuffleBox(),
+        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
     train_reader = DetectionData(transform)
     model = hub.Module(name='yolov3_darknet53_pascalvoc')
     model.train()
...
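The hunk ends before the training boilerplate. For orientation only, a typical continuation of this demo script might look like the sketch below; the optimizer settings, checkpoint directory, and the `trainer.train(...)` arguments are illustrative assumptions, not part of this commit:

```python
# Hypothetical continuation of the demo script above (not in this commit).
# 'model', 'train_reader', and 'Trainer' come from the diffed script.
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='ckpt_yolov3_pascalvoc')  # path is illustrative
trainer.train(train_reader, epochs=5, batch_size=4)
```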
@@ -7,7 +7,7 @@
 from paddle.nn.initializer import Normal, Constant
 from paddle.regularizer import L2Decay
 from pycocotools.coco import COCO
 from paddlehub.module.cv_module import Yolov3Module
-from paddlehub.process.detect_transforms import Compose, RandomDistort, RandomExpand, RandomCrop, Resize, RandomFlip, ShuffleBox, Normalize
+import paddlehub.process.detect_transforms as T
 from paddlehub.module.module import moduleinfo
@@ -288,19 +288,19 @@ class YOLOv3(nn.Layer):
     def transform(self, img):
         if self.is_train:
-            transform = Compose([
-                RandomDistort(),
-                RandomExpand(fill=[0.485, 0.456, 0.406]),
-                RandomCrop(),
-                Resize(target_size=416),
-                RandomFlip(),
-                ShuffleBox(),
-                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-            ])
+            transform = T.Compose([
+                T.RandomDistort(),
+                T.RandomExpand(fill=[0.485, 0.456, 0.406]),
+                T.RandomCrop(),
+                T.Resize(target_size=416),
+                T.RandomFlip(),
+                T.ShuffleBox(),
+                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+            ])
         else:
-            transform = Compose([
-                Resize(target_size=416, interp='CUBIC'),
-                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-            ])
+            transform = T.Compose([
+                T.Resize(target_size=416, interp='CUBIC'),
+                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+            ])
         return transform(img)
...
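Seen on its own, the eval branch of the refactored method is just a deterministic resize plus normalization; a minimal sketch of that pipeline, using only names that appear in this diff:

```python
import paddlehub.process.detect_transforms as T

# Eval-time preprocessing from the hunk above: no random augmentation, only a
# cubic resize to the 416x416 YOLOv3 input and ImageNet-style normalization.
eval_transform = T.Compose([
    T.Resize(target_size=416, interp='CUBIC'),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```

The train branch prepends the stochastic ops (`RandomDistort`, `RandomExpand`, `RandomCrop`, `RandomFlip`, `ShuffleBox`), so augmentation runs only when `self.is_train` is set.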
@@ -14,15 +14,137 @@
 # limitations under the License.
 import os
+import copy
 from typing import Callable

 import paddle
+import numpy as np

 from paddlehub.env import DATA_HOME
 from pycocotools.coco import COCO
-from paddlehub.process.transforms import DetectCatagory, ParseImages
+
+
+class DetectCatagory:
+    """Load label names, ids, and the category-to-id map from a detection dataset.
+
+    Args:
+        attrbox(Callable): Method to get detection attributes of images.
+        data_dir(str): Image dataset path.
+
+    Returns:
+        label_names(List(str)): The dataset label names.
+        label_ids(List(int)): The dataset label ids.
+        category_to_id_map(dict): Mapping relations of category and id for images.
+    """
+    def __init__(self, attrbox: Callable, data_dir: str):
+        self.attrbox = attrbox
+        self.img_dir = data_dir
+
+    def __call__(self):
+        self.categories = self.attrbox.loadCats(self.attrbox.getCatIds())
+        self.num_category = len(self.categories)
+        label_names = []
+        label_ids = []
+        for category in self.categories:
+            label_names.append(category['name'])
+            label_ids.append(int(category['id']))
+        category_to_id_map = {v: i for i, v in enumerate(label_ids)}
+        return label_names, label_ids, category_to_id_map
+
+
+class ParseImages:
+    """Prepare images for detection.
+
+    Args:
+        attrbox(Callable): Method to get detection attributes of images.
+        data_dir(str): Image dataset path.
+        category_to_id_map(dict): Mapping relations of category and id for images.
+
+    Returns:
+        imgs(list): The inputs for the detection model, one dict per image.
+    """
+    def __init__(self, attrbox: Callable, data_dir: str, category_to_id_map: dict):
+        self.attrbox = attrbox
+        self.img_dir = data_dir
+        self.category_to_id_map = category_to_id_map
+        self.parse_gt_annotations = GTAnotations(self.attrbox, self.category_to_id_map)
+
+    def __call__(self):
+        image_ids = self.attrbox.getImgIds()
+        image_ids.sort()
+        imgs = copy.deepcopy(self.attrbox.loadImgs(image_ids))
+        for img in imgs:
+            img['image'] = os.path.join(self.img_dir, img['file_name'])
+            assert os.path.exists(img['image']), "image {} not found.".format(img['image'])
+            box_num = 50
+            img['gt_boxes'] = np.zeros((box_num, 4), dtype=np.float32)
+            img['gt_labels'] = np.zeros((box_num), dtype=np.int32)
+            img = self.parse_gt_annotations(img)
+        return imgs
+
+
+class GTAnotations:
+    """Set gt boxes and gt labels for training.
+
+    Args:
+        attrbox(Callable): Method to get detection attributes of images.
+        category_to_id_map(dict): Mapping relations of category and id for images.
+        img(dict): Input for detection model.
+
+    Returns:
+        img(dict): Input with the 'gt_boxes' and 'gt_labels' attributes filled in.
+    """
+    def __init__(self, attrbox: Callable, category_to_id_map: dict):
+        self.attrbox = attrbox
+        self.category_to_id_map = category_to_id_map
+
+    def box_to_center_relative(self, box: list, img_height: int, img_width: int) -> np.ndarray:
+        """
+        Convert a COCO annotation box in [x1, y1, w, h] format to center
+        form [center_x, center_y, w, h], then divide by image width and
+        height to get relative values in the range [0, 1].
+        """
+        assert len(box) == 4, "box should be a len(4) list or tuple"
+        x, y, w, h = box
+        x1 = max(x, 0)
+        x2 = min(x + w - 1, img_width - 1)
+        y1 = max(y, 0)
+        y2 = min(y + h - 1, img_height - 1)
+        x = (x1 + x2) / 2 / img_width
+        y = (y1 + y2) / 2 / img_height
+        w = (x2 - x1) / img_width
+        h = (y2 - y1) / img_height
+        return np.array([x, y, w, h])
+
+    def __call__(self, img: dict):
+        img_height = img['height']
+        img_width = img['width']
+        anno = self.attrbox.loadAnns(self.attrbox.getAnnIds(imgIds=img['id'], iscrowd=None))
+        gt_index = 0
+        for target in anno:
+            if target['area'] < -1:
+                continue
+            if 'ignore' in target and target['ignore']:
+                continue
+            box = self.box_to_center_relative(target['bbox'], img_height, img_width)
+            if box[2] <= 0 and box[3] <= 0:
+                continue
+            img['gt_boxes'][gt_index] = box
+            img['gt_labels'][gt_index] = \
+                self.category_to_id_map[target['category_id']]
+            gt_index += 1
+            if gt_index >= 50:
+                break
+        return img
 class DetectionData(paddle.io.Dataset):
     """
     Dataset for image detection.
@@ -57,7 +179,7 @@ class DetectionData(paddle.io.Dataset):
         parse_dataset_catagory = DetectCatagory(self.COCO, self.img_dir)
         self.label_names, self.label_ids, self.category_to_id_map = parse_dataset_catagory()
-        parse_images = ParseImages(self.COCO, self.mode, self.img_dir, self.category_to_id_map)
+        parse_images = ParseImages(self.COCO, self.img_dir, self.category_to_id_map)
         self.data = parse_images()

     def __getitem__(self, idx: int):
...
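The coordinate conversion in `box_to_center_relative` is easy to sanity-check by hand. A small self-contained replica with hypothetical numbers (the function body mirrors the method added above; the sample box and image size are made up for illustration):

```python
import numpy as np

def box_to_center_relative(box, img_height, img_width):
    # Same math as the method above: clamp the COCO [x1, y1, w, h] box to the
    # image, convert to center form, then normalize to [0, 1].
    x, y, w, h = box
    x1, x2 = max(x, 0), min(x + w - 1, img_width - 1)
    y1, y2 = max(y, 0), min(y + h - 1, img_height - 1)
    return np.array([(x1 + x2) / 2 / img_width,
                     (y1 + y2) / 2 / img_height,
                     (x2 - x1) / img_width,
                     (y2 - y1) / img_height])

# A 100x80 box at (50, 40) in a 400x300 image:
print(box_to_center_relative([50, 40, 100, 80], img_height=300, img_width=400))
# -> approximately [0.24875, 0.265, 0.2475, 0.2633]
```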
-import copy
 import os
 import random
 from typing import Callable
...
@@ -15,108 +14,10 @@ from paddlehub.process.functional import *
 matplotlib.use('Agg')


-class DetectCatagory:
-    """Load label name, id and map from detection dataset.
-
-    Args:
-        attrbox(Callable): Method to get detection attributes of images.
-        data_dir(str): Image dataset path.
-
-    Returns:
-        label_names(List(str)): The dataset label names.
-        label_ids(List(int)): The dataset label ids.
-        category_to_id_map(dict): Mapping relations of category and id for images.
-    """
-    def __init__(self, attrbox: Callable, data_dir: str):
-        self.attrbox = attrbox
-        self.img_dir = data_dir
-
-    def __call__(self):
-        self.categories = self.attrbox.loadCats(self.attrbox.getCatIds())
-        self.num_category = len(self.categories)
-        label_names = []
-        label_ids = []
-        for category in self.categories:
-            label_names.append(category['name'])
-            label_ids.append(int(category['id']))
-        category_to_id_map = {v: i for i, v in enumerate(label_ids)}
-        return label_names, label_ids, category_to_id_map
-
-
-class ParseImages:
-    """Prepare images for detection.
-
-    Args:
-        attrbox(Callable): Method to get detection attributes of images.
-        is_train(bool): Select the mode for train or test.
-        data_dir(str): Image dataset path.
-        category_to_id_map(dict): Mapping relations of category and id for images.
-
-    Returns:
-        imgs(dict): The input for detection model, it is a dict.
-    """
-    def __init__(self, attrbox: Callable, data_dir: str, category_to_id_map: dict):
-        self.attrbox = attrbox
-        self.img_dir = data_dir
-        self.category_to_id_map = category_to_id_map
-        self.parse_gt_annotations = GTAnotations(self.attrbox, self.category_to_id_map)
-
-    def __call__(self):
-        image_ids = self.attrbox.getImgIds()
-        image_ids.sort()
-        imgs = copy.deepcopy(self.attrbox.loadImgs(image_ids))
-        for img in imgs:
-            img['image'] = os.path.join(self.img_dir, img['file_name'])
-            assert os.path.exists(img['image']), "image {} not found.".format(img['image'])
-            box_num = 50
-            img['gt_boxes'] = np.zeros((box_num, 4), dtype=np.float32)
-            img['gt_labels'] = np.zeros((box_num), dtype=np.int32)
-            img = self.parse_gt_annotations(img)
-        return imgs
-
-
-class GTAnotations:
-    """Set gt boxes and gt labels for train.
-
-    Args:
-        attrbox(Callable): Method for get detection attributes for images.
-        category_to_id_map(dict): Mapping relations of category and id for images.
-        img(dict): Input for detection model.
-
-    Returns:
-        img(dict): Set specific value on the attributes of 'gt boxes' and 'gt labels' for input.
-    """
-    def __init__(self, attrbox: Callable, category_to_id_map: dict):
-        self.attrbox = attrbox
-        self.category_to_id_map = category_to_id_map
-
-    def __call__(self, img: dict):
-        img_height = img['height']
-        img_width = img['width']
-        anno = self.attrbox.loadAnns(self.attrbox.getAnnIds(imgIds=img['id'], iscrowd=None))
-        gt_index = 0
-        for target in anno:
-            if target['area'] < -1:
-                continue
-            if 'ignore' in target and target['ignore']:
-                continue
-            box = coco_anno_box_to_center_relative(target['bbox'], img_height, img_width)
-            if box[2] <= 0 and box[3] <= 0:
-                continue
-            img['gt_boxes'][gt_index] = box
-            img['gt_labels'][gt_index] = \
-                self.category_to_id_map[target['category_id']]
-            gt_index += 1
-            if gt_index >= 50:
-                break
-        return img
-
-
-class RandomDistort:
-    """ Distort the input image randomly.
+class RandomDistort:
+    """
+    Distort the input image randomly.
+
     Args:
         lower(float): The lower bound value for enhancement, default is 0.5.
         upper(float): The upper bound value for enhancement, default is 1.5.
@@ -155,7 +56,9 @@ class RandomDistort:


 class RandomExpand:
-    """Randomly expand images and gt boxes by random ratio. It is a data enhancement operation for model training.
+    """
+    Randomly expand images and gt boxes by random ratio. It is a data enhancement operation for model training.
+
     Args:
         max_ratio(float): Max value for expansion ratio, default is 4.
         fill(list): Initialize the pixel value of the image with the input fill value, default is None.
@@ -213,6 +116,7 @@ class RandomExpand:
 class RandomCrop:
     """
     Random crop the input image according to constraints.
+
     Args:
         scales(list): The value of the cutting area relative to the original area, expressed in the form of \
             [min, max]. The default value is [.3, 1.].
@@ -276,6 +180,7 @@ class RandomCrop:
                 data['gt_labels'] = crop_labels
                 data['gt_scores'] = crop_scores
                 return img, data
+
         img = np.asarray(img)
         data['gt_boxes'] = boxes
         data['gt_labels'] = labels
@@ -285,8 +190,10 @@ class RandomCrop:
 class RandomFlip:
     """Flip the images and gt boxes randomly.
+
     Args:
         thresh: Probability for random flip.
+
     Returns:
         img(np.ndarray): Distorted image.
         data(dict): Image info and label info.
@@ -304,9 +211,12 @@ class RandomFlip:
 class Compose:
-    """Preprocess the input data according to the operators.
+    """
+    Preprocess the input data according to the operators.
+
     Args:
         transforms(list): Preprocessing operators.
+
     Returns:
         img(np.ndarray): Preprocessed image.
         data(dict): Image info and label info, default is None.
@@ -342,10 +252,13 @@ class Compose:
 class Resize:
-    """Resize the input images.
+    """
+    Resize the input images.
+
     Args:
         target_size(int): Targeted input size.
         interp(str): Interpolation method.
+
     Returns:
         img(np.ndarray): Preprocessed image.
         data(dict): Image info and label info, default is None.
@@ -385,10 +298,13 @@ class Resize:
 class Normalize:
-    """Normalize the input images.
+    """
+    Normalize the input images.
+
     Args:
         mean(list): Mean values for normalization, default is [0.5, 0.5, 0.5].
         std(list): Standard deviation for normalization, default is [0.5, 0.5, 0.5].
+
     Returns:
         img(np.ndarray): Preprocessed image.
         data(dict): Image info and label info, default is None.
@@ -403,20 +319,19 @@ class Normalize:
             raise ValueError('{}: std is invalid!'.format(self))

     def __call__(self, im, data=None):
+        mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
+        std = np.array(self.std)[np.newaxis, np.newaxis, :]
+        im = normalize(im, mean, std)
+
         if data is not None:
-            mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
-            std = np.array(self.std)[np.newaxis, np.newaxis, :]
-            im = normalize(im, mean, std)
             return im, data
         else:
-            mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
-            std = np.array(self.std)[np.newaxis, np.newaxis, :]
-            im = normalize(im, mean, std)
             return im


 class ShuffleBox:
-    """Shuffle data information."""
+    """Shuffle detection information for corresponding input image."""
     def __call__(self, img, data):
         gt = np.concatenate([data['gt_boxes'], data['gt_labels'][:, np.newaxis], data['gt_scores'][:, np.newaxis]],
                             axis=1)
...
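The `Normalize.__call__` change is a pure deduplication: the mean/std reshaping and the `normalize` call are hoisted out of both branches, and only the return shape still depends on whether `data` was passed. Assuming `normalize(im, mean, std)` computes `(im - mean) / std` channel-wise (the helper lives in `paddlehub.process.functional` and is not shown in this diff), the hoisted computation behaves like this standalone sketch:

```python
import numpy as np

# Channel-wise standardization with broadcast shapes, mirroring the hunk above.
mean = np.array([0.485, 0.456, 0.406])[np.newaxis, np.newaxis, :]
std = np.array([0.229, 0.224, 0.225])[np.newaxis, np.newaxis, :]

im = np.random.rand(416, 416, 3).astype('float32')  # stand-in HWC image
out = (im - mean) / std  # assumed equivalent of functional.normalize(im, mean, std)
print(out.shape)         # (416, 416, 3)
```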
@@ -124,28 +124,6 @@ def get_img_file(dir_name: str) -> list:
     return images


-def coco_anno_box_to_center_relative(box: list, img_height: int, img_width: int) -> np.ndarray:
-    """
-    Convert COCO annotations box with format [x1, y1, w, h] to
-    center mode [center_x, center_y, w, h] and divide image width
-    and height to get relative value in range [0, 1]
-    """
-    assert len(box) == 4, "box should be a len(4) list or tuple"
-    x, y, w, h = box
-    x1 = max(x, 0)
-    x2 = min(x + w - 1, img_width - 1)
-    y1 = max(y, 0)
-    y2 = min(y + h - 1, img_height - 1)
-    x = (x1 + x2) / 2 / img_width
-    y = (y1 + y2) / 2 / img_height
-    w = (x2 - x1) / img_width
-    h = (y2 - y1) / img_height
-    return np.array([x, y, w, h])
-
-
 def box_crop(boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray, crop: list, img_shape: list):
     """Crop the boxes, labels, and scores according to the given shape."""
...