未验证 提交 f8c0beeb 编写于 作者: W Wei Shengyu 提交者: GitHub

remove data/imaug (#805)

上级 4905424f
......@@ -128,8 +128,8 @@ ResNet系列模型中,相比于其他模型,ResNet_vd模型在预测速度
**A**
* 对于单张图像的增广,可以参考[基于单张图片的数据增广脚本](../../../ppcls/data/imaug/operators.py),参考`ResizeImage`或者`CropImage`等数据算子的写法,创建一个新的类,然后在`__call__`中,实现对应的增广方法即可。
* 对于一个batch图像的增广,可以参考[基于batch数据的数据增广脚本](../../../ppcls/data/imaug/batch_operators.py),参考`MixupOperator`或者`CutmixOperator`等数据算子的写法,创建一个新的类,然后在`__call__`中,实现对应的增广方法即可。
* 对于单张图像的增广,可以参考[基于单张图片的数据增广脚本](../../../ppcls/data/preprocess/ops),参考`ResizeImage`或者`CropImage`等数据算子的写法,创建一个新的类,然后在`__call__`中,实现对应的增广方法即可。
* 对于一个batch图像的增广,可以参考[基于batch数据的数据增广脚本](../../../ppcls/data/preprocess/batch_ops),参考`MixupOperator`或者`CutmixOperator`等数据算子的写法,创建一个新的类,然后在`__call__`中,实现对应的增广方法即可。
## Q3.5: 怎么进一步加速模型训练过程呢?
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .autoaugment import ImageNetPolicy as RawImageNetPolicy
from .randaugment import RandAugment as RawRandAugment
from .cutout import Cutout
from .hide_and_seek import HideAndSeek
from .random_erasing import RandomErasing
from .grid import GridMask
from .operators import DecodeImage
from .operators import ResizeImage
from .operators import CropImage
from .operators import RandCropImage
from .operators import RandFlipImage
from .operators import NormalizeImage
from .operators import ToCHWImage
from .batch_operators import MixupOperator
from .batch_operators import CutmixOperator
from .batch_operators import FmixOperator
import six
import numpy as np
from PIL import Image
def transform(data, ops=[]):
""" transform """
for op in ops:
data = op(data)
return data
class AutoAugment(RawImageNetPolicy):
""" ImageNetPolicy wrapper to auto fit different img types """
def __init__(self, *args, **kwargs):
if six.PY2:
super(AutoAugment, self).__init__(*args, **kwargs)
else:
super().__init__(*args, **kwargs)
def __call__(self, img):
if not isinstance(img, Image.Image):
img = np.ascontiguousarray(img)
img = Image.fromarray(img)
if six.PY2:
img = super(AutoAugment, self).__call__(img)
else:
img = super().__call__(img)
if isinstance(img, Image.Image):
img = np.asarray(img)
return img
class RandAugment(RawRandAugment):
""" RandAugment wrapper to auto fit different img types """
def __init__(self, *args, **kwargs):
if six.PY2:
super(RandAugment, self).__init__(*args, **kwargs)
else:
super().__init__(*args, **kwargs)
def __call__(self, img):
if not isinstance(img, Image.Image):
img = np.ascontiguousarray(img)
img = Image.fromarray(img)
if six.PY2:
img = super(RandAugment, self).__call__(img)
else:
img = super().__call__(img)
if isinstance(img, Image.Image):
img = np.asarray(img)
return img
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is based on https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random
class ImageNetPolicy(object):
""" Randomly choose one of the best 24 Sub-policies on ImageNet.
Example:
>>> policy = ImageNetPolicy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> ImageNetPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
]
def __call__(self, img, policy_idx=None):
if policy_idx is None or not isinstance(policy_idx, int):
policy_idx = random.randint(0, len(self.policies) - 1)
else:
policy_idx = policy_idx % len(self.policies)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment ImageNet Policy"
class CIFAR10Policy(object):
""" Randomly choose one of the best 25 Sub-policies on CIFAR10.
Example:
>>> policy = CIFAR10Policy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> CIFAR10Policy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor),
SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
]
def __call__(self, img, policy_idx=None):
if policy_idx is None or not isinstance(policy_idx, int):
policy_idx = random.randint(0, len(self.policies) - 1)
else:
policy_idx = policy_idx % len(self.policies)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment CIFAR10 Policy"
class SVHNPolicy(object):
""" Randomly choose one of the best 25 Sub-policies on SVHN.
Example:
>>> policy = SVHNPolicy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> SVHNPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
SubPolicy(
0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), SubPolicy(
0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), SubPolicy(
0.7, "solarize", 2, 0.6, "translateY", 7,
fillcolor), SubPolicy(0.8, "shearY", 4, 0.8, "invert",
8, fillcolor), SubPolicy(
0.7, "shearX", 9, 0.8,
"translateY", 3,
fillcolor), SubPolicy(
0.8, "shearY", 5, 0.7,
"autocontrast", 3,
fillcolor),
SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
]
def __call__(self, img, policy_idx=None):
if policy_idx is None or not isinstance(policy_idx, int):
policy_idx = random.randint(0, len(self.policies) - 1)
else:
policy_idx = policy_idx % len(self.policies)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment SVHN Policy"
class SubPolicy(object):
def __init__(self,
p1,
operation1,
magnitude_idx1,
p2,
operation2,
magnitude_idx2,
fillcolor=(128, 128, 128)):
ranges = {
"shearX": np.linspace(0, 0.3, 10),
"shearY": np.linspace(0, 0.3, 10),
"translateX": np.linspace(0, 150 / 331, 10),
"translateY": np.linspace(0, 150 / 331, 10),
"rotate": np.linspace(0, 30, 10),
"color": np.linspace(0.0, 0.9, 10),
"posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int),
"solarize": np.linspace(256, 0, 10),
"contrast": np.linspace(0.0, 0.9, 10),
"sharpness": np.linspace(0.0, 0.9, 10),
"brightness": np.linspace(0.0, 0.9, 10),
"autocontrast": [0] * 10,
"equalize": [0] * 10,
"invert": [0] * 10
}
# from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
def rotate_with_fill(img, magnitude):
rot = img.convert("RGBA").rotate(magnitude)
return Image.composite(rot,
Image.new("RGBA", rot.size, (128, ) * 4),
rot).convert(img.mode)
func = {
"shearX": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC, fillcolor=fillcolor),
"shearY": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
Image.BICUBIC, fillcolor=fillcolor),
"translateX": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
fillcolor=fillcolor),
"translateY": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
fillcolor=fillcolor),
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
# "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
"posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
"solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
"contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
"equalize": lambda img, magnitude: ImageOps.equalize(img),
"invert": lambda img, magnitude: ImageOps.invert(img)
}
self.p1 = p1
self.operation1 = func[operation1]
self.magnitude1 = ranges[operation1][magnitude_idx1]
self.p2 = p2
self.operation2 = func[operation2]
self.magnitude2 = ranges[operation2][magnitude_idx2]
def __call__(self, img):
if random.random() < self.p1:
img = self.operation1(img, self.magnitude1)
if random.random() < self.p2:
img = self.operation2(img, self.magnitude2)
return img
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
from .fmix import sample_mask
class BatchOperator(object):
""" BatchOperator """
def __init__(self, *args, **kwargs):
pass
def _unpack(self, batch):
""" _unpack """
assert isinstance(batch, list), \
'batch should be a list filled with tuples (img, label)'
bs = len(batch)
assert bs > 0, 'size of the batch data should > 0'
imgs, labels = list(zip(*batch))
return np.array(imgs), np.array(labels), bs
def __call__(self, batch):
return batch
class MixupOperator(BatchOperator):
""" Mixup operator """
def __init__(self, alpha=0.2):
assert alpha > 0., \
'parameter alpha[%f] should > 0.0' % (alpha)
self._alpha = alpha
def __call__(self, batch):
imgs, labels, bs = self._unpack(batch)
idx = np.random.permutation(bs)
lam = np.random.beta(self._alpha, self._alpha)
lams = np.array([lam] * bs, dtype=np.float32)
imgs = lam * imgs + (1 - lam) * imgs[idx]
return list(zip(imgs, labels, labels[idx], lams))
class CutmixOperator(BatchOperator):
""" Cutmix operator """
def __init__(self, alpha=0.2):
assert alpha > 0., \
'parameter alpha[%f] should > 0.0' % (alpha)
self._alpha = alpha
def _rand_bbox(self, size, lam):
""" _rand_bbox """
w = size[2]
h = size[3]
cut_rat = np.sqrt(1. - lam)
cut_w = np.int(w * cut_rat)
cut_h = np.int(h * cut_rat)
# uniform
cx = np.random.randint(w)
cy = np.random.randint(h)
bbx1 = np.clip(cx - cut_w // 2, 0, w)
bby1 = np.clip(cy - cut_h // 2, 0, h)
bbx2 = np.clip(cx + cut_w // 2, 0, w)
bby2 = np.clip(cy + cut_h // 2, 0, h)
return bbx1, bby1, bbx2, bby2
def __call__(self, batch):
imgs, labels, bs = self._unpack(batch)
idx = np.random.permutation(bs)
lam = np.random.beta(self._alpha, self._alpha)
bbx1, bby1, bbx2, bby2 = self._rand_bbox(imgs.shape, lam)
imgs[:, :, bbx1:bbx2, bby1:bby2] = imgs[idx, :, bbx1:bbx2, bby1:bby2]
lam = 1 - (float(bbx2 - bbx1) * (bby2 - bby1) /
(imgs.shape[-2] * imgs.shape[-1]))
lams = np.array([lam] * bs, dtype=np.float32)
return list(zip(imgs, labels, labels[idx], lams))
class FmixOperator(BatchOperator):
""" Fmix operator """
def __init__(self, alpha=1, decay_power=3, max_soft=0., reformulate=False):
self._alpha = alpha
self._decay_power = decay_power
self._max_soft = max_soft
self._reformulate = reformulate
def __call__(self, batch):
imgs, labels, bs = self._unpack(batch)
idx = np.random.permutation(bs)
size = (imgs.shape[2], imgs.shape[3])
lam, mask = sample_mask(self._alpha, self._decay_power, \
size, self._max_soft, self._reformulate)
imgs = mask * imgs + (1 - mask) * imgs[idx]
return list(zip(imgs, labels, labels[idx], [lam] * bs))
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is based on https://github.com/uoguelph-mlrg/Cutout
import numpy as np
import random
class Cutout(object):
def __init__(self, n_holes=1, length=112):
self.n_holes = n_holes
self.length = length
def __call__(self, img):
""" cutout_image """
h, w = img.shape[:2]
mask = np.ones((h, w), np.float32)
for n in range(self.n_holes):
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
img[y1:y2, x1:x2] = 0
return img
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import random
import numpy as np
from scipy.stats import beta
def fftfreqnd(h, w=None, z=None):
""" Get bin values for discrete fourier transform of size (h, w, z)
:param h: Required, first dimension size
:param w: Optional, second dimension size
:param z: Optional, third dimension size
"""
fz = fx = 0
fy = np.fft.fftfreq(h)
if w is not None:
fy = np.expand_dims(fy, -1)
if w % 2 == 1:
fx = np.fft.fftfreq(w)[:w // 2 + 2]
else:
fx = np.fft.fftfreq(w)[:w // 2 + 1]
if z is not None:
fy = np.expand_dims(fy, -1)
if z % 2 == 1:
fz = np.fft.fftfreq(z)[:, None]
else:
fz = np.fft.fftfreq(z)[:, None]
return np.sqrt(fx * fx + fy * fy + fz * fz)
def get_spectrum(freqs, decay_power, ch, h, w=0, z=0):
""" Samples a fourier image with given size and frequencies decayed by decay power
:param freqs: Bin values for the discrete fourier transform
:param decay_power: Decay power for frequency decay prop 1/f**d
:param ch: Number of channels for the resulting mask
:param h: Required, first dimension size
:param w: Optional, second dimension size
:param z: Optional, third dimension size
"""
scale = np.ones(1) / (np.maximum(freqs, np.array([1. / max(w, h, z)]))
**decay_power)
param_size = [ch] + list(freqs.shape) + [2]
param = np.random.randn(*param_size)
scale = np.expand_dims(scale, -1)[None, :]
return scale * param
def make_low_freq_image(decay, shape, ch=1):
""" Sample a low frequency image from fourier space
:param decay_power: Decay power for frequency decay prop 1/f**d
:param shape: Shape of desired mask, list up to 3 dims
:param ch: Number of channels for desired mask
"""
freqs = fftfreqnd(*shape)
spectrum = get_spectrum(freqs, decay, ch,
*shape) #.reshape((1, *shape[:-1], -1))
spectrum = spectrum[:, 0] + 1j * spectrum[:, 1]
mask = np.real(np.fft.irfftn(spectrum, shape))
if len(shape) == 1:
mask = mask[:1, :shape[0]]
if len(shape) == 2:
mask = mask[:1, :shape[0], :shape[1]]
if len(shape) == 3:
mask = mask[:1, :shape[0], :shape[1], :shape[2]]
mask = mask
mask = (mask - mask.min())
mask = mask / mask.max()
return mask
def sample_lam(alpha, reformulate=False):
""" Sample a lambda from symmetric beta distribution with given alpha
:param alpha: Alpha value for beta distribution
:param reformulate: If True, uses the reformulation of [1].
"""
if reformulate:
lam = beta.rvs(alpha + 1, alpha)
else:
lam = beta.rvs(alpha, alpha)
return lam
def binarise_mask(mask, lam, in_shape, max_soft=0.0):
""" Binarises a given low frequency image such that it has mean lambda.
:param mask: Low frequency image, usually the result of `make_low_freq_image`
:param lam: Mean value of final mask
:param in_shape: Shape of inputs
:param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
:return:
"""
idx = mask.reshape(-1).argsort()[::-1]
mask = mask.reshape(-1)
num = math.ceil(lam * mask.size) if random.random() > 0.5 else math.floor(
lam * mask.size)
eff_soft = max_soft
if max_soft > lam or max_soft > (1 - lam):
eff_soft = min(lam, 1 - lam)
soft = int(mask.size * eff_soft)
num_low = int(num - soft)
num_high = int(num + soft)
mask[idx[:num_high]] = 1
mask[idx[num_low:]] = 0
mask[idx[num_low:num_high]] = np.linspace(1, 0, (num_high - num_low))
mask = mask.reshape((1, 1, in_shape[0], in_shape[1]))
return mask
def sample_mask(alpha, decay_power, shape, max_soft=0.0, reformulate=False):
""" Samples a mean lambda from beta distribution parametrised by alpha, creates a low frequency image and binarises
it based on this lambda
:param alpha: Alpha value for beta distribution from which to sample mean of mask
:param decay_power: Decay power for frequency decay prop 1/f**d
:param shape: Shape of desired mask, list up to 3 dims
:param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
:param reformulate: If True, uses the reformulation of [1].
"""
if isinstance(shape, int):
shape = (shape, )
# Choose lambda
lam = sample_lam(alpha, reformulate)
# Make mask, get mean / std
mask = make_low_freq_image(decay_power, shape)
mask = binarise_mask(mask, lam, shape, max_soft)
return float(lam), mask
def sample_and_apply(x,
alpha,
decay_power,
shape,
max_soft=0.0,
reformulate=False):
"""
:param x: Image batch on which to apply fmix of shape [b, c, shape*]
:param alpha: Alpha value for beta distribution from which to sample mean of mask
:param decay_power: Decay power for frequency decay prop 1/f**d
:param shape: Shape of desired mask, list up to 3 dims
:param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
:param reformulate: If True, uses the reformulation of [1].
:return: mixed input, permutation indices, lambda value of mix,
"""
lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate)
index = np.random.permutation(x.shape[0])
x1, x2 = x * mask, x[index] * (1 - mask)
return x1 + x2, index, lam
class FMixBase:
""" FMix augmentation
Args:
decay_power (float): Decay power for frequency decay prop 1/f**d
alpha (float): Alpha value for beta distribution from which to sample mean of mask
size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims
max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask.
reformulate (bool): If True, uses the reformulation of [1].
"""
def __init__(self,
decay_power=3,
alpha=1,
size=(32, 32),
max_soft=0.0,
reformulate=False):
super().__init__()
self.decay_power = decay_power
self.reformulate = reformulate
self.size = size
self.alpha = alpha
self.max_soft = max_soft
self.index = None
self.lam = None
def __call__(self, x):
raise NotImplementedError
def loss(self, *args, **kwargs):
raise NotImplementedError
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is based on https://github.com/akuxcw/GridMask
import numpy as np
from PIL import Image
import pdb
# curr
CURR_EPOCH = 0
# epoch for the prob to be the upper limit
NUM_EPOCHS = 240
class GridMask(object):
def __init__(self, d1=96, d2=224, rotate=1, ratio=0.5, mode=0, prob=1.):
self.d1 = d1
self.d2 = d2
self.rotate = rotate
self.ratio = ratio
self.mode = mode
self.st_prob = prob
self.prob = prob
self.last_prob = -1
def set_prob(self):
global CURR_EPOCH
global NUM_EPOCHS
self.prob = self.st_prob * min(1, 1.0 * CURR_EPOCH / NUM_EPOCHS)
def __call__(self, img):
self.set_prob()
if abs(self.last_prob - self.prob) > 1e-10:
global CURR_EPOCH
global NUM_EPOCHS
print(
"self.prob is updated, self.prob={}, CURR_EPOCH: {}, NUM_EPOCHS: {}".
format(self.prob, CURR_EPOCH, NUM_EPOCHS))
self.last_prob = self.prob
# print("CURR_EPOCH: {}, NUM_EPOCHS: {}, self.prob is set as: {}".format(CURR_EPOCH, NUM_EPOCHS, self.prob) )
if np.random.rand() > self.prob:
return img
_, h, w = img.shape
hh = int(1.5 * h)
ww = int(1.5 * w)
d = np.random.randint(self.d1, self.d2)
#d = self.d
self.l = int(d * self.ratio + 0.5)
mask = np.ones((hh, ww), np.float32)
st_h = np.random.randint(d)
st_w = np.random.randint(d)
for i in range(-1, hh // d + 1):
s = d * i + st_h
t = s + self.l
s = max(min(s, hh), 0)
t = max(min(t, hh), 0)
mask[s:t, :] *= 0
for i in range(-1, ww // d + 1):
s = d * i + st_w
t = s + self.l
s = max(min(s, ww), 0)
t = max(min(t, ww), 0)
mask[:, s:t] *= 0
r = np.random.randint(self.rotate)
mask = Image.fromarray(np.uint8(mask))
mask = mask.rotate(r)
mask = np.asarray(mask)
mask = mask[(hh - h) // 2:(hh - h) // 2 + h, (ww - w) // 2:(ww - w) //
2 + w]
if self.mode == 1:
mask = 1 - mask
mask = np.expand_dims(mask, axis=0)
img = (img * mask).astype(img.dtype)
return img
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is based on https://github.com/kkanshul/Hide-and-Seek
import numpy as np
import random
class HideAndSeek(object):
def __init__(self):
# possible grid size, 0 means no hiding
self.grid_sizes = [0, 16, 32, 44, 56]
# hiding probability
self.hide_prob = 0.5
def __call__(self, img):
# randomly choose one grid size
grid_size = np.random.choice(self.grid_sizes)
_, h, w = img.shape
# hide the patches
if grid_size == 0:
return img
for x in range(0, w, grid_size):
for y in range(0, h, grid_size):
x_end = min(w, x + grid_size)
y_end = min(h, y + grid_size)
if (random.random() <= self.hide_prob):
img[:, x:x_end, y:y_end] = 0
return img
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import six
import math
import random
import cv2
import numpy as np
from .autoaugment import ImageNetPolicy
class OperatorParamError(ValueError):
""" OperatorParamError
"""
pass
class DecodeImage(object):
""" decode image """
def __init__(self, to_rgb=True, to_np=False, channel_first=False):
self.to_rgb = to_rgb
self.to_np = to_np # to numpy
self.channel_first = channel_first # only enabled when to_np is True
def __call__(self, img):
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
data = np.frombuffer(img, dtype='uint8')
img = cv2.imdecode(data, 1)
if self.to_rgb:
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
img.shape)
img = img[:, :, ::-1]
if self.channel_first:
img = img.transpose((2, 0, 1))
return img
class ResizeImage(object):
""" resize image """
def __init__(self, size=None, resize_short=None, interpolation=-1):
self.interpolation = interpolation if interpolation >= 0 else None
if resize_short is not None and resize_short > 0:
self.resize_short = resize_short
self.w = None
self.h = None
elif size is not None:
self.resize_short = None
self.w = size if type(size) is int else size[0]
self.h = size if type(size) is int else size[1]
else:
raise OperatorParamError("invalid params for ReisizeImage for '\
'both 'size' and 'resize_short' are None")
def __call__(self, img):
img_h, img_w = img.shape[:2]
if self.resize_short is not None:
percent = float(self.resize_short) / min(img_w, img_h)
w = int(round(img_w * percent))
h = int(round(img_h * percent))
else:
w = self.w
h = self.h
if self.interpolation is None:
return cv2.resize(img, (w, h))
else:
return cv2.resize(img, (w, h), interpolation=self.interpolation)
class CropImage(object):
""" crop image """
def __init__(self, size):
if type(size) is int:
self.size = (size, size)
else:
self.size = size # (h, w)
def __call__(self, img):
w, h = self.size
img_h, img_w = img.shape[:2]
w_start = (img_w - w) // 2
h_start = (img_h - h) // 2
w_end = w_start + w
h_end = h_start + h
return img[h_start:h_end, w_start:w_end, :]
class RandCropImage(object):
""" random crop image """
def __init__(self, size, scale=None, ratio=None, interpolation=-1):
self.interpolation = interpolation if interpolation >= 0 else None
if type(size) is int:
self.size = (size, size) # (h, w)
else:
self.size = size
self.scale = [0.08, 1.0] if scale is None else scale
self.ratio = [3. / 4., 4. / 3.] if ratio is None else ratio
def __call__(self, img):
size = self.size
scale = self.scale
ratio = self.ratio
aspect_ratio = math.sqrt(random.uniform(*ratio))
w = 1. * aspect_ratio
h = 1. / aspect_ratio
img_h, img_w = img.shape[:2]
bound = min((float(img_w) / img_h) / (w**2),
(float(img_h) / img_w) / (h**2))
scale_max = min(scale[1], bound)
scale_min = min(scale[0], bound)
target_area = img_w * img_h * random.uniform(scale_min, scale_max)
target_size = math.sqrt(target_area)
w = int(target_size * w)
h = int(target_size * h)
i = random.randint(0, img_w - w)
j = random.randint(0, img_h - h)
img = img[j:j + h, i:i + w, :]
if self.interpolation is None:
return cv2.resize(img, size)
else:
return cv2.resize(img, size, interpolation=self.interpolation)
class RandFlipImage(object):
""" random flip image
flip_code:
1: Flipped Horizontally
0: Flipped Vertically
-1: Flipped Horizontally & Vertically
"""
def __init__(self, flip_code=1):
assert flip_code in [-1, 0, 1
], "flip_code should be a value in [-1, 0, 1]"
self.flip_code = flip_code
def __call__(self, img):
if random.randint(0, 1) == 1:
return cv2.flip(img, self.flip_code)
else:
return img
class AutoAugment(object):
def __init__(self):
self.policy = ImageNetPolicy()
def __call__(self, img):
from PIL import Image
img = np.ascontiguousarray(img)
img = Image.fromarray(img)
img = self.policy(img)
img = np.asarray(img)
class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
def __init__(self, scale=None, mean=None, std=None, order='chw', output_fp16=False, channel_num=3):
if isinstance(scale, str):
scale = eval(scale)
assert channel_num in [3, 4], "channel number of input image should be set to 3 or 4."
self.channel_num = channel_num
self.output_dtype = 'float16' if output_fp16 else 'float32'
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
self.order = order
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
shape = (3, 1, 1) if self.order == 'chw' else (1, 1, 3)
self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, img):
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
img = (img.astype('float32') * self.scale - self.mean) / self.std
if self.channel_num == 4:
img_h = img.shape[1] if self.order == 'chw' else img.shape[0]
img_w = img.shape[2] if self.order == 'chw' else img.shape[1]
pad_zeros = np.zeros((1, img_h, img_w)) if self.order == 'chw' else np.zeros((img_h, img_w, 1))
img = (np.concatenate((img, pad_zeros), axis=0) if self.order == 'chw'
else np.concatenate((img, pad_zeros), axis=2))
return img.astype(self.output_dtype)
class ToCHWImage(object):
""" convert hwc image to chw image
"""
def __init__(self):
pass
def __call__(self, img):
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
return img.transpose((2, 0, 1))
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is based on https://github.com/heartInsert/randaugment
from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random
class RandAugment(object):
def __init__(self, num_layers=2, magnitude=5, fillcolor=(128, 128, 128)):
self.num_layers = num_layers
self.magnitude = magnitude
self.max_level = 10
abso_level = self.magnitude / self.max_level
self.level_map = {
"shearX": 0.3 * abso_level,
"shearY": 0.3 * abso_level,
"translateX": 150.0 / 331 * abso_level,
"translateY": 150.0 / 331 * abso_level,
"rotate": 30 * abso_level,
"color": 0.9 * abso_level,
"posterize": int(4.0 * abso_level),
"solarize": 256.0 * abso_level,
"contrast": 0.9 * abso_level,
"sharpness": 0.9 * abso_level,
"brightness": 0.9 * abso_level,
"autocontrast": 0,
"equalize": 0,
"invert": 0
}
# from https://stackoverflow.com/questions/5252170/
# specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
def rotate_with_fill(img, magnitude):
rot = img.convert("RGBA").rotate(magnitude)
return Image.composite(rot,
Image.new("RGBA", rot.size, (128, ) * 4),
rot).convert(img.mode)
rnd_ch_op = random.choice
self.func = {
"shearX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC,
fillcolor=fillcolor),
"shearY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0),
Image.BICUBIC,
fillcolor=fillcolor),
"translateX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, magnitude * img.size[0] * rnd_ch_op([-1, 1]), 0, 1, 0),
fillcolor=fillcolor),
"translateY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, 0, 1, magnitude * img.size[1] * rnd_ch_op([-1, 1])),
fillcolor=fillcolor),
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"posterize": lambda img, magnitude:
ImageOps.posterize(img, magnitude),
"solarize": lambda img, magnitude:
ImageOps.solarize(img, magnitude),
"contrast": lambda img, magnitude:
ImageEnhance.Contrast(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"sharpness": lambda img, magnitude:
ImageEnhance.Sharpness(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"brightness": lambda img, magnitude:
ImageEnhance.Brightness(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"autocontrast": lambda img, magnitude:
ImageOps.autocontrast(img),
"equalize": lambda img, magnitude: ImageOps.equalize(img),
"invert": lambda img, magnitude: ImageOps.invert(img)
}
def __call__(self, img):
avaiable_op_names = list(self.level_map.keys())
for layer_num in range(self.num_layers):
op_name = np.random.choice(avaiable_op_names)
img = self.func[op_name](img, self.level_map[op_name])
return img
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#This code is based on https://github.com/zhunzhong07/Random-Erasing
import math
import random
import numpy as np
class RandomErasing(object):
def __init__(self, EPSILON=0.5, sl=0.02, sh=0.4, r1=0.3,
mean=[0., 0., 0.]):
self.EPSILON = EPSILON
self.mean = mean
self.sl = sl
self.sh = sh
self.r1 = r1
def __call__(self, img):
if random.uniform(0, 1) > self.EPSILON:
return img
for attempt in range(100):
area = img.shape[1] * img.shape[2]
target_area = random.uniform(self.sl, self.sh) * area
aspect_ratio = random.uniform(self.r1, 1 / self.r1)
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < img.shape[2] and h < img.shape[1]:
x1 = random.randint(0, img.shape[1] - h)
y1 = random.randint(0, img.shape[2] - w)
if img.shape[0] == 3:
img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
else:
img[0, x1:x1 + h, y1:y1 + w] = self.mean[1]
return img
return img
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册