visual.py 6.6 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
import math
import paddle
L
LielinJiang 已提交
17 18 19
import numpy as np
from PIL import Image

20
# Alias of the builtin ``range``: ``make_grid`` takes a parameter named
# ``range`` that shadows the builtin inside its body, so it loops with
# ``irange`` instead.
irange = range
郑启航 已提交
21 22


23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
def make_grid(tensor, nrow=8, normalize=False, range=None, scale_each=False):
    """Make a grid of images.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W),
            a single image of shape (C x H x W) or (H x W), or a list of
            images all of the same size.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The final grid size is ``(ceil(B / nrow), nrow)``. Default: ``8``.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by the min and max values specified by :attr:`range`. Default: ``False``.
        range (tuple, optional): tuple (min, max) where min and max are numbers,
            then these numbers are used to normalize the image. By default, min and max
            are computed from the tensor.
        scale_each (bool, optional): If ``True``, scale each image in the batch of
            images separately rather than the (min, max) over all images. Default: ``False``.

    Returns:
        Tensor: the single ``(C x H x W)`` image when the batch holds one image,
        otherwise a ``(C x H*rows x W*cols)`` canvas with the images tiled
        row-major. Single-channel images are replicated to 3 channels.
        The input tensor is never modified.

    Raises:
        TypeError: if ``tensor`` is neither a Tensor nor a list of Tensors.
        ValueError: if the (stacked) input is not a 2D, 3D or 4D tensor.
    """
    if not (isinstance(tensor, paddle.Tensor) or
            (isinstance(tensor, list)
             and all(isinstance(t, paddle.Tensor) for t in tensor))):
        raise TypeError('tensor or list of tensors expected, got {}'.format(
            type(tensor)))

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = paddle.stack(tensor, 0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image C x H x W
        if tensor.shape[0] == 1:  # if single-channel, convert to 3-channel
            tensor = paddle.concat([tensor, tensor, tensor], 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() != 4:
        raise ValueError(
            'expected a 2D, 3D or 4D tensor of images, got a {}D tensor'.format(
                tensor.dim()))

    if tensor.shape[1] == 1:  # single-channel images
        tensor = paddle.concat([tensor, tensor, tensor], 1)

    if normalize is True:
        if range is not None:
            assert isinstance(range, tuple), \
                "range has to be a tuple (min, max) if specified. min and max are numbers"

        def norm_range(img):
            # Out-of-place normalization: returns a new tensor instead of
            # writing through ``img[:]``. Per-image slices of a Paddle tensor
            # are not guaranteed to be views, so in-place writes on them may
            # never reach the batch; this also leaves the caller's tensor intact.
            if range is not None:
                low, high = range[0], range[1]
            else:
                low, high = float(img.min()), float(img.max())
            img = img.clip(min=low, max=high)
            return (img - low) / (high - low + 1e-5)

        if scale_each is True:
            # normalize every image with its own (min, max), then re-batch
            tensor = paddle.stack(
                [norm_range(tensor[i]) for i in irange(tensor.shape[0])], 0)
        else:
            tensor = norm_range(tensor)

    if tensor.shape[0] == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.shape[0]
    xmaps = min(nrow, nmaps)  # columns
    ymaps = int(math.ceil(float(nmaps) / xmaps))  # rows
    num_channels = tensor.shape[1]
    height, width = int(tensor.shape[2]), int(tensor.shape[3])

    canvas = paddle.zeros((num_channels, height * ymaps, width * xmaps),
                          dtype=tensor.dtype)
    for k in irange(nmaps):
        y, x = divmod(k, xmaps)  # row-major placement of image k
        canvas[:, y * height:(y + 1) * height,
               x * width:(x + 1) * width] = tensor[k]
    return canvas

L
LielinJiang 已提交
101

郑启航 已提交
102
def tensor2img(input_image, min_max=(-1., 1.), image_num=1, imtype=np.uint8):
    """Converts a Tensor array into a numpy image array.

    Parameters:
        input_image (tensor) --  the input image tensor array (any object with a
                                 ``.numpy()`` method), 3D ``(C, H, W)`` or 4D
                                 ``(N, C, H, W)``. A numpy array is passed
                                 through as-is apart from rounding (integer
                                 ``imtype`` only) and the final cast.
        min_max (tuple)      --  (min, max) value range of the input; values are
                                 clipped to it and rescaled to [0, 1]
                                 (then to [0, 255] when ``imtype`` is uint8)
        image_num (int)      --  the number of images to convert from a 4D batch
        imtype (type)        --  the desired type of the converted numpy array

    Returns:
        numpy array: one ``(H, W, C)`` image when a single image is converted,
        otherwise an ``(N, C, H, W)`` stack. Single-channel images are
        replicated to 3 channels.

    Raises:
        ValueError: if the tensor is neither 3D nor 4D.
    """
    def processing(img, transpose=True):
        """Clip, rescale and optionally transpose one channels-first image.

        Parameters:
            img (numpy array) --  the input ``(C, H, W)`` image
            transpose (bool)  --  return ``(H, W, C)`` if True, else ``(C, H, W)``
        """
        if img.shape[0] == 1:  # grayscale to RGB
            img = np.tile(img, (3, 1, 1))
        img = img.clip(min_max[0], min_max[1])
        img = (img - min_max[0]) / (min_max[1] - min_max[0])
        if imtype == np.uint8:
            img = img * 255.0  # scaling
        return np.transpose(img, (1, 2, 0)) if transpose else img

    if not isinstance(input_image, np.ndarray):
        image_numpy = input_image.numpy()  # convert it into a numpy array
        ndim = image_numpy.ndim
        if ndim == 4:
            image_numpy = image_numpy[0:image_num]
        elif ndim == 3:
            # NOTE for eval mode, need add dim
            image_numpy = np.expand_dims(image_numpy, 0)
            image_num = 1
        else:
            raise ValueError(
                "Image numpy ndim is {} not 3 or 4, Please check data".format(
                    ndim))

        if image_num == 1:
            # for one image, log HWC image
            image_numpy = processing(image_numpy[0])
        else:
            # for more image, log NCHW image
            image_numpy = np.stack(
                [processing(im, transpose=False) for im in image_numpy])
    else:
        # a numpy array is taken as an already-converted image
        image_numpy = input_image

    # Round only when casting to an integer type: for float outputs the values
    # live in [0, 1] and rounding would collapse them to 0 or 1.
    if np.issubdtype(imtype, np.integer):
        image_numpy = image_numpy.round()
    return image_numpy.astype(imtype)


def save_image(image_numpy, image_path, aspect_ratio=1.0):
    """Save a numpy image to the disk

    Parameters:
        image_numpy (numpy array) -- input numpy array accepted by
                                     ``PIL.Image.fromarray``, e.g. ``(H, W, 3)``
                                     uint8 or ``(H, W)`` grayscale
        image_path (str)          -- the path of the image
        aspect_ratio (float)      -- if > 1, the height is stretched by this
                                     factor; if < 1, the width is stretched by
                                     ``1 / aspect_ratio``; 1.0 keeps the size
    """

    image_pil = Image.fromarray(image_numpy)
    # only the spatial dims are needed; this also accepts 2D grayscale arrays
    h, w = image_numpy.shape[:2]

    # PIL's resize takes (width, height). Each branch keeps one side and
    # stretches the other, so non-square images are never transposed; for
    # square images the output size is the same as before.
    if aspect_ratio > 1.0:
        image_pil = image_pil.resize((w, int(h * aspect_ratio)), Image.BICUBIC)
    elif aspect_ratio < 1.0:
        image_pil = image_pil.resize((int(w / aspect_ratio), h), Image.BICUBIC)
    image_pil.save(image_path)
169 170 171 172 173 174 175 176 177 178


def mask2image(mask: np.ndarray, format="HWC"):
    """Colorize a 2D label mask with one random color per label.

    Args:
        mask (np.ndarray): ``(H, W)`` array of labels.
        format (str): currently unused; the result is always ``(H, W, 3)``.
            Kept for backward compatibility.

    Returns:
        np.ndarray: ``(H, W, 3)`` uint8 image. Every label in
        ``0 .. mask.max()`` (inclusive) gets a color drawn from the global
        ``np.random`` state; pixels holding any other value (e.g. negative
        labels) stay black. An empty mask yields an empty canvas.
    """
    H, W = mask.shape

    canvas = np.zeros((H, W, 3), dtype=np.uint8)
    if mask.size == 0:
        # mask.max() is undefined for empty arrays
        return canvas
    # +1 so that the largest label present in the mask is colored too
    for label in range(int(mask.max()) + 1):
        color = (np.random.rand(1, 1, 3) * 255).astype(np.uint8)
        canvas += (mask == label)[:, :, None] * color
    return canvas