diff --git a/configs/cond_dcgan_mnist.yaml b/configs/cond_dcgan_mnist.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0c3aba181863b481ba44720d376973f761e855e3 --- /dev/null +++ b/configs/cond_dcgan_mnist.yaml @@ -0,0 +1,67 @@ +epochs: 200 +output_dir: output_dir + +model: + name: GANModel + generator: + name: ConditionalDeepConvGenerator + latent_dim: 128 + output_nc: 1 + size: 28 + ngf: 64 + n_class: 10 + discriminator: + name: NLayerDiscriminatorWithClassification + ndf: 16 + n_layers: 3 + input_nc: 1 + norm_type: batch + n_class: 10 + use_sigmoid: True + gan_mode: vanilla + +dataset: + train: + name: CommonVisionDataset + class_name: MNIST + dataroot: None + num_workers: 4 + batch_size: 64 + mode: train + return_cls: True + transforms: + - name: Normalize + mean: [127.5] + std: [127.5] + keys: [image] + test: + name: CommonVisionDataset + class_name: MNIST + dataroot: None + num_workers: 0 + batch_size: 64 + mode: test + transforms: + - name: Normalize + mean: [127.5] + std: [127.5] + keys: [image] + return_cls: True + + +optimizer: + name: Adam + beta1: 0.5 + +lr_scheduler: + name: linear + learning_rate: 0.0002 + start_epoch: 100 + decay_epochs: 100 + +log_config: + interval: 100 + visiual_interval: 500 + +snapshot_config: + interval: 5 diff --git a/configs/wgan_mnist.yaml b/configs/wgan_mnist.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4db41860e908bd04fbb20b6024e708a756a51287 --- /dev/null +++ b/configs/wgan_mnist.yaml @@ -0,0 +1,65 @@ +epochs: 200 +output_dir: output_dir + +model: + name: GANModel + generator: + name: DeepConvGenerator + latent_dim: 128 + output_nc: 1 + size: 28 + ngf: 64 + discriminator: + name: NLayerDiscriminator + ndf: 16 + n_layers: 3 + input_nc: 1 + norm_type: instance + gan_mode: wgan + n_critic: 5 + +dataset: + train: + name: CommonVisionDataset + class_name: MNIST + dataroot: None + num_workers: 4 + batch_size: 64 + mode: train + return_cls: False + transforms: + - name: Normalize + mean: [127.5] + std: [127.5] + keys: [image] + test: + name: CommonVisionDataset + class_name: MNIST + dataroot: None + num_workers: 0 + batch_size: 64 + mode: test + transforms: + - name: Normalize + mean: [127.5] + std: [127.5] + keys: [image] + return_cls: False + + +optimizer: + name: Adam + beta1: 0.5 + +lr_scheduler: + name: linear + learning_rate: 0.0002 + start_epoch: 100 + decay_epochs: 100 + +log_config: + interval: 100 + visiual_interval: 500 + +snapshot_config: + interval: 5 diff --git a/ppgan/datasets/__init__.py b/ppgan/datasets/__init__.py index a1e541803039f0e0458276d84cc512af327ae3e3..5b9d9568837d04be1ace8b6516a3204f95f861f8 100644 --- a/ppgan/datasets/__init__.py +++ b/ppgan/datasets/__init__.py @@ -17,4 +17,5 @@ from .single_dataset import SingleDataset from .paired_dataset import PairedDataset from .sr_image_dataset import SRImageDataset from .makeup_dataset import MakeupDataset +from .common_vision_dataset import CommonVisionDataset from .animeganv2_dataset import AnimeGANV2Dataset diff --git a/ppgan/datasets/common_vision_dataset.py b/ppgan/datasets/common_vision_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f6bab62aeb0b20221132b9bb2c16aa4c7510f46a --- /dev/null +++ b/ppgan/datasets/common_vision_dataset.py @@ -0,0 +1,66 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle + +from .builder import DATASETS +from .base_dataset import BaseDataset +from .transforms.builder import build_transforms + + +@DATASETS.register() +class CommonVisionDataset(BaseDataset): + """ + Dataset for using paddle vision default datasets + """ + def __init__(self, cfg): + """Initialize this dataset class. + + Args: + cfg (dict) -- stores all the experiment flags + """ + super(CommonVisionDataset, self).__init__(cfg) + + dataset_cls = getattr(paddle.vision.datasets, cfg.pop('class_name')) + transform = build_transforms(cfg.pop('transforms', None)) + self.return_cls = cfg.pop('return_cls', True) + + param_dict = {} + param_names = list(dataset_cls.__init__.__code__.co_varnames) + if 'transform' in param_names: + param_dict['transform'] = transform + for name in param_names: + if name in cfg: + param_dict[name] = cfg.get(name) + + self.dataset = dataset_cls(**param_dict) + + def __getitem__(self, index): + return_dict = {} + return_list = self.dataset[index] + if isinstance(return_list, (tuple, list)): + if len(return_list) == 2: + return_dict['img'] = return_list[0] + if self.return_cls: + return_dict['class_id'] = np.asarray(return_list[1]) + else: + return_dict['img'] = return_list[0] + else: + return_dict['img'] = return_list + + return return_dict + + def __len__(self): + return len(self.dataset) diff --git a/ppgan/datasets/single_dataset.py b/ppgan/datasets/single_dataset.py index 9cc72bb4951bff59de58e4df987e7faed63d00b5..246f769259cae4d98105dbfd4f00b271eb5f1bd9 100644 --- a/ppgan/datasets/single_dataset.py +++ b/ppgan/datasets/single_dataset.py @@ -57,7 +57,7 @@ class SingleDataset(BaseDataset): return len(self.A_paths) def get_path_by_indexs(self, indexs): - if isinstance(indexs, paddle.Variable): + if isinstance(indexs, paddle.Tensor): indexs = indexs.numpy() current_paths = [] for index in indexs: diff --git a/ppgan/models/__init__.py b/ppgan/models/__init__.py index 62839bc80ddfe8a5dce1bfc27100c935a2360763..12a78ff7c0734e890e4c18f71075ae9c69f11e64 100644 --- a/ppgan/models/__init__.py +++ b/ppgan/models/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from .base_model import BaseModel +from .gan_model import GANModel from .cycle_gan_model import CycleGANModel from .pix2pix_model import Pix2PixModel from .srgan_model import SRGANModel diff --git a/ppgan/models/discriminators/__init__.py b/ppgan/models/discriminators/__init__.py index 9157dd5ec2fe583282df058187fea0067121467f..a2c991511c79e76e71535b8b8dd7c9b244d198da 100644 --- a/ppgan/models/discriminators/__init__.py +++ b/ppgan/models/discriminators/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .nlayers import NLayerDiscriminator +from .nlayers import NLayerDiscriminator, NLayerDiscriminatorWithClassification from .discriminator_ugatit import UGATITDiscriminator from .dcdiscriminator import DCDiscriminator from .discriminator_animegan import AnimeDiscriminator diff --git a/ppgan/models/discriminators/nlayers.py b/ppgan/models/discriminators/nlayers.py index 9395c8227e92f79168ba1431e0abe7079b83c3af..938b1f43ee297333e9b485f434ee242d6ba4ec05 100644 --- a/ppgan/models/discriminators/nlayers.py +++ b/ppgan/models/discriminators/nlayers.py @@ -27,7 +27,7 @@ from .builder import DISCRIMINATORS @DISCRIMINATORS.register() class NLayerDiscriminator(nn.Layer): """Defines a PatchGAN discriminator""" - def __init__(self, input_nc, ndf=64, n_layers=3, norm_type='instance'): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_type='instance', use_sigmoid=False): """Construct a PatchGAN discriminator Parameters: @@ -35,6 +35,7 @@ class NLayerDiscriminator(nn.Layer): ndf (int) -- the number of filters in the last conv layer n_layers (int) -- the number of conv layers in the discriminator norm_type (str) -- normalization layer type + use_sigmoid (bool) -- whether use sigmoid at last """ super(NLayerDiscriminator, self).__init__() norm_layer = build_norm_layer(norm_type) @@ -139,7 +140,28 @@ class NLayerDiscriminator(nn.Layer): ] # output 1 channel prediction map self.model = nn.Sequential(*sequence) + self.final_act = F.sigmoid if use_sigmoid else (lambda x:x) def forward(self, input): """Standard forward.""" - return self.model(input) + return self.final_act(self.model(input)) + + +@DISCRIMINATORS.register() +class NLayerDiscriminatorWithClassification(NLayerDiscriminator): + def __init__(self, input_nc, n_class=10, **kwargs): + input_nc = input_nc + n_class + super(NLayerDiscriminatorWithClassification, self).__init__(input_nc, **kwargs) + + self.n_class = n_class + + def forward(self, x, class_id): + if self.n_class > 0: + class_id = (class_id % self.n_class).detach() + class_id = F.one_hot(class_id, self.n_class).astype('float32') + class_id = class_id.reshape([x.shape[0], -1, 1, 1]) + class_id = class_id.tile([1,1,*x.shape[2:]]) + x = paddle.concat([x, class_id], 1) + + return super(NLayerDiscriminatorWithClassification, self).forward(x) + diff --git a/ppgan/models/gan_model.py b/ppgan/models/gan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2981d11ab327978bebf5dd1cadc0f999901fb088 --- /dev/null +++ b/ppgan/models/gan_model.py @@ -0,0 +1,185 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numpy as np +from .base_model import BaseModel + +from .builder import MODELS +from .generators.builder import build_generator +from .discriminators.builder import build_discriminator +from .losses import GANLoss + +from ..solver import build_optimizer +from ..modules.init import init_weights +from ..utils.visual import make_grid + + +@MODELS.register() +class GANModel(BaseModel): + """ This class implements the vanilla GAN model with some tricks. + + vanilla GAN paper: https://arxiv.org/abs/1406.2661 + """ + def __init__(self, cfg): + """Initialize the GAN Model class. + + Parameters: + cfg (config dict)-- stores all the experiment flags; needs to be a subclass of Dict + """ + super(GANModel, self).__init__(cfg) + self.step = 0 + self.n_critic = cfg.model.get('n_critic', 1) + self.visual_interval = cfg.log_config.visiual_interval + self.samples_every_row = cfg.model.get('samples_every_row', 8) + + # define networks (both generator and discriminator) + self.nets['netG'] = build_generator(cfg.model.generator) + init_weights(self.nets['netG']) + + # define a discriminator + if self.is_train: + self.nets['netD'] = build_discriminator(cfg.model.discriminator) + init_weights(self.nets['netD']) + + if self.is_train: + self.losses = {} + # define loss functions + self.criterionGAN = GANLoss(cfg.model.gan_mode) + + # build optimizers + self.build_lr_scheduler() + self.optimizers['optimizer_G'] = build_optimizer( + cfg.optimizer, + self.lr_scheduler, + parameter_list=self.nets['netG'].parameters()) + self.optimizers['optimizer_D'] = build_optimizer( + cfg.optimizer, + self.lr_scheduler, + parameter_list=self.nets['netD'].parameters()) + + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input (list): include the data itself and its metadata information. + """ + if isinstance(input, (list, tuple)): + input = input[0] + if not isinstance(input, dict): + input = {'img': input} + self.D_real_inputs = [paddle.to_tensor(input['img'])] + if 'class_id' in input: # n class input + self.n_class = self.nets['netG'].n_class + self.D_real_inputs += [paddle.to_tensor(input['class_id'], dtype='int64')] + else: + self.n_class = 0 + + batch_size = self.D_real_inputs[0].shape[0] + self.G_inputs = self.nets['netG'].random_inputs(batch_size) + if not isinstance(self.G_inputs, (list, tuple)): + self.G_inputs = [self.G_inputs] + + if not hasattr(self, 'G_fixed_inputs'): + self.G_fixed_inputs = [t for t in self.G_inputs] + if self.n_class > 0: + rows_num = (batch_size - 1) // self.samples_every_row + 1 + class_ids = paddle.randint(0, self.n_class, [rows_num, 1]) + class_ids = class_ids.tile([1, self.samples_every_row]) + class_ids = class_ids.reshape([-1,])[:batch_size].detach() + self.G_fixed_inputs[1] = class_ids.detach() + + def forward(self): + """Run forward pass; called by both functions and .""" + self.fake_imgs = self.nets['netG'](*self.G_inputs) # G(img, class_id) + + # put items to visual dict + self.visual_items['fake_imgs'] = make_grid(self.fake_imgs, self.samples_every_row).detach() + + def backward_D(self): + """Calculate GAN loss for the discriminator""" + # Fake; stop backprop to the generator by detaching fake_imgs + # use conditional GANs; we need to feed both input and output to the discriminator + self.loss_D = 0 + self.D_fake_inputs = [self.fake_imgs.detach()] + if len(self.G_inputs) > 1 and self.G_inputs[1] is not None: + self.D_fake_inputs += [self.G_inputs[1]] + pred_fake = self.nets['netD'](*self.D_fake_inputs) + # Real + real_imgs = self.D_real_inputs[0] + self.visual_items['real_imgs'] = make_grid(real_imgs, self.samples_every_row).detach() + pred_real = self.nets['netD'](*self.D_real_inputs) + + self.loss_D_fake = self.criterionGAN(pred_fake, False, True) + self.loss_D_real = self.criterionGAN(pred_real, True, True) + + # combine loss and calculate gradients + if self.cfg.model.gan_mode in ['vanilla', 'lsgan']: + self.loss_D = self.loss_D + (self.loss_D_fake + self.loss_D_real) * 0.5 + else: + self.loss_D = self.loss_D + self.loss_D_fake + self.loss_D_real + + self.loss_D.backward() + + self.losses['D_fake_loss'] = self.loss_D_fake + self.losses['D_real_loss'] = self.loss_D_real + + def backward_G(self): + """Calculate GAN loss for the generator""" + # First, G(imgs) should fake the discriminator + self.D_fake_inputs = [self.fake_imgs] + if len(self.G_inputs) > 1 and self.G_inputs[1] is not None: + self.D_fake_inputs += [self.G_inputs[1]] + pred_fake = self.nets['netD'](*self.D_fake_inputs) + + self.loss_G_GAN = self.criterionGAN(pred_fake, True, False) + + # combine loss and calculate gradients + self.loss_G = self.loss_G_GAN + + self.loss_G.backward() + + self.losses['G_adv_loss'] = self.loss_G_GAN + + def optimize_parameters(self): + + # compute fake images: G(imgs) + self.forward() + + # update D + self.set_requires_grad(self.nets['netD'], True) + self.optimizers['optimizer_D'].clear_grad() + self.backward_D() + self.optimizers['optimizer_D'].step() + self.set_requires_grad(self.nets['netD'], False) + + # weight clip + if self.cfg.model.gan_mode == 'wgan': + with paddle.no_grad(): + for p in self.nets['netD'].parameters(): + p[:] = p.clip(-0.01, 0.01) + + if self.step % self.n_critic == 0: + # update G + self.optimizers['optimizer_G'].clear_grad() + self.backward_G() + self.optimizers['optimizer_G'].step() + + if self.step % self.visual_interval == 0: + with paddle.no_grad(): + self.visual_items['fixed_generated_imgs'] = make_grid( + self.nets['netG'](*self.G_fixed_inputs), self.samples_every_row + ) + + self.step += 1 diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py index 0c6fd94b7868b196142f622e1d99910f830d2cbe..ad04c1cdf2d164eaa62b288a902aedf6594dcdb9 100644 --- a/ppgan/models/generators/__init__.py +++ b/ppgan/models/generators/__init__.py @@ -16,6 +16,7 @@ from .resnet import ResnetGenerator from .unet import UnetGenerator from .rrdb_net import RRDBNet from .makeup import GeneratorPSGANAttention +from .deep_conv import DeepConvGenerator, ConditionalDeepConvGenerator from .resnet_ugatit import ResnetUGATITGenerator from .dcgenerator import DCGenerator from .generater_animegan import AnimeGenerator, AnimeGeneratorLite diff --git a/ppgan/models/generators/deep_conv.py b/ppgan/models/generators/deep_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..9712c9f6b1c505d9a981bc4d8c45db53739b0188 --- /dev/null +++ b/ppgan/models/generators/deep_conv.py @@ -0,0 +1,86 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from .builder import GENERATORS + + +@GENERATORS.register() +class DeepConvGenerator(nn.Layer): + """Create a Deep Convolutional generator""" + def __init__(self, latent_dim, output_nc, size=64, ngf=64): + """Construct a Deep Convolutional generator + Args: + latent_dim (int) -- the number of latent dimension + output_nc (int) -- the number of channels in output images + size (int) -- size of output tensor + ngf (int) -- the number of filters in the last conv layer + + Refer to https://arxiv.org/abs/1511.06434 + """ + super(DeepConvGenerator, self).__init__() + + self.latent_dim = latent_dim + self.ngf = ngf + self.init_size = size // 4 + self.l1 = nn.Sequential(nn.Linear(latent_dim, ngf*2 * self.init_size ** 2)) + + self.conv_blocks = nn.Sequential( + nn.BatchNorm2D(ngf*2), + nn.Upsample(scale_factor=2), + nn.Conv2D(ngf*2, ngf*2, 3, stride=1, padding=1), + nn.BatchNorm2D(ngf*2, 0.2), + nn.LeakyReLU(0.2), + nn.Upsample(scale_factor=2), + nn.Conv2D(ngf*2, ngf, 3, stride=1, padding=1), + nn.BatchNorm2D(ngf, 0.2), + nn.LeakyReLU(0.2), + nn.Conv2D(ngf, output_nc, 3, stride=1, padding=1), + nn.Tanh(), + ) + + def random_inputs(self, batch_size): + return paddle.randn([batch_size, self.latent_dim]) + + def forward(self, z): + out = self.l1(z) + out = out.reshape([out.shape[0], self.ngf * 2, self.init_size, self.init_size]) + img = self.conv_blocks(out) + return img + + +@GENERATORS.register() +class ConditionalDeepConvGenerator(DeepConvGenerator): + def __init__(self, latent_dim, output_nc, n_class=10, **kwargs): + super(ConditionalDeepConvGenerator, self).__init__(latent_dim + n_class, output_nc, **kwargs) + + self.n_class = n_class + self.latent_dim = latent_dim + + def random_inputs(self, batch_size): + return_list = [super(ConditionalDeepConvGenerator, self).random_inputs(batch_size)] + class_id = paddle.randint(0, self.n_class, [batch_size]) + return return_list + [class_id] + + def forward(self, x, class_id=None): + if self.n_class > 0: + class_id = (class_id % self.n_class).detach() + class_id = F.one_hot(class_id, self.n_class).astype('float32') + class_id = class_id.reshape([x.shape[0], -1]) + x = paddle.concat([x, class_id], 1) + + return super(ConditionalDeepConvGenerator, self).forward(x) diff --git a/ppgan/models/losses.py b/ppgan/models/losses.py index 7dd63f24a3120b66108f500231d47218a587c0b4..9139266894e44d3a284d6347d6043b33a1c43c1c 100644 --- a/ppgan/models/losses.py +++ b/ppgan/models/losses.py @@ -16,6 +16,7 @@ import numpy as np import paddle import paddle.nn as nn +import paddle.nn.functional as F class GANLoss(nn.Layer): @@ -44,7 +45,7 @@ class GANLoss(nn.Layer): self.loss = nn.MSELoss() elif gan_mode == 'vanilla': self.loss = nn.BCEWithLogitsLoss() - elif gan_mode in ['wgangp']: + elif gan_mode in ['wgan', 'wgangp', 'hinge', 'logistic']: self.loss = None else: raise NotImplementedError('gan mode %s not implemented' % gan_mode) @@ -77,12 +78,13 @@ class GANLoss(nn.Layer): # target_tensor.stop_gradient = True return target_tensor - def __call__(self, prediction, target_is_real): + def __call__(self, prediction, target_is_real, is_updating_D=None): """Calculate loss given Discriminator's output and grount truth labels. Parameters: prediction (tensor) - - tpyically the prediction output from a discriminator target_is_real (bool) - - if the ground truth label is for real images or fake images + is_updating_D (bool) - - if we are in updating D step or not Returns: the calculated loss. @@ -90,9 +92,20 @@ class GANLoss(nn.Layer): if self.gan_mode in ['lsgan', 'vanilla']: target_tensor = self.get_target_tensor(prediction, target_is_real) loss = self.loss(prediction, target_tensor) - elif self.gan_mode == 'wgangp': + elif self.gan_mode.find('wgan') != -1: if target_is_real: loss = -prediction.mean() else: loss = prediction.mean() + elif self.gan_mode == 'hinge': + if target_is_real: + loss = F.relu(1 - prediction) if is_updating_D else -prediction + else: + loss = F.relu(1 + prediction) if is_updating_D else prediction + loss = loss.mean() + elif self.gan_mode == 'logistic': + if target_is_real: + loss = F.softplus(-prediction).mean() + else: + loss = F.softplus(prediction).mean() return loss diff --git a/ppgan/utils/visual.py b/ppgan/utils/visual.py index 2b4a184d08f27d7eb6ebe8c6d3b1d12b940bacb3..13806382920a3d6c0bb3ad2b238cbce58a1b6aea 100644 --- a/ppgan/utils/visual.py +++ b/ppgan/utils/visual.py @@ -12,9 +12,86 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math +import paddle import numpy as np from PIL import Image +irange = range +def make_grid(tensor, nrow=8, normalize=False, range=None, scale_each=False): + """Make a grid of images. + Args: + tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W) + or a list of images all of the same size. + nrow (int, optional): Number of images displayed in each row of the grid. + The final grid size is ``(B / nrow, nrow)``. Default: ``8``. + normalize (bool, optional): If True, shift the image to the range (0, 1), + by the min and max values specified by :attr:`range`. Default: ``False``. + range (tuple, optional): tuple (min, max) where min and max are numbers, + then these numbers are used to normalize the image. By default, min and max + are computed from the tensor. + scale_each (bool, optional): If ``True``, scale each image in the batch of + images separately rather than the (min, max) over all images. Default: ``False``. + """ + if not (isinstance(tensor, paddle.Tensor) or + (isinstance(tensor, list) and all(isinstance(tensor, t) for t in tensor))): + raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor))) + + # if list of tensors, convert to a 4D mini-batch Tensor + if isinstance(tensor, list): + tensor = paddle.stack(tensor, 0) + + if tensor.dim() == 2: # single image H x W + tensor = tensor.unsqueeze(0) + if tensor.dim() == 3: # single image + if tensor.shape[0] == 1: # if single-channel, convert to 3-channel + tensor = paddle.concat([tensor, tensor, tensor], 0) + tensor = tensor.unsqueeze(0) + + if tensor.dim() == 4 and tensor.shape[1] == 1: # single-channel images + tensor = paddle.concat([tensor, tensor, tensor], 1) + + if normalize is True: + tensor = tensor.astype(tensor.dtype) # avoid modifying tensor in-place + if range is not None: + assert isinstance(range, tuple), \ + "range has to be a tuple (min, max) if specified. min and max are numbers" + + def norm_ip(img, min, max): + img[:] = img.clip(min=min, max=max) + img[:] = (img - min) / (max - min + 1e-5) + + def norm_range(t, range): + if range is not None: + norm_ip(t, range[0], range[1]) + else: + norm_ip(t, float(t.min()), float(t.max())) + + if scale_each is True: + for t in tensor: # loop over mini-batch dimension + norm_range(t, range) + else: + norm_range(tensor, range) + + if tensor.shape[0] == 1: + return tensor.squeeze(0) + + # make the mini-batch of images into a grid + nmaps = tensor.shape[0] + xmaps = min(nrow, nmaps) + ymaps = int(math.ceil(float(nmaps) / xmaps)) + height, width = int(tensor.shape[2]), int(tensor.shape[3]) + num_channels = tensor.shape[1] + canvas = paddle.zeros((num_channels, height * ymaps, width * xmaps), dtype=tensor.dtype) + k = 0 + for y in irange(ymaps): + for x in irange(xmaps): + if k >= nmaps: + break + canvas[:, y * height:(y + 1) * height, x * width:(x + 1) * width] = tensor[k] + k = k + 1 + return canvas + def tensor2img(input_image, min_max=(-1., 1.), imtype=np.uint8): """"Converts a Tensor array into a numpy image array.