cls_transforms.py 10.1 KB
Newer Older
S
syyxsxx 已提交
1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
S
syyxsxx 已提交
2 3 4 5 6
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
S
syyxsxx 已提交
7
#     http://www.apache.org/licenses/LICENSE-2.0
S
syyxsxx 已提交
8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .ops import *
import random
import os.path as osp
import numpy as np
from PIL import Image, ImageEnhance


class ClsTransform:
    """分类Transform的基类
    """

    def __init__(self):
        pass


class Compose(ClsTransform):
    """根据数据预处理/增强算子对输入数据进行操作。
       所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。

    Args:
        transforms (list): 数据预处理/增强算子。

    Raises:
        TypeError: 形参数据类型不满足需求。
        ValueError: 数据长度不匹配。
    """

    def __init__(self, transforms):
        if not isinstance(transforms, list):
            raise TypeError('The transforms must be a list!')
        if len(transforms) < 1:
            raise ValueError('The length of transforms ' + \
                            'must be equal or larger than 1!')
        self.transforms = transforms

    def __call__(self, im, label=None):
        """
        Args:
            im (str/np.ndarray): 图像路径/图像np.ndarray数据。
            label (int): 每张图像所对应的类别序号。
        Returns:
            tuple: 根据网络所需字段所组成的tuple;
                字段由transforms中的最后一个数据预处理操作决定。
        """
        if isinstance(im, np.ndarray):
            if len(im.shape) != 3:
                raise Exception(
                    "im should be 3-dimension, but now is {}-dimensions".
                    format(len(im.shape)))
        else:
            try:
                im = cv2.imread(im).astype('float32')
            except:
                raise TypeError('Can\'t read The image file {}!'.format(im))
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        for op in self.transforms:
S
syyxsxx 已提交
71 72 73 74
            outputs = op(im, label)
            im = outputs[0]
            if len(outputs) == 2:
                label = outputs[1]
S
syyxsxx 已提交
75 76 77 78 79 80 81 82 83
        return outputs

    def add_augmenters(self, augmenters):
        if not isinstance(augmenters, list):
            raise Exception(
                "augmenters should be list type in func add_augmenters()")
        transform_names = [type(x).__name__ for x in self.transforms]
        for aug in augmenters:
            if type(aug).__name__ in transform_names:
S
syyxsxx 已提交
84 85 86
                print(
                    "{} is already in ComposedTransforms, need to remove it from add_augmenters().".
                    format(type(aug).__name__))
S
syyxsxx 已提交
87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271
        self.transforms = augmenters + self.transforms


class Normalize(ClsTransform):
    """对图像进行标准化。

    1. 对图像进行归一化到区间[0.0, 1.0]。
    2. 对图像进行减均值除以标准差操作。

    Args:
        mean (list): 图像数据集的均值。默认为[0.485, 0.456, 0.406]。
        std (list): 图像数据集的标准差。默认为[0.229, 0.224, 0.225]。

    """

    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, im, label=None):
        """
        Args:
            im (np.ndarray): 图像np.ndarray数据。
            label (int): 每张图像所对应的类别序号。

        Returns:
            tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
                   当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
        """
        mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
        std = np.array(self.std)[np.newaxis, np.newaxis, :]
        im = normalize(im, mean, std)
        if label is None:
            return (im, )
        else:
            return (im, label)


class ResizeByShort(ClsTransform):
    """根据图像短边对图像重新调整大小(resize)。

    1. 获取图像的长边和短边长度。
    2. 根据短边与short_size的比例,计算长边的目标长度,
       此时高、宽的resize比例为short_size/原图短边长度。
    3. 如果max_size>0,调整resize比例:
       如果长边的目标长度>max_size,则高、宽的resize比例为max_size/原图长边长度;
    4. 根据调整大小的比例对图像进行resize。

    Args:
        short_size (int): 调整大小后的图像目标短边长度。默认为256。
        max_size (int): 长边目标长度的最大限制。默认为-1。
    """

    def __init__(self, short_size=256, max_size=-1):
        self.short_size = short_size
        self.max_size = max_size

    def __call__(self, im, label=None):
        """
        Args:
            im (np.ndarray): 图像np.ndarray数据。
            label (int): 每张图像所对应的类别序号。

        Returns:
            tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
                   当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
        """
        im_short_size = min(im.shape[0], im.shape[1])
        im_long_size = max(im.shape[0], im.shape[1])
        scale = float(self.short_size) / im_short_size
        if self.max_size > 0 and np.round(scale *
                                          im_long_size) > self.max_size:
            scale = float(self.max_size) / float(im_long_size)
        resized_width = int(round(im.shape[1] * scale))
        resized_height = int(round(im.shape[0] * scale))
        im = cv2.resize(
            im, (resized_width, resized_height),
            interpolation=cv2.INTER_LINEAR)

        if label is None:
            return (im, )
        else:
            return (im, label)


class CenterCrop(ClsTransform):
    """以图像中心点扩散裁剪长宽为`crop_size`的正方形

    1. 计算剪裁的起始点。
    2. 剪裁图像。

    Args:
        crop_size (int): 裁剪的目标边长。默认为224。
    """

    def __init__(self, crop_size=224):
        self.crop_size = crop_size

    def __call__(self, im, label=None):
        """
        Args:
            im (np.ndarray): 图像np.ndarray数据。
            label (int): 每张图像所对应的类别序号。

        Returns:
            tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
                   当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
        """
        im = center_crop(im, self.crop_size)
        if label is None:
            return (im, )
        else:
            return (im, label)


class ArrangeClassifier(ClsTransform):
    """获取训练/验证/预测所需信息。注意:此操作不需用户自己显示调用

    Args:
        mode (str): 指定数据用于何种用途,取值范围为['train', 'eval', 'test', 'quant']。

    Raises:
        ValueError: mode的取值不在['train', 'eval', 'test', 'quant']之内。
    """

    def __init__(self, mode=None):
        if mode not in ['train', 'eval', 'test', 'quant']:
            raise ValueError(
                "mode must be in ['train', 'eval', 'test', 'quant']!")
        self.mode = mode

    def __call__(self, im, label=None):
        """
        Args:
            im (np.ndarray): 图像np.ndarray数据。
            label (int): 每张图像所对应的类别序号。

        Returns:
            tuple: 当mode为'train'或'eval'时,返回(im, label),分别对应图像np.ndarray数据、
                图像类别id;当mode为'test'或'quant'时,返回(im, ),对应图像np.ndarray数据。
        """
        im = permute(im, False).astype('float32')
        if self.mode == 'train' or self.mode == 'eval':
            outputs = (im, label)
        else:
            outputs = (im, )
        return outputs


class ComposedClsTransforms(Compose):
    """ 分类模型的基础Transforms流程,具体如下
        训练阶段:
        1. 随机从图像中crop一块子图,并resize成crop_size大小
        2. 将1的输出按0.5的概率随机进行水平翻转
        3. 将图像进行归一化
        验证/预测阶段:
        1. 将图像按比例Resize,使得最小边长度为crop_size[0] * 1.14
        2. 从图像中心crop出一个大小为crop_size的图像
        3. 将图像进行归一化

        Args:
            mode(str): 图像处理流程所处阶段,训练/验证/预测,分别对应'train', 'eval', 'test'
            crop_size(int|list): 输入模型里的图像大小
            mean(list): 图像均值
            std(list): 图像方差
    """

    def __init__(self,
                 mode,
                 crop_size=[224, 224],
                 mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225]):
        width = crop_size
        if isinstance(crop_size, list):
            if crop_size[0] != crop_size[1]:
                raise Exception(
                    "In classifier model, width and height should be equal, please modify your parameter `crop_size`"
                )
            width = crop_size[0]
        if width % 32 != 0:
            raise Exception(
                "In classifier model, width and height should be multiple of 32, e.g 224、256、320...., please modify your parameter `crop_size`"
            )

        if mode == 'train':
S
syyxsxx 已提交
272
            pass
S
syyxsxx 已提交
273 274 275 276 277 278 279 280 281
        else:
            # 验证/预测时的transforms
            transforms = [
                ResizeByShort(short_size=int(width * 1.14)),
                CenterCrop(crop_size=width), Normalize(
                    mean=mean, std=std)
            ]

        super(ComposedClsTransforms, self).__init__(transforms)