# image_util.py
# Bounding-box sampling, overlap and image-cropping helpers.
from PIL import Image
import numpy as np
import random
import math


class sampler:
    """Plain record holding the settings of one random-crop sampler.

    Attributes:
        max_sample: Number of accepted boxes after which sampling stops.
        max_trial: Number of candidate boxes that may be drawn.
        min_scale, max_scale: Range the box scale is drawn from.
        min_aspect_ratio, max_aspect_ratio: Range limiting the aspect ratio.
        min_jaccard_overlap, max_jaccard_overlap: Bounds on the Jaccard
            overlap with an object box; a bound equal to 0 is treated as
            "not set" by the constraint check in this file.
    """

    def __init__(self, max_sample, max_trial, min_scale, max_scale,
                 min_aspect_ratio, max_aspect_ratio, min_jaccard_overlap,
                 max_jaccard_overlap):
        # Sampling budget.
        self.max_sample, self.max_trial = max_sample, max_trial
        # Geometry ranges for the candidate box.
        self.min_scale, self.max_scale = min_scale, max_scale
        self.min_aspect_ratio, self.max_aspect_ratio = (
            min_aspect_ratio, max_aspect_ratio)
        # Acceptance bounds on overlap with ground-truth boxes.
        self.min_jaccard_overlap, self.max_jaccard_overlap = (
            min_jaccard_overlap, max_jaccard_overlap)


class bbox:
    """Axis-aligned box described by its two corners.

    Attributes:
        xmin, ymin: Coordinates of the first corner.
        xmax, ymax: Coordinates of the opposite corner.

    The values are stored exactly as given; no ordering or range check is
    performed here.
    """

    def __init__(self, xmin, ymin, xmax, ymax):
        self.xmin, self.ymin = xmin, ymin
        self.xmax, self.ymax = xmax, ymax


def bbox_area(src_bbox):
    """Return the area covered by ``src_bbox``.

    Args:
        src_bbox: Object exposing ``xmin``, ``ymin``, ``xmax`` and ``ymax``.

    Returns:
        ``width * height`` for a well-formed box, or ``0.0`` when the box is
        inverted on either axis (``xmax < xmin`` or ``ymax < ymin``).
    """
    width = src_bbox.xmax - src_bbox.xmin
    height = src_bbox.ymax - src_bbox.ymin
    # An inverted box covers nothing. Without this guard two negative sides
    # would multiply into a bogus positive area.
    if width < 0 or height < 0:
        return 0.0
    return width * height


def generate_sample(sampler):
    """Draw one random box inside the unit square.

    The scale is drawn uniformly from ``[min_scale, max_scale]``. The aspect
    ratio range is then narrowed to ``[scale**2, 1 / scale**2]`` so that, for
    a scale of at most 1, neither side of the box exceeds 1. The box is
    finally placed uniformly at random so that it fits in ``[0, 1] x [0, 1]``.

    Args:
        sampler: Object exposing ``min_scale``, ``max_scale``,
            ``min_aspect_ratio`` and ``max_aspect_ratio``.

    Returns:
        A new ``bbox`` holding the sampled corners.
    """
    scale = random.uniform(sampler.min_scale, sampler.max_scale)
    min_aspect_ratio = max(sampler.min_aspect_ratio, (scale**2.0))
    max_aspect_ratio = min(sampler.max_aspect_ratio, 1 / (scale**2.0))
    aspect_ratio = random.uniform(min_aspect_ratio, max_aspect_ratio)

    # width * height == scale**2 and width / height == aspect_ratio.
    sqrt_ratio = aspect_ratio**0.5
    bbox_width = scale * sqrt_ratio
    bbox_height = scale / sqrt_ratio

    # Largest top-left corner that still keeps the box inside the unit square.
    xmin = random.uniform(0, 1 - bbox_width)
    ymin = random.uniform(0, 1 - bbox_height)
    return bbox(xmin, ymin, xmin + bbox_width, ymin + bbox_height)


def jaccard_overlap(sample_bbox, object_bbox):
    """Return the Jaccard overlap (intersection over union) of two boxes.

    Args:
        sample_bbox: Box exposing ``xmin``, ``ymin``, ``xmax``, ``ymax``.
        object_bbox: Box exposing the same four attributes.

    Returns:
        ``0`` when the boxes do not overlap (touching edges count as no
        overlap), otherwise ``intersection / union``.
    """
    # Disjoint or merely touching on any axis: nothing in common.
    if sample_bbox.xmin >= object_bbox.xmax or \
            sample_bbox.xmax <= object_bbox.xmin or \
            sample_bbox.ymin >= object_bbox.ymax or \
            sample_bbox.ymax <= object_bbox.ymin:
        return 0
    intersect_xmin = max(sample_bbox.xmin, object_bbox.xmin)
    intersect_ymin = max(sample_bbox.ymin, object_bbox.ymin)
    intersect_xmax = min(sample_bbox.xmax, object_bbox.xmax)
    intersect_ymax = min(sample_bbox.ymax, object_bbox.ymax)
    intersect_size = (intersect_xmax - intersect_xmin) * (
        intersect_ymax - intersect_ymin)
    sample_bbox_size = bbox_area(sample_bbox)
    object_bbox_size = bbox_area(object_bbox)
    union_size = sample_bbox_size + object_bbox_size - intersect_size
    # float() keeps the ratio fractional even if every coordinate is an
    # integer and the interpreter floors int / int.
    return float(intersect_size) / union_size


def satisfy_sample_constraint(sampler, sample_bbox, bbox_labels):
    """Tell whether ``sample_bbox`` meets the sampler's overlap bounds.

    Args:
        sampler: Object exposing ``min_jaccard_overlap`` and
            ``max_jaccard_overlap``; a bound equal to 0 is ignored.
        sample_bbox: Candidate box to test.
        bbox_labels: Sequence of labels whose items 1..4 are read as
            ``xmin, ymin, xmax, ymax`` of an object box.

    Returns:
        ``True`` when both bounds are 0, or when at least one object box has
        an overlap with ``sample_bbox`` inside the active bounds; ``False``
        otherwise (including when ``bbox_labels`` is empty and a bound is
        set).
    """
    min_overlap = sampler.min_jaccard_overlap
    max_overlap = sampler.max_jaccard_overlap
    if min_overlap == 0 and max_overlap == 0:
        return True
    for label in bbox_labels:
        object_bbox = bbox(label[1], label[2], label[3], label[4])
        overlap = jaccard_overlap(sample_bbox, object_bbox)
        if min_overlap != 0 and overlap < min_overlap:
            continue
        if max_overlap != 0 and overlap > max_overlap:
            continue
        # One object within bounds is enough to accept the sample.
        return True
    return False


def generate_batch_samples(batch_sampler, bbox_labels, image_width,
                           image_height):
    """Collect the sample boxes accepted by every sampler in a batch.

    For each sampler, up to ``max_trial`` candidate boxes are drawn and the
    ones satisfying the sampler's overlap constraint are kept; drawing stops
    early once ``max_sample`` boxes have been accepted for that sampler.

    Args:
        batch_sampler: Iterable of sampler objects.
        bbox_labels: Object labels passed through to the constraint check.
        image_width: Unused; kept so existing callers keep working.
        image_height: Unused; kept so existing callers keep working.

    Returns:
        List of accepted boxes, in the order the samplers were visited.
    """
    sampled_bbox = []
    for cur_sampler in batch_sampler:
        found = 0
        for _ in range(cur_sampler.max_trial):
            if found >= cur_sampler.max_sample:
                break
            sample_bbox = generate_sample(cur_sampler)
            if satisfy_sample_constraint(cur_sampler, sample_bbox,
                                         bbox_labels):
                sampled_bbox.append(sample_bbox)
                found += 1
    return sampled_bbox


def clip_bbox(src_bbox):
    """Clamp every coordinate of ``src_bbox`` to ``[0.0, 1.0]`` in place.

    Args:
        src_bbox: Object exposing writable ``xmin``, ``ymin``, ``xmax`` and
            ``ymax`` attributes.

    Returns:
        The same ``src_bbox`` object, after modification.
    """
    src_bbox.xmin = max(min(src_bbox.xmin, 1.0), 0.0)
    src_bbox.ymin = max(min(src_bbox.ymin, 1.0), 0.0)
    src_bbox.xmax = max(min(src_bbox.xmax, 1.0), 0.0)
    src_bbox.ymax = max(min(src_bbox.ymax, 1.0), 0.0)
    return src_bbox


def meet_emit_constraint(src_bbox, sample_bbox):
    """Tell whether the center of ``src_bbox`` lies inside ``sample_bbox``.

    Args:
        src_bbox: Box whose center point is tested.
        sample_bbox: Box the center must fall in; its edges are inclusive.

    Returns:
        ``True`` when the center is inside or on the border of
        ``sample_bbox``, ``False`` otherwise.
    """
    # Divide by 2.0 so integer coordinates still yield a fractional center
    # on interpreters where int / int floors.
    center_x = (src_bbox.xmax + src_bbox.xmin) / 2.0
    center_y = (src_bbox.ymax + src_bbox.ymin) / 2.0
    return bool(sample_bbox.xmin <= center_x <= sample_bbox.xmax and
                sample_bbox.ymin <= center_y <= sample_bbox.ymax)


def transform_labels(bbox_labels, sample_bbox):
    """Re-express object labels in the coordinate frame of ``sample_bbox``.

    A label is kept only if the center of its box lies inside
    ``sample_bbox``. Its box is then projected so that ``sample_bbox`` maps
    to the unit square, clamped to ``[0, 1]``, and dropped if nothing is left
    after clamping.

    Args:
        bbox_labels: Sequence of labels indexed as
            ``[0]`` leading value (copied through), ``[1..4]``
            ``xmin, ymin, xmax, ymax``, ``[5]`` trailing value (copied
            through).
        sample_bbox: Box defining the new coordinate frame.

    Returns:
        List of ``[label[0], xmin, ymin, xmax, ymax, label[5]]`` lists with
        the four coordinates converted to ``float``. Empty when
        ``sample_bbox`` has no positive width or height.
    """
    # The frame size does not depend on the label, so compute it once.
    sample_width = sample_bbox.xmax - sample_bbox.xmin
    sample_height = sample_bbox.ymax - sample_bbox.ymin
    # A sample without positive extent cannot hold any object; bail out
    # instead of dividing by zero below.
    if sample_width <= 0 or sample_height <= 0:
        return []

    sample_labels = []
    for label in bbox_labels:
        object_bbox = bbox(label[1], label[2], label[3], label[4])
        if not meet_emit_constraint(object_bbox, sample_bbox):
            continue
        proj_bbox = clip_bbox(
            bbox((object_bbox.xmin - sample_bbox.xmin) / sample_width,
                 (object_bbox.ymin - sample_bbox.ymin) / sample_height,
                 (object_bbox.xmax - sample_bbox.xmin) / sample_width,
                 (object_bbox.ymax - sample_bbox.ymin) / sample_height))
        if bbox_area(proj_bbox) > 0:
            sample_labels.append([
                label[0],
                float(proj_bbox.xmin),
                float(proj_bbox.ymin),
                float(proj_bbox.xmax),
                float(proj_bbox.ymax),
                label[5],
            ])
    return sample_labels


def crop_image(img, bbox_labels, sample_bbox, image_width, image_height):
    """Crop ``img`` to ``sample_bbox`` and transform the labels to match.

    Args:
        img: Image indexed as ``img[rows, cols]``; presumably a numpy array
            with height as the first axis -- confirm against the caller.
        bbox_labels: Object labels forwarded to ``transform_labels``.
        sample_bbox: Crop region as fractions of the image size. It is
            clamped to ``[0, 1]`` in place, so the caller's object changes.
        image_width: Image width in pixels, used to scale x coordinates.
        image_height: Image height in pixels, used to scale y coordinates.

    Returns:
        Tuple ``(sample_img, sample_labels)``: the cropped image slice and
        the labels expressed relative to the crop.
    """
    sample_bbox = clip_bbox(sample_bbox)
    # int() truncates the scaled coordinates to whole pixel indices.
    xmin = int(sample_bbox.xmin * image_width)
    xmax = int(sample_bbox.xmax * image_width)
    ymin = int(sample_bbox.ymin * image_height)
    ymax = int(sample_bbox.ymax * image_height)
    sample_img = img[ymin:ymax, xmin:xmax]
    sample_labels = transform_labels(bbox_labels, sample_bbox)
    return sample_img, sample_labels