diff --git a/demo/detection/yolov3_darknet53_pascalvoc/predict.py b/demo/detection/yolov3_darknet53_pascalvoc/predict.py
index 27886f078bb5093f591d42d729e11c3ac63c24fe..e8e49eb41fe8d22124c9b7fb9c0b1de77a7da8f6 100644
--- a/demo/detection/yolov3_darknet53_pascalvoc/predict.py
+++ b/demo/detection/yolov3_darknet53_pascalvoc/predict.py
@@ -6,4 +6,4 @@ if __name__ == '__main__':
     paddle.disable_static()
     model = hub.Module(name='yolov3_darknet53_pascalvoc', is_train=False)
     model.eval()
-    model.predict(imgpath="/PATH/TO/IMAGE", filelist="/PATH/TO/JSON/FILE")
+    model.predict(imgpath="4026.jpeg", filelist="/PATH/TO/JSON/FILE")
diff --git a/demo/detection/yolov3_darknet53_pascalvoc/train.py b/demo/detection/yolov3_darknet53_pascalvoc/train.py
index d846c6b7443d524ada4ac57ea6802fed0854ffcb..2db0833dc6c97a9438c952005a9a5cc6bddbc0bc 100644
--- a/demo/detection/yolov3_darknet53_pascalvoc/train.py
+++ b/demo/detection/yolov3_darknet53_pascalvoc/train.py
@@ -3,18 +3,21 @@ import paddlehub as hub
 import paddle.nn as nn
 from paddlehub.finetune.trainer import Trainer
 from paddlehub.datasets.pascalvoc import DetectionData
-from paddlehub.process.transforms import DetectTrainReader, DetectTestReader
+from paddlehub.process.detect_transforms import Compose, RandomDistort, RandomExpand, RandomCrop, RandomFlip, Normalize, Resize, ShuffleBox
 
 if __name__ == "__main__":
     place = paddle.CUDAPlace(0)
     paddle.disable_static()
-    is_train = True
-    if is_train:
-        transform = DetectTrainReader()
-        train_reader = DetectionData(transform)
-    else:
-        transform = DetectTestReader()
-        test_reader = DetectionData(transform)
+    transform = Compose([
+        RandomDistort(),
+        RandomExpand(fill=[0.485, 0.456, 0.406]),
+        RandomCrop(),
+        Resize(target_size=416),
+        RandomFlip(),
+        ShuffleBox(),
+        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
+    train_reader = DetectionData(transform)
     model = hub.Module(name='yolov3_darknet53_pascalvoc')
     model.train()
     optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
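For reviewers, a minimal smoke test of the new training pipeline. This is a sketch, not part of the patch: the image file and the ground-truth box values below are fabricated; only `Compose` and the transform classes come from this PR.

```python
import cv2
import numpy as np
from paddlehub.process.detect_transforms import (Compose, Normalize, RandomCrop, RandomDistort,
                                                 RandomExpand, RandomFlip, Resize, ShuffleBox)

# Fabricate a small test image; real samples come from ParseImages.
cv2.imwrite('sample.jpg', np.random.randint(0, 255, (300, 400, 3), dtype=np.uint8))
data = {
    'image': 'sample.jpg',
    'gt_boxes': np.zeros((50, 4), dtype=np.float32),  # padded to 50 boxes, [cx, cy, w, h] relative
    'gt_labels': np.zeros((50, ), dtype=np.int32),
}
data['gt_boxes'][0] = [0.5, 0.5, 0.3, 0.3]            # one fake ground-truth box

transform = Compose([
    RandomDistort(),
    RandomExpand(fill=[0.485, 0.456, 0.406]),
    RandomCrop(),
    Resize(target_size=416),
    RandomFlip(),
    ShuffleBox(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
img, data = transform(data)                           # dict input: returns image and label info
print(img.shape, data['gt_boxes'].shape)              # (3, 416, 416) (50, 4)
```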
diff --git a/hub_module/modules/image/object_detection/yolov3_darknet53_pascalvoc/module.py b/hub_module/modules/image/object_detection/yolov3_darknet53_pascalvoc/module.py
index 920839a95fd054555fc9f0de236c6166a76a04c7..d4c81e0fc4c5196f6fea0218f82a41bc2744fb6a 100644
--- a/hub_module/modules/image/object_detection/yolov3_darknet53_pascalvoc/module.py
+++ b/hub_module/modules/image/object_detection/yolov3_darknet53_pascalvoc/module.py
@@ -7,7 +7,7 @@ from paddle.nn.initializer import Normal, Constant
 from paddle.regularizer import L2Decay
 from pycocotools.coco import COCO
 from paddlehub.module.cv_module import Yolov3Module
-from paddlehub.process.transforms import DetectTrainReader, DetectTestReader
+from paddlehub.process.detect_transforms import Compose, RandomDistort, RandomExpand, RandomCrop, Resize, RandomFlip, ShuffleBox, Normalize
 from paddlehub.module.module import moduleinfo
@@ -286,12 +286,24 @@ class YOLOv3(nn.Layer):
         self.set_dict(model_dict)
         print("load pretrained checkpoint success")
 
-    def transform(self, img: paddle.Tensor, size: int):
+    def transform(self, img):
         if self.is_train:
-            transforms = DetectTrainReader()
+            transform = Compose([
+                RandomDistort(),
+                RandomExpand(fill=[0.485, 0.456, 0.406]),
+                RandomCrop(),
+                Resize(target_size=416),
+                RandomFlip(),
+                ShuffleBox(),
+                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+            ])
         else:
-            transforms = DetectTestReader()
-        return transforms(img, size)
+            transform = Compose([
+                Resize(target_size=416, interp='CUBIC'),
+                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+            ])
+
+        return transform(img)
 
     def get_label_infos(self, file_list: str):
         self.COCO = COCO(file_list)
@@ -301,23 +313,8 @@ class YOLOv3(nn.Layer):
             label_names.append(category['name'])
         return label_names
 
-    def forward(self,
-                inputs: paddle.Tensor,
-                gtbox: paddle.Tensor = None,
-                gtlabel: paddle.Tensor = None,
-                gtscore: paddle.Tensor = None,
-                im_shape: paddle.Tensor = None):
-
-        self.gtbox = gtbox
-        self.gtlabel = gtlabel
-        self.gtscore = gtscore
-        self.im_shape = im_shape
-        self.outputs = []
-        self.boxes = []
-        self.scores = []
-        self.losses = []
-        self.pred = []
-        self.downsample = 32
+    def forward(self, inputs: paddle.Tensor):
+        outputs = []
         blocks = self.block(inputs)
         route = None
         for i, block in enumerate(blocks):
@@ -325,58 +322,9 @@ class YOLOv3(nn.Layer):
                 block = paddle.concat([route, block], axis=1)
             route, tip = self.yolo_blocks[i](block)
             block_out = self.block_outputs[i](tip)
-            self.outputs.append(block_out)
+            outputs.append(block_out)
             if i < 2:
                 route = self.route_blocks_2[i](route)
                 route = self.upsample(route)
-        for i, out in enumerate(self.outputs):
-            anchor_mask = self.anchor_masks[i]
-
-            if self.is_train:
-                loss = F.yolov3_loss(x=out,
-                                     gt_box=self.gtbox,
-                                     gt_label=self.gtlabel,
-                                     gt_score=self.gtscore,
-                                     anchors=self.anchors,
-                                     anchor_mask=anchor_mask,
-                                     class_num=self.class_num,
-                                     ignore_thresh=self.ignore_thresh,
-                                     downsample_ratio=self.downsample,
-                                     use_label_smooth=False)
-            else:
-                loss = paddle.to_tensor(0.0)
-            self.losses.append(paddle.reduce_mean(loss))
-
-            mask_anchors = []
-            for m in anchor_mask:
-                mask_anchors.append((self.anchors[2 * m]))
-                mask_anchors.append(self.anchors[2 * m + 1])
-
-            boxes, scores = F.yolo_box(x=out,
-                                       img_size=self.im_shape,
-                                       anchors=mask_anchors,
-                                       class_num=self.class_num,
-                                       conf_thresh=self.valid_thresh,
-                                       downsample_ratio=self.downsample,
-                                       name="yolo_box" + str(i))
-
-            self.boxes.append(boxes)
-            self.scores.append(paddle.transpose(scores, perm=[0, 2, 1]))
-            self.downsample //= 2
-
-        for i in range(self.boxes[0].shape[0]):
-            yolo_boxes = paddle.unsqueeze(paddle.concat([self.boxes[0][i], self.boxes[1][i], self.boxes[2][i]], axis=0),
-                                          0)
-            yolo_scores = paddle.unsqueeze(
-                paddle.concat([self.scores[0][i], self.scores[1][i], self.scores[2][i]], axis=1), 0)
-            pred = F.multiclass_nms(bboxes=yolo_boxes,
-                                    scores=yolo_scores,
-                                    score_threshold=self.valid_thresh,
-                                    nms_top_k=self.nms_topk,
-                                    keep_top_k=self.nms_posk,
-                                    nms_threshold=self.nms_thresh,
-                                    background_label=-1)
-            self.pred.append(pred)
-
-        return sum(self.losses), self.pred
+        return outputs
diff --git a/paddlehub/datasets/pascalvoc.py b/paddlehub/datasets/pascalvoc.py
index cc2d58f08bac0cfd2af50e78af89ee730c5ab684..42a2cbfd8a8e9076ba779c4519fa9630654ea9f1 100644
--- a/paddlehub/datasets/pascalvoc.py
+++ b/paddlehub/datasets/pascalvoc.py
@@ -61,14 +61,10 @@ class DetectionData(paddle.io.Dataset):
         self.data = parse_images()
 
     def __getitem__(self, idx: int):
-        if self.mode == "train":
-            img = self.data[idx]
-            out_img, gt_boxes, gt_labels, gt_scores = self.transform(img, 416)
-            return out_img, gt_boxes, gt_labels, gt_scores
-        elif self.mode == "test":
-            img = self.data[idx]
-            out_img, id, (h, w) = self.transform(img)
-            return out_img, id, (h, w)
+        img = self.data[idx]
+        im, data = self.transform(img)
+        out_img, gt_boxes, gt_labels, gt_scores = im, data['gt_boxes'], data['gt_labels'], data['gt_scores']
+        return out_img, gt_boxes, gt_labels, gt_scores
 
     def __len__(self):
         return len(self.data)
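With the `forward` rewrite above, the module now returns the three raw YOLO head tensors instead of `(loss, pred)`. A sketch of the new contract (assumes the module can be loaded locally; 75 channels = 3 anchors x (20 VOC classes + 5)):

```python
import numpy as np
import paddle
import paddlehub as hub

paddle.disable_static()
model = hub.Module(name='yolov3_darknet53_pascalvoc', is_train=False)
model.eval()
x = paddle.to_tensor(np.zeros((1, 3, 416, 416), dtype='float32'))
outputs = model(x)          # raw head outputs; loss and decoding now live in Yolov3Module
for out in outputs:
    print(out.shape)        # expect [1, 75, 13, 13], [1, 75, 26, 26], [1, 75, 52, 52]
```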
diff --git a/paddlehub/module/cv_module.py b/paddlehub/module/cv_module.py
index 8cef24921efdeab5eb447a59bad045ad7171e981..5ca749e6da33a2a0aaf5b74375735fb6982f94f7 100644
--- a/paddlehub/module/cv_module.py
+++ b/paddlehub/module/cv_module.py
@@ -26,7 +26,8 @@ from PIL import Image
 
 from paddlehub.module.module import serving, RunModule
 from paddlehub.utils.utils import base64_to_cv2
-from paddlehub.process.transforms import ConvertColorSpace, ColorPostprocess, Resize, BoxTool
+from paddlehub.process.transforms import ConvertColorSpace, ColorPostprocess, Resize
+from paddlehub.process.functional import subtract_imagenet_mean_batch, gram_matrix, draw_boxes_on_image, img_shape
 
 
 class ImageServing(object):
@@ -218,38 +219,30 @@
         Returns:
             results(dict) : The model outputs, such as metrics.
         '''
-        ious = []
-        boxtool = BoxTool()
         img = batch[0].astype('float32')
-        B, C, W, H = img.shape
-        im_shape = np.array([(W, H)] * B).astype('int32')
-        im_shape = paddle.to_tensor(im_shape)
-
-        gt_box = batch[1].astype('float32')
-        gt_label = batch[2].astype('int32')
-        gt_score = batch[3].astype("float32")
-        loss, pred = self(img, gt_box, gt_label, gt_score, im_shape)
-
-        for i in range(len(pred)):
-            bboxes = pred[i].numpy()
-            labels = bboxes[:, 0].astype('int32')
-            scores = bboxes[:, 1].astype('float32')
-            boxes = bboxes[:, 2:].astype('float32')
-            iou = []
-
-            for j, (box, score, label) in enumerate(zip(boxes, scores, labels)):
-                x1, y1, x2, y2 = box
-                w = x2 - x1 + 1
-                h = y2 - y1 + 1
-                bbox = [x1, y1, w, h]
-                bbox = np.expand_dims(boxtool.coco_anno_box_to_center_relative(bbox, H, W), 0)
-                gt = gt_box[i].numpy()
-                iou.append(max(boxtool.box_iou_xywh(bbox, gt)))
-
-            ious.append(max(iou))
-        ious = paddle.to_tensor(np.array(ious))
-
-        return {'loss': loss, 'metrics': {'iou': ious}}
+        gtbox = batch[1].astype('float32')
+        gtlabel = batch[2].astype('int32')
+        gtscore = batch[3].astype("float32")
+        losses = []
+        outputs = self(img)
+        self.downsample = 32
+
+        for i, out in enumerate(outputs):
+            anchor_mask = self.anchor_masks[i]
+            loss = F.yolov3_loss(x=out,
+                                 gt_box=gtbox,
+                                 gt_label=gtlabel,
+                                 gt_score=gtscore,
+                                 anchors=self.anchors,
+                                 anchor_mask=anchor_mask,
+                                 class_num=self.class_num,
+                                 ignore_thresh=self.ignore_thresh,
+                                 downsample_ratio=self.downsample,
+                                 use_label_smooth=False)
+            losses.append(paddle.reduce_mean(loss))
+            self.downsample //= 2
+
+        return {'loss': sum(losses)}
 
     def predict(self, imgpath: str, filelist: str, visualization: bool = True, save_path: str = 'result'):
         '''
@@ -266,28 +259,53 @@
             scores(np.ndarray): Predict score.
             labels(np.ndarray): Predict labels.
         '''
-        boxtool = BoxTool()
-        img = {}
-        img['image'] = imgpath
-        img['id'] = 0
-        im, im_id, im_shape = self.transform(img, 416)
+        boxes = []
+        scores = []
+        self.downsample = 32
+        im = self.transform(imgpath)
+        h, w, c = img_shape(imgpath)
+        im_shape = paddle.to_tensor(np.array([[h, w]]).astype('int32'))
         label_names = self.get_label_infos(filelist)
-        img_data = np.array([im]).astype('float32')
-        img_data = paddle.to_tensor(img_data)
-        im_shape = np.array([im_shape]).astype('int32')
-        im_shape = paddle.to_tensor(im_shape)
-
-        output, pred = self(img_data, None, None, None, im_shape)
-
-        for i in range(len(pred)):
-            bboxes = pred[i].numpy()
-            labels = bboxes[:, 0].astype('int32')
-            scores = bboxes[:, 1].astype('float32')
-            boxes = bboxes[:, 2:].astype('float32')
-
-            if visualization:
-                if not os.path.exists(save_path):
-                    os.mkdir(save_path)
-                boxtool.draw_boxes_on_image(imgpath, boxes, scores, labels, label_names, 0.5)
+        img_data = paddle.to_tensor(np.array([im]).astype('float32'))
+
+        outputs = self(img_data)
+
+        for i, out in enumerate(outputs):
+            anchor_mask = self.anchor_masks[i]
+            mask_anchors = []
+            for m in anchor_mask:
+                mask_anchors.append(self.anchors[2 * m])
+                mask_anchors.append(self.anchors[2 * m + 1])
+
+            box, score = F.yolo_box(x=out,
+                                    img_size=im_shape,
+                                    anchors=mask_anchors,
+                                    class_num=self.class_num,
+                                    conf_thresh=self.valid_thresh,
+                                    downsample_ratio=self.downsample,
+                                    name="yolo_box" + str(i))
+
+            boxes.append(box)
+            scores.append(paddle.transpose(score, perm=[0, 2, 1]))
+            self.downsample //= 2
+
+        yolo_boxes = paddle.concat(boxes, axis=1)
+        yolo_scores = paddle.concat(scores, axis=2)
+
+        pred = F.multiclass_nms(bboxes=yolo_boxes,
+                                scores=yolo_scores,
+                                score_threshold=self.valid_thresh,
+                                nms_top_k=self.nms_topk,
+                                keep_top_k=self.nms_posk,
+                                nms_threshold=self.nms_thresh,
+                                background_label=-1)
+
+        bboxes = pred.numpy()
+        labels = bboxes[:, 0].astype('int32')
+        scores = bboxes[:, 1].astype('float32')
+        boxes = bboxes[:, 2:].astype('float32')
+
+        if visualization:
+            draw_boxes_on_image(imgpath, boxes, scores, labels, label_names, 0.5)
 
         return boxes, scores, labels
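`training_step` now consumes the raw head outputs plus the padded ground-truth batch directly (note `downsample_ratio` must track `self.downsample`, which steps 32 -> 16 -> 8 across the three heads). A hypothetical check of the batch layout it expects — `DetectionData` pads every sample to 50 boxes, and the `transform` is the train-time `Compose` pipeline from `train.py`; the `DataLoader` wiring here is illustrative:

```python
import paddle
from paddlehub.datasets.pascalvoc import DetectionData
from paddlehub.process.detect_transforms import (Compose, Normalize, RandomCrop, RandomDistort,
                                                 RandomExpand, RandomFlip, Resize, ShuffleBox)

paddle.disable_static()
transform = Compose([
    RandomDistort(),
    RandomExpand(fill=[0.485, 0.456, 0.406]),
    RandomCrop(),
    Resize(target_size=416),
    RandomFlip(),
    ShuffleBox(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
loader = paddle.io.DataLoader(DetectionData(transform), batch_size=4)
img, gt_boxes, gt_labels, gt_scores = next(iter(loader))
# img: [4, 3, 416, 416], gt_boxes: [4, 50, 4], gt_labels / gt_scores: [4, 50]
```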
diff --git a/paddlehub/process/detect_transforms.py b/paddlehub/process/detect_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..811a68bd766f88a6f0d62b40b93abb5f59f4b76b
--- /dev/null
+++ b/paddlehub/process/detect_transforms.py
@@ -0,0 +1,427 @@
+import copy
+import os
+import random
+from typing import Callable
+
+import cv2
+import numpy as np
+import matplotlib
+import PIL
+from PIL import Image, ImageEnhance
+from matplotlib import pyplot as plt
+
+from paddlehub.process.functional import *
+
+matplotlib.use('Agg')
+
+
+class DetectCatagory:
+    """Load label names, ids and the category map from a detection dataset.
+
+    Args:
+        attrbox(Callable): Method to get detection attributes of images.
+        data_dir(str): Image dataset path.
+
+    Returns:
+        label_names(List(str)): The dataset label names.
+        label_ids(List(int)): The dataset label ids.
+        category_to_id_map(dict): Mapping relations of category and id for images.
+    """
+    def __init__(self, attrbox: Callable, data_dir: str):
+        self.attrbox = attrbox
+        self.img_dir = data_dir
+
+    def __call__(self):
+        self.categories = self.attrbox.loadCats(self.attrbox.getCatIds())
+        self.num_category = len(self.categories)
+        label_names = []
+        label_ids = []
+        for category in self.categories:
+            label_names.append(category['name'])
+            label_ids.append(int(category['id']))
+        category_to_id_map = {v: i for i, v in enumerate(label_ids)}
+        return label_names, label_ids, category_to_id_map
+
+
+class ParseImages:
+    """Prepare images for detection.
+
+    Args:
+        attrbox(Callable): Method to get detection attributes of images.
+        data_dir(str): Image dataset path.
+        category_to_id_map(dict): Mapping relations of category and id for images.
+
+    Returns:
+        imgs(dict): The input for the detection model, it is a dict.
+    """
+    def __init__(self, attrbox: Callable, data_dir: str, category_to_id_map: dict):
+        self.attrbox = attrbox
+        self.img_dir = data_dir
+        self.category_to_id_map = category_to_id_map
+        self.parse_gt_annotations = GTAnotations(self.attrbox, self.category_to_id_map)
+
+    def __call__(self):
+        image_ids = self.attrbox.getImgIds()
+        image_ids.sort()
+        imgs = copy.deepcopy(self.attrbox.loadImgs(image_ids))
+
+        for img in imgs:
+            img['image'] = os.path.join(self.img_dir, img['file_name'])
+            assert os.path.exists(img['image']), "image {} not found.".format(img['image'])
+            box_num = 50
+            img['gt_boxes'] = np.zeros((box_num, 4), dtype=np.float32)
+            img['gt_labels'] = np.zeros((box_num), dtype=np.int32)
+            img = self.parse_gt_annotations(img)
+        return imgs
+
+
+class GTAnotations:
+    """Set gt boxes and gt labels for training.
+
+    Args:
+        attrbox(Callable): Method to get detection attributes of images.
+        category_to_id_map(dict): Mapping relations of category and id for images.
+        img(dict): Input for the detection model.
+
+    Returns:
+        img(dict): Input with the 'gt_boxes' and 'gt_labels' attributes filled in.
+    """
+    def __init__(self, attrbox: Callable, category_to_id_map: dict):
+        self.attrbox = attrbox
+        self.category_to_id_map = category_to_id_map
+
+    def __call__(self, img: dict):
+        img_height = img['height']
+        img_width = img['width']
+        anno = self.attrbox.loadAnns(self.attrbox.getAnnIds(imgIds=img['id'], iscrowd=None))
+        gt_index = 0
+
+        for target in anno:
+            if target['area'] < -1:
+                continue
+            if 'ignore' in target and target['ignore']:
+                continue
+            box = coco_anno_box_to_center_relative(target['bbox'], img_height, img_width)
+
+            if box[2] <= 0 and box[3] <= 0:
+                continue
+            img['gt_boxes'][gt_index] = box
+            img['gt_labels'][gt_index] = \
+                self.category_to_id_map[target['category_id']]
+            gt_index += 1
+            if gt_index >= 50:
+                break
+        return img
+
+
+class RandomDistort:
+    """Distort the input image randomly.
+
+    Args:
+        lower(float): The lower bound value for enhancement, default is 0.5.
+        upper(float): The upper bound value for enhancement, default is 1.5.
+
+    Returns:
+        img(np.ndarray): Distorted image.
+        data(dict): Image info and label info.
+    """
+    def __init__(self, lower: float = 0.5, upper: float = 1.5):
+        self.lower = lower
+        self.upper = upper
+
+    def random_brightness(self, img: PIL.Image):
+        e = np.random.uniform(self.lower, self.upper)
+        return ImageEnhance.Brightness(img).enhance(e)
+
+    def random_contrast(self, img: PIL.Image):
+        e = np.random.uniform(self.lower, self.upper)
+        return ImageEnhance.Contrast(img).enhance(e)
+
+    def random_color(self, img: PIL.Image):
+        e = np.random.uniform(self.lower, self.upper)
+        return ImageEnhance.Color(img).enhance(e)
+
+    def __call__(self, img: np.ndarray, data: dict):
+        ops = [self.random_brightness, self.random_contrast, self.random_color]
+        np.random.shuffle(ops)
+        img = Image.fromarray(img)
+        img = ops[0](img)
+        img = ops[1](img)
+        img = ops[2](img)
+        img = np.asarray(img)
+
+        return img, data
+
+
+class RandomExpand:
+    """Randomly expand images and gt boxes by a random ratio. It is a data augmentation operation for model training.
+
+    Args:
+        max_ratio(float): Max value for the expansion ratio, default is 4.
+        fill(list): Initialize the pixel value of the image with the input fill value, default is None.
+        keep_ratio(bool): Whether the image keeps its aspect ratio, default is True.
+        thresh(float): If the random ratio does not exceed the thresh, return the original image and gt boxes, default is 0.5.
+
+    Returns:
+        img(np.ndarray): Distorted image.
+        data(dict): Image info and label info.
+    """
+    def __init__(self, max_ratio: float = 4., fill: list = None, keep_ratio: bool = True, thresh: float = 0.5):
+        self.max_ratio = max_ratio
+        self.fill = fill
+        self.keep_ratio = keep_ratio
+        self.thresh = thresh
+
+    def __call__(self, img: np.ndarray, data: dict):
+        gtboxes = data['gt_boxes']
+
+        if random.random() > self.thresh:
+            return img, data
+        if self.max_ratio < 1.0:
+            return img, data
+        h, w, c = img.shape
+
+        ratio_x = random.uniform(1, self.max_ratio)
+        if self.keep_ratio:
+            ratio_y = ratio_x
+        else:
+            ratio_y = random.uniform(1, self.max_ratio)
+
+        oh = int(h * ratio_y)
+        ow = int(w * ratio_x)
+        off_x = random.randint(0, ow - w)
+        off_y = random.randint(0, oh - h)
+
+        out_img = np.zeros((oh, ow, c))
+        if self.fill and len(self.fill) == c:
+            for i in range(c):
+                out_img[:, :, i] = self.fill[i] * 255.0
+
+        out_img[off_y:off_y + h, off_x:off_x + w, :] = img
+        gtboxes[:, 0] = ((gtboxes[:, 0] * w) + off_x) / float(ow)
+        gtboxes[:, 1] = ((gtboxes[:, 1] * h) + off_y) / float(oh)
+        gtboxes[:, 2] = gtboxes[:, 2] / ratio_x
+        gtboxes[:, 3] = gtboxes[:, 3] / ratio_y
+        data['gt_boxes'] = gtboxes
+        img = out_img.astype('uint8')
+
+        return img, data
+
+
+class RandomCrop:
+    """Randomly crop the input image according to constraints.
+
+    Args:
+        scales(list): The range of the cropped area relative to the original area, expressed as [min, max]. The default value is [0.3, 1.0].
+        max_ratio(float): Max aspect ratio of the cropped area, default is 2.0.
+        constraints(list): A list of (min_iou, max_iou) constraints between the crop and the gt boxes, default is None.
+        max_trial(int): The max number of trials for finding a valid crop area. The default value is 50.
+
+    Returns:
+        img(np.ndarray): Distorted image.
+        data(dict): Image info and label info.
+    """
+    def __init__(self,
+                 scales: list = [0.3, 1.0],
+                 max_ratio: float = 2.0,
+                 constraints: list = None,
+                 max_trial: int = 50):
+        self.scales = scales
+        self.max_ratio = max_ratio
+        self.constraints = constraints
+        self.max_trial = max_trial
+
+    def __call__(self, img: np.ndarray, data: dict):
+        boxes = data['gt_boxes']
+        labels = data['gt_labels']
+        scores = data['gt_scores']
+
+        if len(boxes) == 0:
+            return img, data
+        if not self.constraints:
+            self.constraints = [(0.1, 1.0), (0.3, 1.0), (0.5, 1.0), (0.7, 1.0), (0.9, 1.0), (0.0, 1.0)]
+
+        img = Image.fromarray(img)
+        w, h = img.size
+        crops = [(0, 0, w, h)]
+        for min_iou, max_iou in self.constraints:
+            for _ in range(self.max_trial):
+                scale = random.uniform(self.scales[0], self.scales[1])
+                aspect_ratio = random.uniform(max(1 / self.max_ratio, scale * scale), \
+                                              min(self.max_ratio, 1 / scale / scale))
+                crop_h = int(h * scale / np.sqrt(aspect_ratio))
+                crop_w = int(w * scale * np.sqrt(aspect_ratio))
+                crop_x = random.randrange(w - crop_w)
+                crop_y = random.randrange(h - crop_h)
+                crop_box = np.array([[(crop_x + crop_w / 2.0) / w, (crop_y + crop_h / 2.0) / h, crop_w / float(w),
+                                      crop_h / float(h)]])
+                iou = box_iou_xywh(crop_box, boxes)
+                if min_iou <= iou.min() and max_iou >= iou.max():
+                    crops.append((crop_x, crop_y, crop_w, crop_h))
+                    break
+
+        while crops:
+            crop = crops.pop(np.random.randint(0, len(crops)))
+            crop_boxes, crop_labels, crop_scores, box_num = box_crop(boxes, labels, scores, crop, (w, h))
+
+            if box_num < 1:
+                continue
+            img = img.crop((crop[0], crop[1], crop[0] + crop[2], crop[1] + crop[3])).resize(img.size, Image.LANCZOS)
+            img = np.asarray(img)
+            data['gt_boxes'] = crop_boxes
+            data['gt_labels'] = crop_labels
+            data['gt_scores'] = crop_scores
+            return img, data
+        img = np.asarray(img)
+        data['gt_boxes'] = boxes
+        data['gt_labels'] = labels
+        data['gt_scores'] = scores
+        return img, data
+
+
+class RandomFlip:
+    """Flip the image and gt boxes randomly.
+
+    Args:
+        thresh(float): Probability for the random flip, default is 0.5.
+
+    Returns:
+        img(np.ndarray): Distorted image.
+        data(dict): Image info and label info.
+    """
+    def __init__(self, thresh: float = 0.5):
+        self.thresh = thresh
+
+    def __call__(self, img, data):
+        gtboxes = data['gt_boxes']
+        if random.random() > self.thresh:
+            img = img[:, ::-1, :]
+            gtboxes[:, 0] = 1.0 - gtboxes[:, 0]
+        data['gt_boxes'] = gtboxes
+        return img, data
+
+
+class Compose:
+    """Preprocess the input data according to the operators.
+
+    Args:
+        transforms(list): Preprocessing operators.
+
+    Returns:
+        img(np.ndarray): Preprocessed image.
+        data(dict): Image info and label info, default is None.
+    """
+    def __init__(self, transforms: list):
+        if not isinstance(transforms, list):
+            raise TypeError('The transforms must be a list!')
+        if len(transforms) < 1:
+            raise ValueError('The length of transforms ' + \
+                             'must be equal or larger than 1!')
+        self.transforms = transforms
+
+    def __call__(self, data: dict):
+        if isinstance(data, dict):
+            if isinstance(data['image'], str):
+                img = cv2.imread(data['image'])
+                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            gt_labels = data['gt_labels'].copy()
+            data['gt_scores'] = np.ones_like(gt_labels)
+            for op in self.transforms:
+                img, data = op(img, data)
+            img = img.transpose((2, 0, 1))
+            return img, data
+
+        if isinstance(data, str):
+            img = cv2.imread(data)
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            for op in self.transforms:
+                img, data = op(img, data)
+            img = img.transpose((2, 0, 1))
+            return img
+
+
+class Resize:
+    """Resize the input image.
+
+    Args:
+        target_size(int): Target input size, default is 512.
+        interp(str): Interpolation method, one of the keys of interp_dict or 'RANDOM', default is 'RANDOM'.
+
+    Returns:
+        img(np.ndarray): Preprocessed image.
+        data(dict): Image info and label info, default is None.
+    """
+    def __init__(self, target_size: int = 512, interp: str = 'RANDOM'):
+        self.interp_dict = {
+            'NEAREST': cv2.INTER_NEAREST,
+            'LINEAR': cv2.INTER_LINEAR,
+            'CUBIC': cv2.INTER_CUBIC,
+            'AREA': cv2.INTER_AREA,
+            'LANCZOS4': cv2.INTER_LANCZOS4
+        }
+        self.interp = interp
+        if not (interp == "RANDOM" or interp in self.interp_dict):
+            raise ValueError("interp should be one of {}".format(self.interp_dict.keys()))
+        if isinstance(target_size, list) or isinstance(target_size, tuple):
+            if len(target_size) != 2:
+                raise TypeError(
+                    'when target is list or tuple, it should include 2 elements, but it is {}'.format(target_size))
+        elif not isinstance(target_size, int):
+            raise TypeError("Type of target_size is invalid. Must be Integer or List or tuple, now is {}".format(
+                type(target_size)))
+
+        self.target_size = target_size
+
+    def __call__(self, img, data=None):
+        if self.interp == "RANDOM":
+            interp = random.choice(list(self.interp_dict.keys()))
+        else:
+            interp = self.interp
+        img = resize(img, self.target_size, self.interp_dict[interp])
+        if data is not None:
+            return img, data
+        else:
+            return img
+
+
+class Normalize:
+    """Normalize the input image.
+
+    Args:
+        mean(list): Mean values for normalization, default is [0.5, 0.5, 0.5].
+        std(list): Standard deviation for normalization, default is [0.5, 0.5, 0.5].
+
+    Returns:
+        img(np.ndarray): Preprocessed image.
+        data(dict): Image info and label info, default is None.
+    """
+    def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
+        self.mean = mean
+        self.std = std
+        if not (isinstance(self.mean, list) and isinstance(self.std, list)):
+            raise ValueError("{}: input type is invalid.".format(self))
+        from functools import reduce
+        if reduce(lambda x, y: x * y, self.std) == 0:
+            raise ValueError('{}: std is invalid!'.format(self))
+
+    def __call__(self, im, data=None):
+        mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
+        std = np.array(self.std)[np.newaxis, np.newaxis, :]
+        im = normalize(im, mean, std)
+        if data is not None:
+            return im, data
+        else:
+            return im
+
+
+class ShuffleBox:
+    """Shuffle the gt boxes, labels and scores in unison."""
+    def __call__(self, img, data):
+        gt = np.concatenate([data['gt_boxes'], data['gt_labels'][:, np.newaxis], data['gt_scores'][:, np.newaxis]],
+                            axis=1)
+        idx = np.arange(gt.shape[0])
+        np.random.shuffle(idx)
+        gt = gt[idx, :]
+        data['gt_boxes'], data['gt_labels'], data['gt_scores'] = gt[:, :4], gt[:, 4], gt[:, 5]
+        return img, data
diff --git a/paddlehub/process/functional.py b/paddlehub/process/functional.py
index 633adb5bcd1b75b5bba97291ae064cf1e1cd2222..7b97a004356c64e2d9fd9e6e0f016948fbee9d59 100644
--- a/paddlehub/process/functional.py
+++ b/paddlehub/process/functional.py
@@ -11,12 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import os
 
 import cv2
+import paddle
+import matplotlib
 import numpy as np
 from PIL import Image, ImageEnhance
+from matplotlib import pyplot as plt
+
+matplotlib.use('Agg')
 
 
 def normalize(im, mean, std):
@@ -117,4 +121,141 @@ def get_img_file(dir_name: str) -> list:
             print(img_path)
             images.append(img_path)
     images.sort()
-    return images
\ No newline at end of file
+    return images
+
+
+def coco_anno_box_to_center_relative(box: list, img_height: int, img_width: int) -> np.ndarray:
+    """
+    Convert a COCO annotation box with format [x1, y1, w, h] to
+    center mode [center_x, center_y, w, h] and divide by image width
+    and height to get relative values in range [0, 1].
+    """
+    assert len(box) == 4, "box should be a len(4) list or tuple"
+    x, y, w, h = box
+
+    x1 = max(x, 0)
+    x2 = min(x + w - 1, img_width - 1)
+    y1 = max(y, 0)
+    y2 = min(y + h - 1, img_height - 1)
+
+    x = (x1 + x2) / 2 / img_width
+    y = (y1 + y2) / 2 / img_height
+    w = (x2 - x1) / img_width
+    h = (y2 - y1) / img_height
+
+    return np.array([x, y, w, h])
+
+
+def box_crop(boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray, crop: list, img_shape: list):
+    """Crop the boxes, labels and scores according to the given crop window and image shape."""
+
+    x, y, w, h = map(float, crop)
+    im_w, im_h = map(float, img_shape)
+
+    boxes = boxes.copy()
+    boxes[:, 0], boxes[:, 2] = (boxes[:, 0] - boxes[:, 2] / 2) * im_w, (boxes[:, 0] + boxes[:, 2] / 2) * im_w
+    boxes[:, 1], boxes[:, 3] = (boxes[:, 1] - boxes[:, 3] / 2) * im_h, (boxes[:, 1] + boxes[:, 3] / 2) * im_h
+
+    crop_box = np.array([x, y, x + w, y + h])
+    centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
+    mask = np.logical_and(crop_box[:2] <= centers, centers <= crop_box[2:]).all(axis=1)
+
+    boxes[:, :2] = np.maximum(boxes[:, :2], crop_box[:2])
+    boxes[:, 2:] = np.minimum(boxes[:, 2:], crop_box[2:])
+    boxes[:, :2] -= crop_box[:2]
+    boxes[:, 2:] -= crop_box[:2]
+
+    mask = np.logical_and(mask, (boxes[:, :2] < boxes[:, 2:]).all(axis=1))
+    boxes = boxes * np.expand_dims(mask.astype('float32'), axis=1)
+    labels = labels * mask.astype('float32')
+    scores = scores * mask.astype('float32')
+    boxes[:, 0], boxes[:, 2] = (boxes[:, 0] + boxes[:, 2]) / 2 / w, (boxes[:, 2] - boxes[:, 0]) / w
+    boxes[:, 1], boxes[:, 3] = (boxes[:, 1] + boxes[:, 3]) / 2 / h, (boxes[:, 3] - boxes[:, 1]) / h
+
+    return boxes, labels, scores, mask.sum()
+
+
+def box_iou_xywh(box1: np.ndarray, box2: np.ndarray) -> float:
+    """Calculate iou of boxes in xywh format."""
+
+    assert box1.shape[-1] == 4, "Box1 shape[-1] should be 4."
+    assert box2.shape[-1] == 4, "Box2 shape[-1] should be 4."
+
+    b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
+    b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
+    b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
+    b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
+
+    inter_x1 = np.maximum(b1_x1, b2_x1)
+    inter_x2 = np.minimum(b1_x2, b2_x2)
+    inter_y1 = np.maximum(b1_y1, b2_y1)
+    inter_y2 = np.minimum(b1_y2, b2_y2)
+    inter_w = inter_x2 - inter_x1
+    inter_h = inter_y2 - inter_y1
+    inter_w[inter_w < 0] = 0
+    inter_h[inter_h < 0] = 0
+
+    inter_area = inter_w * inter_h
+    b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
+    b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
+
+    return inter_area / (b1_area + b2_area - inter_area)
+
+
+def draw_boxes_on_image(image_path: str,
+                        boxes: np.ndarray,
+                        scores: np.ndarray,
+                        labels: np.ndarray,
+                        label_names: list,
+                        score_thresh: float = 0.5):
+    """Draw boxes on images."""
+    image = np.array(Image.open(image_path))
+    plt.figure()
+    _, ax = plt.subplots(1)
+    ax.imshow(image)
+
+    image_name = image_path.split('/')[-1]
+    print("Image {} detect: ".format(image_name))
+    colors = {}
+    for box, score, label in zip(boxes, scores, labels):
+        if score < score_thresh:
+            continue
+        if box[2] <= box[0] or box[3] <= box[1]:
+            continue
+        label = int(label)
+        if label not in colors:
+            colors[label] = plt.get_cmap('hsv')(label / len(label_names))
+        x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
+        rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, linewidth=2.0, edgecolor=colors[label])
+        ax.add_patch(rect)
+        ax.text(x1,
+                y1,
+                '{} {:.4f}'.format(label_names[label], score),
+                verticalalignment='bottom',
+                horizontalalignment='left',
+                bbox={
+                    'facecolor': colors[label],
+                    'alpha': 0.5,
+                    'pad': 0
+                },
+                fontsize=8,
+                color='white')
+        print("\t {:15s} at {:25} score: {:.5f}".format(label_names[int(label)], str(list(map(int, list(box)))), score))
+    image_name = image_name.replace('jpg', 'png')
+    plt.axis('off')
+    plt.gca().xaxis.set_major_locator(plt.NullLocator())
+    plt.gca().yaxis.set_major_locator(plt.NullLocator())
+    if not os.path.exists("./output"):
+        os.makedirs("./output")
+    plt.savefig("./output/{}".format(image_name), bbox_inches='tight', pad_inches=0.0)
+    print("Detect result save at ./output/{}\n".format(image_name))
+    plt.cla()
+    plt.close('all')
+
+
+def img_shape(img_path: str):
+    """Get image shape."""
+    im = cv2.imread(img_path)
+    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+    h, w, c = im.shape
+    return h, w, c
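The two box helpers moved into `functional.py` can be sanity-checked by hand; the values below are worked examples, not from the PR:

```python
import numpy as np
from paddlehub.process.functional import box_iou_xywh, coco_anno_box_to_center_relative

# COCO-style [x1, y1, w, h] = [10, 10, 80, 80] in a 100x100 image:
# x2 = min(10 + 80 - 1, 99) = 89, so center = (10 + 89) / 2 / 100 = 0.495
# and width = (89 - 10) / 100 = 0.79 after the inclusive-pixel clipping.
box = coco_anno_box_to_center_relative([10, 10, 80, 80], 100, 100)
print(box)  # [0.495 0.495 0.79 0.79]

# IoU of a 0.4x0.4 box and a concentric 0.2x0.2 box: the small box is fully
# inside the large one, so IoU = 0.04 / 0.16 = 0.25.
a = np.array([[0.5, 0.5, 0.4, 0.4]])
b = np.array([[0.5, 0.5, 0.2, 0.2]])
print(box_iou_xywh(a, b))  # [0.25]
```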
diff --git a/paddlehub/process/transforms.py b/paddlehub/process/transforms.py
index 9e301b060548cb80a57abd9a02a10754cdda98fa..80821267424dceddfdc9aff4cbccbc78ca47e639 100644
--- a/paddlehub/process/transforms.py
+++ b/paddlehub/process/transforms.py
@@ -22,12 +22,9 @@
 import cv2
 import numpy as np
 import matplotlib
 from PIL import Image, ImageEnhance
-from matplotlib import pyplot as plt
 
 from paddlehub.process.functional import *
 
-matplotlib.use('Agg')
-
 
 class Compose:
     def __init__(self, transforms, to_rgb=True, stay_rgb=False):
@@ -737,505 +734,3 @@
         img = np.clip(img, 0, 1) * 255
         img = img.astype(self.type)
         return img
-
-
-class DetectCatagory:
-    """Load label name, id and map from detection dataset.
-
-    Args:
-        COCO(Callable): Method for get detection attributes for images.
-        data_dir(str): Image dataset path.
-
-    Returns:
-        label_names(List(str)): The dataset label names.
-        label_ids(List(int)): The dataset label ids.
-        category_to_id_map(dict): Mapping relations of category and id for images.
-    """
-    def __init__(self, COCO: Callable, data_dir: str):
-        self.COCO = COCO
-        self.img_dir = data_dir
-
-    def __call__(self):
-        self.categories = self.COCO.loadCats(self.COCO.getCatIds())
-        self.num_category = len(self.categories)
-        label_names = []
-        label_ids = []
-        for category in self.categories:
-            label_names.append(category['name'])
-            label_ids.append(int(category['id']))
-        category_to_id_map = {v: i for i, v in enumerate(label_ids)}
-        return label_names, label_ids, category_to_id_map
-
-
-class ParseImages:
-    """Prepare images for detection.
-
-    Args:
-        COCO(Callable): Method for get detection attributes for images.
-        is_train(bool): Select the mode for train or test.
-        data_dir(str): Image dataset path.
-        category_to_id_map(dict): Mapping relations of category and id for images.
-
-    Returns:
-        imgs(dict): The input for detection model, it is a dict.
-    """
-    def __init__(self, COCO: Callable, is_train: bool, data_dir: str, category_to_id_map: dict):
-        self.COCO = COCO
-        self.is_train = is_train
-        self.img_dir = data_dir
-        self.category_to_id_map = category_to_id_map
-        self.parse_gt_annotations = GTAnotations(self.COCO, self.category_to_id_map)
-
-    def __call__(self):
-        image_ids = self.COCO.getImgIds()
-        image_ids.sort()
-        imgs = copy.deepcopy(self.COCO.loadImgs(image_ids))
-
-        for img in imgs:
-            img['image'] = os.path.join(self.img_dir, img['file_name'])
-            assert os.path.exists(img['image']), \
-                "image {} not found.".format(img['image'])
-            box_num = 50
-            img['gt_boxes'] = np.zeros((box_num, 4), dtype=np.float32)
-            img['gt_labels'] = np.zeros((box_num), dtype=np.int32)
-            if self.is_train:
-                img = self.parse_gt_annotations(img)
-
-        return imgs
-
-
-class GTAnotations:
-    """Set gt boxes and gt labels for train.
-
-    Args:
-        COCO(Callable): Method for get detection attributes for images.
-        category_to_id_map(dict): Mapping relations of category and id for images.
-        img(dict): Input for detection model.
-
-    Returns:
-        img(dict): Set specific value on the attributes of 'gt boxes' and 'gt labels' for input.
-    """
-    def __init__(self, COCO: Callable, category_to_id_map: dict):
-        self.COCO = COCO
-        self.category_to_id_map = category_to_id_map
-        self.boxtool = BoxTool()
-
-    def __call__(self, img: dict):
-        img_height = img['height']
-        img_width = img['width']
-        anno = self.COCO.loadAnns(self.COCO.getAnnIds(imgIds=img['id'], iscrowd=None))
-        gt_index = 0
-
-        for target in anno:
-            if target['area'] < -1:
-                continue
-            if 'ignore' in target and target['ignore']:
-                continue
-
-            box = self.boxtool.coco_anno_box_to_center_relative(target['bbox'], img_height, img_width)
-            if box[2] <= 0 and box[3] <= 0:
-                continue
-            img['gt_boxes'][gt_index] = box
-            img['gt_labels'][gt_index] = \
-                self.category_to_id_map[target['category_id']]
-            gt_index += 1
-            if gt_index >= 50:
-                break
-        return img
-
-
-class DetectTestReader:
-    """Preprocess for detection dataset on test mode.
-
-    Args:
-        mean(list): Mean values for normalization, default is [0.485, 0.456, 0.406].
-        std(list): Standard deviation for normalization, default is [0.229, 0.224, 0.225].
-        img(dict): Prepared input for detection model.
-        size(int): Image size for detection.
-
-    Returns:
-        out_img(np.ndarray): Normalized image, shape is [C, H, W].
-        id(int): Id number for corresponding out_img.
-        (h, w)(tuple): height and weight for corresponding out_img.
-    """
-    def __init__(self, mean: list = [0.485, 0.456, 0.406], std: list = [0.229, 0.224, 0.225]):
-        self.mean = mean
-        self.std = std
-
-    def __call__(self, img, size):
-        im_path = img['image']
-        im = cv2.imread(im_path).astype('float32')
-        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
-        h, w, _ = im.shape
-        im_scale_x = size / float(w)
-        im_scale_y = size / float(h)
-
-        out_img = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=cv2.INTER_CUBIC)
-
-        mean = np.array(self.mean).reshape((1, 1, -1))
-        std = np.array(self.std).reshape((1, 1, -1))
-        out_img = (out_img / 255.0 - mean) / std
-        out_img = out_img.transpose((2, 0, 1))
-        id = int(img['id'])
-        return out_img, id, (h, w)
-
-
-class DetectTrainReader:
-    """Preprocess for detection dataset on train mode.
-
-    Args:
-        mean(list): Mean values for normalization, default is [0.485, 0.456, 0.406].
-        std(list): Standard deviation for normalization, default is [0.229, 0.224, 0.225].
-        img(dict): Prepared input for detection model.
-        size(int): Image size for detection.
-
-    Returns:
-        out_img(np.ndarray): Normalized image, shape is [C, H, W].
-        gt_boxes(np.ndarray): Ground truth boxes information.
-        gt_labels(np.ndarray): Ground truth labels.
-        gt_scores(np.ndarray): Ground truth scores.
-    """
-    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
-        self.mean = mean
-        self.std = std
-        self.boxtool = BoxTool()
-
-    def __call__(self, img, size):
-        im_path = img['image']
-        im = cv2.imread(im_path)
-        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
-        gt_boxes = img['gt_boxes'].copy()
-        gt_labels = img['gt_labels'].copy()
-        gt_scores = np.ones_like(gt_labels)
-        im, gt_boxes, gt_labels, gt_scores = self.boxtool.image_augment(im, gt_boxes, gt_labels, gt_scores, size,
-                                                                        self.mean)
-        mean = np.array(self.mean).reshape((1, 1, -1))
-        std = np.array(self.std).reshape((1, 1, -1))
-        out_img = (im / 255.0 - mean) / std
-        out_img = out_img.astype('float32').transpose((2, 0, 1))
-        return out_img, gt_boxes, gt_labels, gt_scores
-
-
-class BoxTool:
-    """This class provides common methods for box processing in detection tasks."""
-    def __init__(self):
-        super(BoxTool, self).__init__()
-
-    def coco_anno_box_to_center_relative(self, box: list, img_height: int, img_width: int) -> np.ndarray:
-        """
-        Convert COCO annotations box with format [x1, y1, w, h] to
-        center mode [center_x, center_y, w, h] and divide image width
-        and height to get relative value in range[0, 1]
-        """
-        assert len(box) == 4, "box should be a len(4) list or tuple"
-        x, y, w, h = box
-
-        x1 = max(x, 0)
-        x2 = min(x + w - 1, img_width - 1)
-        y1 = max(y, 0)
-        y2 = min(y + h - 1, img_height - 1)
-
-        x = (x1 + x2) / 2 / img_width
-        y = (y1 + y2) / 2 / img_height
-        w = (x2 - x1) / img_width
-        h = (y2 - y1) / img_height
-
-        return np.array([x, y, w, h])
-
-    def clip_relative_box_in_image(self, x: int, y: int, w: int, h: int) -> int:
-        """Clip relative box coordinates x, y, w, h to [0, 1]"""
-
-        x1 = max(x - w / 2, 0.)
-        x2 = min(x + w / 2, 1.)
-        y1 = min(y - h / 2, 0.)
-        y2 = max(y + h / 2, 1.)
-        x = (x1 + x2) / 2
-        y = (y1 + y2) / 2
-        w = x2 - x1
-        h = y2 - y1
-        return x, y, w, h
-
-    def box_xywh_to_xyxy(self, box: np.ndarray) -> np.ndarray:
-        """Change box from xywh to xyxy"""
-
-        shape = box.shape
-        assert shape[-1] == 4, "Box shape[-1] should be 4."
-
-        box = box.reshape((-1, 4))
-        box[:, 0], box[:, 2] = box[:, 0] - box[:, 2] / 2, box[:, 0] + box[:, 2] / 2
-        box[:, 1], box[:, 3] = box[:, 1] - box[:, 3] / 2, box[:, 1] + box[:, 3] / 2
-        box = box.reshape(shape)
-        return box
-
-    def box_iou_xywh(self, box1: np.ndarray, box2: np.ndarray) -> float:
-        """Calculate iou by xywh"""
-
-        assert box1.shape[-1] == 4, "Box1 shape[-1] should be 4."
-        assert box2.shape[-1] == 4, "Box2 shape[-1] should be 4."
-
-        b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
-        b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
-        b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
-        b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
-
-        inter_x1 = np.maximum(b1_x1, b2_x1)
-        inter_x2 = np.minimum(b1_x2, b2_x2)
-        inter_y1 = np.maximum(b1_y1, b2_y1)
-        inter_y2 = np.minimum(b1_y2, b2_y2)
-        inter_w = inter_x2 - inter_x1
-        inter_h = inter_y2 - inter_y1
-        inter_w[inter_w < 0] = 0
-        inter_h[inter_h < 0] = 0
-
-        inter_area = inter_w * inter_h
-        b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
-        b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
-
-        return inter_area / (b1_area + b2_area - inter_area)
-
-    def box_iou_xyxy(self, box1: np.ndarray, box2: np.ndarray) -> float:
-        """Calculate iou by xyxy"""
-
-        assert box1.shape[-1] == 4, "Box1 shape[-1] should be 4."
-        assert box2.shape[-1] == 4, "Box2 shape[-1] should be 4."
-
-        b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
-        b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
-
-        inter_x1 = np.maximum(b1_x1, b2_x1)
-        inter_x2 = np.minimum(b1_x2, b2_x2)
-        inter_y1 = np.maximum(b1_y1, b2_y1)
-        inter_y2 = np.minimum(b1_y2, b2_y2)
-        inter_w = inter_x2 - inter_x1
-        inter_h = inter_y2 - inter_y1
-        inter_w[inter_w < 0] = 0
-        inter_h[inter_h < 0] = 0
-
-        inter_area = inter_w * inter_h
-        b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
-        b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
-
-        return inter_area / (b1_area + b2_area - inter_area)
-
-    def box_crop(self, boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray, crop: list, img_shape: list):
-        """Crop the boxes ,labels, scores according to the given shape"""
-
-        x, y, w, h = map(float, crop)
-        im_w, im_h = map(float, img_shape)
-
-        boxes = boxes.copy()
-        boxes[:, 0], boxes[:, 2] = (boxes[:, 0] - boxes[:, 2] / 2) * im_w, (boxes[:, 0] + boxes[:, 2] / 2) * im_w
-        boxes[:, 1], boxes[:, 3] = (boxes[:, 1] - boxes[:, 3] / 2) * im_h, (boxes[:, 1] + boxes[:, 3] / 2) * im_h
-
-        crop_box = np.array([x, y, x + w, y + h])
-        centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
-        mask = np.logical_and(crop_box[:2] <= centers, centers <= crop_box[2:]).all(axis=1)
-
-        boxes[:, :2] = np.maximum(boxes[:, :2], crop_box[:2])
-        boxes[:, 2:] = np.minimum(boxes[:, 2:], crop_box[2:])
-        boxes[:, :2] -= crop_box[:2]
-        boxes[:, 2:] -= crop_box[:2]
-
-        mask = np.logical_and(mask, (boxes[:, :2] < boxes[:, 2:]).all(axis=1))
-        boxes = boxes * np.expand_dims(mask.astype('float32'), axis=1)
-        labels = labels * mask.astype('float32')
-        scores = scores * mask.astype('float32')
-        boxes[:, 0], boxes[:, 2] = (boxes[:, 0] + boxes[:, 2]) / 2 / w, (boxes[:, 2] - boxes[:, 0]) / w
-        boxes[:, 1], boxes[:, 3] = (boxes[:, 1] + boxes[:, 3]) / 2 / h, (boxes[:, 3] - boxes[:, 1]) / h
-
-        return boxes, labels, scores, mask.sum()
-
-    def random_distort(self, img):
-        """ Distort the input image randomly."""
-        def random_brightness(img, lower=0.5, upper=1.5):
-            e = np.random.uniform(lower, upper)
-            return ImageEnhance.Brightness(img).enhance(e)
-
-        def random_contrast(img, lower=0.5, upper=1.5):
-            e = np.random.uniform(lower, upper)
-            return ImageEnhance.Contrast(img).enhance(e)
-
-        def random_color(img, lower=0.5, upper=1.5):
-            e = np.random.uniform(lower, upper)
-            return ImageEnhance.Color(img).enhance(e)
-
-        ops = [random_brightness, random_contrast, random_color]
-        np.random.shuffle(ops)
-
-        img = Image.fromarray(img)
-        img = ops[0](img)
-        img = ops[1](img)
-        img = ops[2](img)
-        img = np.asarray(img)
-
-        return img
-
-    def random_crop(self, img, boxes, labels, scores, scales=[0.3, 1.0], max_ratio=2.0, constraints=None, max_trial=50):
-        """Random crop the input image according to constraints."""
-        if len(boxes) == 0:
-            return img, boxes
-
-        if not constraints:
-            constraints = [(0.1, 1.0), (0.3, 1.0), (0.5, 1.0), (0.7, 1.0), (0.9, 1.0), (0.0, 1.0)]
-
-        img = Image.fromarray(img)
-        w, h = img.size
-        crops = [(0, 0, w, h)]
-        for min_iou, max_iou in constraints:
-            for _ in range(max_trial):
-                scale = random.uniform(scales[0], scales[1])
-                aspect_ratio = random.uniform(max(1 / max_ratio, scale * scale), \
-                                              min(max_ratio, 1 / scale / scale))
-                crop_h = int(h * scale / np.sqrt(aspect_ratio))
-                crop_w = int(w * scale * np.sqrt(aspect_ratio))
-                crop_x = random.randrange(w - crop_w)
-                crop_y = random.randrange(h - crop_h)
-                crop_box = np.array([[(crop_x + crop_w / 2.0) / w, (crop_y + crop_h / 2.0) / h, crop_w / float(w),
-                                      crop_h / float(h)]])
-
-                iou = self.box_iou_xywh(crop_box, boxes)
-                if min_iou <= iou.min() and max_iou >= iou.max():
-                    crops.append((crop_x, crop_y, crop_w, crop_h))
-                    break
-
-        while crops:
-            crop = crops.pop(np.random.randint(0, len(crops)))
-            crop_boxes, crop_labels, crop_scores, box_num = \
-                self.box_crop(boxes, labels, scores, crop, (w, h))
-            if box_num < 1:
-                continue
-            img = img.crop((crop[0], crop[1], crop[0] + crop[2], crop[1] + crop[3])).resize(img.size, Image.LANCZOS)
-            img = np.asarray(img)
-            return img, crop_boxes, crop_labels, crop_scores
-        img = np.asarray(img)
-        return img, boxes, labels, scores
-
-    def random_flip(self, img, gtboxes, thresh=0.5):
-        """Flip the images randomly"""
-        if random.random() > thresh:
-            img = img[:, ::-1, :]
-            gtboxes[:, 0] = 1.0 - gtboxes[:, 0]
-        return img, gtboxes
-
-    def random_interp(self, img, size, interp=None):
-        interp_method = [
-            cv2.INTER_NEAREST,
-            cv2.INTER_LINEAR,
-            cv2.INTER_AREA,
-            cv2.INTER_CUBIC,
-            cv2.INTER_LANCZOS4,
-        ]
-        if not interp or interp not in interp_method:
-            interp = interp_method[random.randint(0, len(interp_method) - 1)]
-        h, w, _ = img.shape
-        im_scale_x = size / float(w)
-        im_scale_y = size / float(h)
-        img = cv2.resize(img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=interp)
-        return img
-
-    def random_expand(self, img, gtboxes, max_ratio=4., fill=None, keep_ratio=True, thresh=0.5):
-        """Expand input image and ground truth box by random ratio."""
-        if random.random() > thresh:
-            return img, gtboxes
-
-        if max_ratio < 1.0:
-            return img, gtboxes
-
-        h, w, c = img.shape
-        ratio_x = random.uniform(1, max_ratio)
-        if keep_ratio:
-            ratio_y = ratio_x
-        else:
-            ratio_y = random.uniform(1, max_ratio)
-        oh = int(h * ratio_y)
-        ow = int(w * ratio_x)
-        off_x = random.randint(0, ow - w)
-        off_y = random.randint(0, oh - h)
-
-        out_img = np.zeros((oh, ow, c))
-        if fill and len(fill) == c:
-            for i in range(c):
-                out_img[:, :, i] = fill[i] * 255.0
-
-        out_img[off_y:off_y + h, off_x:off_x + w, :] = img
-        gtboxes[:, 0] = ((gtboxes[:, 0] * w) + off_x) / float(ow)
-        gtboxes[:, 1] = ((gtboxes[:, 1] * h) + off_y) / float(oh)
-        gtboxes[:, 2] = gtboxes[:, 2] / ratio_x
-        gtboxes[:, 3] = gtboxes[:, 3] / ratio_y
-
-        return out_img.astype('uint8'), gtboxes
-
-    def shuffle_gtbox(self, gtbox, gtlabel, gtscore):
-        """Shuffle gt box."""
-
-        gt = np.concatenate([gtbox, gtlabel[:, np.newaxis], gtscore[:, np.newaxis]], axis=1)
-        idx = np.arange(gt.shape[0])
-        np.random.shuffle(idx)
-        gt = gt[idx, :]
-        return gt[:, :4], gt[:, 4], gt[:, 5]
-
-    def image_augment(self, img, gtboxes, gtlabels, gtscores, size, means=None):
-        """Random processes for input image."""
-
-        img = self.random_distort(img)
-        img, gtboxes = self.random_expand(img, gtboxes, fill=means)
-        img, gtboxes, gtlabels, gtscores = \
-            self.random_crop(img, gtboxes, gtlabels, gtscores)
-        img = self.random_interp(img, size)
-        img, gtboxes = self.random_flip(img, gtboxes)
-        gtboxes, gtlabels, gtscores = self.shuffle_gtbox(gtboxes, gtlabels, gtscores)
-
-        return img.astype('float32'), gtboxes.astype('float32'), \
-            gtlabels.astype('int32'), gtscores.astype('float32')
-
-    def draw_boxes_on_image(self,
-                            image_path: str,
-                            boxes: np.ndarray,
-                            scores: np.ndarray,
-                            labels: np.ndarray,
-                            label_names: list,
-                            score_thresh: float = 0.5):
-        """Draw boxes on images"""
-
-        image = np.array(Image.open(image_path))
-        plt.figure()
-        _, ax = plt.subplots(1)
-        ax.imshow(image)
-
-        image_name = image_path.split('/')[-1]
-        print("Image {} detect: ".format(image_name))
-        colors = {}
-
-        for box, score, label in zip(boxes, scores, labels):
-            if score < score_thresh:
-                continue
-            if box[2] <= box[0] or box[3] <= box[1]:
-                continue
-            label = int(label)
-            if label not in colors:
-                colors[label] = plt.get_cmap('hsv')(label / len(label_names))
-            x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
-            rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, linewidth=2.0, edgecolor=colors[label])
-            ax.add_patch(rect)
-            ax.text(x1,
-                    y1,
-                    '{} {:.4f}'.format(label_names[label], score),
-                    verticalalignment='bottom',
-                    horizontalalignment='left',
-                    bbox={
-                        'facecolor': colors[label],
-                        'alpha': 0.5,
-                        'pad': 0
-                    },
-                    fontsize=8,
-                    color='white')
-            print("\t {:15s} at {:25} score: {:.5f}".format(label_names[int(label)], str(list(map(int, list(box)))),
-                                                            score))
-        image_name = image_name.replace('jpg', 'png')
-        plt.axis('off')
-        plt.gca().xaxis.set_major_locator(plt.NullLocator())
-        plt.gca().yaxis.set_major_locator(plt.NullLocator())
-        plt.savefig("./output/{}".format(image_name), bbox_inches='tight', pad_inches=0.0)
-        print("Detect result save at ./output/{}\n".format(image_name))
-        plt.cla()
-        plt.close('all')
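Putting it together, an end-to-end smoke test of the refactored predict path. The image and annotation paths are placeholders (the filelist must be a COCO-format JSON); `draw_boxes_on_image` writes its visualization to `./output`:

```python
import os
import paddle
import paddlehub as hub

if __name__ == '__main__':
    paddle.disable_static()
    os.makedirs('./output', exist_ok=True)  # draw_boxes_on_image saves here
    model = hub.Module(name='yolov3_darknet53_pascalvoc', is_train=False)
    model.eval()
    boxes, scores, labels = model.predict(imgpath='sample.jpg', filelist='annotations.json')
    print(len(boxes), 'detections above the NMS score threshold')
```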