From 7e6b370f80d0baa6faadf8ed6ccda77ce512bab6 Mon Sep 17 00:00:00 2001
From: haoyuying <18844182690@163.com>
Date: Fri, 9 Oct 2020 16:05:16 +0800
Subject: [PATCH] revise transform and pascalvoc dataset

---
 .../yolov3_darknet53_pascalvoc/train.py       |  21 +--
 .../yolov3_darknet53_pascalvoc/module.py      |  24 +--
 paddlehub/datasets/pascalvoc.py               | 124 +++++++++++++++-
 paddlehub/process/detect_transforms.py        | 139 ++++--------------
 paddlehub/process/functional.py               |  22 ---
 5 files changed, 173 insertions(+), 157 deletions(-)

diff --git a/demo/detection/yolov3_darknet53_pascalvoc/train.py b/demo/detection/yolov3_darknet53_pascalvoc/train.py
index 2db0833d..a2d72e9e 100644
--- a/demo/detection/yolov3_darknet53_pascalvoc/train.py
+++ b/demo/detection/yolov3_darknet53_pascalvoc/train.py
@@ -3,20 +3,21 @@ import paddlehub as hub
 import paddle.nn as nn
 from paddlehub.finetune.trainer import Trainer
 from paddlehub.datasets.pascalvoc import DetectionData
-from paddlehub.process.detect_transforms import Compose, RandomDistort, RandomExpand, RandomCrop, RandomFlip, Normalize, Resize, ShuffleBox
-
+import paddlehub.process.detect_transforms as T
 if __name__ == "__main__":
     place = paddle.CUDAPlace(0)
     paddle.disable_static()
-    transform = Compose([
-        RandomDistort(),
-        RandomExpand(fill=[0.485, 0.456, 0.406]),
-        RandomCrop(),
-        Resize(target_size=416),
-        RandomFlip(),
-        ShuffleBox(),
-        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    transform = T.Compose([
+        T.RandomDistort(),
+        T.RandomExpand(fill=[0.485, 0.456, 0.406]),
+        T.RandomCrop(),
+        T.Resize(target_size=416),
+        T.RandomFlip(),
+        T.ShuffleBox(),
+        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
     ])
+
     train_reader = DetectionData(transform)
     model = hub.Module(name='yolov3_darknet53_pascalvoc')
     model.train()
diff --git a/hub_module/modules/image/object_detection/yolov3_darknet53_pascalvoc/module.py b/hub_module/modules/image/object_detection/yolov3_darknet53_pascalvoc/module.py
index d4c81e0f..819d5ea5 100644
--- a/hub_module/modules/image/object_detection/yolov3_darknet53_pascalvoc/module.py
+++ b/hub_module/modules/image/object_detection/yolov3_darknet53_pascalvoc/module.py
@@ -7,7 +7,7 @@ from paddle.nn.initializer import Normal, Constant
 from paddle.regularizer import L2Decay
 from pycocotools.coco import COCO
 from paddlehub.module.cv_module import Yolov3Module
-from paddlehub.process.detect_transforms import Compose, RandomDistort, RandomExpand, RandomCrop, Resize, RandomFlip, ShuffleBox, Normalize
+import paddlehub.process.detect_transforms as T
 from paddlehub.module.module import moduleinfo
 
 
@@ -288,19 +288,19 @@ class YOLOv3(nn.Layer):
     def transform(self, img):
         if self.is_train:
-            transform = Compose([
-                RandomDistort(),
-                RandomExpand(fill=[0.485, 0.456, 0.406]),
-                RandomCrop(),
-                Resize(target_size=416),
-                RandomFlip(),
-                ShuffleBox(),
-                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+            transform = T.Compose([
+                T.RandomDistort(),
+                T.RandomExpand(fill=[0.485, 0.456, 0.406]),
+                T.RandomCrop(),
+                T.Resize(target_size=416),
+                T.RandomFlip(),
+                T.ShuffleBox(),
+                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
             ])
         else:
-            transform = Compose([
-                Resize(target_size=416, interp='CUBIC'),
-                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+            transform = T.Compose([
+                T.Resize(target_size=416, interp='CUBIC'),
+                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
             ])
         return transform(img)
diff --git a/paddlehub/datasets/pascalvoc.py b/paddlehub/datasets/pascalvoc.py
index 42a2cbfd..d6f6a344 100644
--- a/paddlehub/datasets/pascalvoc.py
+++ b/paddlehub/datasets/pascalvoc.py
@@ -14,15 +14,137 @@
 # limitations under the License.
 
 import os
+import copy
 from typing import Callable
 
 import paddle
+import numpy as np
 from paddlehub.env import DATA_HOME
 from pycocotools.coco import COCO
 from paddlehub.process.transforms import DetectCatagory, ParseImages
 
 
+class DetectCatagory:
+    """Load label names, ids and the category-to-id map from a detection dataset.
+
+    Args:
+        attrbox(Callable): Method to get detection attributes of images.
+        data_dir(str): Image dataset path.
+
+    Returns:
+        label_names(List(str)): The dataset label names.
+        label_ids(List(int)): The dataset label ids.
+        category_to_id_map(dict): Mapping from category id to contiguous label id.
+    """
+    def __init__(self, attrbox: Callable, data_dir: str):
+        self.attrbox = attrbox
+        self.img_dir = data_dir
+
+    def __call__(self):
+        self.categories = self.attrbox.loadCats(self.attrbox.getCatIds())
+        self.num_category = len(self.categories)
+        label_names = []
+        label_ids = []
+        for category in self.categories:
+            label_names.append(category['name'])
+            label_ids.append(int(category['id']))
+        category_to_id_map = {v: i for i, v in enumerate(label_ids)}
+        return label_names, label_ids, category_to_id_map
+
+
+class ParseImages:
+    """Prepare images for detection.
+
+    Args:
+        attrbox(Callable): Method to get detection attributes of images.
+        data_dir(str): Image dataset path.
+        category_to_id_map(dict): Mapping from category id to contiguous label id.
+
+    Returns:
+        imgs(list[dict]): Image records used as input for the detection model.
+    """
+    def __init__(self, attrbox: Callable, data_dir: str, category_to_id_map: dict):
+        self.attrbox = attrbox
+        self.img_dir = data_dir
+        self.category_to_id_map = category_to_id_map
+        self.parse_gt_annotations = GTAnotations(self.attrbox, self.category_to_id_map)
+
+    def __call__(self):
+        image_ids = self.attrbox.getImgIds()
+        image_ids.sort()
+        imgs = copy.deepcopy(self.attrbox.loadImgs(image_ids))
+
+        for img in imgs:
+            img['image'] = os.path.join(self.img_dir, img['file_name'])
+            assert os.path.exists(img['image']), "image {} not found.".format(img['image'])
+            box_num = 50
+            img['gt_boxes'] = np.zeros((box_num, 4), dtype=np.float32)
+            img['gt_labels'] = np.zeros((box_num), dtype=np.int32)
+            img = self.parse_gt_annotations(img)
+        return imgs
+
+
+class GTAnotations:
+    """Set ground truth boxes and labels for training.
+
+    Args:
+        attrbox(Callable): Method to get detection attributes of images.
+        category_to_id_map(dict): Mapping from category id to contiguous label id.
+        img(dict): Input image record for the detection model.
+
+    Returns:
+        img(dict): Image record with its 'gt_boxes' and 'gt_labels' attributes filled in.
+    """
+    def __init__(self, attrbox: Callable, category_to_id_map: dict):
+        self.attrbox = attrbox
+        self.category_to_id_map = category_to_id_map
+
+    def box_to_center_relative(self, box: list, img_height: int, img_width: int) -> np.ndarray:
+        """
+        Convert COCO annotations box with format [x1, y1, w, h] to
+        center mode [center_x, center_y, w, h] and divide image width
+        and height to get relative value in range [0, 1]
+        """
+        assert len(box) == 4, "box should be a len(4) list or tuple"
+        x, y, w, h = box
+
+        x1 = max(x, 0)
+        x2 = min(x + w - 1, img_width - 1)
+        y1 = max(y, 0)
+        y2 = min(y + h - 1, img_height - 1)
+
+        x = (x1 + x2) / 2 / img_width
+        y = (y1 + y2) / 2 / img_height
+        w = (x2 - x1) / img_width
+        h = (y2 - y1) / img_height
+
+        return np.array([x, y, w, h])
+
+    def __call__(self, img: dict):
+        img_height = img['height']
+        img_width = img['width']
+        anno = self.attrbox.loadAnns(self.attrbox.getAnnIds(imgIds=img['id'], iscrowd=None))
+        gt_index = 0
+
+        for target in anno:
+            if target['area'] < -1:
+                continue
+            if 'ignore' in target and target['ignore']:
+                continue
+            box = self.box_to_center_relative(target['bbox'], img_height, img_width)
+
+            if box[2] <= 0 and box[3] <= 0:
+                continue
+            img['gt_boxes'][gt_index] = box
+            img['gt_labels'][gt_index] = \
+                self.category_to_id_map[target['category_id']]
+            gt_index += 1
+            if gt_index >= 50:
+                break
+        return img
+
+
 class DetectionData(paddle.io.Dataset):
     """
     Dataset for image detection.
@@ -57,7 +179,7 @@ class DetectionData(paddle.io.Dataset):
         parse_dataset_catagory = DetectCatagory(self.COCO, self.img_dir)
         self.label_names, self.label_ids, self.category_to_id_map = parse_dataset_catagory()
 
-        parse_images = ParseImages(self.COCO, self.mode, self.img_dir, self.category_to_id_map)
+        parse_images = ParseImages(self.COCO, self.img_dir, self.category_to_id_map)
         self.data = parse_images()
 
     def __getitem__(self, idx: int):
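The conversion this hunk moves into GTAnotations.box_to_center_relative clips a COCO-style [x1, y1, w, h] box to the image bounds and rescales it to a relative [center_x, center_y, w, h]. A standalone sketch of the same logic, with an invented image size and box purely for illustration:

import numpy as np

def box_to_center_relative(box, img_height, img_width):
    # Clip the COCO [x1, y1, w, h] box to the image bounds, then convert it
    # to [center_x, center_y, w, h] expressed as fractions of the image size.
    x, y, w, h = box
    x1, y1 = max(x, 0), max(y, 0)
    x2 = min(x + w - 1, img_width - 1)
    y2 = min(y + h - 1, img_height - 1)
    return np.array([(x1 + x2) / 2 / img_width, (y1 + y2) / 2 / img_height,
                     (x2 - x1) / img_width, (y2 - y1) / img_height])

# A 100x200 box with its top-left corner at (50, 100) in a 500x400 image:
print(box_to_center_relative([50, 100, 100, 200], img_height=400, img_width=500))
# -> [0.199, 0.49875, 0.198, 0.4975]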
diff --git a/paddlehub/process/detect_transforms.py b/paddlehub/process/detect_transforms.py
index 811a68bd..5a50ffb0 100644
--- a/paddlehub/process/detect_transforms.py
+++ b/paddlehub/process/detect_transforms.py
@@ -1,4 +1,3 @@
-import copy
 import os
 import random
 from typing import Callable
@@ -15,108 +14,10 @@ from paddlehub.process.functional import *
 matplotlib.use('Agg')
 
 
-class DetectCatagory:
-    """Load label name, id and map from detection dataset.
-
-    Args:
-        attrbox(Callable): Method to get detection attributes of images.
-        data_dir(str): Image dataset path.
-
-    Returns:
-        label_names(List(str)): The dataset label names.
-        label_ids(List(int)): The dataset label ids.
-        category_to_id_map(dict): Mapping relations of category and id for images.
-    """
-    def __init__(self, attrbox: Callable, data_dir: str):
-        self.attrbox = attrbox
-        self.img_dir = data_dir
-
-    def __call__(self):
-        self.categories = self.attrbox.loadCats(self.attrbox.getCatIds())
-        self.num_category = len(self.categories)
-        label_names = []
-        label_ids = []
-        for category in self.categories:
-            label_names.append(category['name'])
-            label_ids.append(int(category['id']))
-        category_to_id_map = {v: i for i, v in enumerate(label_ids)}
-        return label_names, label_ids, category_to_id_map
-
-
-class ParseImages:
-    """Prepare images for detection.
-
-    Args:
-        attrbox(Callable): Method to get detection attributes of images.
-        is_train(bool): Select the mode for train or test.
-        data_dir(str): Image dataset path.
-        category_to_id_map(dict): Mapping relations of category and id for images.
-
-    Returns:
-        imgs(dict): The input for detection model, it is a dict.
+class RandomDistort:
     """
-    def __init__(self, attrbox: Callable, data_dir: str, category_to_id_map: dict):
-        self.attrbox = attrbox
-        self.img_dir = data_dir
-        self.category_to_id_map = category_to_id_map
-        self.parse_gt_annotations = GTAnotations(self.attrbox, self.category_to_id_map)
-
-    def __call__(self):
-        image_ids = self.attrbox.getImgIds()
-        image_ids.sort()
-        imgs = copy.deepcopy(self.attrbox.loadImgs(image_ids))
-
-        for img in imgs:
-            img['image'] = os.path.join(self.img_dir, img['file_name'])
-            assert os.path.exists(img['image']), "image {} not found.".format(img['image'])
-            box_num = 50
-            img['gt_boxes'] = np.zeros((box_num, 4), dtype=np.float32)
-            img['gt_labels'] = np.zeros((box_num), dtype=np.int32)
-            img = self.parse_gt_annotations(img)
-        return imgs
-
-
-class GTAnotations:
-    """Set gt boxes and gt labels for train.
+    Distort the input image randomly.
 
-    Args:
-        attrbox(Callable): Method for get detection attributes for images.
-        category_to_id_map(dict): Mapping relations of category and id for images.
-        img(dict): Input for detection model.
-
-    Returns:
-        img(dict): Set specific value on the attributes of 'gt boxes' and 'gt labels' for input.
-    """
-    def __init__(self, attrbox: Callable, category_to_id_map: dict):
-        self.attrbox = attrbox
-        self.category_to_id_map = category_to_id_map
-
-    def __call__(self, img: dict):
-        img_height = img['height']
-        img_width = img['width']
-        anno = self.attrbox.loadAnns(self.attrbox.getAnnIds(imgIds=img['id'], iscrowd=None))
-        gt_index = 0
-
-        for target in anno:
-            if target['area'] < -1:
-                continue
-            if 'ignore' in target and target['ignore']:
-                continue
-            box = coco_anno_box_to_center_relative(target['bbox'], img_height, img_width)
-
-            if box[2] <= 0 and box[3] <= 0:
-                continue
-            img['gt_boxes'][gt_index] = box
-            img['gt_labels'][gt_index] = \
-                self.category_to_id_map[target['category_id']]
-            gt_index += 1
-            if gt_index >= 50:
-                break
-        return img
-
-
-class RandomDistort:
-    """ Distort the input image randomly.
-
     Args:
         lower(float): The lower bound value for enhancement, default is 0.5.
         upper(float): The upper bound value for enhancement, default is 1.5.
@@ -155,7 +56,9 @@ class RandomDistort:
 
 
 class RandomExpand:
-    """Randomly expand images and gt boxes by random ratio. It is a data enhancement operation for model training.
+    """
+    Randomly expand images and gt boxes by random ratio. It is a data enhancement operation for model training.
+
     Args:
         max_ratio(float): Max value for expansion ratio, default is 4.
         fill(list): Initialize the pixel value of the image with the input fill value, default is None.
@@ -213,6 +116,7 @@ class RandomExpand:
 class RandomCrop:
     """
     Random crop the input image according to constraints.
+
     Args:
         scales(list): The value of the cutting area relative to the original area, expressed in the form of \
             [min, max]. The default value is [.3, 1.].
@@ -276,6 +180,7 @@
                 data['gt_labels'] = crop_labels
                 data['gt_scores'] = crop_scores
                 return img, data
+
         img = np.asarray(img)
         data['gt_boxes'] = boxes
         data['gt_labels'] = labels
@@ -285,8 +190,10 @@
 
 class RandomFlip:
     """Flip the images and gt boxes randomly.
+
     Args:
         thresh: Probability for random flip.
+
     Returns:
         img(np.ndarray): Distorted image.
         data(dict): Image info and label info.
@@ -304,9 +211,12 @@
 
 
 class Compose:
-    """Preprocess the input data according to the operators.
+    """
+    Preprocess the input data according to the operators.
+
     Args:
         transforms(list): Preprocessing operators.
+
     Returns:
         img(np.ndarray): Preprocessed image.
         data(dict): Image info and label info, default is None.
@@ -342,10 +252,13 @@
 
 
 class Resize:
-    """Resize the input images.
+    """
+    Resize the input images.
+
     Args:
         target_size(int): Targeted input size.
         interp(str): Interpolation method.
+
     Returns:
         img(np.ndarray): Preprocessed image.
         data(dict): Image info and label info, default is None.
@@ -385,10 +298,13 @@
 
 
 class Normalize:
-    """Normalize the input images.
+    """
+    Normalize the input images.
+
     Args:
         mean(list): Mean values for normalization, default is [0.5, 0.5, 0.5].
         std(list): Standard deviation for normalization, default is [0.5, 0.5, 0.5].
+
     Returns:
         img(np.ndarray): Preprocessed image.
         data(dict): Image info and label info, default is None.
@@ -403,20 +319,19 @@
             raise ValueError('{}: std is invalid!'.format(self))
 
     def __call__(self, im, data=None):
+
+        mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
+        std = np.array(self.std)[np.newaxis, np.newaxis, :]
+        im = normalize(im, mean, std)
+
         if data is not None:
-            mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
-            std = np.array(self.std)[np.newaxis, np.newaxis, :]
-            im = normalize(im, mean, std)
             return im, data
         else:
-            mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
-            std = np.array(self.std)[np.newaxis, np.newaxis, :]
-            im = normalize(im, mean, std)
             return im
 
 
 class ShuffleBox:
-    """Shuffle data information."""
+    """Shuffle detection information for the corresponding input image."""
     def __call__(self, img, data):
         gt = np.concatenate([data['gt_boxes'], data['gt_labels'][:, np.newaxis], data['gt_scores'][:, np.newaxis]],
                             axis=1)
diff --git a/paddlehub/process/functional.py b/paddlehub/process/functional.py
index 7b97a004..2c4c5f01 100644
--- a/paddlehub/process/functional.py
+++ b/paddlehub/process/functional.py
@@ -124,28 +124,6 @@ def get_img_file(dir_name: str) -> list:
     return images
 
 
-def coco_anno_box_to_center_relative(box: list, img_height: int, img_width: int) -> np.ndarray:
-    """
-    Convert COCO annotations box with format [x1, y1, w, h] to
-    center mode [center_x, center_y, w, h] and divide image width
-    and height to get relative value in range[0, 1]
-    """
-    assert len(box) == 4, "box should be a len(4) list or tuple"
-    x, y, w, h = box
-
-    x1 = max(x, 0)
-    x2 = min(x + w - 1, img_width - 1)
-    y1 = max(y, 0)
-    y2 = min(y + h - 1, img_height - 1)
-
-    x = (x1 + x2) / 2 / img_width
-    y = (y1 + y2) / 2 / img_height
-    w = (x2 - x1) / img_width
-    h = (y2 - y1) / img_height
-
-    return np.array([x, y, w, h])
-
-
 def box_crop(boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray, crop: list, img_shape: list):
     """Crop the boxes ,labels, scores according to the given shape"""
-- 
GitLab
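With the dataset helpers relocated, training code only needs the aliased transforms module and DetectionData. The sketch below mirrors the revised train.py demo and adds a shape check on a single record; it assumes the PASCAL VOC data is already present under DATA_HOME, and the three-value unpacking is an assumption, since __getitem__ is not shown in this patch:

import paddle
import paddlehub.process.detect_transforms as T
from paddlehub.datasets.pascalvoc import DetectionData

paddle.disable_static()

# The same training-time pipeline the revised demo builds.
transform = T.Compose([
    T.RandomDistort(),
    T.RandomExpand(fill=[0.485, 0.456, 0.406]),
    T.RandomCrop(),
    T.Resize(target_size=416),
    T.RandomFlip(),
    T.ShuffleBox(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

data = DetectionData(transform)
img, gt_boxes, gt_labels = data[0]  # assumed return layout; __getitem__ is not part of this diff
print(gt_boxes.shape)   # (50, 4): ParseImages pads every record to 50 boxes
print(gt_labels.shape)  # (50,)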