diff --git a/ppcls/data/preprocess/ops/autoaugment.py b/ppcls/data/preprocess/ops/autoaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..6065697e2a61f18ba2ae5ef05f651f0ae223ddaf --- /dev/null +++ b/ppcls/data/preprocess/ops/autoaugment.py @@ -0,0 +1,264 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This code is based on https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py + +from PIL import Image, ImageEnhance, ImageOps +import numpy as np +import random + + +class ImageNetPolicy(object): + """ Randomly choose one of the best 24 Sub-policies on ImageNet. + + Example: + >>> policy = ImageNetPolicy() + >>> transformed = policy(image) + + Example as a PyTorch Transform: + >>> transform=transforms.Compose([ + >>> transforms.Resize(256), + >>> ImageNetPolicy(), + >>> transforms.ToTensor()]) + """ + + def __init__(self, fillcolor=(128, 128, 128)): + self.policies = [ + SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), + SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), + SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), + SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), + SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), + SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), + SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), + SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), + SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), + SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), + SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), + SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), + SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), + SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), + SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), + SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), + SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), + SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), + SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), + SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), + SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), + SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), + SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), + SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), + SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor) + ] + + def __call__(self, img, policy_idx=None): + if policy_idx is None or not isinstance(policy_idx, int): + policy_idx = random.randint(0, len(self.policies) - 1) + else: + policy_idx = policy_idx % len(self.policies) + return self.policies[policy_idx](img) + + def __repr__(self): + return "AutoAugment ImageNet Policy" + + +class CIFAR10Policy(object): + """ Randomly choose one of the best 25 Sub-policies on CIFAR10. + + Example: + >>> policy = CIFAR10Policy() + >>> transformed = policy(image) + + Example as a PyTorch Transform: + >>> transform=transforms.Compose([ + >>> transforms.Resize(256), + >>> CIFAR10Policy(), + >>> transforms.ToTensor()]) + """ + + def __init__(self, fillcolor=(128, 128, 128)): + self.policies = [ + SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), + SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), + SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), + SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), + SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), + SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), + SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), + SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), + SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), + SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), + SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), + SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), + SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), + SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), + SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), + SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), + SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor), + SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), + SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), + SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), + SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), + SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), + SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), + SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), + SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) + ] + + def __call__(self, img, policy_idx=None): + if policy_idx is None or not isinstance(policy_idx, int): + policy_idx = random.randint(0, len(self.policies) - 1) + else: + policy_idx = policy_idx % len(self.policies) + return self.policies[policy_idx](img) + + def __repr__(self): + return "AutoAugment CIFAR10 Policy" + + +class SVHNPolicy(object): + """ Randomly choose one of the best 25 Sub-policies on SVHN. + + Example: + >>> policy = SVHNPolicy() + >>> transformed = policy(image) + + Example as a PyTorch Transform: + >>> transform=transforms.Compose([ + >>> transforms.Resize(256), + >>> SVHNPolicy(), + >>> transforms.ToTensor()]) + """ + + def __init__(self, fillcolor=(128, 128, 128)): + self.policies = [ + SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), + SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), + SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), + SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), + SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), + SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), + SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), + SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), + SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), + SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), + SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), + SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), + SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), + SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), + SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), + SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), + SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), + SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), + SubPolicy( + 0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), SubPolicy( + 0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), SubPolicy( + 0.7, "solarize", 2, 0.6, "translateY", 7, + fillcolor), SubPolicy(0.8, "shearY", 4, 0.8, "invert", + 8, fillcolor), SubPolicy( + 0.7, "shearX", 9, 0.8, + "translateY", 3, + fillcolor), SubPolicy( + 0.8, "shearY", 5, 0.7, + "autocontrast", 3, + fillcolor), + SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) + ] + + def __call__(self, img, policy_idx=None): + if policy_idx is None or not isinstance(policy_idx, int): + policy_idx = random.randint(0, len(self.policies) - 1) + else: + policy_idx = policy_idx % len(self.policies) + return self.policies[policy_idx](img) + + def __repr__(self): + return "AutoAugment SVHN Policy" + + +class SubPolicy(object): + def __init__(self, + p1, + operation1, + magnitude_idx1, + p2, + operation2, + magnitude_idx2, + fillcolor=(128, 128, 128)): + ranges = { + "shearX": np.linspace(0, 0.3, 10), + "shearY": np.linspace(0, 0.3, 10), + "translateX": np.linspace(0, 150 / 331, 10), + "translateY": np.linspace(0, 150 / 331, 10), + "rotate": np.linspace(0, 30, 10), + "color": np.linspace(0.0, 0.9, 10), + "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), + "solarize": np.linspace(256, 0, 10), + "contrast": np.linspace(0.0, 0.9, 10), + "sharpness": np.linspace(0.0, 0.9, 10), + "brightness": np.linspace(0.0, 0.9, 10), + "autocontrast": [0] * 10, + "equalize": [0] * 10, + "invert": [0] * 10 + } + + # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand + def rotate_with_fill(img, magnitude): + rot = img.convert("RGBA").rotate(magnitude) + return Image.composite(rot, + Image.new("RGBA", rot.size, (128, ) * 4), + rot).convert(img.mode) + + func = { + "shearX": lambda img, magnitude: img.transform( + img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), + Image.BICUBIC, fillcolor=fillcolor), + "shearY": lambda img, magnitude: img.transform( + img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), + Image.BICUBIC, fillcolor=fillcolor), + "translateX": lambda img, magnitude: img.transform( + img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), + fillcolor=fillcolor), + "translateY": lambda img, magnitude: img.transform( + img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), + fillcolor=fillcolor), + "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), + # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), + "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), + "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), + "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), + "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( + 1 + magnitude * random.choice([-1, 1])), + "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( + 1 + magnitude * random.choice([-1, 1])), + "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( + 1 + magnitude * random.choice([-1, 1])), + "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), + "equalize": lambda img, magnitude: ImageOps.equalize(img), + "invert": lambda img, magnitude: ImageOps.invert(img) + } + + self.p1 = p1 + self.operation1 = func[operation1] + self.magnitude1 = ranges[operation1][magnitude_idx1] + self.p2 = p2 + self.operation2 = func[operation2] + self.magnitude2 = ranges[operation2][magnitude_idx2] + + def __call__(self, img): + if random.random() < self.p1: + img = self.operation1(img, self.magnitude1) + if random.random() < self.p2: + img = self.operation2(img, self.magnitude2) + return img diff --git a/ppcls/data/preprocess/ops/cutout.py b/ppcls/data/preprocess/ops/cutout.py new file mode 100644 index 0000000000000000000000000000000000000000..43d557f8619075976f3f313c1535ed9badb65563 --- /dev/null +++ b/ppcls/data/preprocess/ops/cutout.py @@ -0,0 +1,41 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This code is based on https://github.com/uoguelph-mlrg/Cutout + +import numpy as np +import random + + +class Cutout(object): + def __init__(self, n_holes=1, length=112): + self.n_holes = n_holes + self.length = length + + def __call__(self, img): + """ cutout_image """ + h, w = img.shape[:2] + mask = np.ones((h, w), np.float32) + + for n in range(self.n_holes): + y = np.random.randint(h) + x = np.random.randint(w) + + y1 = np.clip(y - self.length // 2, 0, h) + y2 = np.clip(y + self.length // 2, 0, h) + x1 = np.clip(x - self.length // 2, 0, w) + x2 = np.clip(x + self.length // 2, 0, w) + + img[y1:y2, x1:x2] = 0 + return img diff --git a/ppcls/data/preprocess/ops/fmix.py b/ppcls/data/preprocess/ops/fmix.py new file mode 100644 index 0000000000000000000000000000000000000000..fb9382115c1f3eb32ac2fd3e90699899006a0469 --- /dev/null +++ b/ppcls/data/preprocess/ops/fmix.py @@ -0,0 +1,217 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import random + +import numpy as np +from scipy.stats import beta + + +def fftfreqnd(h, w=None, z=None): + """ Get bin values for discrete fourier transform of size (h, w, z) + + :param h: Required, first dimension size + :param w: Optional, second dimension size + :param z: Optional, third dimension size + """ + fz = fx = 0 + fy = np.fft.fftfreq(h) + + if w is not None: + fy = np.expand_dims(fy, -1) + + if w % 2 == 1: + fx = np.fft.fftfreq(w)[:w // 2 + 2] + else: + fx = np.fft.fftfreq(w)[:w // 2 + 1] + + if z is not None: + fy = np.expand_dims(fy, -1) + if z % 2 == 1: + fz = np.fft.fftfreq(z)[:, None] + else: + fz = np.fft.fftfreq(z)[:, None] + + return np.sqrt(fx * fx + fy * fy + fz * fz) + + +def get_spectrum(freqs, decay_power, ch, h, w=0, z=0): + """ Samples a fourier image with given size and frequencies decayed by decay power + + :param freqs: Bin values for the discrete fourier transform + :param decay_power: Decay power for frequency decay prop 1/f**d + :param ch: Number of channels for the resulting mask + :param h: Required, first dimension size + :param w: Optional, second dimension size + :param z: Optional, third dimension size + """ + scale = np.ones(1) / (np.maximum(freqs, np.array([1. / max(w, h, z)])) + **decay_power) + + param_size = [ch] + list(freqs.shape) + [2] + param = np.random.randn(*param_size) + + scale = np.expand_dims(scale, -1)[None, :] + + return scale * param + + +def make_low_freq_image(decay, shape, ch=1): + """ Sample a low frequency image from fourier space + + :param decay_power: Decay power for frequency decay prop 1/f**d + :param shape: Shape of desired mask, list up to 3 dims + :param ch: Number of channels for desired mask + """ + freqs = fftfreqnd(*shape) + spectrum = get_spectrum(freqs, decay, ch, + *shape) #.reshape((1, *shape[:-1], -1)) + spectrum = spectrum[:, 0] + 1j * spectrum[:, 1] + mask = np.real(np.fft.irfftn(spectrum, shape)) + + if len(shape) == 1: + mask = mask[:1, :shape[0]] + if len(shape) == 2: + mask = mask[:1, :shape[0], :shape[1]] + if len(shape) == 3: + mask = mask[:1, :shape[0], :shape[1], :shape[2]] + + mask = mask + mask = (mask - mask.min()) + mask = mask / mask.max() + return mask + + +def sample_lam(alpha, reformulate=False): + """ Sample a lambda from symmetric beta distribution with given alpha + + :param alpha: Alpha value for beta distribution + :param reformulate: If True, uses the reformulation of [1]. + """ + if reformulate: + lam = beta.rvs(alpha + 1, alpha) + else: + lam = beta.rvs(alpha, alpha) + + return lam + + +def binarise_mask(mask, lam, in_shape, max_soft=0.0): + """ Binarises a given low frequency image such that it has mean lambda. + + :param mask: Low frequency image, usually the result of `make_low_freq_image` + :param lam: Mean value of final mask + :param in_shape: Shape of inputs + :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask. + :return: + """ + idx = mask.reshape(-1).argsort()[::-1] + mask = mask.reshape(-1) + num = math.ceil(lam * mask.size) if random.random() > 0.5 else math.floor( + lam * mask.size) + + eff_soft = max_soft + if max_soft > lam or max_soft > (1 - lam): + eff_soft = min(lam, 1 - lam) + + soft = int(mask.size * eff_soft) + num_low = int(num - soft) + num_high = int(num + soft) + + mask[idx[:num_high]] = 1 + mask[idx[num_low:]] = 0 + mask[idx[num_low:num_high]] = np.linspace(1, 0, (num_high - num_low)) + + mask = mask.reshape((1, 1, in_shape[0], in_shape[1])) + return mask + + +def sample_mask(alpha, decay_power, shape, max_soft=0.0, reformulate=False): + """ Samples a mean lambda from beta distribution parametrised by alpha, creates a low frequency image and binarises + it based on this lambda + + :param alpha: Alpha value for beta distribution from which to sample mean of mask + :param decay_power: Decay power for frequency decay prop 1/f**d + :param shape: Shape of desired mask, list up to 3 dims + :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask. + :param reformulate: If True, uses the reformulation of [1]. + """ + if isinstance(shape, int): + shape = (shape, ) + + # Choose lambda + lam = sample_lam(alpha, reformulate) + + # Make mask, get mean / std + mask = make_low_freq_image(decay_power, shape) + mask = binarise_mask(mask, lam, shape, max_soft) + + return float(lam), mask + + +def sample_and_apply(x, + alpha, + decay_power, + shape, + max_soft=0.0, + reformulate=False): + """ + + :param x: Image batch on which to apply fmix of shape [b, c, shape*] + :param alpha: Alpha value for beta distribution from which to sample mean of mask + :param decay_power: Decay power for frequency decay prop 1/f**d + :param shape: Shape of desired mask, list up to 3 dims + :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask. + :param reformulate: If True, uses the reformulation of [1]. + :return: mixed input, permutation indices, lambda value of mix, + """ + lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate) + index = np.random.permutation(x.shape[0]) + + x1, x2 = x * mask, x[index] * (1 - mask) + return x1 + x2, index, lam + + +class FMixBase: + """ FMix augmentation + + Args: + decay_power (float): Decay power for frequency decay prop 1/f**d + alpha (float): Alpha value for beta distribution from which to sample mean of mask + size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims + max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask. + reformulate (bool): If True, uses the reformulation of [1]. + """ + + def __init__(self, + decay_power=3, + alpha=1, + size=(32, 32), + max_soft=0.0, + reformulate=False): + super().__init__() + self.decay_power = decay_power + self.reformulate = reformulate + self.size = size + self.alpha = alpha + self.max_soft = max_soft + self.index = None + self.lam = None + + def __call__(self, x): + raise NotImplementedError + + def loss(self, *args, **kwargs): + raise NotImplementedError diff --git a/ppcls/data/preprocess/ops/functional.py b/ppcls/data/preprocess/ops/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..7ec6bbbd0e3b3d468916f283868e0227ed133233 --- /dev/null +++ b/ppcls/data/preprocess/ops/functional.py @@ -0,0 +1,124 @@ +# encoding: utf-8 + +import numpy as np +from PIL import Image, ImageOps, ImageEnhance + + + +def int_parameter(level, maxval): + """Helper function to scale `val` between 0 and maxval . + Args: + level: Level of the operation that will be between [0, `PARAMETER_MAX`]. + maxval: Maximum value that the operation can have. This will be scaled to + level/PARAMETER_MAX. + Returns: + An int that results from scaling `maxval` according to `level`. + """ + return int(level * maxval / 10) + + +def float_parameter(level, maxval): + """Helper function to scale `val` between 0 and maxval. + Args: + level: Level of the operation that will be between [0, `PARAMETER_MAX`]. + maxval: Maximum value that the operation can have. This will be scaled to + level/PARAMETER_MAX. + Returns: + A float that results from scaling `maxval` according to `level`. + """ + return float(level) * maxval / 10. + + +def sample_level(n): + return np.random.uniform(low=0.1, high=n) + + +def autocontrast(pil_img, *args): + return ImageOps.autocontrast(pil_img) + + +def equalize(pil_img, *args): + return ImageOps.equalize(pil_img) + + +def posterize(pil_img, level, *args): + level = int_parameter(sample_level(level), 4) + return ImageOps.posterize(pil_img, 4 - level) + + +def rotate(pil_img, level, *args): + degrees = int_parameter(sample_level(level), 30) + if np.random.uniform() > 0.5: + degrees = -degrees + return pil_img.rotate(degrees, resample=Image.BILINEAR) + + +def solarize(pil_img, level, *args): + level = int_parameter(sample_level(level), 256) + return ImageOps.solarize(pil_img, 256 - level) + + +def shear_x(pil_img, level): + level = float_parameter(sample_level(level), 0.3) + if np.random.uniform() > 0.5: + level = -level + return pil_img.transform(pil_img.size, + Image.AFFINE, (1, level, 0, 0, 1, 0), + resample=Image.BILINEAR) + + +def shear_y(pil_img, level): + level = float_parameter(sample_level(level), 0.3) + if np.random.uniform() > 0.5: + level = -level + return pil_img.transform(pil_img.size, + Image.AFFINE, (1, 0, 0, level, 1, 0), + resample=Image.BILINEAR) + + +def translate_x(pil_img, level): + level = int_parameter(sample_level(level), pil_img.size[0] / 3) + if np.random.random() > 0.5: + level = -level + return pil_img.transform(pil_img.size, + Image.AFFINE, (1, 0, level, 0, 1, 0), + resample=Image.BILINEAR) + + +def translate_y(pil_img, level): + level = int_parameter(sample_level(level), pil_img.size[1] / 3) + if np.random.random() > 0.5: + level = -level + return pil_img.transform(pil_img.size, + Image.AFFINE, (1, 0, 0, 0, 1, level), + resample=Image.BILINEAR) + + +# operation that overlaps with ImageNet-C's test set +def color(pil_img, level, *args): + level = float_parameter(sample_level(level), 1.8) + 0.1 + return ImageEnhance.Color(pil_img).enhance(level) + + +# operation that overlaps with ImageNet-C's test set +def contrast(pil_img, level, *args): + level = float_parameter(sample_level(level), 1.8) + 0.1 + return ImageEnhance.Contrast(pil_img).enhance(level) + + +# operation that overlaps with ImageNet-C's test set +def brightness(pil_img, level, *args): + level = float_parameter(sample_level(level), 1.8) + 0.1 + return ImageEnhance.Brightness(pil_img).enhance(level) + + +# operation that overlaps with ImageNet-C's test set +def sharpness(pil_img, level, *args): + level = float_parameter(sample_level(level), 1.8) + 0.1 + return ImageEnhance.Sharpness(pil_img).enhance(level) + + +augmentations = [ + autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, + translate_x, translate_y +] \ No newline at end of file diff --git a/ppcls/data/preprocess/ops/grid.py b/ppcls/data/preprocess/ops/grid.py new file mode 100644 index 0000000000000000000000000000000000000000..93e0c58ac756bb43bcfda388240c50786487973b --- /dev/null +++ b/ppcls/data/preprocess/ops/grid.py @@ -0,0 +1,89 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This code is based on https://github.com/akuxcw/GridMask + +import numpy as np +from PIL import Image +import pdb + +# curr +CURR_EPOCH = 0 +# epoch for the prob to be the upper limit +NUM_EPOCHS = 240 + + +class GridMask(object): + def __init__(self, d1=96, d2=224, rotate=1, ratio=0.5, mode=0, prob=1.): + self.d1 = d1 + self.d2 = d2 + self.rotate = rotate + self.ratio = ratio + self.mode = mode + self.st_prob = prob + self.prob = prob + self.last_prob = -1 + + def set_prob(self): + global CURR_EPOCH + global NUM_EPOCHS + self.prob = self.st_prob * min(1, 1.0 * CURR_EPOCH / NUM_EPOCHS) + + def __call__(self, img): + self.set_prob() + if abs(self.last_prob - self.prob) > 1e-10: + global CURR_EPOCH + global NUM_EPOCHS + print( + "self.prob is updated, self.prob={}, CURR_EPOCH: {}, NUM_EPOCHS: {}". + format(self.prob, CURR_EPOCH, NUM_EPOCHS)) + self.last_prob = self.prob + # print("CURR_EPOCH: {}, NUM_EPOCHS: {}, self.prob is set as: {}".format(CURR_EPOCH, NUM_EPOCHS, self.prob) ) + if np.random.rand() > self.prob: + return img + _, h, w = img.shape + hh = int(1.5 * h) + ww = int(1.5 * w) + d = np.random.randint(self.d1, self.d2) + #d = self.d + self.l = int(d * self.ratio + 0.5) + mask = np.ones((hh, ww), np.float32) + st_h = np.random.randint(d) + st_w = np.random.randint(d) + for i in range(-1, hh // d + 1): + s = d * i + st_h + t = s + self.l + s = max(min(s, hh), 0) + t = max(min(t, hh), 0) + mask[s:t, :] *= 0 + for i in range(-1, ww // d + 1): + s = d * i + st_w + t = s + self.l + s = max(min(s, ww), 0) + t = max(min(t, ww), 0) + mask[:, s:t] *= 0 + r = np.random.randint(self.rotate) + mask = Image.fromarray(np.uint8(mask)) + mask = mask.rotate(r) + mask = np.asarray(mask) + mask = mask[(hh - h) // 2:(hh - h) // 2 + h, (ww - w) // 2:(ww - w) // + 2 + w] + + if self.mode == 1: + mask = 1 - mask + + mask = np.expand_dims(mask, axis=0) + img = (img * mask).astype(img.dtype) + + return img diff --git a/ppcls/data/preprocess/ops/hide_and_seek.py b/ppcls/data/preprocess/ops/hide_and_seek.py new file mode 100644 index 0000000000000000000000000000000000000000..5d4a8f97d4a8fb8f5d5972ca7463cd536b69924a --- /dev/null +++ b/ppcls/data/preprocess/ops/hide_and_seek.py @@ -0,0 +1,44 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This code is based on https://github.com/kkanshul/Hide-and-Seek + +import numpy as np +import random + + +class HideAndSeek(object): + def __init__(self): + # possible grid size, 0 means no hiding + self.grid_sizes = [0, 16, 32, 44, 56] + # hiding probability + self.hide_prob = 0.5 + + def __call__(self, img): + # randomly choose one grid size + grid_size = np.random.choice(self.grid_sizes) + + _, h, w = img.shape + + # hide the patches + if grid_size == 0: + return img + for x in range(0, w, grid_size): + for y in range(0, h, grid_size): + x_end = min(w, x + grid_size) + y_end = min(h, y + grid_size) + if (random.random() <= self.hide_prob): + img[:, x:x_end, y:y_end] = 0 + + return img diff --git a/ppcls/data/preprocess/ops/operators.py b/ppcls/data/preprocess/ops/operators.py new file mode 100644 index 0000000000000000000000000000000000000000..e05a61ca0a4a1d51d8b11f922010ce209c2e09c5 --- /dev/null +++ b/ppcls/data/preprocess/ops/operators.py @@ -0,0 +1,281 @@ +""" +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import six +import math +import random +import cv2 +import numpy as np +from PIL import Image + +from .autoaugment import ImageNetPolicy +from .functional import augmentations + +class OperatorParamError(ValueError): + """ OperatorParamError + """ + pass + +class DecodeImage(object): + """ decode image """ + + def __init__(self, to_rgb=True, to_np=False, channel_first=False): + self.to_rgb = to_rgb + self.to_np = to_np # to numpy + self.channel_first = channel_first # only enabled when to_np is True + + def __call__(self, img): + if six.PY2: + assert type(img) is str and len( + img) > 0, "invalid input 'img' in DecodeImage" + else: + assert type(img) is bytes and len( + img) > 0, "invalid input 'img' in DecodeImage" + data = np.frombuffer(img, dtype='uint8') + img = cv2.imdecode(data, 1) + if self.to_rgb: + assert img.shape[2] == 3, 'invalid shape of image[%s]' % ( + img.shape) + img = img[:, :, ::-1] + + if self.channel_first: + img = img.transpose((2, 0, 1)) + + return img + + +class ResizeImage(object): + """ resize image """ + + def __init__(self, size=None, resize_short=None, interpolation=-1): + self.interpolation = interpolation if interpolation >= 0 else None + if resize_short is not None and resize_short > 0: + self.resize_short = resize_short + self.w = None + self.h = None + elif size is not None: + self.resize_short = None + self.w = size if type(size) is int else size[0] + self.h = size if type(size) is int else size[1] + else: + raise OperatorParamError("invalid params for ReisizeImage for '\ + 'both 'size' and 'resize_short' are None") + + def __call__(self, img): + img_h, img_w = img.shape[:2] + if self.resize_short is not None: + percent = float(self.resize_short) / min(img_w, img_h) + w = int(round(img_w * percent)) + h = int(round(img_h * percent)) + else: + w = self.w + h = self.h + if self.interpolation is None: + return cv2.resize(img, (w, h)) + else: + return cv2.resize(img, (w, h), interpolation=self.interpolation) + + +class CropImage(object): + """ crop image """ + + def __init__(self, size): + if type(size) is int: + self.size = (size, size) + else: + self.size = size # (h, w) + + def __call__(self, img): + w, h = self.size + img_h, img_w = img.shape[:2] + w_start = (img_w - w) // 2 + h_start = (img_h - h) // 2 + + w_end = w_start + w + h_end = h_start + h + return img[h_start:h_end, w_start:w_end, :] + + +class RandCropImage(object): + """ random crop image """ + + def __init__(self, size, scale=None, ratio=None, interpolation=-1): + + self.interpolation = interpolation if interpolation >= 0 else None + if type(size) is int: + self.size = (size, size) # (h, w) + else: + self.size = size + + self.scale = [0.08, 1.0] if scale is None else scale + self.ratio = [3. / 4., 4. / 3.] if ratio is None else ratio + + def __call__(self, img): + size = self.size + scale = self.scale + ratio = self.ratio + + aspect_ratio = math.sqrt(random.uniform(*ratio)) + w = 1. * aspect_ratio + h = 1. / aspect_ratio + + img_h, img_w = img.shape[:2] + + bound = min((float(img_w) / img_h) / (w**2), + (float(img_h) / img_w) / (h**2)) + scale_max = min(scale[1], bound) + scale_min = min(scale[0], bound) + + target_area = img_w * img_h * random.uniform(scale_min, scale_max) + target_size = math.sqrt(target_area) + w = int(target_size * w) + h = int(target_size * h) + + i = random.randint(0, img_w - w) + j = random.randint(0, img_h - h) + + img = img[j:j + h, i:i + w, :] + if self.interpolation is None: + return cv2.resize(img, size) + else: + return cv2.resize(img, size, interpolation=self.interpolation) + + +class RandFlipImage(object): + """ random flip image + flip_code: + 1: Flipped Horizontally + 0: Flipped Vertically + -1: Flipped Horizontally & Vertically + """ + + def __init__(self, flip_code=1): + assert flip_code in [-1, 0, 1 + ], "flip_code should be a value in [-1, 0, 1]" + self.flip_code = flip_code + + def __call__(self, img): + if random.randint(0, 1) == 1: + return cv2.flip(img, self.flip_code) + else: + return img + + +class AutoAugment(object): + def __init__(self): + self.policy = ImageNetPolicy() + + def __call__(self, img): + from PIL import Image + img = np.ascontiguousarray(img) + img = Image.fromarray(img) + img = self.policy(img) + img = np.asarray(img) + + +class NormalizeImage(object): + """ normalize image such as substract mean, divide std + """ + + def __init__(self, scale=None, mean=None, std=None, order='chw'): + if isinstance(scale, str): + scale = eval(scale) + self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) + mean = mean if mean is not None else [0.485, 0.456, 0.406] + std = std if std is not None else [0.229, 0.224, 0.225] + + shape = (3, 1, 1) if order == 'chw' else (1, 1, 3) + self.mean = np.array(mean).reshape(shape).astype('float32') + self.std = np.array(std).reshape(shape).astype('float32') + + def __call__(self, img): + from PIL import Image + if isinstance(img, Image.Image): + img = np.array(img) + + assert isinstance(img, + np.ndarray), "invalid input 'img' in NormalizeImage" + return (img.astype('float32') * self.scale - self.mean) / self.std + + +class ToCHWImage(object): + """ convert hwc image to chw image + """ + + def __init__(self): + pass + + def __call__(self, img): + from PIL import Image + if isinstance(img, Image.Image): + img = np.array(img) + + return img.transpose((2, 0, 1)) + + +class AugMix(object): + """ Perform AugMix augmentation and compute mixture. + """ + + def __init__(self, prob=0.5, aug_prob_coeff=0.1, mixture_width=3, mixture_depth=1, aug_severity=1): + """ + Args: + prob: Probability of taking augmix + aug_prob_coeff: Probability distribution coefficients. + mixture_width: Number of augmentation chains to mix per augmented example. + mixture_depth: Depth of augmentation chains. -1 denotes stochastic depth in [1, 3]' + aug_severity: Severity of underlying augmentation operators (between 1 to 10). + """ + # fmt: off + self.prob = prob + self.aug_prob_coeff = aug_prob_coeff + self.mixture_width = mixture_width + self.mixture_depth = mixture_depth + self.aug_severity = aug_severity + self.augmentations = augmentations + # fmt: on + + def __call__(self, image): + """Perform AugMix augmentations and compute mixture. + Returns: + mixed: Augmented and mixed image. + """ + if random.random() > self.prob: + # Avoid the warning: the given NumPy array is not writeable + return np.asarray(image).copy() + + ws = np.float32( + np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width)) + m = np.float32(np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff)) + + # image = Image.fromarray(image) + mix = np.zeros([image.shape[1], image.shape[0], 3]) + for i in range(self.mixture_width): + image_aug = image.copy() + image_aug = Image.fromarray(image_aug) + depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(1, 4) + for _ in range(depth): + op = np.random.choice(self.augmentations) + image_aug = op(image_aug, self.aug_severity) + mix += ws[i] * np.asarray(image_aug) + + mixed = (1 - m) * image + m * mix + return mixed.astype(np.uint8) diff --git a/ppcls/data/preprocess/ops/randaugment.py b/ppcls/data/preprocess/ops/randaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..cb3e9695c08a1ba7a39508a58012c385b4a3a5df --- /dev/null +++ b/ppcls/data/preprocess/ops/randaugment.py @@ -0,0 +1,106 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This code is based on https://github.com/heartInsert/randaugment + +from PIL import Image, ImageEnhance, ImageOps +import numpy as np +import random + + +class RandAugment(object): + def __init__(self, num_layers=2, magnitude=5, fillcolor=(128, 128, 128)): + self.num_layers = num_layers + self.magnitude = magnitude + self.max_level = 10 + + abso_level = self.magnitude / self.max_level + self.level_map = { + "shearX": 0.3 * abso_level, + "shearY": 0.3 * abso_level, + "translateX": 150.0 / 331 * abso_level, + "translateY": 150.0 / 331 * abso_level, + "rotate": 30 * abso_level, + "color": 0.9 * abso_level, + "posterize": int(4.0 * abso_level), + "solarize": 256.0 * abso_level, + "contrast": 0.9 * abso_level, + "sharpness": 0.9 * abso_level, + "brightness": 0.9 * abso_level, + "autocontrast": 0, + "equalize": 0, + "invert": 0 + } + + # from https://stackoverflow.com/questions/5252170/ + # specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand + def rotate_with_fill(img, magnitude): + rot = img.convert("RGBA").rotate(magnitude) + return Image.composite(rot, + Image.new("RGBA", rot.size, (128, ) * 4), + rot).convert(img.mode) + + rnd_ch_op = random.choice + + self.func = { + "shearX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0), + Image.BICUBIC, + fillcolor=fillcolor), + "shearY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0), + Image.BICUBIC, + fillcolor=fillcolor), + "translateX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, magnitude * img.size[0] * rnd_ch_op([-1, 1]), 0, 1, 0), + fillcolor=fillcolor), + "translateY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, 0, 0, 1, magnitude * img.size[1] * rnd_ch_op([-1, 1])), + fillcolor=fillcolor), + "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), + "color": lambda img, magnitude: ImageEnhance.Color(img).enhance( + 1 + magnitude * rnd_ch_op([-1, 1])), + "posterize": lambda img, magnitude: + ImageOps.posterize(img, magnitude), + "solarize": lambda img, magnitude: + ImageOps.solarize(img, magnitude), + "contrast": lambda img, magnitude: + ImageEnhance.Contrast(img).enhance( + 1 + magnitude * rnd_ch_op([-1, 1])), + "sharpness": lambda img, magnitude: + ImageEnhance.Sharpness(img).enhance( + 1 + magnitude * rnd_ch_op([-1, 1])), + "brightness": lambda img, magnitude: + ImageEnhance.Brightness(img).enhance( + 1 + magnitude * rnd_ch_op([-1, 1])), + "autocontrast": lambda img, magnitude: + ImageOps.autocontrast(img), + "equalize": lambda img, magnitude: ImageOps.equalize(img), + "invert": lambda img, magnitude: ImageOps.invert(img) + } + + def __call__(self, img): + avaiable_op_names = list(self.level_map.keys()) + for layer_num in range(self.num_layers): + op_name = np.random.choice(avaiable_op_names) + img = self.func[op_name](img, self.level_map[op_name]) + return img diff --git a/ppcls/data/preprocess/ops/random_erasing.py b/ppcls/data/preprocess/ops/random_erasing.py new file mode 100644 index 0000000000000000000000000000000000000000..0c3c315820463e2549b54be277840278327804cd --- /dev/null +++ b/ppcls/data/preprocess/ops/random_erasing.py @@ -0,0 +1,59 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#This code is based on https://github.com/zhunzhong07/Random-Erasing + +import math +import random + +import numpy as np + + +class RandomErasing(object): + def __init__(self, + EPSILON=0.5, + sl=0.02, + sh=0.4, + r1=0.3, + mean=[0., 0., 0.]): + self.EPSILON = EPSILON + self.mean = mean + self.sl = sl + self.sh = sh + self.r1 = r1 + + def __call__(self, img): + if random.uniform(0, 1) > self.EPSILON: + return img + + for attempt in range(100): + area = img.shape[1] * img.shape[2] + + target_area = random.uniform(self.sl, self.sh) * area + aspect_ratio = random.uniform(self.r1, 1 / self.r1) + + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + + if w < img.shape[2] and h < img.shape[1]: + x1 = random.randint(0, img.shape[1] - h) + y1 = random.randint(0, img.shape[2] - w) + if img.shape[0] == 3: + img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] + img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] + img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] + else: + img[0, x1:x1 + h, y1:y1 + w] = self.mean[1] + return img + return img