diff --git a/configs/starganv2_afhq.yaml b/configs/starganv2_afhq.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6a7ec539b09e8b6b8a004ad0f87d3812ae480afb --- /dev/null +++ b/configs/starganv2_afhq.yaml @@ -0,0 +1,141 @@ +epochs: 200 +output_dir: output_dir + +model: + name: StarGANv2Model + latent_dim: &LATENT_DIM 16 + lambda_sty: 1 + lambda_ds: 2 + lambda_cyc: 1 + generator: + name: StarGANv2Generator + img_size: &IMAGE_SIZE 256 + w_hpf: 0 + style_dim: &STYLE_DIM 64 + style: + name: StarGANv2Style + img_size: *IMAGE_SIZE + style_dim: *STYLE_DIM + num_domains: &NUM_DOMAINS 3 + mapping: + name: StarGANv2Mapping + latent_dim: *LATENT_DIM + style_dim: *STYLE_DIM + num_domains: *NUM_DOMAINS + discriminator: + name: StarGANv2Discriminator + img_size: *IMAGE_SIZE + num_domains: *NUM_DOMAINS + +dataset: + train: + name: StarGANv2Dataset + dataroot: data/stargan-v2/afhq/train + is_train: True + num_workers: 8 + batch_size: 4 + preprocess: + - name: LoadImageFromFile + key: src + - name: LoadImageFromFile + key: ref + - name: LoadImageFromFile + key: ref2 + - name: Transforms + input_keys: [src, ref, ref2] + pipeline: + - name: RandomResizedCropProb + prob: 0.9 + size: [*IMAGE_SIZE, *IMAGE_SIZE] + scale: [0.8, 1.0] + ratio: [0.9, 1.1] + interpolation: 'bilinear' + keys: [image, image, image] + - name: Resize + size: [*IMAGE_SIZE, *IMAGE_SIZE] + interpolation: 'bilinear' + keys: [image, image, image] + - name: RandomHorizontalFlip + prob: 0.5 + keys: [image, image, image] + - name: Transpose + keys: [image, image, image] + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + keys: [image, image, image] + + test: + name: StarGANv2Dataset + dataroot: data/stargan-v2/afhq/val + is_train: False + num_workers: 8 + batch_size: 16 + test_count: 16 + preprocess: + - name: LoadImageFromFile + key: src + - name: LoadImageFromFile + key: ref + - name: Transforms + input_keys: [src, ref] + pipeline: + - name: Resize + size: [*IMAGE_SIZE, *IMAGE_SIZE] + interpolation: 'bicubic' #cv2.INTER_CUBIC + keys: [image, image] + - name: Transpose + keys: [image, image] + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + keys: [image, image] + +lr_scheduler: + name: LinearDecay + learning_rate: 0.0001 + start_epoch: 100 + decay_epochs: 100 + # will get from real dataset + iters_per_epoch: 365 + +optimizer: + generator: + name: Adam + net_names: + - generator + beta1: 0.0 + beta2: 0.99 + weight_decay: 0.0001 + style_encoder: + name: Adam + net_names: + - style_encoder + beta1: 0.0 + beta2: 0.99 + weight_decay: 0.0001 + mapping_network: + name: Adam + net_names: + - mapping_network + beta1: 0.0 + beta2: 0.99 + weight_decay: 0.0001 + discriminator: + name: Adam + net_names: + - discriminator + beta1: 0.0 + beta2: 0.99 + weight_decay: 0.0001 + +validate: + interval: 5000 + save_img: false + +log_config: + interval: 5 + visiual_interval: 100 + +snapshot_config: + interval: 5 diff --git a/configs/starganv2_celeba_hq.yaml b/configs/starganv2_celeba_hq.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c5e00e130695772a610affecfe8bd4dc4c2571dd --- /dev/null +++ b/configs/starganv2_celeba_hq.yaml @@ -0,0 +1,144 @@ +epochs: 200 +output_dir: output_dir + +model: + name: StarGANv2Model + latent_dim: &LATENT_DIM 16 + lambda_sty: 1 + lambda_ds: 1 + lambda_cyc: 1 + generator: + name: StarGANv2Generator + img_size: &IMAGE_SIZE 256 + w_hpf: 1 + style_dim: &STYLE_DIM 64 + style: + name: StarGANv2Style + img_size: *IMAGE_SIZE + style_dim: *STYLE_DIM + num_domains: &NUM_DOMAINS 2 + mapping: + name: StarGANv2Mapping + latent_dim: *LATENT_DIM + style_dim: *STYLE_DIM + num_domains: *NUM_DOMAINS + fan: + name: FAN + fname_pretrained: models/stargan-v2/wing.pdparams + discriminator: + name: StarGANv2Discriminator + img_size: *IMAGE_SIZE + num_domains: *NUM_DOMAINS + +dataset: + train: + name: StarGANv2Dataset + dataroot: data/stargan-v2/celeba_hq/train/ + is_train: True + num_workers: 8 + batch_size: 4 + preprocess: + - name: LoadImageFromFile + key: src + - name: LoadImageFromFile + key: ref + - name: LoadImageFromFile + key: ref2 + - name: Transforms + input_keys: [src, ref, ref2] + pipeline: + - name: RandomResizedCropProb + prob: 0.9 + size: [*IMAGE_SIZE, *IMAGE_SIZE] + scale: [0.8, 1.0] + ratio: [0.9, 1.1] + interpolation: 'bilinear' + keys: [image, image, image] + - name: Resize + size: [*IMAGE_SIZE, *IMAGE_SIZE] + interpolation: 'bilinear' + keys: [image, image, image] + - name: RandomHorizontalFlip + prob: 0.5 + keys: [image, image, image] + - name: Transpose + keys: [image, image, image] + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + keys: [image, image, image] + + test: + name: StarGANv2Dataset + dataroot: data/stargan-v2/celeba_hq/val/ + is_train: False + num_workers: 8 + batch_size: 16 + test_count: 16 + preprocess: + - name: LoadImageFromFile + key: src + - name: LoadImageFromFile + key: ref + - name: Transforms + input_keys: [src, ref] + pipeline: + - name: Resize + size: [*IMAGE_SIZE, *IMAGE_SIZE] + interpolation: 'bicubic' #cv2.INTER_CUBIC + keys: [image, image] + - name: Transpose + keys: [image, image] + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + keys: [image, image] + +lr_scheduler: + name: LinearDecay + learning_rate: 0.0001 + start_epoch: 100 + decay_epochs: 100 + # will get from real dataset + iters_per_epoch: 365 + +optimizer: + generator: + name: Adam + net_names: + - generator + beta1: 0.0 + beta2: 0.99 + weight_decay: 0.0001 + style_encoder: + name: Adam + net_names: + - style_encoder + beta1: 0.0 + beta2: 0.99 + weight_decay: 0.0001 + mapping_network: + name: Adam + net_names: + - mapping_network + beta1: 0.0 + beta2: 0.99 + weight_decay: 0.0001 + discriminator: + name: Adam + net_names: + - discriminator + beta1: 0.0 + beta2: 0.99 + weight_decay: 0.0001 + +validate: + interval: 5000 + save_img: false + +log_config: + interval: 5 + visiual_interval: 100 + +snapshot_config: + interval: 5 diff --git a/ppgan/datasets/__init__.py b/ppgan/datasets/__init__.py old mode 100644 new mode 100755 index 4761233c96b1902a12d31b8d6ef44d3f27be0e87..cd5b5445e3afc78a3619ad855671d922258d9a93 --- a/ppgan/datasets/__init__.py +++ b/ppgan/datasets/__init__.py @@ -20,3 +20,4 @@ from .makeup_dataset import MakeupDataset from .common_vision_dataset import CommonVisionDataset from .animeganv2_dataset import AnimeGANV2Dataset from .wav2lip_dataset import Wav2LipDataset +from .starganv2_dataset import StarGANv2Dataset diff --git a/ppgan/datasets/preprocess/transforms.py b/ppgan/datasets/preprocess/transforms.py index ab378c2b9c9e2ce7aa05b9ed5565e962cf42d11d..dcda652d16be479640c47a219b2108039b0bee86 100644 --- a/ppgan/datasets/preprocess/transforms.py +++ b/ppgan/datasets/preprocess/transforms.py @@ -267,6 +267,25 @@ class SRNoise(T.BaseTransform): return image +@TRANSFORMS.register() +class RandomResizedCropProb(T.RandomResizedCrop): + """RandomResizedCropProb. + + Args: + prob (float): probabilty of using random-resized cropping. + size (int): cropped size. + """ + def __init__(self, prob, size, scale, ratio, interpolation, keys=None): + super().__init__(size, scale, ratio, interpolation) + self.prob = prob + self.keys = keys + + def _apply_image(self, image): + if random.random() < self.prob: + image = super()._apply_image(image) + return image + + @TRANSFORMS.register() class Add(T.BaseTransform): def __init__(self, value, keys=None): diff --git a/ppgan/datasets/starganv2_dataset.py b/ppgan/datasets/starganv2_dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..cd1621e07cef5221287f319a84304f47446185ab --- /dev/null +++ b/ppgan/datasets/starganv2_dataset.py @@ -0,0 +1,178 @@ + +import paddle +from .base_dataset import BaseDataset +from .builder import DATASETS +import os +from itertools import chain +from pathlib import Path +import traceback +import random +import numpy as np +from PIL import Image + +from paddle.io import Dataset, WeightedRandomSampler + + +def listdir(dname): + fnames = list(chain(*[list(Path(dname).rglob('*.' + ext)) + for ext in ['png', 'jpg', 'jpeg', 'JPG']])) + return fnames + + +def _make_balanced_sampler(labels): + class_counts = np.bincount(labels) + class_weights = 1. / class_counts + weights = class_weights[labels] + return WeightedRandomSampler(weights, len(weights)) + + +class ImageFolder(Dataset): + def __init__(self, root, use_sampler=False): + self.samples, self.targets = self._make_dataset(root) + self.use_sampler = use_sampler + if self.use_sampler: + self.sampler = _make_balanced_sampler(self.targets) + self.iter_sampler = iter(self.sampler) + + def _make_dataset(self, root): + domains = os.listdir(root) + fnames, labels = [], [] + for idx, domain in enumerate(sorted(domains)): + class_dir = os.path.join(root, domain) + cls_fnames = listdir(class_dir) + fnames += cls_fnames + labels += [idx] * len(cls_fnames) + return fnames, labels + + def __getitem__(self, i): + if self.use_sampler: + try: + index = next(self.iter_sampler) + except StopIteration: + self.iter_sampler = iter(self.sampler) + index = next(self.iter_sampler) + else: + index = i + fname = self.samples[index] + label = self.targets[index] + return fname, label + + def __len__(self): + return len(self.targets) + + +class ReferenceDataset(Dataset): + def __init__(self, root, use_sampler=None): + self.samples, self.targets = self._make_dataset(root) + self.use_sampler = use_sampler + if self.use_sampler: + self.sampler = _make_balanced_sampler(self.targets) + self.iter_sampler = iter(self.sampler) + + def _make_dataset(self, root): + domains = os.listdir(root) + fnames, fnames2, labels = [], [], [] + for idx, domain in enumerate(sorted(domains)): + class_dir = os.path.join(root, domain) + cls_fnames = listdir(class_dir) + fnames += cls_fnames + fnames2 += random.sample(cls_fnames, len(cls_fnames)) + labels += [idx] * len(cls_fnames) + return list(zip(fnames, fnames2)), labels + + def __getitem__(self, i): + if self.use_sampler: + try: + index = next(self.iter_sampler) + except StopIteration: + self.iter_sampler = iter(self.sampler) + index = next(self.iter_sampler) + else: + index = i + fname, fname2 = self.samples[index] + label = self.targets[index] + return fname, fname2, label + + def __len__(self): + return len(self.targets) + + + +@DATASETS.register() +class StarGANv2Dataset(BaseDataset): + """ + """ + def __init__(self, dataroot, is_train, preprocess, test_count=0): + """Initialize single dataset class. + + Args: + dataroot (str): Directory of dataset. + preprocess (list[dict]): A sequence of data preprocess config. + """ + super(StarGANv2Dataset, self).__init__(preprocess) + + self.dataroot = dataroot + self.is_train = is_train + if self.is_train: + self.src_loader = ImageFolder(self.dataroot, use_sampler=True) + self.ref_loader = ReferenceDataset(self.dataroot, use_sampler=True) + self.counts = len(self.src_loader) + else: + files = os.listdir(self.dataroot) + if 'src' in files and 'ref' in files: + self.src_loader = ImageFolder(os.path.join(self.dataroot, 'src')) + self.ref_loader = ImageFolder(os.path.join(self.dataroot, 'ref')) + else: + self.src_loader = ImageFolder(self.dataroot) + self.ref_loader = ImageFolder(self.dataroot) + self.counts = min(test_count, len(self.src_loader)) + self.counts = min(self.counts, len(self.ref_loader)) + + + def _fetch_inputs(self): + try: + x, y = next(self.iter_src) + except (AttributeError, StopIteration): + self.iter_src = iter(self.src_loader) + x, y = next(self.iter_src) + return x, y + + def _fetch_refs(self): + try: + x, x2, y = next(self.iter_ref) + except (AttributeError, StopIteration): + self.iter_ref = iter(self.ref_loader) + x, x2, y = next(self.iter_ref) + return x, x2, y + + def __getitem__(self, idx): + if self.is_train: + x, y = self._fetch_inputs() + x_ref, x_ref2, y_ref = self._fetch_refs() + datas = { + 'src_path': x, + 'src_cls': y, + 'ref_path': x_ref, + 'ref2_path': x_ref2, + 'ref_cls': y_ref, + } + else: + x, y = self.src_loader[idx] + x_ref, y_ref = self.ref_loader[idx] + datas = { + 'src_path': x, + 'src_cls': y, + 'ref_path': x_ref, + 'ref_cls': y_ref, + } + + if hasattr(self, 'preprocess') and self.preprocess: + datas = self.preprocess(datas) + + return datas + + def __len__(self): + return self.counts + + def prepare_data_infos(self, dataroot): + pass diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py old mode 100644 new mode 100755 index 4b0b5ea9de65a96852e1b4d684a9824d55f0e454..e734f84e5d5ba8a6560e3dedc2aa7993fa8a561b --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -142,6 +142,7 @@ class Trainer: self.time_count = {} self.best_metric = {} + self.model.set_total_iter(self.total_iters) def distributed_data_parallel(self): paddle.distributed.init_parallel_env() diff --git a/ppgan/models/__init__.py b/ppgan/models/__init__.py index 77b7c454852dfbf4434a315f0e4374ad4f6a6539..52d07a98abd22fa30257f03d9b59942a0a6799e6 100644 --- a/ppgan/models/__init__.py +++ b/ppgan/models/__init__.py @@ -26,3 +26,4 @@ from .animeganv2_model import AnimeGANV2Model, AnimeGANV2PreTrainModel from .styleganv2_model import StyleGAN2Model from .wav2lip_model import Wav2LipModel from .wav2lip_hq_model import Wav2LipModelHq +from .starganv2_model import StarGANv2Model diff --git a/ppgan/models/base_model.py b/ppgan/models/base_model.py old mode 100644 new mode 100755 index e70524715328491cef9430823fc97b6514add7fa..4fd0be8531bf14de1b294c027e2b8b00e718ffd9 --- a/ppgan/models/base_model.py +++ b/ppgan/models/base_model.py @@ -95,6 +95,9 @@ class BaseModel(ABC): """Calculate losses, gradients, and update network weights; called in every training iteration""" pass + def set_total_iter(self, total_iter): + self.total_iter = total_iter + def test_iter(self, metrics=None): """Calculate metrics; called in every test iteration""" self.eval() diff --git a/ppgan/models/discriminators/__init__.py b/ppgan/models/discriminators/__init__.py index cbdbc5eee41aba076fb3b7659e199901b9bd00e1..d62fb864abd2f589883a054c67979e05498269ba 100644 --- a/ppgan/models/discriminators/__init__.py +++ b/ppgan/models/discriminators/__init__.py @@ -20,3 +20,4 @@ from .discriminator_animegan import AnimeDiscriminator from .discriminator_styleganv2 import StyleGANv2Discriminator from .syncnet import SyncNetColor from .wav2lip_disc_qual import Wav2LipDiscQual +from .discriminator_starganv2 import StarGANv2Discriminator diff --git a/ppgan/models/discriminators/discriminator_starganv2.py b/ppgan/models/discriminators/discriminator_starganv2.py new file mode 100644 index 0000000000000000000000000000000000000000..a2ff50eb0fc2a1aa72557aa8069382c33818abe4 --- /dev/null +++ b/ppgan/models/discriminators/discriminator_starganv2.py @@ -0,0 +1,39 @@ + +import paddle.nn as nn +import paddle + +from .builder import DISCRIMINATORS +from ..generators.generator_starganv2 import ResBlk + +import numpy as np + + +@DISCRIMINATORS.register() +class StarGANv2Discriminator(nn.Layer): + def __init__(self, img_size=256, num_domains=2, max_conv_dim=512): + super().__init__() + dim_in = 2**14 // img_size + blocks = [] + blocks += [nn.Conv2D(3, dim_in, 3, 1, 1)] + + repeat_num = int(np.log2(img_size)) - 2 + for _ in range(repeat_num): + dim_out = min(dim_in*2, max_conv_dim) + blocks += [ResBlk(dim_in, dim_out, downsample=True)] + dim_in = dim_out + + blocks += [nn.LeakyReLU(0.2)] + blocks += [nn.Conv2D(dim_out, dim_out, 4, 1, 0)] + blocks += [nn.LeakyReLU(0.2)] + blocks += [nn.Conv2D(dim_out, num_domains, 1, 1, 0)] + self.main = nn.Sequential(*blocks) + + def forward(self, x, y): + out = self.main(x) + out = paddle.reshape(out, (out.shape[0], -1)) # (batch, num_domains) + idx = paddle.zeros_like(out) + for i in range(idx.shape[0]): + idx[i, y[i]] = 1 + s = idx * out + s = paddle.sum(s, axis=1) + return s diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py old mode 100644 new mode 100755 index c017baf9759672345ce9f7b7be1e1e0dcd8a5227..b4202544cf81040bdeab4de60243a37499bb8be4 --- a/ppgan/models/generators/__init__.py +++ b/ppgan/models/generators/__init__.py @@ -26,3 +26,5 @@ from .resnet_ugatit_p2c import ResnetUGATITP2CGenerator from .generator_styleganv2 import StyleGANv2Generator from .generator_pixel2style2pixel import Pixel2Style2Pixel from .drn import DRNGenerator +from .generator_starganv2 import StarGANv2Generator, StarGANv2Style, StarGANv2Mapping, FAN + diff --git a/ppgan/models/generators/generator_starganv2.py b/ppgan/models/generators/generator_starganv2.py new file mode 100755 index 0000000000000000000000000000000000000000..bed8c01ac25019b7d4625d22a4792976e3211f77 --- /dev/null +++ b/ppgan/models/generators/generator_starganv2.py @@ -0,0 +1,350 @@ + +import paddle +from paddle import nn +import paddle.nn.functional as F + +from .builder import GENERATORS +import numpy as np +import math + +from ppgan.modules.wing import CoordConvTh, ConvBlock, HourGlass, preprocess + + +class AvgPool2D(nn.Layer): + """ + AvgPool2D + Peplace avg_pool2d because paddle.grad will cause avg_pool2d to report an error when training. + In the future Paddle framework will supports avg_pool2d and remove this class. + """ + def __init__(self): + super(AvgPool2D, self).__init__() + self.filter = paddle.to_tensor([[1, 1], + [1, 1]], dtype='float32') + + def forward(self, x): + filter = self.filter.unsqueeze(0).unsqueeze(1).tile([x.shape[1], 1, 1, 1]) + return F.conv2d(x, filter, stride=2, padding=0, groups=x.shape[1]) / 4 + + +class ResBlk(nn.Layer): + def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), + normalize=False, downsample=False): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = downsample + self.learned_sc = dim_in != dim_out + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2D(dim_in, dim_in, 3, 1, 1) + self.conv2 = nn.Conv2D(dim_in, dim_out, 3, 1, 1) + if self.normalize: + self.norm1 = nn.InstanceNorm2D(dim_in, weight_attr=True, bias_attr=True) + self.norm2 = nn.InstanceNorm2D(dim_in, weight_attr=True, bias_attr=True) + if self.learned_sc: + self.conv1x1 = nn.Conv2D(dim_in, dim_out, 1, 1, 0, bias_attr=False) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = AvgPool2D()(x) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = self.conv1(x) + if self.downsample: + x = AvgPool2D()(x) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x / math.sqrt(2) # unit variance + + +class AdaIN(nn.Layer): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm2D(num_features, weight_attr=False, bias_attr=False) + self.fc = nn.Linear(style_dim, num_features*2) + + def forward(self, x, s): + h = self.fc(s) + # h = h.view(h.size(0), h.size(1), 1, 1) + h = paddle.reshape(h, (h.shape[0], h.shape[1], 1, 1)) + gamma, beta = paddle.chunk(h, chunks=2, axis=1) + return (1 + gamma) * self.norm(x) + beta + + +class AdainResBlk(nn.Layer): + def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=0, + actv=nn.LeakyReLU(0.2), upsample=False): + super().__init__() + self.w_hpf = w_hpf + self.actv = actv + self.upsample = upsample + self.learned_sc = dim_in != dim_out + self._build_weights(dim_in, dim_out, style_dim) + + def _build_weights(self, dim_in, dim_out, style_dim=64): + self.conv1 = nn.Conv2D(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2D(dim_out, dim_out, 3, 1, 1) + self.norm1 = AdaIN(style_dim, dim_in) + self.norm2 = AdaIN(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2D(dim_in, dim_out, 1, 1, 0, bias_attr=False) + + def _shortcut(self, x): + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x, s): + out = self._residual(x, s) + if self.w_hpf == 0: + out = (out + self._shortcut(x)) / math.sqrt(2) + return out + + +class HighPass(nn.Layer): + def __init__(self, w_hpf): + super(HighPass, self).__init__() + self.filter = paddle.to_tensor([[-1, -1, -1], + [-1, 8., -1], + [-1, -1, -1]]) / w_hpf + + def forward(self, x): + # filter = self.filter.unsqueeze(0).unsqueeze(1).repeat(x.size(1), 1, 1, 1) + filter = self.filter.unsqueeze(0).unsqueeze(1).tile([x.shape[1], 1, 1, 1]) + return F.conv2d(x, filter, padding=1, groups=x.shape[1]) + + +@GENERATORS.register() +class StarGANv2Generator(nn.Layer): + def __init__(self, img_size=256, style_dim=64, max_conv_dim=512, w_hpf=1): + super().__init__() + dim_in = 2**14 // img_size + self.img_size = img_size + self.from_rgb = nn.Conv2D(3, dim_in, 3, 1, 1) + self.encode = nn.LayerList() + self.decode = nn.LayerList() + self.to_rgb = nn.Sequential( + nn.InstanceNorm2D(dim_in, weight_attr=True, bias_attr=True), + nn.LeakyReLU(0.2), + nn.Conv2D(dim_in, 3, 1, 1, 0)) + + # down/up-sampling blocks + repeat_num = int(np.log2(img_size)) - 4 + if w_hpf > 0: + repeat_num += 1 + for _ in range(repeat_num): + dim_out = min(dim_in*2, max_conv_dim) + self.encode.append( + ResBlk(dim_in, dim_out, normalize=True, downsample=True)) + if len(self.decode) == 0: + self.decode.append(AdainResBlk(dim_out, dim_in, style_dim, + w_hpf=w_hpf, upsample=True)) + else: + self.decode.insert( + 0, AdainResBlk(dim_out, dim_in, style_dim, + w_hpf=w_hpf, upsample=True)) # stack-like + dim_in = dim_out + + # bottleneck blocks + for _ in range(2): + self.encode.append( + ResBlk(dim_out, dim_out, normalize=True)) + self.decode.insert( + 0, AdainResBlk(dim_out, dim_out, style_dim, w_hpf=w_hpf)) + + if w_hpf > 0: + self.hpf = HighPass(w_hpf) + + def forward(self, x, s, masks=None): + x = self.from_rgb(x) + cache = {} + for block in self.encode: + if (masks is not None) and (x.shape[2] in [32, 64, 128]): + cache[x.shape[2]] = x + x = block(x) + for block in self.decode: + x = block(x, s) + if (masks is not None) and (x.shape[2] in [32, 64, 128]): + mask = masks[0] if x.shape[2] in [32] else masks[1] + mask = F.interpolate(mask, size=[x.shape[2], x.shape[2]], mode='bilinear') + x = x + self.hpf(mask * cache[x.shape[2]]) + return self.to_rgb(x) + + +@GENERATORS.register() +class StarGANv2Mapping(nn.Layer): + def __init__(self, latent_dim=16, style_dim=64, num_domains=2): + super().__init__() + layers = [] + layers += [nn.Linear(latent_dim, 512)] + layers += [nn.ReLU()] + for _ in range(3): + layers += [nn.Linear(512, 512)] + layers += [nn.ReLU()] + self.shared = nn.Sequential(*layers) + + self.unshared = nn.LayerList() + for _ in range(num_domains): + self.unshared.append(nn.Sequential(nn.Linear(512, 512), + nn.ReLU(), + nn.Linear(512, 512), + nn.ReLU(), + nn.Linear(512, 512), + nn.ReLU(), + nn.Linear(512, style_dim))) + + def forward(self, z, y): + h = self.shared(z) + out = [] + for layer in self.unshared: + out += [layer(h)] + out = paddle.stack(out, axis=1) # (batch, num_domains, style_dim) + idx = paddle.to_tensor(np.array(range(y.shape[0]))).astype('int') + s = [] + for i in range(idx.shape[0]): + s += [out[idx[i].numpy().astype(np.int).tolist()[0], y[i].numpy().astype(np.int).tolist()[0]]] + s = paddle.stack(s) + s = paddle.reshape(s, (s.shape[0], -1)) + return s + + +@GENERATORS.register() +class StarGANv2Style(nn.Layer): + def __init__(self, img_size=256, style_dim=64, num_domains=2, max_conv_dim=512): + super().__init__() + dim_in = 2**14 // img_size + blocks = [] + blocks += [nn.Conv2D(3, dim_in, 3, 1, 1)] + + repeat_num = int(np.log2(img_size)) - 2 + for _ in range(repeat_num): + dim_out = min(dim_in*2, max_conv_dim) + blocks += [ResBlk(dim_in, dim_out, downsample=True)] + dim_in = dim_out + + blocks += [nn.LeakyReLU(0.2)] + blocks += [nn.Conv2D(dim_out, dim_out, 4, 1, 0)] + blocks += [nn.LeakyReLU(0.2)] + self.shared = nn.Sequential(*blocks) + + self.unshared = nn.LayerList() + for _ in range(num_domains): + self.unshared.append(nn.Linear(dim_out, style_dim)) + + def forward(self, x, y): + h = self.shared(x) + h = paddle.reshape(h, (h.shape[0], -1)) + out = [] + for layer in self.unshared: + out += [layer(h)] + out = paddle.stack(out, axis=1) # (batch, num_domains, style_dim) + idx = paddle.to_tensor(np.array(range(y.shape[0]))).astype('int') + s = [] + for i in range(idx.shape[0]): + s += [out[idx[i].numpy().astype(np.int).tolist()[0], y[i].numpy().astype(np.int).tolist()[0]]] + s = paddle.stack(s) + s = paddle.reshape(s, (s.shape[0], -1)) + return s + + +@GENERATORS.register() +class FAN(nn.Layer): + def __init__(self, num_modules=1, end_relu=False, num_landmarks=98, fname_pretrained=None): + super(FAN, self).__init__() + self.num_modules = num_modules + self.end_relu = end_relu + + # Base part + self.conv1 = CoordConvTh(256, 256, True, False, + in_channels=3, out_channels=64, + kernel_size=7, stride=2, padding=3) + self.bn1 = nn.BatchNorm2D(64) + self.conv2 = ConvBlock(64, 128) + self.conv3 = ConvBlock(128, 128) + self.conv4 = ConvBlock(128, 256) + + # Stacking part + self.add_sublayer('m0', HourGlass(1, 4, 256, first_one=True)) + self.add_sublayer('top_m_0', ConvBlock(256, 256)) + self.add_sublayer('conv_last0', nn.Conv2D(256, 256, 1, 1, 0)) + self.add_sublayer('bn_end0', nn.BatchNorm2D(256)) + self.add_sublayer('l0', nn.Conv2D(256, num_landmarks+1, 1, 1, 0)) + + if fname_pretrained is not None: + self.load_pretrained_weights(fname_pretrained) + + def load_pretrained_weights(self, fname): + import pickle + import six + + with open(fname, 'rb') as f: + checkpoint = pickle.load(f) if six.PY2 else pickle.load( + f, encoding='latin1') + + model_weights = self.state_dict() + model_weights.update({k: v for k, v in checkpoint['state_dict'].items() + if k in model_weights}) + self.set_state_dict(model_weights) + + def forward(self, x): + x, _ = self.conv1(x) + x = F.relu(self.bn1(x), True) + x = F.avg_pool2d(self.conv2(x), 2, stride=2) + x = self.conv3(x) + x = self.conv4(x) + + outputs = [] + boundary_channels = [] + tmp_out = None + ll, boundary_channel = self._sub_layers['m0'](x, tmp_out) + ll = self._sub_layers['top_m_0'](ll) + ll = F.relu(self._sub_layers['bn_end0'] + (self._sub_layers['conv_last0'](ll)), True) + + # Predict heatmaps + tmp_out = self._sub_layers['l0'](ll) + if self.end_relu: + tmp_out = F.relu(tmp_out) # HACK: Added relu + outputs.append(tmp_out) + boundary_channels.append(boundary_channel) + return outputs, boundary_channels + + @paddle.no_grad() + def get_heatmap(self, x, b_preprocess=True): + ''' outputs 0-1 normalized heatmap ''' + x = F.interpolate(x, size=[256, 256], mode='bilinear') + x_01 = x*0.5 + 0.5 + outputs, _ = self(x_01) + heatmaps = outputs[-1][:, :-1, :, :] + scale_factor = x.shape[2] // heatmaps.shape[2] + if b_preprocess: + heatmaps = F.interpolate(heatmaps, scale_factor=scale_factor, + mode='bilinear', align_corners=True) + heatmaps = preprocess(heatmaps) + return heatmaps diff --git a/ppgan/models/starganv2_model.py b/ppgan/models/starganv2_model.py new file mode 100755 index 0000000000000000000000000000000000000000..d386e68fb1f003022a8b6f13600c97fb24e84cc4 --- /dev/null +++ b/ppgan/models/starganv2_model.py @@ -0,0 +1,289 @@ +from paddle.fluid.layers.nn import soft_relu +from .base_model import BaseModel + +from paddle import nn +import paddle +import paddle.nn.functional as F +from .builder import MODELS +from .generators.builder import build_generator +from .discriminators.builder import build_discriminator +from ..modules.init import kaiming_normal_, constant_ +from ppgan.utils.visual import make_grid, tensor2img + +import numpy as np + + +def translate_using_reference(nets, w_hpf, x_src, x_ref, y_ref): + N, C, H, W = x_src.shape + wb = paddle.to_tensor(np.ones((1, C, H, W))).astype('float32') + x_src_with_wb = paddle.concat([wb, x_src], axis=0) + + masks = nets['fan'].get_heatmap(x_src) if w_hpf > 0 else None + s_ref = nets['style_encoder'](x_ref, y_ref) + s_ref_list = paddle.unsqueeze(s_ref, axis=[1]) + s_ref_lists = [] + for _ in range(N): + s_ref_lists.append(s_ref_list) + s_ref_list = paddle.stack(s_ref_lists, axis=1) + s_ref_list = paddle.reshape(s_ref_list, (s_ref_list.shape[0], s_ref_list.shape[1], s_ref_list.shape[3])) + x_concat = [x_src_with_wb] + for i, s_ref in enumerate(s_ref_list): + x_fake = nets['generator'](x_src, s_ref, masks=masks) + x_fake_with_ref = paddle.concat([x_ref[i:i+1], x_fake], axis=0) + x_concat += [x_fake_with_ref] + + x_concat = paddle.concat(x_concat, axis=0) + img = tensor2img(make_grid(x_concat, nrow=N+1, range=(0, 1))) + del x_concat + return img + + +def compute_d_loss(nets, lambda_reg, x_real, y_org, y_trg, z_trg=None, x_ref=None, masks=None): + assert (z_trg is None) != (x_ref is None) + # with real images + x_real.stop_gradient = False + out = nets['discriminator'](x_real, y_org) + loss_real = adv_loss(out, 1) + loss_reg = r1_reg(out, x_real) + + # with fake images + with paddle.no_grad(): + if z_trg is not None: + s_trg = nets['mapping_network'](z_trg, y_trg) + else: # x_ref is not None + s_trg = nets['style_encoder'](x_ref, y_trg) + + x_fake = nets['generator'](x_real, s_trg, masks=masks) + out = nets['discriminator'](x_fake, y_trg) + loss_fake = adv_loss(out, 0) + + loss = loss_real + loss_fake + lambda_reg * loss_reg + return loss, {'real': loss_real.numpy(), + 'fake': loss_fake.numpy(), + 'reg': loss_reg.numpy()} + + +def adv_loss(logits, target): + assert target in [1, 0] + targets = paddle.full_like(logits, fill_value=target) + loss = F.binary_cross_entropy_with_logits(logits, targets) + return loss + + +def r1_reg(d_out, x_in): + # zero-centered gradient penalty for real images + batch_size = x_in.shape[0] + grad_dout = paddle.grad( + outputs=d_out.sum(), inputs=x_in, + create_graph=True, retain_graph=True, only_inputs=True + )[0] + grad_dout2 = grad_dout.pow(2) + assert(grad_dout2.shape == x_in.shape) + reg = 0.5 * paddle.reshape(grad_dout2, (batch_size, -1)).sum(1).mean(0) + return reg + +def soft_update(source, target, beta=1.0): + assert 0.0 <= beta <= 1.0 + target_model_map = dict(target.named_parameters()) + for param_name, source_param in source.named_parameters(): + target_param = target_model_map[param_name] + target_param.set_value(beta * source_param + (1.0 - beta) * target_param) + +def dump_model(model): + params = {} + for k in model.state_dict().keys(): + if k.endswith('.scale'): + params[k] = model.state_dict()[k].shape + return params + + +def compute_g_loss(nets, w_hpf, lambda_sty, lambda_ds, lambda_cyc, x_real, y_org, y_trg, z_trgs=None, x_refs=None, masks=None): + assert (z_trgs is None) != (x_refs is None) + if z_trgs is not None: + z_trg, z_trg2 = z_trgs + if x_refs is not None: + x_ref, x_ref2 = x_refs + + # adversarial loss + if z_trgs is not None: + s_trg = nets['mapping_network'](z_trg, y_trg) + else: + s_trg = nets['style_encoder'](x_ref, y_trg) + + x_fake = nets['generator'](x_real, s_trg, masks=masks) + out = nets['discriminator'](x_fake, y_trg) + loss_adv = adv_loss(out, 1) + + # style reconstruction loss + s_pred = nets['style_encoder'](x_fake, y_trg) + loss_sty = paddle.mean(paddle.abs(s_pred - s_trg)) + + # diversity sensitive loss + if z_trgs is not None: + s_trg2 = nets['mapping_network'](z_trg2, y_trg) + else: + s_trg2 = nets['style_encoder'](x_ref2, y_trg) + x_fake2 = nets['generator'](x_real, s_trg2, masks=masks) + loss_ds = paddle.mean(paddle.abs(x_fake - x_fake2)) + + # cycle-consistency loss + masks = nets['fan'].get_heatmap(x_fake) if w_hpf > 0 else None + s_org = nets['style_encoder'](x_real, y_org) + x_rec = nets['generator'](x_fake, s_org, masks=masks) + loss_cyc = paddle.mean(paddle.abs(x_rec - x_real)) + + loss = loss_adv + lambda_sty * loss_sty \ + - lambda_ds * loss_ds + lambda_cyc * loss_cyc + return loss, {'adv': loss_adv.numpy(), + 'sty': loss_sty.numpy(), + 'ds:': loss_ds.numpy(), + 'cyc': loss_cyc.numpy()} + + +def he_init(module): + if isinstance(module, nn.Conv2D): + kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu') + if module.bias is not None: + constant_(module.bias, 0) + if isinstance(module, nn.Linear): + kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu') + if module.bias is not None: + constant_(module.bias, 0) + + +@MODELS.register() +class StarGANv2Model(BaseModel): + def __init__( + self, + generator, + style=None, + mapping=None, + discriminator=None, + fan=None, + latent_dim=16, + lambda_reg=1, + lambda_sty=1, + lambda_ds=1, + lambda_cyc=1, + ): + super(StarGANv2Model, self).__init__() + self.w_hpf = generator['w_hpf'] + self.nets_ema = {} + self.nets['generator'] = build_generator(generator) + self.nets_ema['generator'] = build_generator(generator) + self.nets['style_encoder'] = build_generator(style) + self.nets_ema['style_encoder'] = build_generator(style) + self.nets['mapping_network'] = build_generator(mapping) + self.nets_ema['mapping_network'] = build_generator(mapping) + if discriminator: + self.nets['discriminator'] = build_discriminator(discriminator) + if self.w_hpf > 0: + fan_model = build_generator(fan) + fan_model.eval() + self.nets['fan'] = fan_model + self.nets_ema['fan'] = fan_model + self.latent_dim = latent_dim + self.lambda_reg = lambda_reg + self.lambda_sty = lambda_sty + self.lambda_ds = lambda_ds + self.lambda_cyc = lambda_cyc + + self.nets['generator'].apply(he_init) + self.nets['style_encoder'].apply(he_init) + self.nets['mapping_network'].apply(he_init) + self.nets['discriminator'].apply(he_init) + + # remember the initial value of ds weight + self.initial_lambda_ds = self.lambda_ds + + def setup_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Args: + input (dict): include the data itself and its metadata information. + + The option 'direction' can be used to swap images in domain A and domain B. + """ + pass + self.input = input + self.input['z_trg'] = paddle.randn((input['src'].shape[0], self.latent_dim)) + self.input['z_trg2'] = paddle.randn((input['src'].shape[0], self.latent_dim)) + + def forward(self): + """Run forward pass; called by both functions and .""" + pass + + def _reset_grad(self, optims): + for optim in optims.values(): + optim.clear_gradients() + + def train_iter(self, optimizers=None): + #TODO + x_real, y_org = self.input['src'], self.input['src_cls'] + x_ref, x_ref2, y_trg = self.input['ref'], self.input['ref2'], self.input['ref_cls'] + z_trg, z_trg2 = self.input['z_trg'], self.input['z_trg2'] + + masks = self.nets['fan'].get_heatmap(x_real) if self.w_hpf > 0 else None + + # train the discriminator + d_loss, d_losses_latent = compute_d_loss( + self.nets, self.lambda_reg, x_real, y_org, y_trg, z_trg=z_trg, masks=masks) + self._reset_grad(optimizers) + d_loss.backward() + optimizers['discriminator'].minimize(d_loss) + + d_loss, d_losses_ref = compute_d_loss( + self.nets, self.lambda_reg, x_real, y_org, y_trg, x_ref=x_ref, masks=masks) + self._reset_grad(optimizers) + d_loss.backward() + optimizers['discriminator'].step() + + # train the generator + g_loss, g_losses_latent = compute_g_loss( + self.nets, self.w_hpf, self.lambda_sty, self.lambda_ds, self.lambda_cyc, x_real, y_org, y_trg, z_trgs=[z_trg, z_trg2], masks=masks) + self._reset_grad(optimizers) + g_loss.backward() + optimizers['generator'].step() + optimizers['mapping_network'].step() + optimizers['style_encoder'].step() + + g_loss, g_losses_ref = compute_g_loss( + self.nets, self.w_hpf, self.lambda_sty, self.lambda_ds, self.lambda_cyc, x_real, y_org, y_trg, x_refs=[x_ref, x_ref2], masks=masks) + self._reset_grad(optimizers) + g_loss.backward() + optimizers['generator'].step() + + # compute moving average of network parameters + soft_update(self.nets['generator'], self.nets_ema['generator'], beta=0.999) + soft_update(self.nets['mapping_network'], self.nets_ema['mapping_network'], beta=0.999) + soft_update(self.nets['style_encoder'], self.nets_ema['style_encoder'], beta=0.999) + + # decay weight for diversity sensitive loss + if self.lambda_ds > 0: + self.lambda_ds -= (self.initial_lambda_ds / self.total_iter) + + for loss, prefix in zip([d_losses_latent, d_losses_ref, g_losses_latent, g_losses_ref], + ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']): + for key, value in loss.items(): + self.losses[prefix + key] = value + self.losses['G/lambda_ds'] = self.lambda_ds + self.losses['Total iter'] = int(self.total_iter) + + def test_iter(self, metrics=None): + #TODO + self.nets_ema['generator'].eval() + self.nets_ema['style_encoder'].eval() + soft_update(self.nets['generator'], self.nets_ema['generator'], beta=0.999) + soft_update(self.nets['mapping_network'], self.nets_ema['mapping_network'], beta=0.999) + soft_update(self.nets['style_encoder'], self.nets_ema['style_encoder'], beta=0.999) + src_img = self.input['src'] + ref_img = self.input['ref'] + ref_label = self.input['ref_cls'] + with paddle.no_grad(): + img = translate_using_reference(self.nets_ema, self.w_hpf, + paddle.to_tensor(src_img).astype('float32'), + paddle.to_tensor(ref_img).astype('float32'), + paddle.to_tensor(ref_label).astype('float32')) + self.visual_items['reference'] = img + self.nets_ema['generator'].train() + self.nets_ema['style_encoder'].train() diff --git a/ppgan/modules/wing.py b/ppgan/modules/wing.py new file mode 100755 index 0000000000000000000000000000000000000000..6b583db4c7d58fb1052d737838f54e35e2598c21 --- /dev/null +++ b/ppgan/modules/wing.py @@ -0,0 +1,279 @@ +""" +StarGAN v2 +Copyright (c) 2020-present NAVER Corp. + +""" + +from collections import namedtuple +from copy import deepcopy +from functools import partial + +from munch import Munch +import numpy as np +import cv2 +from skimage.filters import gaussian +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from ppgan.models.generators.builder import GENERATORS + + +class HourGlass(nn.Layer): + def __init__(self, num_modules, depth, num_features, first_one=False): + super(HourGlass, self).__init__() + self.num_modules = num_modules + self.depth = depth + self.features = num_features + self.coordconv = CoordConvTh(64, 64, True, True, 256, first_one, + out_channels=256, + kernel_size=1, stride=1, padding=0) + self._generate_network(self.depth) + + def _generate_network(self, level): + self.add_sublayer('b1_' + str(level), ConvBlock(256, 256)) + self.add_sublayer('b2_' + str(level), ConvBlock(256, 256)) + if level > 1: + self._generate_network(level - 1) + else: + self.add_sublayer('b2_plus_' + str(level), ConvBlock(256, 256)) + self.add_sublayer('b3_' + str(level), ConvBlock(256, 256)) + + def _forward(self, level, inp): + up1 = inp + up1 = self._sub_layers['b1_' + str(level)](up1) + low1 = F.avg_pool2d(inp, 2, stride=2) + low1 = self._sub_layers['b2_' + str(level)](low1) + + if level > 1: + low2 = self._forward(level - 1, low1) + else: + low2 = low1 + low2 = self._sub_layers['b2_plus_' + str(level)](low2) + low3 = low2 + low3 = self._sub_layers['b3_' + str(level)](low3) + up2 = F.interpolate(low3, scale_factor=2, mode='nearest') + + return up1 + up2 + + def forward(self, x, heatmap): + x, last_channel = self.coordconv(x, heatmap) + return self._forward(self.depth, x), last_channel + + +class AddCoordsTh(nn.Layer): + def __init__(self, height=64, width=64, with_r=False, with_boundary=False): + super(AddCoordsTh, self).__init__() + self.with_r = with_r + self.with_boundary = with_boundary + + with paddle.no_grad(): + x_coords = paddle.arange(height).unsqueeze(1).expand((height, width)).astype('float32') + y_coords = paddle.arange(width).unsqueeze(0).expand((height, width)).astype('float32') + x_coords = (x_coords / (height - 1)) * 2 - 1 + y_coords = (y_coords / (width - 1)) * 2 - 1 + coords = paddle.stack([x_coords, y_coords], axis=0) # (2, height, width) + + if self.with_r: + rr = paddle.sqrt(paddle.pow(x_coords, 2) + paddle.pow(y_coords, 2)) # (height, width) + rr = (rr / paddle.max(rr)).unsqueeze(0) + coords = paddle.concat([coords, rr], axis=0) + + self.coords = coords.unsqueeze(0) # (1, 2 or 3, height, width) + self.x_coords = x_coords + self.y_coords = y_coords + + def forward(self, x, heatmap=None): + """ + x: (batch, c, x_dim, y_dim) + """ + coords = self.coords.tile((x.shape[0], 1, 1, 1)) + + if self.with_boundary and heatmap is not None: + boundary_channel = paddle.clip(heatmap[:, -1:, :, :], 0.0, 1.0) + zero_tensor = paddle.zeros_like(self.x_coords) + xx_boundary_channel = paddle.where(boundary_channel > 0.05, self.x_coords, zero_tensor) + yy_boundary_channel = paddle.where(boundary_channel > 0.05, self.y_coords, zero_tensor) + coords = paddle.concat([coords, xx_boundary_channel, yy_boundary_channel], axis=1) + + x_and_coords = paddle.concat([x, coords], axis=1) + return x_and_coords + + +class CoordConvTh(nn.Layer): + """CoordConv layer as in the paper.""" + def __init__(self, height, width, with_r, with_boundary, + in_channels, first_one=False, *args, **kwargs): + super(CoordConvTh, self).__init__() + self.addcoords = AddCoordsTh(height, width, with_r, with_boundary) + in_channels += 2 + if with_r: + in_channels += 1 + if with_boundary and not first_one: + in_channels += 2 + self.conv = nn.Conv2D(in_channels=in_channels, *args, **kwargs) + + def forward(self, input_tensor, heatmap=None): + ret = self.addcoords(input_tensor, heatmap) + last_channel = ret[:, -2:, :, :] + ret = self.conv(ret) + return ret, last_channel + + +class ConvBlock(nn.Layer): + def __init__(self, in_planes, out_planes): + super(ConvBlock, self).__init__() + self.bn1 = nn.BatchNorm2D(in_planes) + conv3x3 = partial(nn.Conv2D, kernel_size=3, stride=1, padding=1, bias_attr=False, dilation=1) + self.conv1 = conv3x3(in_planes, int(out_planes / 2)) + self.bn2 = nn.BatchNorm2D(int(out_planes / 2)) + self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4)) + self.bn3 = nn.BatchNorm2D(int(out_planes / 4)) + self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4)) + + self.downsample = None + if in_planes != out_planes: + self.downsample = nn.Sequential(nn.BatchNorm2D(in_planes), + nn.ReLU(True), + nn.Conv2D(in_planes, out_planes, 1, 1, bias_attr=False)) + + def forward(self, x): + residual = x + + out1 = self.bn1(x) + out1 = F.relu(out1, True) + out1 = self.conv1(out1) + + out2 = self.bn2(out1) + out2 = F.relu(out2, True) + out2 = self.conv2(out2) + + out3 = self.bn3(out2) + out3 = F.relu(out3, True) + out3 = self.conv3(out3) + + out3 = paddle.concat((out1, out2, out3), 1) + if self.downsample is not None: + residual = self.downsample(residual) + out3 += residual + return out3 + + +# ========================== # +# Mask related functions # +# ========================== # + + +def normalize(x, eps=1e-6): + """Apply min-max normalization.""" + # x = x.contiguous() + N, C, H, W = x.shape + x_ = paddle.reshape(x, (N*C, -1)) + max_val = paddle.max(x_, axis=1, keepdim=True)[0] + min_val = paddle.min(x_, axis=1, keepdim=True)[0] + x_ = (x_ - min_val) / (max_val - min_val + eps) + out = paddle.reshape(x_, (N, C, H, W)) + return out + + +def truncate(x, thres=0.1): + """Remove small values in heatmaps.""" + return paddle.where(x < thres, paddle.zeros_like(x), x) + + +def resize(x, p=2): + """Resize heatmaps.""" + return x**p + + +def shift(x, N): + """Shift N pixels up or down.""" + x = x.numpy() + up = N >= 0 + N = abs(N) + _, _, H, W = x.shape + head = np.arange(N) + tail = np.arange(H-N) + + if up: + head = np.arange(H-N)+N + tail = np.arange(N) + else: + head = np.arange(N) + (H-N) + tail = np.arange(H-N) + + # permutation indices + perm = np.concatenate([head, tail]) + out = x[:, :, perm, :] + out = paddle.to_tensor(out) + return out + + +IDXPAIR = namedtuple('IDXPAIR', 'start end') +index_map = Munch(chin=IDXPAIR(0 + 8, 33 - 8), + eyebrows=IDXPAIR(33, 51), + eyebrowsedges=IDXPAIR(33, 46), + nose=IDXPAIR(51, 55), + nostrils=IDXPAIR(55, 60), + eyes=IDXPAIR(60, 76), + lipedges=IDXPAIR(76, 82), + lipupper=IDXPAIR(77, 82), + liplower=IDXPAIR(83, 88), + lipinner=IDXPAIR(88, 96)) +OPPAIR = namedtuple('OPPAIR', 'shift resize') + + +def preprocess(x): + """Preprocess 98-dimensional heatmaps.""" + N, C, H, W = x.shape + x = truncate(x) + x = normalize(x) + + sw = H // 256 + operations = Munch(chin=OPPAIR(0, 3), + eyebrows=OPPAIR(-7*sw, 2), + nostrils=OPPAIR(8*sw, 4), + lipupper=OPPAIR(-8*sw, 4), + liplower=OPPAIR(8*sw, 4), + lipinner=OPPAIR(-2*sw, 3)) + + for part, ops in operations.items(): + start, end = index_map[part] + x[:, start:end] = resize(shift(x[:, start:end], ops.shift), ops.resize) + + zero_out = paddle.concat([paddle.arange(0, index_map.chin.start), + paddle.arange(index_map.chin.end, 33), + paddle.to_tensor([index_map.eyebrowsedges.start, + index_map.eyebrowsedges.end, + index_map.lipedges.start, + index_map.lipedges.end])]) + x = x.numpy() + zero_out = zero_out.numpy() + x[:, zero_out] = 0 + x = paddle.to_tensor(x) + + start, end = index_map.nose + x[:, start+1:end] = shift(x[:, start+1:end], 4*sw) + x[:, start:end] = resize(x[:, start:end], 1) + + start, end = index_map.eyes + x[:, start:end] = resize(x[:, start:end], 1) + x[:, start:end] = resize(shift(x[:, start:end], -8), 3) + \ + shift(x[:, start:end], -24) + + # Second-level mask + x2 = deepcopy(x) + x2[:, index_map.chin.start:index_map.chin.end] = 0 # start:end was 0:33 + x2[:, index_map.lipedges.start:index_map.lipinner.end] = 0 # start:end was 76:96 + x2[:, index_map.eyebrows.start:index_map.eyebrows.end] = 0 # start:end was 33:51 + + x = paddle.sum(x, axis=1, keepdim=True) # (N, 1, H, W) + x2 = paddle.sum(x2, axis=1, keepdim=True) # mask without faceline and mouth + + x = x.numpy() + x2 = x2.numpy() + x[x != x] = 0 # set nan to zero + x2[x != x] = 0 # set nan to zero + x = paddle.to_tensor(x) + x2 = paddle.to_tensor(x2) + return x.clip(0, 1), x2.clip(0, 1) diff --git a/requirements.txt b/requirements.txt index d1446d98bd21e6c74642bf98b352df25a477cb30..eb3a44aa99fb09e17185018160c32de38b6f462b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ imageio-ffmpeg librosa==0.7.0 numba==0.48 easydict +munch