# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This code is based on https://github.com/ecs-vlc/FMix
# reference: https://arxiv.org/abs/2002.12047

import math
import random

import numpy as np
from scipy.stats import beta


def _half_spectrum_freqs(n):
    """
    Return the frequency bins kept along the last (half-spectrum) axis.

    ``np.fft.irfftn`` reads ``n // 2 + 1`` bins along its last axis. For odd
    ``n`` one extra bin is kept here and cropped again by ``irfftn``; it is
    retained so the number of random draws made by ``get_spectrum`` (and thus
    seeded results) does not change.

    :param n: Size of the last spatial dimension
    """
    return np.fft.fftfreq(n)[:n // 2 + 1 + n % 2]


def fftfreqnd(h, w=None, z=None):
    """
    Get bin values for discrete fourier transform of size (h, w, z)

    The grid matches the layout ``np.fft.irfftn`` expects: every axis holds
    the full ``fftfreq`` bins except the last one, which holds the
    non-negative half of the spectrum.

    :param h: Required, first dimension size
    :param w: Optional, second dimension size
    :param z: Optional, third dimension size (requires ``w``)
    :return: Radial frequency magnitudes, shape (h,), (h, w') or (h, w, z')
        where the primed size is the half-spectrum length.
    :raises ValueError: if ``z`` is given without ``w``.
    """
    fx = fz = 0
    fy = np.fft.fftfreq(h)
    if w is None:
        if z is not None:
            raise ValueError("fftfreqnd: z requires w to be given")
    elif z is None:
        # 2D: rows hold the full spectrum, columns the half spectrum.
        fy = fy[:, None]
        fx = _half_spectrum_freqs(w)
    else:
        # 3D: the half spectrum belongs on the last (z) axis, which is the
        # axis irfftn reconstructs as real-valued.
        fy = fy[:, None, None]
        fx = np.fft.fftfreq(w)[:, None]
        fz = _half_spectrum_freqs(z)
    return np.sqrt(fx * fx + fy * fy + fz * fz)


def get_spectrum(freqs, decay_power, ch, h, w=0, z=0):
    """
    Samples a fourier image with given size and frequencies decayed by decay power

    :param freqs: Bin values for the discrete fourier transform
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param ch: Number of channels for the resulting mask
    :param h: Required, first dimension size
    :param w: Optional, second dimension size
    :param z: Optional, third dimension size
    :return: Array of shape (ch, *freqs.shape, 2); the trailing axis holds the
        real and imaginary parts of each frequency bin.
    """
    # Clamp tiny frequencies (notably the DC bin, f == 0) so 1 / f**d stays finite.
    min_freq = 1. / max(w, h, z)
    scale = 1. / np.maximum(freqs, min_freq) ** decay_power
    # Independent Gaussian draws for the real and imaginary part of every bin.
    param = np.random.randn(ch, *freqs.shape, 2)
    return scale[None, ..., None] * param


def make_low_freq_image(decay, shape, ch=1):
    """
    Sample a low frequency image from fourier space

    :param decay: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims
    :param ch: Number of channels to sample; only the first one is returned
    :return: Array of shape (1, *shape) rescaled to [0, 1] (all zeros if the
        sampled image happens to be constant).
    """
    shape = tuple(shape)
    freqs = fftfreqnd(*shape)
    spectrum = get_spectrum(freqs, decay, ch, *shape)
    # Combine the trailing (real, imag) axis into complex coefficients.
    spectrum = spectrum[..., 0] + 1j * spectrum[..., 1]
    # Transform only the spatial axes; the leading axis is the channel axis.
    axes = tuple(range(-len(shape), 0))
    mask = np.fft.irfftn(spectrum, shape, axes=axes)
    mask = mask[:1]  # keep the first channel only
    mask = mask - mask.min()
    peak = mask.max()
    if peak > 0:
        mask = mask / peak
    return mask


def sample_lam(alpha, reformulate=False):
    """
    Sample a lambda from symmetric beta distribution with given alpha

    :param alpha: Alpha value for beta distribution
    :param reformulate: If True, sample from Beta(alpha + 1, alpha) (the
        reformulation described in https://arxiv.org/abs/2002.12047) instead
        of the symmetric Beta(alpha, alpha).
    """
    if reformulate:
        lam = beta.rvs(alpha + 1, alpha)
    else:
        lam = beta.rvs(alpha, alpha)
    return lam


def binarise_mask(mask, lam, in_shape, max_soft=0.0):
    """
    Binarises a given low frequency image such that it has mean lambda.

    :param mask: Low frequency image, usually the result of `make_low_freq_image`
    :param lam: Mean value of final mask
    :param in_shape: Shape of inputs (int or sequence of up to 3 ints)
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :return: Mask of shape (1, 1, *in_shape), broadcastable against a batch of
        shape (b, c, *in_shape). The input ``mask`` is not modified.
    """
    if isinstance(in_shape, int):
        in_shape = (in_shape, )
    # Work on a flat copy so the caller's array is left untouched.
    flat = np.asarray(mask).reshape(-1).copy()
    # Pixel indices ordered from the largest value to the smallest.
    idx = flat.argsort()[::-1]
    size = flat.size
    # Round lam * size up or down with equal probability.
    num = math.ceil(lam * size) if random.random() > 0.5 else math.floor(
        lam * size)
    # The soft band cannot be wider than either side of the threshold.
    eff_soft = min(max_soft, lam, 1 - lam)
    soft = int(size * eff_soft)
    num_low = int(num - soft)
    num_high = int(num + soft)
    flat[idx[:num_high]] = 1
    flat[idx[num_low:]] = 0
    # Linear ramp from 1 to 0 across the soft band around the threshold.
    flat[idx[num_low:num_high]] = np.linspace(1, 0, num_high - num_low)
    return flat.reshape((1, 1, *in_shape))


def sample_mask(alpha, decay_power, shape, max_soft=0.0, reformulate=False):
    """
    Samples a mean lambda from beta distribution parametrised by alpha, creates a
    low frequency image and binarises it based on this lambda

    :param alpha: Alpha value for beta distribution from which to sample mean of mask
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, int or list up to 3 dims
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :param reformulate: If True, sample lambda from Beta(alpha + 1, alpha)
        (see ``sample_lam``).
    :return: Tuple ``(lam, mask)`` with ``lam`` a float and ``mask`` of shape
        (1, 1, *shape).
    """
    if isinstance(shape, int):
        shape = (shape, )

    # Choose lambda
    lam = sample_lam(alpha, reformulate)

    # Make mask, get mean / std
    mask = make_low_freq_image(decay_power, shape)
    mask = binarise_mask(mask, lam, shape, max_soft)

    return float(lam), mask


def sample_and_apply(x,
                     alpha,
                     decay_power,
                     shape,
                     max_soft=0.0,
                     reformulate=False):
    """
    Sample an FMix mask and mix ``x`` with a random permutation of itself.

    :param x: Image batch on which to apply fmix of shape [b, c, shape*]
    :param alpha: Alpha value for beta distribution from which to sample mean of mask
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :param reformulate: If True, sample lambda from Beta(alpha + 1, alpha)
        (see ``sample_lam``).
    :return: mixed input, permutation indices, lambda value of mix,
    """
    lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate)
    index = np.random.permutation(x.shape[0])

    x1, x2 = x * mask, x[index] * (1 - mask)
    return x1 + x2, index, lam


class FMixBase:
    """ FMix augmentation

        Args:
            decay_power (float): Decay power for frequency decay prop 1/f**d
            alpha (float): Alpha value for beta distribution from which to sample mean of mask
            size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims
            max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask.
            reformulate (bool): If True, sample lambda from Beta(alpha + 1, alpha)
                instead of Beta(alpha, alpha).

        Subclasses must implement ``__call__`` and ``loss``.
    """

    def __init__(self,
                 decay_power=3,
                 alpha=1,
                 size=(32, 32),
                 max_soft=0.0,
                 reformulate=False):
        super().__init__()
        self.decay_power = decay_power
        self.reformulate = reformulate
        self.size = size
        self.alpha = alpha
        self.max_soft = max_soft
        # Presumably filled in by subclasses' __call__ (permutation indices
        # and mixing lambda) for use in loss() -- confirm in subclasses.
        self.index = None
        self.lam = None

    def __call__(self, x):
        raise NotImplementedError

    def loss(self, *args, **kwargs):
        raise NotImplementedError