randaugment.py 5.3 KB
Newer Older
Z
zhoujun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random
import six


class RawRandAugment(object):
    """Apply a fixed number of randomly chosen PIL augmentation ops.

    Every call draws ``num_layers`` op names (independently, so the same op
    may be picked more than once) from the keys of ``level_map`` and applies
    the matching callables from ``func`` one after another.

    The strength of each op is computed once, in ``__init__``, by scaling a
    per-op maximum with ``magnitude / max_level`` (``max_level`` is 10).

    Args:
        num_layers (int): number of ops applied per call.
        magnitude (int|float): global strength on a 0..10 scale.
        fillcolor: colour for pixels uncovered by the geometric ops. It is
            forwarded as ``fillcolor`` to ``img.transform`` for shear and
            translate, and used as the background when rotating.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 num_layers=2,
                 magnitude=5,
                 fillcolor=(128, 128, 128),
                 **kwargs):
        self.num_layers = num_layers
        self.magnitude = magnitude
        self.max_level = 10

        abso_level = self.magnitude / self.max_level
        # Per-op strength handed to the callables in ``self.func``.
        # The key order also fixes the order ``__call__`` samples from.
        self.level_map = {
            "shearX": 0.3 * abso_level,
            "shearY": 0.3 * abso_level,
            "translateX": 150.0 / 331 * abso_level,
            "translateY": 150.0 / 331 * abso_level,
            "rotate": 30 * abso_level,
            "color": 0.9 * abso_level,
            # NOTE(review): int() truncates, so magnitude < 2.5 gives 0 bits;
            # confirm that ImageOps.posterize(img, 0) is really intended.
            "posterize": int(4.0 * abso_level),
            "solarize": 256.0 * abso_level,
            "contrast": 0.9 * abso_level,
            "sharpness": 0.9 * abso_level,
            "brightness": 0.9 * abso_level,
            # The last three ops take no strength; 0 is a placeholder.
            "autocontrast": 0,
            "equalize": 0,
            "invert": 0
        }

        # Background colour for rotation. An RGB triple keeps the alpha of
        # 128 that used to be hard-coded, so the default ``fillcolor`` still
        # produces (128, 128, 128, 128). Any other colour spec is passed to
        # ``Image.new`` unchanged.
        if isinstance(fillcolor, (tuple, list)) and len(fillcolor) == 3:
            rotate_fill = tuple(fillcolor) + (128, )
        else:
            rotate_fill = fillcolor

        # from https://stackoverflow.com/questions/5252170/
        # specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
        def rotate_with_fill(img, magnitude):
            # Rotate in RGBA, then use the rotated image as its own mask to
            # paste it over a solid background before restoring the mode.
            rot = img.convert("RGBA").rotate(magnitude)
            background = Image.new("RGBA", rot.size, rotate_fill)
            return Image.composite(rot, background, rot).convert(img.mode)

        def rnd_sign():
            # Random direction for ops that can act both ways.
            return random.choice([-1, 1])

        def enhance(enhancer_cls, img, magnitude):
            # Factor is 1 +/- magnitude; 1 leaves the image unchanged.
            return enhancer_cls(img).enhance(1 + magnitude * rnd_sign())

        # Op name -> callable(img, magnitude) returning the augmented image.
        self.func = {
            "shearX": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, magnitude * rnd_sign(), 0, 0, 1, 0),
                Image.BICUBIC,
                fillcolor=fillcolor),
            "shearY": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, 0, 0, magnitude * rnd_sign(), 1, 0),
                Image.BICUBIC,
                fillcolor=fillcolor),
            # Translation magnitudes are fractions of the image width/height.
            "translateX": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, 0, magnitude * img.size[0] * rnd_sign(), 0, 1, 0),
                fillcolor=fillcolor),
            "translateY": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, 0, 0, 0, 1, magnitude * img.size[1] * rnd_sign()),
                fillcolor=fillcolor),
            "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
            "color": lambda img, magnitude:
            enhance(ImageEnhance.Color, img, magnitude),
            "posterize": lambda img, magnitude:
            ImageOps.posterize(img, magnitude),
            "solarize": lambda img, magnitude:
            ImageOps.solarize(img, magnitude),
            "contrast": lambda img, magnitude:
            enhance(ImageEnhance.Contrast, img, magnitude),
            "sharpness": lambda img, magnitude:
            enhance(ImageEnhance.Sharpness, img, magnitude),
            "brightness": lambda img, magnitude:
            enhance(ImageEnhance.Brightness, img, magnitude),
            "autocontrast": lambda img, magnitude:
            ImageOps.autocontrast(img),
            "equalize": lambda img, magnitude: ImageOps.equalize(img),
            "invert": lambda img, magnitude: ImageOps.invert(img)
        }

    def __call__(self, img):
        """Apply ``num_layers`` randomly selected ops to ``img``.

        Args:
            img: image object passed to the callables in ``self.func``.

        Returns:
            The result of chaining the selected ops; ``img`` itself when
            ``num_layers`` is 0.
        """
        available_op_names = list(self.level_map)
        for _ in range(self.num_layers):
            op_name = np.random.choice(available_op_names)
            img = self.func[op_name](img, self.level_map[op_name])
        return img


class RandAugment(RawRandAugment):
    """RandAugment wrapper that works on a data dict and fits img types.

    The image is read from ``data['image']``, converted to a PIL image when
    it is not one already, augmented by ``RawRandAugment`` and written back
    to ``data['image']``.

    Args:
        prob (float): threshold compared against ``np.random.rand()``; the
            augmentation is skipped whenever the draw is greater than it.
        *args, **kwargs: forwarded to ``RawRandAugment.__init__``.
    """

    def __init__(self, prob=0.5, *args, **kwargs):
        self.prob = prob
        # The explicit two-argument super() is valid on Python 2 and 3 alike,
        # so no interpreter-specific branching is needed.
        super(RandAugment, self).__init__(*args, **kwargs)

    def __call__(self, data):
        """Randomly augment ``data['image']`` and return ``data``.

        Args:
            data (dict): must contain the key ``'image'``.

        Returns:
            The same dict object. When the augmentation runs,
            ``data['image']`` is replaced by the augmented image, converted
            to a numpy array if the parent call returned a PIL image.
        """
        # Skip the augmentation for this sample.
        if np.random.rand() > self.prob:
            return data

        img = data['image']
        if not isinstance(img, Image.Image):
            # Build a PIL image from a contiguous copy/view of the input.
            img = np.ascontiguousarray(img)
            img = Image.fromarray(img)

        img = super(RandAugment, self).__call__(img)

        if isinstance(img, Image.Image):
            img = np.asarray(img)
        data['image'] = img
        return data