# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import traceback
import numpy as np
from PIL import Image

import logging
logger = logging.getLogger(__name__)
# Names exported by a star-import of this module.
__all__ = [
    'GroupScale', 'GroupMultiScaleCrop', 'GroupRandomCrop', 'GroupRandomFlip',
    'GroupCenterCrop', 'NormalizeImage'
]


class GroupScale(object):
    """
    Group scale image

    An image whose shorter side already equals ``target_size`` is passed
    through untouched. Every other image is resized with bilinear
    interpolation to a fixed shape: the shorter side becomes
    ``target_size`` and the longer side ``int(target_size * 4 / 3)``.
    The aspect ratio of the source image is NOT preserved, and a square
    image is handled like a landscape one (width gets the longer side).

    Args:
        target_size (int): image resize target size
    """

    def __init__(self, target_size=224):
        self.target_size = target_size

    def __call__(self, imgs, label):
        """
        Args:
            imgs (list): images exposing ``size`` as ``(width, height)``
                and a ``resize(size, resample)`` method (PIL.Image-like).
            label: returned unchanged.

        Returns:
            tuple: ``(resized_imgs, label)``; ``resized_imgs`` is a new
                list with one entry per input image.
        """
        short_side = self.target_size
        # Same expression as applied per image before; it does not depend
        # on the image, so compute it once.
        long_side = int(self.target_size * 4.0 / 3.0)

        resized_imgs = []
        for img in imgs:
            w, h = img.size
            if min(w, h) == short_side:
                # shorter side already matches, keep the image as is
                resized_imgs.append(img)
                continue

            if w < h:
                out_size = (short_side, long_side)
            else:
                out_size = (long_side, short_side)
            resized_imgs.append(img.resize(out_size, Image.BILINEAR))

        return resized_imgs, label


class GroupMultiScaleCrop(object):
    """
    Group multi-scale crop

    Draw one crop window (size and position) at random, cut that same
    window out of every image of the group and resize each cut-out to
    ``short_size`` x ``short_size`` with bilinear interpolation. The
    window is derived from the size of the first image only.

    Candidate crop side lengths are ``int(min(w, h) * s)`` for every ``s``
    in ``scales``; a candidate within 3 pixels of ``short_size`` is
    snapped to ``short_size``. Crop width and height may come from
    different scales as long as their positions in ``scales`` differ by at
    most ``max_distort``.

    Args:
        short_size (int): side length of the square output images.
        scales (list|None): crop sizes as fractions of the shorter image
            side. None selects ``[1, .875, .75, .66]``.
        max_distort (int): largest allowed index distance in ``scales``
            between the scale used for width and the one used for height.
        fix_crop (bool): if True, pick the crop position from a fixed set
            of anchors (corners, center, ...); otherwise draw it
            uniformly at random.
        more_fix_crop (bool): with ``fix_crop``, add eight more anchors
            (edge centers and quarter points).
    """

    def __init__(self,
                 short_size=256,
                 scales=None,
                 max_distort=1,
                 fix_crop=True,
                 more_fix_crop=True):
        self.short_size = short_size
        self.scales = scales if scales is not None \
                        else [1, .875, .75, .66]
        self.max_distort = max_distort
        self.fix_crop = fix_crop
        self.more_fix_crop = more_fix_crop

    def _sample_crop_size(self, im_size):
        """Draw a crop window for an image of size ``(width, height)``.

        Returns:
            tuple: ``(crop_w, crop_h, w_offset, h_offset)``, all ints.
        """
        image_w, image_h = im_size[0], im_size[1]

        base_size = min(image_w, image_h)
        crop_sizes = [int(base_size * x) for x in self.scales]
        # Output is square, so width and height share one candidate list.
        candidates = [
            self.short_size if abs(x - self.short_size) < 3 else x
            for x in crop_sizes
        ]

        pairs = []
        for i, h in enumerate(candidates):
            for j, w in enumerate(candidates):
                if abs(i - j) <= self.max_distort:
                    pairs.append((w, h))
        crop_w, crop_h = random.choice(pairs)

        # Room left to move the window. Snapping to short_size can make the
        # crop slightly larger than the image; clamp so offsets stay >= 0.
        slack_w = max(image_w - crop_w, 0)
        slack_h = max(image_h - crop_h, 0)

        if not self.fix_crop:
            # randint excludes its upper bound: +1 makes the last valid
            # offset reachable and avoids a ValueError when slack is 0.
            w_offset = int(np.random.randint(0, slack_w + 1))
            h_offset = int(np.random.randint(0, slack_h + 1))
            return crop_w, crop_h, w_offset, h_offset

        # Floor division keeps every anchor on an integer pixel position.
        w_step = slack_w // 4
        h_step = slack_h // 4

        anchors = [(0, 0)]  # upper left
        if w_step != 0:
            anchors.append((4 * w_step, 0))  # upper right
        if h_step != 0:
            anchors.append((0, 4 * h_step))  # lower left
        if h_step != 0 and w_step != 0:
            anchors.append((4 * w_step, 4 * h_step))  # lower right
        if h_step != 0 or w_step != 0:
            anchors.append((2 * w_step, 2 * h_step))  # center

        if self.more_fix_crop:
            anchors.append((0, 2 * h_step))  # center left
            anchors.append((4 * w_step, 2 * h_step))  # center right
            anchors.append((2 * w_step, 4 * h_step))  # lower center
            anchors.append((2 * w_step, 0 * h_step))  # upper center

            anchors.append((1 * w_step, 1 * h_step))  # upper left quarter
            anchors.append((3 * w_step, 1 * h_step))  # upper right quarter
            anchors.append((1 * w_step, 3 * h_step))  # lower left quarter
            anchors.append((3 * w_step, 3 * h_step))  # lower right quarter

        w_offset, h_offset = random.choice(anchors)
        return crop_w, crop_h, w_offset, h_offset

    def __call__(self, imgs, label):
        """
        Args:
            imgs (list): non-empty list of images exposing ``size``,
                ``crop(box)`` and ``resize(size, resample)``
                (PIL.Image-like).
            label: returned unchanged.

        Returns:
            tuple: ``(ret_imgs, label)`` where every image in ``ret_imgs``
                was cropped with the same box and resized to
                ``(short_size, short_size)``.
        """
        crop_w, crop_h, offset_w, offset_h = \
            self._sample_crop_size(imgs[0].size)
        box = (offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)
        out_size = (self.short_size, self.short_size)

        ret_imgs = [
            img.crop(box).resize(out_size, Image.BILINEAR) for img in imgs
        ]

        return ret_imgs, label


class GroupRandomCrop(object):
    """
    Group random crop

    Cut the same randomly positioned ``target_size`` x ``target_size``
    window out of every image of the group. The window position is drawn
    from the size of the first image; the other images are cropped with
    the identical box.

    Args:
        target_size (int): side length of the square crop.
    """

    def __init__(self, target_size=224):
        self.target_size = target_size

    def __call__(self, imgs, label):
        """
        Args:
            imgs (list): non-empty list of images exposing ``size`` as
                ``(width, height)`` and ``crop(box)`` (PIL.Image-like).
            label: returned unchanged.

        Returns:
            tuple: ``(out_images, label)``. If the first image already has
                the target size, the input images are returned uncropped
                in a new list.

        Raises:
            AssertionError: if the first image is smaller than the crop in
                either dimension.
        """
        w, h = imgs[0].size
        th, tw = self.target_size, self.target_size

        assert (w >= tw) and (h >= th), \
            "image width({}) and height({}) should be larger than or " \
            "equal to crop size({})".format(w, h, self.target_size)

        if w == tw and h == th:
            # nothing to crop; skip the random draw entirely
            return list(imgs), label

        # randint excludes its upper bound: +1 makes the last valid offset
        # reachable and avoids a ValueError when one dimension has no slack.
        x1 = int(np.random.randint(0, w - tw + 1))
        y1 = int(np.random.randint(0, h - th + 1))
        box = (x1, y1, x1 + tw, y1 + th)

        out_images = [img.crop(box) for img in imgs]

        return out_images, label


class GroupRandomFlip(object):
    """
    Group random flip

    With probability 0.5 mirror every image of the group left-to-right.
    A single random draw decides for the whole group, so the images of
    one group are either all flipped or all left untouched.
    """

    def __call__(self, imgs, label):
        """
        Args:
            imgs (list): images exposing ``transpose(method)``
                (PIL.Image-like).
            label: returned unchanged.

        Returns:
            tuple: ``(imgs, label)``; a new list of flipped images when
                the flip is applied, otherwise the input list itself.
        """
        if np.random.random() < 0.5:
            flipped = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in imgs]
            return flipped, label
        return imgs, label


class GroupCenterCrop(object):
    """
    Group center crop

    Cut a centered ``target_size`` x ``target_size`` window out of every
    image of the group. The window is computed per image from that
    image's own size.

    Args:
        target_size (int): side length of the square crop.
    """

    def __init__(self, target_size=224):
        self.target_size = target_size

    def __call__(self, imgs, label):
        """
        Args:
            imgs (list): images exposing ``size`` as ``(width, height)``
                and ``crop(box)`` (PIL.Image-like).
            label: returned unchanged.

        Returns:
            tuple: ``(crop_imgs, label)`` with one cropped image per input.

        Raises:
            AssertionError: if an image is smaller than the crop in either
                dimension.
        """
        th, tw = self.target_size, self.target_size

        crop_imgs = []
        for img in imgs:
            w, h = img.size
            assert (w >= tw) and (h >= th), \
                "image width({}) and height({}) should be larger than or " \
                "equal to crop size({})".format(w, h, self.target_size)
            x1 = int(round((w - tw) / 2.))
            y1 = int(round((h - th) / 2.))
            crop_imgs.append(img.crop((x1, y1, x1 + tw, y1 + th)))

        return crop_imgs, label


class NormalizeImage(object):
    """
    Normalize a group of images into one float32 array

    Each image is converted to a float32 array, moved from
    height-width-channel to channel-height-width order, scaled by 1/255,
    then normalized per channel with ``(x - img_mean) / img_std``. The
    group is finally reshaped to
    ``(seg_num, seg_len * 3, target_size, target_size)``.

    Args:
        target_size (int): height and width of every input image.
        img_mean (sequence): three per-channel means.
        img_std (sequence): three per-channel standard deviations.
        seg_num (int): first dimension of the output array.
        seg_len (int): number of images stacked along the channel
            dimension for each of the ``seg_num`` entries.
    """

    def __init__(self,
                 target_size=224,
                 img_mean=(0.485, 0.456, 0.406),
                 img_std=(0.229, 0.224, 0.225),
                 seg_num=8,
                 seg_len=1):
        self.target_size = target_size
        # shaped (3, 1, 1) so they broadcast over (N, 3, H, W)
        self.img_mean = np.array(img_mean).reshape((3, 1, 1)).astype('float32')
        self.img_std = np.array(img_std).reshape((3, 1, 1)).astype('float32')
        self.seg_num = seg_num
        self.seg_len = seg_len

    def __call__(self, imgs, label):
        """
        Args:
            imgs (list): ``seg_num * seg_len`` images, each convertible by
                ``np.array`` to a ``target_size`` x ``target_size`` x 3
                array.
            label: returned unchanged.

        Returns:
            tuple: ``(np_imgs, label)`` where ``np_imgs`` is a float32
                array of shape
                ``(seg_num, seg_len * 3, target_size, target_size)``.
        """
        size = self.target_size
        frames = [
            np.array(img).astype('float32').transpose((2, 0, 1)).reshape(
                3, size, size) for img in imgs
        ]
        # One stack instead of growing the array with repeated concatenate.
        np_imgs = np.stack(frames) / 255

        np_imgs -= self.img_mean
        np_imgs /= self.img_std
        np_imgs = np.reshape(np_imgs,
                             (self.seg_num, self.seg_len * 3, size, size))

        return np_imgs, label