diff --git a/README.md b/README.md index 16c9f445d80b437e4c8861d23c079d0d6bb44877..9b2354aa76ca27750a87de31956412b66e467b26 100644 --- a/README.md +++ b/README.md @@ -121,6 +121,7 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional * [PReNet](./docs/en_US/tutorials/prenet.md) * [SwinIR](./docs/en_US/tutorials/swinir.md) * [InvDN](./docs/en_US/tutorials/invdn.md) +* [AOT-GAN](./docs/en_US/tutorials/aotgan.md) ## Composite Application diff --git a/README_cn.md b/README_cn.md index 0ec28e289404dfe8fdb92df72a1e6b80c16f2b1a..48f53fa07c9d514cfaf7f4af58f6ed096f27268a 100644 --- a/README_cn.md +++ b/README_cn.md @@ -141,6 +141,7 @@ GAN--生成对抗网络,被“卷积网络之父”**Yann LeCun(杨立昆) * 图像去模糊去噪去雨:[MPR Net](./docs/zh_CN/tutorials/mpr_net.md)、[SwinIR](./docs/zh_CN/tutorials/swinir.md)、[InvDN](./docs/zh_CN/tutorials/invdn.md) * 视频去模糊:[EDVR](./docs/zh_CN/tutorials/video_super_resolution.md) * 图像去雨:[PReNet](./docs/zh_CN/tutorials/prenet.md) + * 图像补全:[AOT-GAN](./docs/zh_CN/tutorials/aotgan.md) ## 产业级应用 diff --git a/applications/tools/aotgan.py b/applications/tools/aotgan.py new file mode 100644 index 0000000000000000000000000000000000000000..e545eda8b022bc3916e77d333c62dbebb7197f55 --- /dev/null +++ b/applications/tools/aotgan.py @@ -0,0 +1,67 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
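+#
+# Inpainting prediction script for AOT-GAN. Example usage (see
+# docs/en_US/tutorials/aotgan.md for details):
+#   python applications/tools/aotgan.py \
+#       --input_image_path data/aotgan/armani1.jpg \
+#       --input_mask_path data/aotgan/armani1.png \
+#       --weight_path test/aotgan/g.pdparams \
+#       --output_path output_dir/armani_pred.jpg \
+#       --config-file configs/aotgan.yaml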
+
+import argparse
+import os
+import sys
+
+import paddle
+
+sys.path.insert(0, os.getcwd())
+from ppgan.apps import AOTGANPredictor
+from ppgan.utils.config import get_config
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--input_image_path",
+                        type=str,
+                        default=None,
+                        help="path to input image")
+
+    parser.add_argument("--input_mask_path",
+                        type=str,
+                        default=None,
+                        help="path to input mask")
+
+    parser.add_argument("--output_path",
+                        type=str,
+                        default=None,
+                        help="path to output image dir")
+
+    parser.add_argument("--weight_path",
+                        type=str,
+                        default=None,
+                        help="path to model weight")
+
+    parser.add_argument("--config-file",
+                        type=str,
+                        default=None,
+                        help="path to yaml file")
+
+    parser.add_argument("--cpu",
+                        dest="cpu",
+                        action="store_true",
+                        help="cpu mode.")
+
+    args = parser.parse_args()
+
+    if args.cpu:
+        paddle.set_device('cpu')
+
+    cfg = get_config(args.config_file)
+
+    predictor = AOTGANPredictor(output_path=args.output_path,
+                                weight_path=args.weight_path,
+                                gen_cfg=cfg.predict)
+    predictor.run(input_image_path=args.input_image_path,
+                  input_mask_path=args.input_mask_path)
diff --git a/configs/aotgan.yaml b/configs/aotgan.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e2d979abe7e429e9a37bb7c583fdfb8957d3fc95
--- /dev/null
+++ b/configs/aotgan.yaml
@@ -0,0 +1,71 @@
+total_iters: 1000000
+output_dir: output_dir
+checkpoints_dir: checkpoints
+epochs: 5
+
+model:
+  name: AOTGANModel
+  generator:
+    name: InpaintGenerator
+    rates: [1, 2, 4, 8]
+    block_num: 8
+  discriminator:
+    name: Discriminator
+    inc: 3
+  criterion:
+    name: AOTGANCriterionLoss
+    pretrained: https://paddlegan.bj.bcebos.com/models/vgg19feats.pdparams
+  l1_weight: 1
+  perceptual_weight: 1
+  style_weight: 250
+  adversal_weight: 0.01
+  img_size: 512
+
+dataset:
+  train:
+    name: AOTGANDataset
+    dataset_path: data/aotgan
+    batch_size: 8  # use 4 per card for multi-card training
+    img_size: 512
+  test:
+    name: AOTGANDataset_test
+    dataset_path: data/aotgan
+    batch_size: 1
+    img_size: 512
+
+lr_scheduler:  # abandoned; the constant lr under `optimizer` is used instead
+  name: MultiStepDecay
+  learning_rate: 0.0001
+  milestones: [990000]
+  gamma: 0.1
+
+optimizer:
+  lr: 0.0001
+  optimG:
+    name: Adam
+    net_names:
+      - net_gen
+    beta1: 0.5
+    beta2: 0.999
+  optimD:
+    name: Adam
+    net_names:
+      - net_des
+    beta1: 0.5
+    beta2: 0.999
+
+log_config:
+  interval: 100
+  visiual_interval: 100
+
+snapshot_config:
+  interval: 1000
+
+predict:
+  name: AOTGANGenerator
+  rates: [1, 2, 4, 8]
+  block_num: 8
+  img_size: 512
+
+export_model:
+  - {name: 'net_gen', inputs_num: 1}
diff --git a/docs/en_US/tutorials/aotgan.md b/docs/en_US/tutorials/aotgan.md
new file mode 100644
index 0000000000000000000000000000000000000000..9d0c3609ed359e9ed1586fccfb24f9d1db8ddae4
--- /dev/null
+++ b/docs/en_US/tutorials/aotgan.md
@@ -0,0 +1,89 @@
+# AOT GAN
+
+## 1 Principle
+
+The Aggregated COntextual-Transformations GAN (AOT-GAN) is designed for high-resolution image inpainting. Its AOT blocks aggregate contextual transformations from multiple receptive fields, capturing both informative distant image contexts and rich patterns of interest for context reasoning.
+
+![](https://ai-studio-static-online.cdn.bcebos.com/c3b71d7f28ce4906aa7cccb10ed09ae5e317513b6dbd471aa5cca8144a7fd593)
+
+**Paper:** [Aggregated Contextual Transformations for High-Resolution Image Inpainting](https://paperswithcode.com/paper/aggregated-contextual-transformations-for)
+
+**Official Repo:** [https://github.com/researchmm/AOT-GAN-for-Inpainting](https://github.com/researchmm/AOT-GAN-for-Inpainting)
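+
+Each AOT block splits the feature map into parallel dilated-convolution branches with different rates and re-merges them through a learned spatial gate. A simplified sketch of the idea in Paddle (the full implementation, including the layer-normalized gate, is `ppgan/models/generators/generater_aotgan.py` in this PR):
+
+```python
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+class SimpleAOTBlock(nn.Layer):
+    """Toy AOT block: parallel dilated branches + gated residual fusion."""
+    def __init__(self, dim=256, rates=(1, 2, 4, 8)):
+        super().__init__()
+        # one dilated conv branch per rate, each producing dim//4 channels
+        self.branches = nn.LayerList([
+            nn.Sequential(nn.Pad2D(r, mode='reflect'),
+                          nn.Conv2D(dim, dim // 4, 3, dilation=r),
+                          nn.ReLU()) for r in rates
+        ])
+        self.fuse = nn.Conv2D(dim, dim, 3, padding=1)  # merge the concatenated branches
+        self.gate = nn.Conv2D(dim, dim, 3, padding=1)  # per-pixel mixing weights
+
+    def forward(self, x):
+        out = self.fuse(paddle.concat([b(x) for b in self.branches], 1))
+        gate = F.sigmoid(self.gate(x))
+        return x * (1 - gate) + out * gate  # gated residual update
+
+# usage: feats = SimpleAOTBlock()(paddle.randn([1, 256, 128, 128]))
+```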
+
+
+## 2 How to use
+
+### 2.1 Prediction
+
+Download the pretrained generator weights from (https://paddlegan.bj.bcebos.com/models/AotGan_g.pdparams).
+
+```
+python applications/tools/aotgan.py \
+    --input_image_path data/aotgan/armani1.jpg \
+    --input_mask_path data/aotgan/armani1.png \
+    --weight_path test/aotgan/g.pdparams \
+    --output_path output_dir/armani_pred.jpg \
+    --config-file configs/aotgan.yaml
+```
+Parameters:
+* input_image_path: path to the input image
+* input_mask_path: path to the input mask
+* weight_path: path to the pretrained generator weights
+* output_path: path where the predicted image is saved
+* config-file: path to the yaml file, the same one used for training
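+
+The predictor can also be called directly from Python. A minimal sketch based on `applications/tools/aotgan.py` (run from the repo root; the paths are the sample ones from the command above):
+
+```python
+from ppgan.apps import AOTGANPredictor
+from ppgan.utils.config import get_config
+
+cfg = get_config('configs/aotgan.yaml')  # same yaml as for training
+predictor = AOTGANPredictor(output_path='output_dir/armani_pred.jpg',
+                            weight_path='test/aotgan/g.pdparams',
+                            gen_cfg=cfg.predict)  # the `predict` section of the yaml
+predictor.run(input_image_path='data/aotgan/armani1.jpg',
+              input_mask_path='data/aotgan/armani1.png')
+```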
+
+AI Studio Project: (https://aistudio.baidu.com/aistudio/datasetdetail/165081)
+
+### 2.2 Train
+
+Data Preparation:
+
+The pretrained model was trained on the 'Places365-Standard' and 'NVIDIA Irregular Mask' datasets. You can download them from [Places365-Standard](http://places2.csail.mit.edu/download.html) and [NVIDIA Irregular Mask Dataset](https://nv-adlr.github.io/publication/partialconv-inpainting).
+
+```
+└─data
+  └─aotgan
+    ├─train_img
+    ├─train_mask
+    ├─val_img
+    └─val_mask
+```
+Train (Single Card):
+
+`python -u tools/main.py --config-file configs/aotgan.yaml`
+
+Train (Multi-Card):
+
+```
+python -m paddle.distributed.launch \
+    tools/main.py \
+    --config-file configs/aotgan.yaml \
+    -o dataset.train.batch_size=6
+```
+Train (Continue):
+
+```
+python -u tools/main.py \
+    --config-file configs/aotgan.yaml \
+    --resume output_dir/[path_to_checkpoint]/iter_[iternumber]_checkpoint.pdparams
+```
+
+## 3 Results
+
+On the Places365-Val dataset:
+
+| mask | PSNR | SSIM | download |
+| ---- | ---- | ---- | ---- |
+| 20-30% | 26.04001 | 0.89011 | [download](https://paddlegan.bj.bcebos.com/models/AotGan_g.pdparams) |
+
+## 4 References
+
+```
+@inproceedings{yan2021agg,
+  author = {Zeng, Yanhong and Fu, Jianlong and Chao, Hongyang and Guo, Baining},
+  title = {Aggregated Contextual Transformations for High-Resolution Image Inpainting},
+  booktitle = {Arxiv},
+  pages={-},
+  year = {2020}
+}
+```
diff --git a/docs/zh_CN/tutorials/aotgan.md b/docs/zh_CN/tutorials/aotgan.md
new file mode 100644
index 0000000000000000000000000000000000000000..716d613b8b469324ace96379ef6caeef1ec6c4aa
--- /dev/null
+++ b/docs/zh_CN/tutorials/aotgan.md
@@ -0,0 +1,101 @@
+# AOT GAN
+
+## 1. 简介
+
+本应用的 AOT GAN 模型出自论文《Aggregated Contextual Transformations for High-Resolution Image Inpainting》,其通过聚合不同膨胀率的空洞卷积学习到的图片特征,刷新了 inpainting 任务的 SOTA。模型推理效果如下:
+
+![](https://ai-studio-static-online.cdn.bcebos.com/c3b71d7f28ce4906aa7cccb10ed09ae5e317513b6dbd471aa5cca8144a7fd593)
+
+**论文:** [Aggregated Contextual Transformations for High-Resolution Image Inpainting](https://paperswithcode.com/paper/aggregated-contextual-transformations-for)
+
+**参考repo:** [https://github.com/researchmm/AOT-GAN-for-Inpainting](https://github.com/researchmm/AOT-GAN-for-Inpainting)
+
+## 2. 快速体验
+
+预训练模型权重文件 g.pdparams 可以从如下地址下载:(https://paddlegan.bj.bcebos.com/models/AotGan_g.pdparams)
+
+输入一张 512x512 尺寸的图片和擦除 mask 给模型,输出一张补全(inpainting)的图片。预测代码如下:
+
+```
+python applications/tools/aotgan.py \
+    --input_image_path data/aotgan/armani1.jpg \
+    --input_mask_path data/aotgan/armani1.png \
+    --weight_path test/aotgan/g.pdparams \
+    --output_path output_dir/armani_pred.jpg \
+    --config-file configs/aotgan.yaml
+```
+
+**参数说明:**
+* input_image_path:输入图片路径
+* input_mask_path:输入擦除 mask 路径
+* weight_path:训练完成的模型权重存储路径,为 statedict 格式(.pdparams)的 Paddle 模型权重文件
+* output_path:预测生成图片的存储路径
+* config-file:存储参数设定的 yaml 文件路径,与训练过程使用同一个 yaml 文件,预测参数由其中的 predict 字段设定
+
+AI Studio 快速体验项目:(https://aistudio.baidu.com/aistudio/datasetdetail/165081)
+
+## 3. 训练
+
+**数据准备:**
+
+* 训练用的图片解压到项目路径下的 data/aotgan/train_img 文件夹内,可包含多层目录,dataloader 会递归读取每层目录下的图片。训练用的 mask 图片解压到项目路径下的 data/aotgan/train_mask 文件夹内。
+* 验证用的图片和 mask 图片相应地放到项目路径下的 data/aotgan/val_img 文件夹和 data/aotgan/val_mask 文件夹内。
+
+数据集目录结构如下:
+
+```
+└─data
+  └─aotgan
+    ├─train_img
+    ├─train_mask
+    ├─val_img
+    └─val_mask
+```
+
+* 预训练模型的训练使用了 Places365-Standard 数据集的训练集图片,以及 NVIDIA Irregular Mask Dataset 数据集的测试集掩码图片。Places365-Standard 的训练集为 160 万张长或宽最小为 512 像素的图片,NVIDIA Irregular Mask Dataset 的测试集为 12000 张尺寸为 512 x 512 的不规则掩码图片。数据集下载链接:[Places365-Standard](http://places2.csail.mit.edu/download.html)、[NVIDIA Irregular Mask Dataset](https://nv-adlr.github.io/publication/partialconv-inpainting)
+
+### 3.1 gpu 单卡训练
+
+`python -u tools/main.py --config-file configs/aotgan.yaml`
+
+* config-file:训练使用的超参设置 yaml 文件的存储路径
+
+### 3.2 gpu 多卡训练
+
+```
+python -m paddle.distributed.launch \
+    tools/main.py \
+    --config-file configs/aotgan.yaml \
+    -o dataset.train.batch_size=6
+```
+
+* config-file:训练使用的超参设置 yaml 文件的存储路径
+* -o dataset.train.batch_size=6:-o 设置参数覆盖 yaml 文件中的值,这里调整了 batch_size 参数
+
+### 3.3 继续训练
+
+```
+python -u tools/main.py \
+    --config-file configs/aotgan.yaml \
+    --resume output_dir/[path_to_checkpoint]/iter_[iternumber]_checkpoint.pdparams
+```
+
+* config-file:训练使用的超参设置 yaml 文件的存储路径
+* resume:指定读取的 checkpoint 路径
+
+### 3.4 实验结果展示
+
+在 Places365 数据集的验证集上的指标如下:
+
+| mask | PSNR | SSIM | download |
+| ---- | ---- | ---- | ---- |
+| 20-30% | 26.04001 | 0.89011 | [download](https://paddlegan.bj.bcebos.com/models/AotGan_g.pdparams) |
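+
+### 3.5 模型导出
+
+如需导出静态图推理模型,可以参考本次新增的 test_tipc 配置(test_tipc/configs/aotgan/train_infer_python.txt)中的导出命令,示例如下(--load 后接待导出的权重路径,inputs_size 的 4 通道对应「RGB 图 + mask」拼接后的输入):
+
+```
+python tools/export_model.py -c configs/aotgan.yaml \
+    --inputs_size="-1,4,-1,-1" \
+    --model_name inference \
+    --load output_dir/[path_to_checkpoint]/iter_[iternumber]_checkpoint.pdparams
+```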
+## 4. 参考链接与文献
+
+```
+@inproceedings{yan2021agg,
+  author = {Zeng, Yanhong and Fu, Jianlong and Chao, Hongyang and Guo, Baining},
+  title = {Aggregated Contextual Transformations for High-Resolution Image Inpainting},
+  booktitle = {Arxiv},
+  pages={-},
+  year = {2020}
+}
+```
diff --git a/ppgan/apps/__init__.py b/ppgan/apps/__init__.py
index 10d25ff3a705339e9a9b6ec31818c83e1733240f..ad49bea04fdbbac91ea4c5683f16e464e941d333 100644
--- a/ppgan/apps/__init__.py
+++ b/ppgan/apps/__init__.py
@@ -39,3 +39,4 @@ from .singan_predictor import SinGANPredictor
 from .gpen_predictor import GPENPredictor
 from .swinir_predictor import SwinIRPredictor
 from .invdn_predictor import InvDNPredictor
+from .aotgan_predictor import AOTGANPredictor
diff --git a/ppgan/apps/aotgan_predictor.py b/ppgan/apps/aotgan_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..73304ee29d652d778d1b7a7a2a219b04044d81ab
--- /dev/null
+++ b/ppgan/apps/aotgan_predictor.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from PIL import Image, ImageOps
+import cv2
+import numpy as np
+import os
+
+import paddle
+from paddle.vision.transforms import Resize
+
+from .base_predictor import BasePredictor
+from ppgan.models.generators import InpaintGenerator
+from ..utils.filesystem import load
+
+
+class AOTGANPredictor(BasePredictor):
+    def __init__(self, output_path, weight_path, gen_cfg):
+
+        # initialize the generator and load the pretrained weights
+        gen = InpaintGenerator(
+            gen_cfg.rates,
+            gen_cfg.block_num,
+        )
+        gen.eval()
+        para = load(weight_path)
+        if 'net_gen' in para:
+            gen.set_state_dict(para['net_gen'])
+        else:
+            gen.set_state_dict(para)
+
+        self.gen = gen
+        self.output_path = output_path
+        self.gen_cfg = gen_cfg
+
+    def run(self, input_image_path, input_mask_path):
+        img = Image.open(input_image_path)
+        mask = Image.open(input_mask_path)
+        img = Resize([self.gen_cfg.img_size, self.gen_cfg.img_size],
+                     interpolation='bilinear')(img)
+        mask = Resize([self.gen_cfg.img_size, self.gen_cfg.img_size],
+                      interpolation='nearest')(mask)
+        img = img.convert('RGB')
+        mask = mask.convert('L')
+        img = np.array(img)
+        mask = np.array(mask)
+
+        # normalize image data to (-1, +1); image tensor shape: [n=1, c=3, h=512, w=512]
+        img = (img.astype('float32') / 255.) * 2. - 1.
+        img = np.transpose(img, (2, 0, 1))
+        img = paddle.to_tensor(np.expand_dims(img, 0))
+        # mask tensor shape: [n=1, c=1, h=512, w=512]; value 0 denotes known pixels and 1 denotes missing regions
+        mask = np.expand_dims(mask.astype('float32') / 255., 0)
+        mask = paddle.to_tensor(np.expand_dims(mask, 0))
+
+        # predict
+        img_masked = (img * (1 - mask)) + mask  # put the mask onto the image
+        input_data = paddle.concat((img_masked, mask), axis=1)  # concatenate into a 4-channel input
+        pred_img = self.gen(input_data)  # predict from the masked image
+        comp_img = (1 - mask) * img + mask * pred_img  # keep known pixels, take the prediction in the holes
+        img_save = ((comp_img.numpy()[0].transpose((1, 2, 0)) + 1.) / 2. * 255).astype('uint8')
+
+        pic = cv2.cvtColor(img_save, cv2.COLOR_RGB2BGR)  # PIL gives RGB; cv2.imwrite expects BGR
+        path, _ = os.path.split(self.output_path)
+        if not os.path.exists(path):
+            os.makedirs(path)
+        cv2.imwrite(self.output_path, pic)
+        print('Predicted picture is saved: ' + self.output_path)
diff --git a/ppgan/datasets/__init__.py b/ppgan/datasets/__init__.py
old mode 100644
new mode 100755
index 3e5a487d33e82ff72d45d7419656426c874b06e0..e660d8dac5d67c319c0d19cbc3f9c4fbafcc3ff6
--- a/ppgan/datasets/__init__.py
+++ b/ppgan/datasets/__init__.py
@@ -35,3 +35,4 @@ from .swinir_dataset import SwinIRDataset
 from .gfpgan_datasets import FFHQDegradationDataset
 from .paired_image_datasets import PairedImageDataset
 from .invdn_dataset import InvDNDataset
+from .aotgan_dataset import AOTGANDataset
diff --git a/ppgan/datasets/aotgan_dataset.py b/ppgan/datasets/aotgan_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c49b203186cbda9c9f930bf56bb1adf466796d7d
--- /dev/null
+++ b/ppgan/datasets/aotgan_dataset.py
@@ -0,0 +1,186 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from PIL import Image, ImageOps
+import os
+import numpy as np
+import logging
+
+from paddle.io import Dataset, DataLoader
+from paddle.vision.transforms import Compose, RandomResizedCrop, RandomHorizontalFlip, RandomRotation, ColorJitter, Resize
+
+from .builder import DATASETS
+
+logger = logging.getLogger(__name__)
+
+@DATASETS.register()
+class AOTGANDataset(Dataset):
+    def __init__(self, dataset_path, img_size, istrain=True):
+        super(AOTGANDataset, self).__init__()
+
+        self.image_path = []
+        def get_all_sub_dirs(root_dir):  # collect image files recursively, including subdirectories
+            file_list = []
+            def get_sub_dirs(r_dir):
+                for root, dirs, files in os.walk(r_dir):
+                    if len(files) > 0:
+                        for f in files:
+                            file_list.append(os.path.join(root, f))
+                    if len(dirs) > 0:
+                        for d in dirs:
+                            get_sub_dirs(os.path.join(root, d))
+                    break
+            get_sub_dirs(root_dir)
+            return file_list
+
+        # set data path
+        if istrain:
+            self.img_list = get_all_sub_dirs(os.path.join(dataset_path, 'train_img'))
+            self.mask_dir = os.path.join(dataset_path, 'train_mask')
+        else:
+            self.img_list = get_all_sub_dirs(os.path.join(dataset_path, 'val_img'))
+            self.mask_dir = os.path.join(dataset_path, 'val_mask')
+        self.img_list = np.sort(np.array(self.img_list))
+        _, _, mask_list = next(os.walk(self.mask_dir))
+        self.mask_list = np.sort(mask_list)
+
+        self.istrain = istrain
+
+        # augmentations
+        if istrain:
+            self.img_trans = Compose([
+                Resize(img_size),
+                RandomResizedCrop(img_size),
+                RandomHorizontalFlip(),
+                ColorJitter(0.05, 0.05, 0.05, 0.05),
+            ])
+            self.mask_trans = Compose([
+                Resize([img_size, img_size], interpolation='nearest'),
+                RandomHorizontalFlip(),
+            ])
+        else:
+            self.img_trans = Compose([
+                Resize([img_size, img_size], interpolation='bilinear'),
+            ])
+            self.mask_trans = Compose([
+                Resize([img_size, img_size], interpolation='nearest'),
+            ])
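+
+    # NOTE: each sample pairs the idx-th image with a randomly drawn mask (plus a
+    # random rotation in __getitem__ below), so image/mask pairings change between
+    # epochs; masks and images come from two independent pools.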
+
+    # feed data
+    def __getitem__(self, idx):
+        img = Image.open(self.img_list[idx])
+        mask = Image.open(os.path.join(self.mask_dir, self.mask_list[np.random.randint(0, self.mask_list.shape[0])]))
+        img = self.img_trans(img)
+        mask = self.mask_trans(mask)
+
+        mask = mask.rotate(np.random.randint(0, 45))
+        img = img.convert('RGB')
+        mask = mask.convert('L')
+
+        img = np.array(img).astype('float32')
+        img = (img / 255.) * 2. - 1.
+        img = np.transpose(img, (2, 0, 1))
+        mask = np.array(mask).astype('float32') / 255.
+        mask = np.expand_dims(mask, 0)
+
+        return {'img': img, 'mask': mask, 'img_path': self.img_list[idx]}
+
+    def __len__(self):
+        return len(self.img_list)
+
+    def name(self):
+        return 'PlaceDataset'
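+
+# AOTGANDataset_test mirrors AOTGANDataset (same recursive file scan and random
+# image/mask pairing); its training branch only drops the extra Resize before
+# RandomResizedCrop.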
+
+@DATASETS.register()
+class AOTGANDataset_test(Dataset):
+    def __init__(self, dataset_path, img_size, istrain=True):
+        super(AOTGANDataset_test, self).__init__()
+
+        self.image_path = []
+        def get_all_sub_dirs(root_dir):  # collect image files recursively, including subdirectories
+            file_list = []
+            def get_sub_dirs(r_dir):
+                for root, dirs, files in os.walk(r_dir):
+                    if len(files) > 0:
+                        for f in files:
+                            file_list.append(os.path.join(root, f))
+                    if len(dirs) > 0:
+                        for d in dirs:
+                            get_sub_dirs(os.path.join(root, d))
+                    break
+            get_sub_dirs(root_dir)
+            return file_list
+
+        # set data path
+        if istrain:
+            self.img_list = get_all_sub_dirs(os.path.join(dataset_path, 'train_img'))
+            self.mask_dir = os.path.join(dataset_path, 'train_mask')
+        else:
+            self.img_list = get_all_sub_dirs(os.path.join(dataset_path, 'val_img'))
+            self.mask_dir = os.path.join(dataset_path, 'val_mask')
+        self.img_list = np.sort(np.array(self.img_list))
+        _, _, mask_list = next(os.walk(self.mask_dir))
+        self.mask_list = np.sort(mask_list)
+
+        self.istrain = istrain
+
+        # augmentations
+        if istrain:
+            self.img_trans = Compose([
+                RandomResizedCrop(img_size),
+                RandomHorizontalFlip(),
+                ColorJitter(0.05, 0.05, 0.05, 0.05),
+            ])
+            self.mask_trans = Compose([
+                Resize([img_size, img_size], interpolation='nearest'),
+                RandomHorizontalFlip(),
+            ])
+        else:
+            self.img_trans = Compose([
+                Resize([img_size, img_size], interpolation='bilinear'),
+            ])
+            self.mask_trans = Compose([
+                Resize([img_size, img_size], interpolation='nearest'),
+            ])
+
+    # feed data
+    def __getitem__(self, idx):
+        img = Image.open(self.img_list[idx])
+        mask = Image.open(os.path.join(self.mask_dir, self.mask_list[np.random.randint(0, self.mask_list.shape[0])]))
+        img = self.img_trans(img)
+        mask = self.mask_trans(mask)
+
+        mask = mask.rotate(np.random.randint(0, 45))
+        img = img.convert('RGB')
+        mask = mask.convert('L')
+
+        img = np.array(img).astype('float32')
+        img = (img / 255.) * 2. - 1.
+        img = np.transpose(img, (2, 0, 1))
+        mask = np.array(mask).astype('float32') / 255.
+        mask = np.expand_dims(mask, 0)
+
+        return {'img': img, 'mask': mask, 'img_path': self.img_list[idx]}
+
+    def __len__(self):
+        return len(self.img_list)
+
+    def name(self):
+        return 'PlaceDataset_test'
diff --git a/ppgan/models/__init__.py b/ppgan/models/__init__.py
index 24990dc821af49760bd9a5217005a36a8b3c524b..ffeb4662a723ef3a7b168c7587684a94235e837f 100644
--- a/ppgan/models/__init__.py
+++ b/ppgan/models/__init__.py
@@ -41,3 +41,4 @@ from .gpen_model import GPENModel
 from .swinir_model import SwinIRModel
 from .gfpgan_model import GFPGANModel
 from .invdn_model import InvDNModel
+from .aotgan_model import AOTGANModel
diff --git a/ppgan/models/aotgan_model.py b/ppgan/models/aotgan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b2970f67ce18ee803504b59e496cba6367ce47f
--- /dev/null
+++ b/ppgan/models/aotgan_model.py
@@ -0,0 +1,163 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from .base_model import BaseModel
+
+from .builder import MODELS
+from .generators.builder import build_generator
+from .criterions import build_criterion
+from .discriminators.builder import build_discriminator
+
+from ..modules.init import init_weights
+from ..solver import build_optimizer
+
+# Gaussian blur on the mask, used to build soft labels for the discriminator
+def gaussian_blur(input, kernel_size, sigma):
+    def get_gaussian_kernel(kernel_size: int, sigma: float) -> paddle.Tensor:
+        def gauss_fcn(x, window_size, sigma):
+            return -(x - window_size // 2)**2 / float(2 * sigma**2)
+        gauss = paddle.stack([
+            paddle.exp(paddle.to_tensor(gauss_fcn(x, kernel_size, sigma)))
+            for x in range(kernel_size)
+        ])
+        return gauss / gauss.sum()
+
+    b, c, h, w = input.shape
+    ksize_x, ksize_y = kernel_size
+    sigma_x, sigma_y = sigma
+    kernel_x = get_gaussian_kernel(ksize_x, sigma_x)
+    kernel_y = get_gaussian_kernel(ksize_y, sigma_y)
+    kernel_2d = paddle.matmul(kernel_x, kernel_y, transpose_y=True)
+    kernel = kernel_2d.reshape([1, 1, ksize_x, ksize_y])
+    kernel = kernel.repeat_interleave(c, 0)
+    padding = [(k - 1) // 2 for k in kernel_size]
+    return F.conv2d(input, kernel, padding=padding, stride=1, groups=c)
+
+# GAN loss
+class Adversal():
+    def __init__(self, ksize=71):
+        self.ksize = ksize
+        self.loss_fn = nn.MSELoss()
+
+    def __call__(self, netD, fake, real, masks):
+        fake_detach = fake.detach()
+
+        g_fake = netD(fake)
+        d_fake = netD(fake_detach)
+        d_real = netD(real)
+
+        _, _, h, w = g_fake.shape
+        b, c, ht, wt = masks.shape
+
+        # align the discriminator output shape with the mask
+        if h != ht or w != wt:
+            g_fake = F.interpolate(g_fake, size=(ht, wt), mode='bilinear', align_corners=True)
+            d_fake = F.interpolate(d_fake, size=(ht, wt), mode='bilinear', align_corners=True)
+            d_real = F.interpolate(d_real, size=(ht, wt), mode='bilinear', align_corners=True)
+        # soft labels (SM-PatchGAN): the blurred mask marks hole regions as fake
+        # with smooth transitions, while real images are labeled all-zero
+        d_fake_label = gaussian_blur(masks, (self.ksize, self.ksize), (10, 10)).detach()
+        d_real_label = paddle.zeros_like(d_real)
+        g_fake_label = paddle.ones_like(g_fake)
+
+        dis_loss
= [self.loss_fn(d_fake, d_fake_label).mean(), self.loss_fn(d_real, d_real_label).mean()] + gen_loss = (self.loss_fn(g_fake, g_fake_label) * masks / paddle.mean(masks)).mean() + + return dis_loss, gen_loss + +@MODELS.register() +class AOTGANModel(BaseModel): + def __init__(self, + generator, + discriminator, + criterion, + l1_weight, + perceptual_weight, + style_weight, + adversal_weight, + img_size, + ): + + super(AOTGANModel, self).__init__() + + # define nets + self.nets['net_gen'] = build_generator(generator) + self.nets['net_des'] = build_discriminator(discriminator) + self.net_vgg = build_criterion(criterion) + + self.adv_loss = Adversal() + + self.l1_weight = l1_weight + self.perceptual_weight = perceptual_weight + self.style_weight = style_weight + self.adversal_weight = adversal_weight + self.img_size = img_size + + def setup_input(self, input): + self.img = input['img'] + self.mask = input['mask'] + self.img_masked = (self.img * (1 - self.mask)) + self.mask + self.img_paths = input['img_path'] + + def forward(self): + input_x = paddle.concat([self.img_masked, self.mask], 1) + self.pred_img = self.nets['net_gen'](input_x) + self.comp_img = (1 - self.mask) * self.img + self.mask * self.pred_img + self.visual_items['pred_img'] = self.pred_img + + def train_iter(self, optimizers=None): + self.forward() + l1_loss, perceptual_loss, style_loss = self.net_vgg(self.img, self.pred_img, self.img_size) + self.losses['l1'] = l1_loss * self.l1_weight + self.losses['perceptual'] = perceptual_loss * self.perceptual_weight + self.losses['style'] = style_loss * self.style_weight + dis_loss, gen_loss = self.adv_loss(self.nets['net_des'], self.comp_img, self.img, self.mask) + self.losses['adv_g'] = gen_loss * self.adversal_weight + loss_d_fake = dis_loss[0] + loss_d_real = dis_loss[1] + self.losses['adv_d'] = loss_d_fake + loss_d_real + + loss_g = self.losses['l1'] + self.losses['perceptual'] + self.losses['style'] + self.losses['adv_g'] + loss_d = self.losses['adv_d'] + + self.optimizers['optimG'].clear_grad() + self.optimizers['optimD'].clear_grad() + loss_g.backward() + loss_d.backward() + self.optimizers['optimG'].step() + self.optimizers['optimD'].step() + + def test_iter(self, metrics=None): + self.eval() + with paddle.no_grad(): + self.forward() + self.train() + + def setup_optimizers(self, lr, cfg): + for opt_name, opt_cfg in cfg.items(): + if opt_name == 'lr': + learning_rate = opt_cfg + continue + cfg_ = opt_cfg.copy() + net_names = cfg_.pop('net_names') + parameters = [] + for net_name in net_names: + parameters += self.nets[net_name].parameters() + if opt_name == 'optimG': + lr = learning_rate * 4 + else: + lr = learning_rate + self.optimizers[opt_name] = build_optimizer( + cfg_, lr, parameters) + + return self.optimizers diff --git a/ppgan/models/criterions/__init__.py b/ppgan/models/criterions/__init__.py index fdd609c4125927c3ae09daf8104a88a422e444eb..219b72e9ebe202442f44d38ae9b131181e7b2bfb 100644 --- a/ppgan/models/criterions/__init__.py +++ b/ppgan/models/criterions/__init__.py @@ -11,3 +11,4 @@ from .builder import build_criterion from .ssim import SSIM from .id_loss import IDLoss from .gfpgan_loss import GFPGANGANLoss, GFPGANL1Loss, GFPGANPerceptualLoss +from .aotgan_perceptual_loss import AOTGANCriterionLoss diff --git a/ppgan/models/criterions/aotgan_perceptual_loss.py b/ppgan/models/criterions/aotgan_perceptual_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..48b86d9418bfb31cdd5f04997ab5a161b523e908 --- /dev/null +++ 
b/ppgan/models/criterions/aotgan_perceptual_loss.py @@ -0,0 +1,223 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn.utils import spectral_norm + +from ppgan.utils.download import get_path_from_url +from .builder import CRITERIONS + +# VGG19(ImageNet pretrained) +class VGG19F(nn.Layer): + def __init__(self): + super(VGG19F, self).__init__() + + self.feature_0 = nn.Conv2D(3, 64, 3, 1, 1) + self.relu_1 = nn.ReLU() + self.feature_2 = nn.Conv2D(64, 64, 3, 1, 1) + self.relu_3 = nn.ReLU() + + self.mp_4 = nn.MaxPool2D(2, 2, 0) + self.feature_5 = nn.Conv2D(64, 128, 3, 1, 1) + self.relu_6 = nn.ReLU() + self.feature_7 = nn.Conv2D(128, 128, 3, 1, 1) + self.relu_8 = nn.ReLU() + + self.mp_9 = nn.MaxPool2D(2, 2, 0) + self.feature_10 = nn.Conv2D(128, 256, 3, 1, 1) + self.relu_11 = nn.ReLU() + self.feature_12 = nn.Conv2D(256, 256, 3, 1, 1) + self.relu_13 = nn.ReLU() + self.feature_14 = nn.Conv2D(256, 256, 3, 1, 1) + self.relu_15 = nn.ReLU() + self.feature_16 = nn.Conv2D(256, 256, 3, 1, 1) + self.relu_17 = nn.ReLU() + + self.mp_18 = nn.MaxPool2D(2, 2, 0) + self.feature_19 = nn.Conv2D(256, 512, 3, 1, 1) + self.relu_20 = nn.ReLU() + self.feature_21 = nn.Conv2D(512, 512, 3, 1, 1) + self.relu_22 = nn.ReLU() + self.feature_23 = nn.Conv2D(512, 512, 3, 1, 1) + self.relu_24 = nn.ReLU() + self.feature_25 = nn.Conv2D(512, 512, 3, 1, 1) + self.relu_26 = nn.ReLU() + + self.mp_27 = nn.MaxPool2D(2, 2, 0) + self.feature_28 = nn.Conv2D(512, 512, 3, 1, 1) + self.relu_29 = nn.ReLU() + self.feature_30 = nn.Conv2D(512, 512, 3, 1, 1) + self.relu_31 = nn.ReLU() + self.feature_32 = nn.Conv2D(512, 512, 3, 1, 1) + self.relu_33 = nn.ReLU() + self.feature_34 = nn.Conv2D(512, 512, 3, 1, 1) + self.relu_35 = nn.ReLU() + + def forward(self, x): + x = self.stand(x) + feats = [] + group = [] + x = self.feature_0(x) + x = self.relu_1(x) + group.append(x) + x = self.feature_2(x) + x = self.relu_3(x) + group.append(x) + feats.append(group) + + group = [] + x = self.mp_4(x) + x = self.feature_5(x) + x = self.relu_6(x) + group.append(x) + x = self.feature_7(x) + x = self.relu_8(x) + group.append(x) + feats.append(group) + + group = [] + x = self.mp_9(x) + x = self.feature_10(x) + x = self.relu_11(x) + group.append(x) + x = self.feature_12(x) + x = self.relu_13(x) + group.append(x) + x = self.feature_14(x) + x = self.relu_15(x) + group.append(x) + x = self.feature_16(x) + x = self.relu_17(x) + group.append(x) + feats.append(group) + + group = [] + x = self.mp_18(x) + x = self.feature_19(x) + x = self.relu_20(x) + group.append(x) + x = self.feature_21(x) + x = self.relu_22(x) + group.append(x) + x = self.feature_23(x) + x = self.relu_24(x) + group.append(x) + x = self.feature_25(x) + x = self.relu_26(x) + group.append(x) + feats.append(group) + + group = [] + x = self.mp_27(x) + x = self.feature_28(x) + x = self.relu_29(x) + group.append(x) + x = self.feature_30(x) + x = self.relu_31(x) + group.append(x) 
+        x = self.feature_32(x)
+        x = self.relu_33(x)
+        group.append(x)
+        x = self.feature_34(x)
+        x = self.relu_35(x)
+        group.append(x)
+        feats.append(group)
+
+        return feats
+
+    def stand(self, x):
+        # map input from (-1, 1) to (0, 1), then apply ImageNet normalization
+        mean = paddle.to_tensor([0.485, 0.456, 0.406]).reshape([1, 3, 1, 1])
+        std = paddle.to_tensor([0.229, 0.224, 0.225]).reshape([1, 3, 1, 1])
+        y = (x + 1.) / 2.
+        y = (y - mean) / std
+        return y
+
+# l1 loss
+class L1():
+    def __init__(self,):
+        self.calc = nn.L1Loss()
+
+    def __call__(self, x, y):
+        return self.calc(x, y)
+
+# perceptual loss
+class Perceptual():
+    def __init__(self, vgg, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
+        super(Perceptual, self).__init__()
+        self.vgg = vgg
+        self.criterion = nn.L1Loss()
+        self.weights = weights
+
+    def __call__(self, x, y, img_size):
+        x = F.interpolate(x, (img_size, img_size), mode='bilinear', align_corners=True)
+        y = F.interpolate(y, (img_size, img_size), mode='bilinear', align_corners=True)
+        x_features = self.vgg(x)
+        y_features = self.vgg(y)
+        content_loss = 0.0
+        for i in range(len(self.weights)):
+            # this pretrained VGG19 has no BN layers, so no extra scaling rate is applied here
+            content_loss += self.weights[i] * self.criterion(x_features[i][0], y_features[i][0])
+        return content_loss
+
+# style loss
+class Style():
+    def __init__(self, vgg):
+        super(Style, self).__init__()
+        self.vgg = vgg
+        self.criterion = nn.L1Loss()
+
+    def compute_gram(self, x):
+        b, c, h, w = x.shape
+        f = x.reshape([b, c, w * h])
+        f_T = f.transpose([0, 2, 1])
+        G = paddle.matmul(f, f_T) / (h * w * c)
+        return G
+
+    def __call__(self, x, y, img_size):
+        x = F.interpolate(x, (img_size, img_size), mode='bilinear', align_corners=True)
+        y = F.interpolate(y, (img_size, img_size), mode='bilinear', align_corners=True)
+        x_features = self.vgg(x)
+        y_features = self.vgg(y)
+        style_loss = 0.0
+        blocks = [2, 3, 4, 5]
+        layers = [2, 4, 4, 2]
+        for b, l in list(zip(blocks, layers)):
+            b = b - 1
+            l = l - 1
+            style_loss += self.criterion(self.compute_gram(x_features[b][l]), self.compute_gram(y_features[b][l]))
+        return style_loss
+
+# bundles the l1, perceptual and style losses (weighted later in AOTGANModel)
+@CRITERIONS.register()
+class AOTGANCriterionLoss(nn.Layer):
+    def __init__(self,
+                 pretrained,
+                 ):
+        super(AOTGANCriterionLoss, self).__init__()
+        self.model = VGG19F()
+        weight_path = get_path_from_url(pretrained)
+        vgg_weight = paddle.load(weight_path)
+        self.model.set_state_dict(vgg_weight)
+        print('PerceptualVGG loaded pretrained weights.')
+        self.l1_loss = L1()
+        self.perceptual_loss = Perceptual(self.model)
+        self.style_loss = Style(self.model)
+
+    def forward(self, img_r, img_f, img_size):
+        l1_loss = self.l1_loss(img_r, img_f)
+        perceptual_loss = self.perceptual_loss(img_r, img_f, img_size)
+        style_loss = self.style_loss(img_r, img_f, img_size)
+
+        return l1_loss, perceptual_loss, style_loss
diff --git a/ppgan/models/discriminators/__init__.py b/ppgan/models/discriminators/__init__.py
index 9d64778c8829e71bb4e55545a29296d39f6021c6..3fe48bcf5b8baa1e42bd52ea400eb43b8fedd39b 100644
--- a/ppgan/models/discriminators/__init__.py
+++ b/ppgan/models/discriminators/__init__.py
@@ -26,3 +26,4 @@ from .discriminator_lapstyle import LapStyleDiscriminator
 from .discriminator_photopen import MultiscaleDiscriminator
 from .discriminator_singan import SinGANDiscriminator
 from .arcface_arch_paddle import ResNetArcFace
+from .discriminator_aotgan import Discriminator
diff --git a/ppgan/models/discriminators/discriminator_aotgan.py b/ppgan/models/discriminators/discriminator_aotgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..57a34109364e58845877e0460ccec7f03b18bd61
---
/dev/null +++ b/ppgan/models/discriminators/discriminator_aotgan.py @@ -0,0 +1,39 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle +import paddle.nn as nn +from paddle.nn.utils import spectral_norm + +from .builder import DISCRIMINATORS + +@DISCRIMINATORS.register() +class Discriminator(nn.Layer): + def __init__(self, inc = 3): + super(Discriminator, self).__init__() + self.conv = nn.Sequential( + spectral_norm(nn.Conv2D(inc, 64, 4, 2, 1, bias_attr=False)), + nn.LeakyReLU(0.2), + spectral_norm(nn.Conv2D(64, 128, 4, 2, 1, bias_attr=False)), + nn.LeakyReLU(0.2), + spectral_norm(nn.Conv2D(128, 256, 4, 2, 1, bias_attr=False)), + nn.LeakyReLU(0.2), + spectral_norm(nn.Conv2D(256, 512, 4, 1, 1, bias_attr=False)), + nn.LeakyReLU(0.2), + nn.Conv2D(512, 1, 4, 1, 1) + ) + + def forward(self, x): + feat = self.conv(x) + return feat diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py old mode 100644 new mode 100755 index 71834e2879d5828d13499d5abd41f9d0e6034d30..ba5a7308b0e6f664ce3ed5204e36d51f6358c7fd --- a/ppgan/models/generators/__init__.py +++ b/ppgan/models/generators/__init__.py @@ -46,3 +46,4 @@ from .swinir import SwinIR from .gfpganv1_clean_arch import GFPGANv1Clean from .gfpganv1_arch import GFPGANv1, StyleGAN2DiscriminatorGFPGAN from .invdn import InvDN +from .generater_aotgan import InpaintGenerator diff --git a/ppgan/models/generators/generater_aotgan.py b/ppgan/models/generators/generater_aotgan.py new file mode 100644 index 0000000000000000000000000000000000000000..4b670b23a08600d55a39c1fb01176118b53e84d8 --- /dev/null +++ b/ppgan/models/generators/generater_aotgan.py @@ -0,0 +1,98 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
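+
+# AOT-GAN inpainting generator: a convolutional encoder, a stack of AOT blocks
+# that aggregate dilated convolutions over several receptive fields, and a
+# decoder that upsamples back to full resolution (see docs/en_US/tutorials/aotgan.md).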
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn.utils import spectral_norm
+
+from .builder import GENERATORS
+
+# Aggregated Contextual Transformations block
+class AOTBlock(nn.Layer):
+    def __init__(self, dim, rates):
+        super(AOTBlock, self).__init__()
+
+        self.rates = rates
+        # one dilated-conv branch per rate; each branch outputs dim//4 channels
+        for i, rate in enumerate(rates):
+            self.__setattr__(
+                'block{}'.format(str(i).zfill(2)),
+                nn.Sequential(
+                    nn.Pad2D(rate, mode='reflect'),
+                    nn.Conv2D(dim, dim // 4, 3, 1, 0, dilation=int(rate)),
+                    nn.ReLU()))
+        self.fuse = nn.Sequential(
+            nn.Pad2D(1, mode='reflect'),
+            nn.Conv2D(dim, dim, 3, 1, 0, dilation=1))
+        self.gate = nn.Sequential(
+            nn.Pad2D(1, mode='reflect'),
+            nn.Conv2D(dim, dim, 3, 1, 0, dilation=1))
+
+    def forward(self, x):
+        out = [self.__getattr__(f'block{str(i).zfill(2)}')(x) for i in range(len(self.rates))]
+        out = paddle.concat(out, 1)
+        out = self.fuse(out)
+        # spatial gating: mix the input and the aggregated features per pixel
+        mask = my_layer_norm(self.gate(x))
+        mask = F.sigmoid(mask)
+        return x * (1 - mask) + out * mask
+
+def my_layer_norm(feat):
+    mean = feat.mean((2, 3), keepdim=True)
+    std = feat.std((2, 3), keepdim=True) + 1e-9
+    feat = 2 * (feat - mean) / std - 1
+    feat = 5 * feat
+    return feat
+
+class UpConv(nn.Layer):
+    def __init__(self, inc, outc, scale=2):
+        super(UpConv, self).__init__()
+        self.scale = scale
+        self.conv = nn.Conv2D(inc, outc, 3, 1, 1)
+
+    def forward(self, x):
+        return self.conv(F.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=True))
+
+# generator: encoder -> AOT blocks -> decoder
+@GENERATORS.register()
+class InpaintGenerator(nn.Layer):
+    def __init__(self, rates, block_num):
+        super(InpaintGenerator, self).__init__()
+
+        # encoder: 4-channel input (RGB + mask), downsampled 4x
+        self.encoder = nn.Sequential(
+            nn.Pad2D(3, mode='reflect'),
+            nn.Conv2D(4, 64, 7, 1, 0),
+            nn.ReLU(),
+            nn.Conv2D(64, 128, 4, 2, 1),
+            nn.ReLU(),
+            nn.Conv2D(128, 256, 4, 2, 1),
+            nn.ReLU()
+        )
+
+        self.middle = nn.Sequential(*[AOTBlock(256, rates) for _ in range(block_num)])
+
+        self.decoder = nn.Sequential(
+            UpConv(256, 128),
+            nn.ReLU(),
+            UpConv(128, 64),
+            nn.ReLU(),
+            nn.Conv2D(64, 3, 3, 1, 1)
+        )
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.middle(x)
+        x = self.decoder(x)
+        x = paddle.tanh(x)
+
+        return x
diff --git a/test_tipc/configs/aotgan/train_infer_python.txt b/test_tipc/configs/aotgan/train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..432baaea5618a668d946b4fc6db5b473a478ba9b
--- /dev/null
+++ b/test_tipc/configs/aotgan/train_infer_python.txt
@@ -0,0 +1,59 @@
+===========================train_params===========================
+model_name:aotgan
+python:python3.7
+gpu_list:0
+##
+auto_cast:null
+epochs:lite_train_lite_infer=10|lite_train_whole_infer=10|whole_train_whole_infer=200
+output_dir:./output/
+dataset.train.batch_size:lite_train_lite_infer=1|whole_train_whole_infer=1
+pretrained_model:null
+train_model_name:aotgan*/*checkpoint.pdparams
+train_infer_img_dir:./data/aotgan
+null:null
+##
+trainer:norm_train
+norm_train:tools/main.py -c configs/aotgan.yaml --seed 123 -o log_config.interval=1 snapshot_config.interval=1
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================infer_params===========================
+--output_dir:./output/
+load:null
+norm_export:tools/export_model.py -c configs/aotgan.yaml --inputs_size="-1,4,-1,-1" --model_name inference --load
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+inference_dir:inference
+train_model:./inference/aotgan/aotganmodel_netG +infer_export:null +infer_quant:False +inference:tools/inference.py --model_type aotgan --seed 123 -c configs/aotgan.yaml --output_path test_tipc/output/ +--device:cpu +null:null +null:null +null:null +null:null +null:null +--model_path: +null:null +null:null +--benchmark:True +null:null +===========================train_benchmark_params========================== +batch_size:1 +fp_items:fp32 +epoch:10 +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,256,256]}] diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh index aedeee4acd77cb7cef8f10a4cd4e3cf1f3dd6de5..b29b2ef68e6c1092f505ec08fbb1f68a17cab9b5 100644 --- a/test_tipc/prepare.sh +++ b/test_tipc/prepare.sh @@ -81,6 +81,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/gfpgan_tipc_data.zip --no-check-certificate mkdir -p ./data/gfpgan_data cd ./data/ && unzip -q gfpgan_tipc_data.zip -d gfpgan_data/ && cd ../ ;; + aotgan) + rm -rf ./data/aotgan* + wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/aotgan.zip --no-check-certificate + cd ./data/ && unzip -q aotgan.zip && cd ../ ;; esac elif [ ${MODE} = "whole_train_whole_infer" ];then if [ ${model_name} == "Pix2pix" ]; then diff --git a/tools/inference.py b/tools/inference.py index 0c87503851d0b41b8993e74d5c7d1e74a9ffeab2..536b7b3c16ad5377dacc5254ad6d3b8881c161dc 100644 --- a/tools/inference.py +++ b/tools/inference.py @@ -19,7 +19,7 @@ from ppgan.metrics import build_metric MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", \ - "edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan", "swinir", "invdn"] + "edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan", "swinir", "invdn", "aotgan"] def parse_args(): @@ -423,6 +423,16 @@ def main(): for metric in metrics.values(): metric.update(image_numpy, gt_numpy) break + elif model_type == 'aotgan': + input_data = paddle.concat((data['img'], data['mask']), axis=1).numpy() + input_handles[0].copy_from_cpu(input_data) + predictor.run() + prediction = output_handle.copy_to_cpu() + prediction = paddle.to_tensor(prediction) + image_numpy = tensor2img(prediction, min_max) + save_image( + image_numpy, + os.path.join(args.output_path, "aotgan/{}.png".format(i))) if metrics: log_file = open(metric_file, 'a')