__init__.py 5.4 KB
Newer Older
F
Felix 已提交
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
F
Felix 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ppcls.data.preprocess.ops.autoaugment import ImageNetPolicy as RawImageNetPolicy
from ppcls.data.preprocess.ops.randaugment import RandAugment as RawRandAugment
G
gaotingquan 已提交
17
from ppcls.data.preprocess.ops.timm_autoaugment import RawTimmAutoAugment
F
Felix 已提交
18 19 20 21 22 23 24 25 26
from ppcls.data.preprocess.ops.cutout import Cutout

from ppcls.data.preprocess.ops.hide_and_seek import HideAndSeek
from ppcls.data.preprocess.ops.random_erasing import RandomErasing
from ppcls.data.preprocess.ops.grid import GridMask

from ppcls.data.preprocess.ops.operators import DecodeImage
from ppcls.data.preprocess.ops.operators import ResizeImage
from ppcls.data.preprocess.ops.operators import CropImage
H
add xbm  
HydrogenSulfate 已提交
27
from ppcls.data.preprocess.ops.operators import CenterCrop, Resize
F
Felix 已提交
28
from ppcls.data.preprocess.ops.operators import RandCropImage
H
HydrogenSulfate 已提交
29
from ppcls.data.preprocess.ops.operators import RandCropImageV2
F
Felix 已提交
30 31 32 33
from ppcls.data.preprocess.ops.operators import RandFlipImage
from ppcls.data.preprocess.ops.operators import NormalizeImage
from ppcls.data.preprocess.ops.operators import ToCHWImage
from ppcls.data.preprocess.ops.operators import AugMix
W
weishengyu 已提交
34
from ppcls.data.preprocess.ops.operators import Pad
H
HydrogenSulfate 已提交
35 36
from ppcls.data.preprocess.ops.operators import ToTensor
from ppcls.data.preprocess.ops.operators import Normalize
D
dongshuilong 已提交
37
from ppcls.data.preprocess.ops.operators import RandomHorizontalFlip
H
add xbm  
HydrogenSulfate 已提交
38
from ppcls.data.preprocess.ops.operators import RandomResizedCrop
39 40 41
from ppcls.data.preprocess.ops.operators import CropWithPadding
from ppcls.data.preprocess.ops.operators import RandomInterpolationAugment
from ppcls.data.preprocess.ops.operators import ColorJitter
Z
zh-hike 已提交
42
from ppcls.data.preprocess.ops.operators import RandomGrayscale
Z
zhiboniu 已提交
43
from ppcls.data.preprocess.ops.operators import RandomCropImage
H
HydrogenSulfate 已提交
44
from ppcls.data.preprocess.ops.operators import RandomRotation
Z
zhiboniu 已提交
45
from ppcls.data.preprocess.ops.operators import Padv2
46
from ppcls.data.preprocess.ops.operators import RandomRot90
G
gaotingquan 已提交
47
from .ops.operators import format_data
D
dongshuilong 已提交
48
from paddle.vision.transforms import Pad as Pad_paddle_vision
F
Felix 已提交
49

50
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
悟、's avatar
悟、 已提交
51
from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid
F
Felix 已提交
52

Z
zh-hike 已提交
53 54
from .ops.randaugmentmc import RandAugmentMC, RandomApply

F
Felix 已提交
55 56
import numpy as np
from PIL import Image
C
cuicheng01 已提交
57
import random
Z
zh-hike 已提交
58 59
from paddle.vision.transforms import transforms as T
from paddle.vision.transforms.transforms import RandomCrop, ToTensor, Normalize
F
Felix 已提交
60 61 62 63 64 65 66 67 68 69 70


def transform(data, ops=None):
    """Apply each operator in ``ops`` to ``data`` in sequence.

    Args:
        data: The input handed to the first operator.
        ops: An iterable of callables; each one receives the previous
            operator's output. ``None`` (the default) or an empty iterable
            leaves ``data`` unchanged.

    Returns:
        The output of the last operator, or ``data`` if there are no
        operators.
    """
    # A ``None`` sentinel avoids a shared mutable default list.
    if ops is None:
        return data
    for op in ops:
        data = op(data)
    return data


class AutoAugment(RawImageNetPolicy):
    """Wrap ``ImageNetPolicy`` so it accepts both PIL images and arrays.

    Any non-PIL input is made C-contiguous with ``np.ascontiguousarray`` and
    turned into a PIL image before the policy runs. If the policy hands back
    a PIL image, it is converted to a numpy array; any other result is
    returned untouched.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, img):
        pil_input = (img if isinstance(img, Image.Image) else
                     Image.fromarray(np.ascontiguousarray(img)))
        augmented = super().__call__(pil_input)
        return (np.asarray(augmented)
                if isinstance(augmented, Image.Image) else augmented)


class RandAugment(RawRandAugment):
    """Wrap ``RandAugment`` so it accepts both PIL images and arrays.

    Any non-PIL input is made C-contiguous with ``np.ascontiguousarray`` and
    turned into a PIL image before augmentation. If the augmentation hands
    back a PIL image, it is converted to a numpy array; any other result is
    returned untouched.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, img):
        pil_input = (img if isinstance(img, Image.Image) else
                     Image.fromarray(np.ascontiguousarray(img)))
        augmented = super().__call__(pil_input)
        return (np.asarray(augmented)
                if isinstance(augmented, Image.Image) else augmented)


class TimmAutoAugment(RawTimmAutoAugment):
    """TimmAutoAugment wrapper that accepts both PIL images and arrays.

    Args:
        prob: Probability that the underlying policy is applied on a call.
        *args, **kwargs: Forwarded unchanged to ``RawTimmAutoAugment``.

    NOTE(review): ``prob`` is the first positional parameter, so a caller
    passing the parent's arguments positionally will have the first one
    bound to ``prob``. Pass parent arguments by keyword.
    """

    def __init__(self, prob=1.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.prob = prob

    @format_data
    def __call__(self, img):
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.ascontiguousarray(img))
        # One draw per call; random.random() is in [0, 1), so prob=1.0
        # always applies the policy and prob=0.0 never does.
        should_apply = random.random() < self.prob
        if should_apply:
            img = super().__call__(img)
        # The image is converted back to an array whether or not the policy
        # ran, because the input was already turned into a PIL image above.
        return np.asarray(img) if isinstance(img, Image.Image) else img
Z
zh-hike 已提交
125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162


class BaseTransform:
    """Build a single ``T.Compose`` pipeline from a list of operator configs.

    Each config entry is a dict whose first key names an operator and whose
    value is either ``None`` (construct with no arguments) or a dict of
    keyword arguments. Any additional keys in an entry are ignored.
    """

    def __init__(self, cfg) -> None:
        """
        Args:
            cfg: list [dict, dict, dict]
        """
        pipeline = []
        for op_cfg in cfg:
            op_name = list(op_cfg.keys())[0]
            op_kwargs = op_cfg[op_name]
            # NOTE(review): ``eval`` resolves the operator name in this
            # module's namespace and will execute any expression, so ``cfg``
            # must come from a trusted source. Also note that the later
            # ``from paddle.vision.transforms.transforms import ...`` line
            # rebinds ``ToTensor`` and ``Normalize`` to paddle's classes,
            # shadowing the ppcls operators of the same name.
            op_factory = eval(op_name)
            if op_kwargs is None:
                pipeline.append(op_factory())
            else:
                pipeline.append(op_factory(**op_kwargs))

        self.t = T.Compose(pipeline)

    def __call__(self, img):
        return self.t(img)


class ListTransform:
    """Run several independent ``BaseTransform`` pipelines on one input.

    Calling the instance returns a list with one output per pipeline, in the
    same order as the configs passed to the constructor.
    """

    def __init__(self, ops) -> None:
        """
        Args:
            ops: list[list[dict, dict], ...]
        """
        self.ts = [BaseTransform(pipeline_cfg) for pipeline_cfg in ops]

    def __call__(self, img):
        return [pipeline(img) for pipeline in self.ts]