# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
from typing import Callable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import cv2
import numpy as np
import paddle
import paddle.nn.functional as F
from paddleseg.transforms import functional


class Compose:
    """
    Do transformation on input data with corresponding pre-processing and augmentation operations.
    The shape of input data to all operations is [height, width, channels].

    Args:
        transforms (list): Callables applied in order; each one receives the data dict and
            returns the (possibly modified) data dict, or None to abort the pipeline.
        to_rgb (bool, optional): Stored on the instance; not read by this class. Default: True.

    Raises:
        TypeError: When transforms is not a list.
    """

    def __init__(self, transforms: List[Callable], to_rgb: bool = True):
        if not isinstance(transforms, list):
            raise TypeError('The transforms must be a list!')
        self.transforms = transforms
        self.to_rgb = to_rgb

    def __call__(self, data: dict) -> dict:
        """
        Run every transform on `data`, then convert images from HWC to CHW.

        Returns:
            dict: The transformed data, or None as soon as any transform returns None.
        """
        # Transforms record how to undo themselves here (see reverse_transform).
        if 'trans_info' not in data:
            data['trans_info'] = []

        for op in self.transforms:
            data = op(data)
            if data is None:
                return None

        data['img'] = np.transpose(data['img'], (2, 0, 1))
        for key in data.get('gt_fields', []):
            # 2-D fields have no channel axis to move.
            if len(data[key].shape) == 2:
                continue
            data[key] = np.transpose(data[key], (2, 0, 1))

        return data


class LoadImages:
    """
    Read images from image path.

    Entries of `data` that are already arrays are left as they are; only string
    entries are treated as paths and loaded with OpenCV.

    Args:
        to_rgb (bool, optional): If converting image to RGB color space. Default: True.
    """

    def __init__(self, to_rgb: bool = True):
        self.to_rgb = to_rgb

    @staticmethod
    def _read(path: str, flags: int) -> np.ndarray:
        """Load `path` with cv2.imread and fail loudly instead of returning None."""
        image = cv2.imread(path, flags)
        if image is None:
            raise ValueError('Failed to read image from path: {}'.format(path))
        return image

    def __call__(self, data: dict) -> dict:
        """
        Load `data['img']` and every field named in `data['gt_fields']`.

        Raises:
            ValueError: When a path cannot be read as an image.
        """
        if isinstance(data['img'], str):
            data['img'] = self._read(data['img'], cv2.IMREAD_COLOR)

        for key in data.get('gt_fields', []):
            if isinstance(data[key], str):
                data[key] = self._read(data[key], cv2.IMREAD_UNCHANGED)
            # If alpha or trimap has several channels, keep only the first one.
            if key in ['alpha', 'trimap']:
                if len(data[key].shape) > 2:
                    data[key] = data[key][:, :, 0]

        if self.to_rgb:
            data['img'] = cv2.cvtColor(data['img'], cv2.COLOR_BGR2RGB)
            for key in data.get('gt_fields', []):
                # Single-channel fields have no colour order to swap.
                if len(data[key].shape) == 2:
                    continue
                data[key] = cv2.cvtColor(data[key], cv2.COLOR_BGR2RGB)

        return data


class ResizeByShort:
    """
    Resize the short side of an image to given size, and then scale the other side proportionally.

    Args:
        short_size (int): The target size of short side.
    """

    def __init__(self, short_size: int = 512):
        self.short_size = short_size

    def __call__(self, data: dict) -> dict:
        """Resize `data['img']` and all `gt_fields`, recording the original (h, w)."""
        # Remember the pre-resize size so the prediction can be restored later.
        data['trans_info'].append(('resize', data['img'].shape[0:2]))

        data['img'] = functional.resize_short(data['img'], self.short_size)
        for key in data.get('gt_fields', []):
            data[key] = functional.resize_short(data[key], self.short_size)

        return data


class ResizeToIntMult:
    """
    Resize height and width down to the nearest integer multiple, e.g. 32.

    Args:
        mult_int (int, optional): Both sides of the output are multiples of this value.
            Default: 32.
    """

    def __init__(self, mult_int: int = 32):
        self.mult_int = mult_int

    def __call__(self, data: dict) -> dict:
        """Resize `data['img']` and all `gt_fields`, recording the original (h, w)."""
        # Remember the pre-resize size so the prediction can be restored later.
        data['trans_info'].append(('resize', data['img'].shape[0:2]))

        h, w = data['img'].shape[0:2]
        # Round down to a multiple, but never to zero: a side shorter than
        # mult_int is resized up to exactly one multiple instead.
        rw = max(w - w % self.mult_int, self.mult_int)
        rh = max(h - h % self.mult_int, self.mult_int)

        # The target size is passed as (width, height).
        data['img'] = functional.resize(data['img'], (rw, rh))
        for key in data.get('gt_fields', []):
            data[key] = functional.resize(data[key], (rw, rh))

        return data


class Normalize:
    """
    Normalize an image.

    Args:
        mean (list, optional): The mean value of a data set. Default: [0.5, 0.5, 0.5].
        std (list, optional): The standard deviation of a data set. Default: [0.5, 0.5, 0.5].

    Raises:
        ValueError: When mean/std is not list or any value in std is 0.
    """

    def __init__(self,
                 mean: Union[List[float], Tuple[float, ...]] = (0.5, 0.5, 0.5),
                 std: Union[List[float], Tuple[float, ...]] = (0.5, 0.5, 0.5)):
        self.mean = mean
        self.std = std

        if not (isinstance(self.mean, (list, tuple)) and isinstance(self.std, (list, tuple))):
            raise ValueError("{}: input type is invalid. It should be list or tuple".format(self))

        # A zero std would divide by zero during normalization.
        if any(value == 0 for value in self.std):
            raise ValueError('{}: std is invalid!'.format(self))

    def __call__(self, data: dict) -> dict:
        """Normalize `data['img']`, plus 'fg' and 'bg' when they are listed in `gt_fields`."""
        # Reshape to (1, 1, C) so the statistics broadcast over an HWC image.
        mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
        std = np.array(self.std)[np.newaxis, np.newaxis, :]

        data['img'] = functional.normalize(data['img'], mean, std)

        # Only colour images are normalized; other gt fields are left untouched.
        gt_fields = data.get('gt_fields', [])
        for key in ('fg', 'bg'):
            if key in gt_fields:
                data[key] = functional.normalize(data[key], mean, std)

        return data


def reverse_transform(alpha: paddle.Tensor, trans_info: List[Tuple[str, Tuple[int, int]]]):
    """
    Recover the prediction to the original shape.

    Args:
        alpha (paddle.Tensor): Prediction with four dimensions; the last two are
            treated as height and width.
        trans_info (list): Entries of the form (name, (h, w)) where name is 'resize'
            or 'padding'. They are undone in reverse order.

    Returns:
        paddle.Tensor: The prediction after every recorded transform has been undone.

    Raises:
        Exception: When an entry has a name other than 'resize' or 'padding'.
    """
    for item in trans_info[::-1]:
        if item[0] == 'resize':
            h, w = item[1][0], item[1][1]
            alpha = F.interpolate(alpha, [h, w], mode='bilinear')
        elif item[0] == 'padding':
            # Padding is undone by cropping back to the recorded size.
            h, w = item[1][0], item[1][1]
            alpha = alpha[:, :, 0:h, 0:w]
        else:
            raise Exception("Unexpected info '{}' in im_info".format(item[0]))
    return alpha


def save_alpha_pred(alpha: np.ndarray, trimap: Optional[Union[np.ndarray, str]] = None):
    """
    Force the known trimap regions onto `alpha` and cast the result to uint8.

    Pixels where trimap is 0 are set to 0 and pixels where trimap is 255 are set
    to 255. `alpha` is modified in place by these assignments; the returned array
    is a new uint8 copy.

    NOTE(review): no scaling is applied before the uint8 cast, so `alpha` is
    presumably expected in [0, 255] already, although this function was previously
    documented as taking values in [0, 1] - confirm against the callers.

    Args:
        alpha (np.ndarray): Alpha matte; shape should be [h, w].
        trimap (np.ndarray|str, optional): Trimap array, or a path that is read as a
            grayscale image. When None (or when the path cannot be read) `alpha` is
            only cast. Default: None.

    Returns:
        np.ndarray: The alpha matte as uint8.
    """
    if isinstance(trimap, str):
        trimap = cv2.imread(trimap, 0)

    if trimap is not None:
        alpha[trimap == 0] = 0
        alpha[trimap == 255] = 255

    return alpha.astype('uint8')


def cv2_to_base64(image: np.ndarray):
    """
    Encode an image as PNG and return the bytes as a base64 string.
    """
    data = cv2.imencode('.png', image)[1]
    # tobytes() replaces the deprecated ndarray.tostring().
    return base64.b64encode(data.tobytes()).decode('utf8')


def base64_to_cv2(b64str: str):
    """
    Decode a base64 string into an image (decoded with cv2.IMREAD_COLOR).
    """
    data = base64.b64decode(b64str.encode('utf8'))
    # frombuffer() replaces the deprecated binary mode of np.fromstring().
    data = np.frombuffer(data, np.uint8)
    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
    return data