diff --git a/README.md b/README.md
index cd5008a328caa2b275dfad0855786d999e498cf4..861fac1c6e58e257b423ab9088507246b4e56906 100644
--- a/README.md
+++ b/README.md
@@ -50,7 +50,7 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional
 
 ## Composite Application
 
-* [Video restore](./docs/zh_CN/tutorials/video_restore.md)
+* [Video restore](./docs/en_US/tutorials/video_restore.md)
 
 ## Examples
 
diff --git a/README_cn.md b/README_cn.md
index f0b84147dacccfbc390dc2e8cccca5f9bd8ab26b..a6b71db1d48a2eafcc2002ede3107ae04fd07136 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -45,8 +45,8 @@ GAN--生成对抗网络,被“卷积网络之父”**Yann LeCun(杨立昆)
 * [U-GAT-IT](./docs/zh_CN/tutorials/ugatit.md)
 * [Photo2Cartoon](docs/zh_CN/tutorials/photo2cartoon.md)
 * [Wav2Lip](docs/zh_CN/tutorials/wav2lip.md)
-* [Super_Resolution](./docs/en_US/tutorials/super_resolution.md)
-* [StyleGAN2](./docs/en_US/tutorials/styleganv2.md)
+* [Super_Resolution](./docs/zh_CN/tutorials/super_resolution.md)
+* [StyleGAN2](./docs/zh_CN/tutorials/styleganv2.md)
 
 ## 复合应用
 
diff --git a/docs/en_US/tutorials/super_resolution.md b/docs/en_US/tutorials/super_resolution.md
index c005bd1bfece771a8ee5c79f73393442b643ddeb..596ad7c4c751c74cb9ff5f7a3422f2ad06325032 100644
--- a/docs/en_US/tutorials/super_resolution.md
+++ b/docs/en_US/tutorials/super_resolution.md
@@ -4,9 +4,12 @@
 
 Super resolution is a process of upscaling and improving the details within an image. It usually takes a low-resolution image as input and upscales the same image to a higher resolution as output. Here we provide three super-resolution models, namely [RealSR](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w31/Ji_Real-World_Super-Resolution_via_Kernel_Estimation_and_Noise_Injection_CVPRW_2020_paper.pdf), [ESRGAN](https://arxiv.org/abs/1809.00219v2), [LESRCNN](https://arxiv.org/abs/2007.04344).
 
-  [RealSR](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w31/Ji_Real-World_Super-Resolution_via_Kernel_Estimation_and_Noise_Injection_CVPRW_2020_paper.pdf) proposed a realworld super-resolution model aiming at better perception.
-  [ESRGAN](https://arxiv.org/abs/1809.00219v2) is an enhanced SRGAN that improves the three key components of SRGAN.
-  [LESRCNN](https://arxiv.org/abs/2007.04344) is a lightweight enhanced SR CNN (LESRCNN) with three successive sub-blocks.
+
+  [RealSR](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w31/Ji_Real-World_Super-Resolution_via_Kernel_Estimation_and_Noise_Injection_CVPRW_2020_paper.pdf) focuses on designing a novel degradation framework for real-world images by estimating various blur kernels as well as real noise distributions. Based on this degradation framework, LR images that share a common domain with real-world images can be acquired. RealSR is a real-world super-resolution model aiming at better perception. Extensive experiments on synthetic noise data and real-world images demonstrate that RealSR outperforms the state-of-the-art methods, resulting in lower noise and better visual quality.
+
+  [ESRGAN](https://arxiv.org/abs/1809.00219v2) is an enhanced SRGAN. To further improve the visual quality of SRGAN, ESRGAN refines its three key components and, in addition, introduces the Residual-in-Residual Dense Block (RRDB) without batch normalization as the basic network building unit, lets the discriminator predict relative realness instead of an absolute value, and improves the perceptual loss by using features before activation. Benefiting from these improvements, ESRGAN achieves consistently better visual quality, with more realistic and natural textures than SRGAN, and won first place in the PIRM2018-SR Challenge.
+
+  Considering that applying CNNs to single image super-resolution (SISR) often consumes heavy computation and memory to train an SR model, a lightweight enhanced SR CNN ([LESRCNN](https://arxiv.org/abs/2007.04344)) was proposed. Extensive experiments demonstrate that LESRCNN outperforms state-of-the-art SISR methods in terms of both qualitative and quantitative evaluation.
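+
+  As a rough illustration of the degradation idea behind RealSR (a minimal sketch only — this helper and its inputs are assumptions for illustration, not the paper's kernel/noise estimators): blur the HR image with an estimated kernel, downsample, then inject a noise patch collected from real images.
+
+  ```python
+  import cv2
+  import numpy as np
+
+  def degrade(hr, kernel, noise_patch, scale=4):
+      """Make a realistic LR image: estimated blur kernel + real noise injection."""
+      blurred = cv2.filter2D(hr, -1, kernel)             # blur with an estimated kernel
+      lr = blurred[::scale, ::scale]                     # downsample to the LR grid
+      h, w = lr.shape[:2]
+      lr = lr.astype(np.float32) + noise_patch[:h, :w]   # inject collected real noise
+      return np.clip(lr, 0, 255).astype(np.uint8)
+  ```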
 
 ## 1.2 How to use
@@ -189,5 +192,5 @@
    author={Guo, Yong and Chen, Jian and Wang, Jingdong and Chen, Qi and Cao, Jiezhang and Deng, Zeshuai and Xu, Yanwu and Tan, Mingkui},
    booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
    year={2020}
-}
+  }
 ```
diff --git a/docs/zh_CN/tutorials/super_resolution.md b/docs/zh_CN/tutorials/super_resolution.md
new file mode 100644
index 0000000000000000000000000000000000000000..c5a11fedac0af32633420e8ba6d1064b1592aee9
--- /dev/null
+++ b/docs/zh_CN/tutorials/super_resolution.md
@@ -0,0 +1,176 @@
+# 1 Super Resolution
+
+## 1.1 Principle
+
+  Super resolution is the process of upscaling and improving the details within an image. It usually takes a low-resolution image as input and upscales the same image to a higher resolution as output. Here we provide three super-resolution models, namely [RealSR](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w31/Ji_Real-World_Super-Resolution_via_Kernel_Estimation_and_Noise_Injection_CVPRW_2020_paper.pdf), [ESRGAN](https://arxiv.org/abs/1809.00219v2) and [LESRCNN](https://arxiv.org/abs/2007.04344).
+  [RealSR](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w31/Ji_Real-World_Super-Resolution_via_Kernel_Estimation_and_Noise_Injection_CVPRW_2020_paper.pdf) designs a novel degradation framework for real-world images by estimating various blur kernels as well as real noise distributions. Based on this framework, LR images that share a common domain with real-world images can be acquired. RealSR is a real-world super-resolution model aiming at better perception. Extensive experiments on synthetic noise data and real-world images demonstrate that RealSR effectively reduces noise and improves visual quality.
+  [ESRGAN](https://arxiv.org/abs/1809.00219v2) is an enhanced SRGAN. To further improve the visual quality of SRGAN, ESRGAN refines its three key components. In addition, it introduces the Residual-in-Residual Dense Block (RRDB) without batch normalization as the basic network building unit, lets the discriminator predict relative realness instead of an absolute value, and improves the perceptual loss by using features before activation. Benefiting from these improvements, ESRGAN achieves better visual quality and more realistic and natural textures than SRGAN, and won first place in the PIRM2018-SR Challenge.
+  Considering that applying CNNs to SISR often consumes heavy computation and storage to train an SR model, a lightweight enhanced SR CNN ([LESRCNN](https://arxiv.org/abs/2007.04344)) was proposed. Extensive experiments show that LESRCNN outperforms existing SISR algorithms in both qualitative and quantitative evaluation.
+
+
+
+## 1.2 How to use
+
+### 1.2.1 Prepare the data
+
+  Commonly used image super-resolution datasets are listed below:
+  | name | dataset | description | download |
+  |---|---|---|---|
+  | 2K Resolution | [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) | proposed in [NTIRE17](https://data.vision.ee.ethz.ch/cvl/ntire17//) (800 train and 100 validation) | [official website](https://data.vision.ee.ethz.ch/cvl/DIV2K/) |
+  | Classical SR Testing | Set5 | Set5 test dataset | [Google Drive](https://drive.google.com/drive/folders/1B3DJGQKB6eNdwuQIhdskA64qUuVKLZ9u) / [Baidu Drive](https://pan.baidu.com/s/1q_1ERCMqALH0xFwjLM0pTg#list/path=%2Fsharelink2016187762-785433459861126%2Fclassical_SR_datasets&parentPath=%2Fsharelink2016187762-785433459861126) |
+  | Classical SR Testing | Set14 | Set14 test dataset | [Google Drive](https://drive.google.com/drive/folders/1B3DJGQKB6eNdwuQIhdskA64qUuVKLZ9u) / [Baidu Drive](https://pan.baidu.com/s/1q_1ERCMqALH0xFwjLM0pTg#list/path=%2Fsharelink2016187762-785433459861126%2Fclassical_SR_datasets&parentPath=%2Fsharelink2016187762-785433459861126) |
+
+  The DIV2K, Set5 and Set14 datasets are organized as follows:
+  ```
+    PaddleGAN
+      ├── data
+          ├── DIV2K
+                ├── DIV2K_train_HR
+                ├── DIV2K_train_LR_bicubic
+                |    ├──X2
+                |    ├──X3
+                |    └──X4
+                ├── DIV2K_valid_HR
+                ├── DIV2K_valid_LR_bicubic
+          Set5
+                ├── GTmod12
+                ├── LRbicx2
+                ├── LRbicx3
+                ├── LRbicx4
+                └── original
+          Set14
+                ├── GTmod12
+                ├── LRbicx2
+                ├── LRbicx3
+                ├── LRbicx4
+                └── original
+          ...
+  ```
+  Process the DIV2K dataset with the following command:
+  ```
+    python data/process_div2k_data.py --data-root data/DIV2K
+  ```
+  After the program finishes, check that the DIV2K directory contains the ``DIV2K_train_HR_sub``, ``X2_sub``, ``X3_sub`` and ``X4_sub`` directories:
+  ```
+    PaddleGAN
+      ├── data
+          ├── DIV2K
+                ├── DIV2K_train_HR
+                ├── DIV2K_train_HR_sub
+                ├── DIV2K_train_LR_bicubic
+                |    ├──X2
+                |    ├──X2_sub
+                |    ├──X3
+                |    ├──X3_sub
+                |    ├──X4
+                |    └──X4_sub
+                ├── DIV2K_valid_HR
+                ├── DIV2K_valid_LR_bicubic
+          ...
+  ```
+
+#### Data preparation for the RealSR df2k model
+
+  Download the dataset from [NTIRE 2020 RWSR](https://competitions.codalab.org/competitions/22220#participate) and unzip it to your path.
+  Unzip Corrupted-tr-x.zip and Corrupted-tr-y.zip to the ``PaddleGAN/data/ntire20`` directory.
+
+  Run the following commands:
+  ```
+    python ./data/realsr_preprocess/create_bicubic_dataset.py --dataset df2k --artifacts tdsr
+    python ./data/realsr_preprocess/collect_noise.py --dataset df2k --artifacts tdsr
+  ```
+
+### 1.2.2 Train/Test
+
+  The example uses the df2k dataset and the RealSR model. If you want to use your own dataset, change the dataset path in the config file; if you want to use another model, swap in the corresponding config file.
+
+  Train the model:
+  ```
+     python -u tools/main.py --config-file configs/realsr_bicubic_noise_x4_df2k.yaml
+  ```
+
+  Test the model:
+  ```
+     python tools/main.py --config-file configs/realsr_bicubic_noise_x4_df2k.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
+  ```
+
+## 1.3 Results
+The numerical results are evaluated on RGB channels, with `scale` pixels of each border cropped before evaluation.
+
+The metrics are PSNR / SSIM.
+
+| Model | Set5 | Set14 | DIV2K |
+|---|---|---|---|
+| realsr_df2k | 28.4385 / 0.8106 | 24.7424 / 0.6678 | 26.7306 / 0.7512 |
+| realsr_dped | 20.2421 / 0.6158 | 19.3775 / 0.5259 | 20.5976 / 0.6051 |
+| realsr_merge | 24.8315 / 0.7030 | 23.0393 / 0.5986 | 24.8510 / 0.6856 |
+| lesrcnn_x4 | 31.9476 / 0.8909 | 28.4110 / 0.7770 | 30.231 / 0.8326 |
+| esrgan_psnr_x4 | 32.5512 / 0.8991 | 28.8114 / 0.7871 | 30.7565 / 0.8449 |
+| esrgan_x4 | 28.7647 / 0.8187 | 25.0065 / 0.6762 | 26.9013 / 0.7542 |
+| drns_x4 | 32.6684 / 0.8999 | 28.9037 / 0.7885 | - |
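+
+A minimal sketch of how the PSNR numbers above are defined (an assumed helper for illustration, not PaddleGAN's evaluation code): compute MSE on the RGB channels after cropping `scale` pixels from each border, then convert to decibels.
+
+```python
+import numpy as np
+
+def psnr_rgb(sr, gt, scale=4):
+    """PSNR on RGB channels, cropping `scale` pixels from each border first."""
+    sr = sr[scale:-scale, scale:-scale].astype(np.float64)
+    gt = gt[scale:-scale, scale:-scale].astype(np.float64)
+    mse = np.mean((sr - gt) ** 2)
+    return 10 * np.log10(255.0 ** 2 / mse)   # peak value 255 for 8-bit images
+```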
+
+
+
+
+## 1.4 Model download
+| Model | Dataset | Download |
+|---|---|---|
+| realsr_df2k | df2k | [realsr_df2k](https://paddlegan.bj.bcebos.com/models/realsr_df2k.pdparams)
+| realsr_dped | dped | [realsr_dped](https://paddlegan.bj.bcebos.com/models/realsr_dped.pdparams)
+| realsr_merge | DIV2K | [realsr_merge](https://paddlegan.bj.bcebos.com/models/realsr_merge.pdparams)
+| lesrcnn_x4 | DIV2K | [lesrcnn_x4](https://paddlegan.bj.bcebos.com/models/lesrcnn_x4.pdparams)
+| esrgan_psnr_x4 | DIV2K | [esrgan_psnr_x4](https://paddlegan.bj.bcebos.com/models/esrgan_psnr_x4.pdparams)
+| esrgan_x4 | DIV2K | [esrgan_x4](https://paddlegan.bj.bcebos.com/models/esrgan_x4.pdparams)
+| drns_x4 | DIV2K | [drns_x4](https://paddlegan.bj.bcebos.com/models/DRNSx4.pdparams)
+
+
+# References
+
+- 1. [Real-World Super-Resolution via Kernel Estimation and Noise Injection](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w31/Ji_Real-World_Super-Resolution_via_Kernel_Estimation_and_Noise_Injection_CVPRW_2020_paper.pdf)
+
+  ```
+  @inproceedings{ji2020real,
+    title={Real-World Super-Resolution via Kernel Estimation and Noise Injection},
+    author={Ji, Xiaozhong and Cao, Yun and Tai, Ying and Wang, Chengjie and Li, Jilin and Huang, Feiyue},
+    booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops},
+    pages={466--467},
+    year={2020}
+  }
+  ```
+
+- 2. [ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks](https://arxiv.org/abs/1809.00219v2)
+
+  ```
+  @inproceedings{wang2018esrgan,
+    title={ESRGAN: Enhanced super-resolution generative adversarial networks},
+    author={Wang, Xintao and Yu, Ke and Wu, Shixiang and Gu, Jinjin and Liu, Yihao and Dong, Chao and Qiao, Yu and Change Loy, Chen},
+    booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
+    year={2018}
+  }
+  ```
+
+- 3. [Lightweight image super-resolution with enhanced CNN](https://arxiv.org/abs/2007.04344)
+
+  ```
+  @article{tian2020lightweight,
+    title={Lightweight image super-resolution with enhanced CNN},
+    author={Tian, Chunwei and Zhuge, Ruibin and Wu, Zhihao and Xu, Yong and Zuo, Wangmeng and Chen, Chen and Lin, Chia-Wen},
+    journal={Knowledge-Based Systems},
+    volume={205},
+    pages={106235},
+    year={2020},
+    publisher={Elsevier}
+  }
+  ```
+- 4. [Closed-loop Matters: Dual Regression Networks for Single Image Super-Resolution](https://arxiv.org/pdf/2003.07018.pdf)
+
+  ```
+  @inproceedings{guo2020closed,
+    title={Closed-loop Matters: Dual Regression Networks for Single Image Super-Resolution},
+    author={Guo, Yong and Chen, Jian and Wang, Jingdong and Chen, Qi and Cao, Jiezhang and Deng, Zeshuai and Xu, Yanwu and Tan, Mingkui},
+    booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
+    year={2020}
+  }
+  ```
\ No newline at end of file
diff --git a/docs/zh_CN/tutorials/ugatit.md b/docs/zh_CN/tutorials/ugatit.md
index b0f2ac870b47c864520a6f3832b0cc1acfd5f282..f50e970a31393ad7cfca0fc358f1aa9d258ef05c 100644
--- a/docs/zh_CN/tutorials/ugatit.md
+++ b/docs/zh_CN/tutorials/ugatit.md
@@ -1,3 +1,61 @@
-### U-GAT-IT
+# 1 U-GAT-IT
 
-待增加,您也可以先参考通用de[训练/评估/推理教程](../get_started.md)
+## 1.1 Principle
+
+  Similar to CycleGAN, [U-GAT-IT](https://arxiv.org/abs/1907.10830) performs image style transfer with unpaired images: given two images of different styles, U-GAT-IT carries out the style translation automatically. What distinguishes it from previous work is that U-GAT-IT introduces, in an end-to-end manner, a new attention module and a new learnable normalization function (AdaLIN, adaptive layer-instance normalization).
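+
+  A minimal sketch of the AdaLIN idea follows (illustrative only — the class below and the ρ initialization are assumptions, not PaddleGAN's implementation). AdaLIN normalizes a feature map with both Instance Norm and Layer Norm, and blends the two with a learnable ratio ρ, while γ and β are predicted from the attention features:
+
+  ```python
+  import paddle
+  import paddle.nn as nn
+
+  class AdaLIN(nn.Layer):
+      """Adaptive Layer-Instance Normalization: a learnable blend of IN and LN."""
+      def __init__(self, num_features, eps=1e-5):
+          super().__init__()
+          self.eps = eps
+          # rho controls, per channel, how much Instance Norm vs Layer Norm to use
+          self.rho = self.create_parameter(
+              [1, num_features, 1, 1],
+              default_initializer=nn.initializer.Constant(0.9))
+
+      def forward(self, x, gamma, beta):
+          out_in = (x - x.mean([2, 3], keepdim=True)) / paddle.sqrt(
+              x.var([2, 3], keepdim=True) + self.eps)        # Instance Norm
+          out_ln = (x - x.mean([1, 2, 3], keepdim=True)) / paddle.sqrt(
+              x.var([1, 2, 3], keepdim=True) + self.eps)     # Layer Norm
+          rho = self.rho.clip(0, 1)
+          out = rho * out_in + (1 - rho) * out_ln
+          # gamma/beta come from fully connected layers fed by attention features
+          return out * gamma.unsqueeze([2, 3]) + beta.unsqueeze([2, 3])
+  ```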
+
+## 1.2 How to use
+
+### 1.2.1 Prepare the data
+
+  The selfie2anime dataset used by U-GAT-IT can be downloaded from [here](https://www.kaggle.com/arnaud58/selfie2anime); you can also use your own dataset.
+
+  The data should be organized as follows:
+
+  ```
+    ├── dataset
+        └── YOUR_DATASET_NAME
+            ├── trainA
+            ├── trainB
+            ├── testA
+            └── testB
+  ```
+
+### 1.2.2 Train/Test
+
+  The example uses the selfie2anime dataset. If you want to use your own dataset, change the dataset path in the config file.
+
+  Train the model:
+  ```
+     python -u tools/main.py --config-file configs/ugatit_selfie2anime_light.yaml
+  ```
+
+  Test the model:
+  ```
+     python tools/main.py --config-file configs/ugatit_selfie2anime_light.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
+  ```
+
+## 1.3 Results
+
+![](../../imgs/ugatit.png)
+
+## 1.4 Model download
+| Model | Dataset | Download |
+|---|---|---|
+| ugatit_light | selfie2anime | [ugatit_light](https://paddlegan.bj.bcebos.com/models/ugatit_light.pdparams)
+
+
+
+
+# References
+
+- 1. [U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation](https://arxiv.org/abs/1907.10830)
+
+  ```
+  @article{kim2019u,
+    title={U-GAT-IT: unsupervised generative attentional networks with adaptive layer-instance normalization for image-to-image translation},
+    author={Kim, Junho and Kim, Minjae and Kang, Hyeonwoo and Lee, Kwanghee},
+    journal={arXiv preprint arXiv:1907.10830},
+    year={2019}
+  }
+  ```
diff --git a/ppgan/apps/base_predictor.py b/ppgan/apps/base_predictor.py
index 2cdd3e31366bd70e640bbe14a03b7efabb2b9718..4569da791b2b128f725fe428bf4f0ae5e5911124 100644
--- a/ppgan/apps/base_predictor.py
+++ b/ppgan/apps/base_predictor.py
@@ -28,8 +28,8 @@ class BasePredictor(object):
             # todo self.model = build_model(self.cfg)
             pass
         else:
-            place = paddle.fluid.framework._current_expected_place()
-            self.exe = paddle.fluid.Executor(place)
+            place = paddle.get_device()
+            self.exe = paddle.static.Executor(place)
             file_names = os.listdir(self.weight_path)
             for file_name in file_names:
                 if file_name.find('model') > -1:
diff --git a/ppgan/apps/dain_predictor.py b/ppgan/apps/dain_predictor.py
index 31d9e974782d5007fdaff141fe7b42f8c41834f9..1fa4ab23a443ef786f47e17c01217ebf2838fc73 100644
--- a/ppgan/apps/dain_predictor.py
+++ b/ppgan/apps/dain_predictor.py
@@ -21,7 +21,6 @@ from tqdm import tqdm
 from imageio import imread, imsave
 
 import paddle
-import paddle.fluid as fluid
 from ppgan.utils.download import get_path_from_url
 from ppgan.utils.video import video2frames, frames2video
 
diff --git a/ppgan/faceutils/face_detection/detection/core.py b/ppgan/faceutils/face_detection/detection/core.py
index b9988f8d707de136057a7fb043cc568baf2b2f2c..adb541ceb3464032c9e83b3b9e4473619fb4e823 100644
--- a/ppgan/faceutils/face_detection/detection/core.py
+++ b/ppgan/faceutils/face_detection/detection/core.py
@@ -134,7 +134,7 @@ class FaceDetector(object):
                 tensor_or_path)[..., ::-1]
         elif isinstance(
                 tensor_or_path,
-                (paddle.fluid.framework.Variable, paddle.fluid.core.VarBase)):
+                (paddle.static.Variable, paddle.Tensor)):
             # Call cpu in case its coming from cuda
             return tensor_or_path.numpy()[
                 ..., ::-1].copy() if not rgb else tensor_or_path.numpy()
diff --git a/ppgan/metrics/compute_fid.py b/ppgan/metrics/compute_fid.py
index da213d2586fc9da41b39fb5afb24af6da458cf6a..fb9f82e594a84fb4e2e8bb68f31e3fc3034016ba 100644
--- a/ppgan/metrics/compute_fid.py
+++ b/ppgan/metrics/compute_fid.py
@@ -16,12 +16,11 @@ import os
 import fnmatch
 import numpy as np
 import cv2
+import paddle
 from PIL import Image
 from cv2 import imread
 from scipy import linalg
-import paddle.fluid as fluid
 from inception import InceptionV3
-from paddle.fluid.dygraph.base import to_variable
 
 try:
     from tqdm import tqdm
@@ -89,8 +88,8 @@ def _get_activations_from_ims(img, model, batch_size, dims, use_gpu,
         images = images.transpose((0, 3, 1, 2))
         images /= 255
 
-        images = to_variable(images)
-        param_dict, _ = fluid.load_dygraph(premodel_path)
+        images = paddle.to_tensor(images)
+        param_dict = paddle.load(premodel_path)
         model.set_dict(param_dict)
         model.eval()
         pred = model(images)[0][0]
@@ -188,9 +187,9 @@ def _get_activations(files,
         if style == 'stargan':
             pred_arr[start:end] = inception_infer(images, premodel_path)
         else:
-            with fluid.dygraph.guard():
-                images = to_variable(images)
-                param_dict, _ = fluid.load_dygraph(premodel_path)
+            with paddle.no_grad():
+                images = paddle.to_tensor(images)
+                param_dict = paddle.load(premodel_path)
                 model.set_dict(param_dict)
                 model.eval()
 
@@ -202,9 +201,9 @@
 
 
 def inception_infer(x, model_path):
-    exe = fluid.Executor()
+    exe = paddle.static.Executor()
     [inference_program, feed_target_names,
-     fetch_targets] = fluid.io.load_inference_model(model_path, exe)
+     fetch_targets] = paddle.static.load_inference_model(model_path, exe)
     results = exe.run(inference_program,
                       feed={feed_target_names[0]: x},
                       fetch_list=fetch_targets)
@@ -264,7 +263,7 @@ def calculate_fid_given_paths(paths,
             raise RuntimeError('Invalid path: %s' % p)
 
     if model is None and style != 'stargan':
-        with fluid.dygraph.guard():
+        with paddle.no_grad():
             block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
 
             model = InceptionV3([block_idx], class_dim=1008)
diff --git a/ppgan/metrics/inception.py b/ppgan/metrics/inception.py
index 643d4766e4deb71a1b3d5c47a2777cfdc9b677a5..aa5cb08295ab157fa3d57cb604a559681304b44d 100644
--- a/ppgan/metrics/inception.py
+++ b/ppgan/metrics/inception.py
@@ -14,16 +14,13 @@
 import math
 
 import paddle
-import paddle.fluid as fluid
-from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.layer_helper import LayerHelper
-from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
-from paddle.fluid.dygraph.base import to_variable
+import paddle.nn as nn
+from paddle.nn import (Conv2D, AvgPool2D, AdaptiveAvgPool2D, MaxPool2D,
+                       BatchNorm, Linear)
 
 __all__ = ['InceptionV3']
 
 
-class InceptionV3(fluid.dygraph.Layer):
+class InceptionV3(nn.Layer):
     DEFAULT_BLOCK_INDEX = 3
     BLOCK_INDEX_BY_DIM = {
         64: 0,  # First max pooling features
@@ -60,21 +57,21 @@ class InceptionV3(fluid.dygraph.Layer):
                                        3,
                                        padding=1,
                                        name='Conv2d_2b_3x3')
-        self.maxpool1 = Pool2D(pool_size=3, pool_stride=2, pool_type='max')
+        self.maxpool1 = MaxPool2D(kernel_size=3, stride=2)
 
         block0 = [
             self.Conv2d_1a_3x3, self.Conv2d_2a_3x3, self.Conv2d_2b_3x3,
             self.maxpool1
         ]
-        self.blocks.append(fluid.dygraph.Sequential(*block0))
+        self.blocks.append(nn.Sequential(*block0))
         ### block1
 
         if self.last_needed_block >= 1:
             self.Conv2d_3b_1x1 = ConvBNLayer(64, 80, 1, name='Conv2d_3b_1x1')
             self.Conv2d_4a_3x3 = ConvBNLayer(80, 192, 3, name='Conv2d_4a_3x3')
-            self.maxpool2 = Pool2D(pool_size=3, pool_stride=2, pool_type='max')
+            self.maxpool2 = MaxPool2D(kernel_size=3, stride=2)
             block1 = [self.Conv2d_3b_1x1, self.Conv2d_4a_3x3, self.maxpool2]
-            self.blocks.append(fluid.dygraph.Sequential(*block1))
+            self.blocks.append(nn.Sequential(*block1))
 
         ### block2
         ### Mixed_5b 5c 5d
@@ -100,7 +97,7 @@ class InceptionV3(fluid.dygraph.Layer):
                 self.Mixed_5b, self.Mixed_5c, self.Mixed_5d, self.Mixed_6a,
                 self.Mixed_6b, self.Mixed_6c, self.Mixed_6d, self.Mixed_6e
             ]
-            self.blocks.append(fluid.dygraph.Sequential(*block2))
+            self.blocks.append(nn.Sequential(*block2))
 
         if self.aux_logits:
             self.AuxLogits = InceptionAux(768, self.class_dim, name='AuxLogits')
@@ -110,19 +107,20 @@ class InceptionV3(fluid.dygraph.Layer):
             self.Mixed_7a = InceptionD(768, name='Mixed_7a')
             self.Mixed_7b = Fid_inceptionE_1(1280, name='Mixed_7b')
             self.Mixed_7c = Fid_inceptionE_2(2048, name='Mixed_7c')
-            self.avgpool = Pool2D(global_pooling=True, pool_type='avg')
+            self.avgpool = AdaptiveAvgPool2D(output_size=1)
 
             block3 = [self.Mixed_7a, self.Mixed_7b, self.Mixed_7c, self.avgpool]
-            self.blocks.append(fluid.dygraph.Sequential(*block3))
+            self.blocks.append(nn.Sequential(*block3))
 
     def forward(self, x):
         out = []
         aux = None
         if self.resize_input:
-            x = fluid.layers.resize_bilinear(x,
-                                             out_shape=[299, 299],
-                                             align_corners=False,
-                                             align_mode=0)
+            x = nn.functional.interpolate(x,
+                                          size=[299, 299],
+                                          mode='bilinear',
+                                          align_corners=False,
+                                          align_mode=0)
 
         if self.normalize_input:
             x = x * 2 - 1
 
@@ -139,7 +137,7 @@ class InceptionV3(fluid.dygraph.Layer):
         return out, aux
 
 
-class InceptionA(fluid.dygraph.Layer):
+class InceptionA(nn.Layer):
     def __init__(self, in_channels, pool_features, name=None):
         super(InceptionA, self).__init__()
         self.branch1x1 = ConvBNLayer(in_channels,
@@ -172,11 +170,10 @@ class InceptionA(fluid.dygraph.Layer):
                                         padding=1,
                                         name=name + '.branch3x3dbl_3')
 
-        self.branch_pool0 = Pool2D(pool_size=3,
-                                   pool_stride=1,
-                                   pool_padding=1,
-                                   exclusive=True,
-                                   pool_type='avg')
+        self.branch_pool0 = AvgPool2D(kernel_size=3,
+                                      stride=1,
+                                      padding=1,
+                                      exclusive=True)
         self.branch_pool = ConvBNLayer(in_channels,
                                        pool_features,
                                        1,
@@ -194,11 +191,11 @@ class InceptionA(fluid.dygraph.Layer):
         branch_pool = self.branch_pool0(x)
         branch_pool = self.branch_pool(branch_pool)
 
-        return fluid.layers.concat(
+        return paddle.concat(
             [branch1x1, branch5x5, branch3x3dbl, branch_pool], axis=1)
 
 
-class InceptionB(fluid.dygraph.Layer):
+class InceptionB(nn.Layer):
     def __init__(self, in_channels, name=None):
         super(InceptionB, self).__init__()
         self.branch3x3 = ConvBNLayer(in_channels,
@@ -222,7 +219,7 @@ class InceptionB(fluid.dygraph.Layer):
                                         stride=2,
                                         name=name + '.branch3x3dbl_3')
 
-        self.branch_pool = Pool2D(pool_size=3, pool_stride=2, pool_type='max')
+        self.branch_pool = MaxPool2D(kernel_size=3, stride=2)
 
     def forward(self, x):
         branch3x3 = self.branch3x3(x)
@@ -232,11 +229,11 @@ class InceptionB(fluid.dygraph.Layer):
         branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
 
         branch_pool = self.branch_pool(x)
-        return fluid.layers.concat([branch3x3, branch3x3dbl, branch_pool],
-                                   axis=1)
+        return paddle.concat([branch3x3, branch3x3dbl, branch_pool],
+                             axis=1)
 
 
-class InceptionC(fluid.dygraph.Layer):
+class InceptionC(nn.Layer):
     def __init__(self, in_channels, c7, name=None):
         super(InceptionC, self).__init__()
         self.branch1x1 = ConvBNLayer(in_channels,
@@ -278,11 +275,10 @@ class InceptionC(fluid.dygraph.Layer):
                                         padding=(0, 3),
                                         name=name + '.branch7x7dbl_5')
 
-        self.branch_pool0 = Pool2D(pool_size=3,
-                                   pool_stride=1,
-                                   pool_padding=1,
-                                   exclusive=True,
-                                   pool_type='avg')
+        self.branch_pool0 = AvgPool2D(kernel_size=3,
+                                      stride=1,
+                                      padding=1,
+                                      exclusive=True)
         self.branch_pool = ConvBNLayer(in_channels,
                                        192,
                                        1,
@@ -304,11 +300,11 @@ class InceptionC(fluid.dygraph.Layer):
         branch_pool = self.branch_pool0(x)
         branch_pool = self.branch_pool(branch_pool)
 
-        return fluid.layers.concat(
+        return paddle.concat(
             [branch1x1, branch7x7, branch7x7dbl, branch_pool], axis=1)
 
 
-class InceptionD(fluid.dygraph.Layer):
+class InceptionD(nn.Layer):
     def __init__(self, in_channels, name=None):
         super(InceptionD, self).__init__()
         self.branch3x3_1 = ConvBNLayer(in_channels,
@@ -339,7 +335,7 @@ class InceptionD(fluid.dygraph.Layer):
                                        stride=2,
                                        name=name + '.branch7x7x3_4')
 
-        self.branch_pool = Pool2D(pool_size=3, pool_stride=2, pool_type='max')
+        self.branch_pool = MaxPool2D(kernel_size=3, stride=2)
 
     def forward(self, x):
         branch3x3 = self.branch3x3_1(x)
@@ -352,11 +348,11 @@ class InceptionD(fluid.dygraph.Layer):
 
         branch_pool = self.branch_pool(x)
 
-        return fluid.layers.concat([branch3x3, branch7x7x3, branch_pool],
-                                   axis=1)
+        return paddle.concat([branch3x3, branch7x7x3, branch_pool],
+                             axis=1)
 
 
-class InceptionE(fluid.dygraph.Layer):
+class InceptionE(nn.Layer):
     def __init__(self, in_channels, name=None):
         super(InceptionE, self).__init__()
         self.branch1x1 = ConvBNLayer(in_channels,
@@ -395,11 +391,10 @@ class InceptionE(fluid.dygraph.Layer):
                                         padding=(1, 0),
                                         name=name + '.branch3x3dbl_3b')
 
-        self.branch_pool0 = Pool2D(pool_size=3,
-                                   pool_stride=1,
-                                   pool_padding=1,
-                                   exclusive=True,
-                                   pool_type='avg')
+        self.branch_pool0 = AvgPool2D(kernel_size=3,
+                                      stride=1,
+                                      padding=1,
+                                      exclusive=True)
         self.branch_pool = ConvBNLayer(in_channels,
                                        192,
                                        1,
@@ -410,42 +405,42 @@ class InceptionE(fluid.dygraph.Layer):
         branch3x3_1 = self.branch3x3_1(x)
         branch3x3_2a = self.branch3x3_2a(branch3x3_1)
         branch3x3_2b = self.branch3x3_2b(branch3x3_1)
-        branch3x3 = fluid.layers.concat([branch3x3_2a, branch3x3_2b], axis=1)
+        branch3x3 = paddle.concat([branch3x3_2a, branch3x3_2b], axis=1)
 
         branch3x3dbl = self.branch3x3dbl_1(x)
         branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
         branch3x3dbl_3a = self.branch3x3dbl_3a(branch3x3dbl)
         branch3x3dbl_3b = self.branch3x3dbl_3b(branch3x3dbl)
-        branch3x3dbl = fluid.layers.concat([branch3x3dbl_3a, branch3x3dbl_3b],
-                                           axis=1)
+        branch3x3dbl = paddle.concat([branch3x3dbl_3a, branch3x3dbl_3b],
+                                     axis=1)
 
         branch_pool = self.branch_pool0(x)
         branch_pool = self.branch_pool(branch_pool)
 
-        return fluid.layers.concat(
+        return paddle.concat(
             [branch1x1, branch3x3, branch3x3dbl, branch_pool], axis=1)
 
 
-class InceptionAux(fluid.dygraph.Layer):
+class InceptionAux(nn.Layer):
     def __init__(self, in_channels, num_classes, name=None):
         super(InceptionAux, self).__init__()
         self.num_classes = num_classes
-        self.pool0 = Pool2D(pool_size=5, pool_stride=3, pool_type='avg')
+        self.pool0 = AvgPool2D(kernel_size=5, stride=3)
         self.conv0 = ConvBNLayer(in_channels, 128, 1, name=name + '.conv0')
         self.conv1 = ConvBNLayer(128, 768, 5, name=name + '.conv1')
-        self.pool1 = Pool2D(global_pooling=True, pool_type='avg')
+        self.pool1 = AdaptiveAvgPool2D(output_size=1)
+        self.fc = Linear(768, num_classes)
 
     def forward(self, x):
         x = self.pool0(x)
         x = self.conv0(x)
         x = self.conv1(x)
         x = self.pool1(x)
-        x = fluid.layers.flatten(x, axis=1)
-        x = fluid.layers.fc(x, size=self.num_classes)
+        x = paddle.flatten(x, start_axis=1)
+        x = self.fc(x)
         return x
 
 
-class Fid_inceptionA(fluid.dygraph.Layer):
+class Fid_inceptionA(nn.Layer):
     """ FID block in inception v3
     """
     def __init__(self, in_channels, pool_features, name=None):
@@ -480,11 +475,10 @@ class Fid_inceptionA(fluid.dygraph.Layer):
                                         padding=1,
                                         name=name + '.branch3x3dbl_3')
 
-        self.branch_pool0 = Pool2D(pool_size=3,
-                                   pool_stride=1,
-                                   pool_padding=1,
-                                   exclusive=True,
-                                   pool_type='avg')
+        self.branch_pool0 = AvgPool2D(kernel_size=3,
+                                      stride=1,
+                                      padding=1,
+                                      exclusive=True)
         self.branch_pool = ConvBNLayer(in_channels,
                                        pool_features,
                                        1,
@@ -502,11 +496,11 @@ class Fid_inceptionA(fluid.dygraph.Layer):
         branch_pool = self.branch_pool0(x)
         branch_pool = self.branch_pool(branch_pool)
 
-        return fluid.layers.concat(
+        return paddle.concat(
             [branch1x1, branch5x5, branch3x3dbl, branch_pool], axis=1)
 
 
-class Fid_inceptionC(fluid.dygraph.Layer):
+class Fid_inceptionC(nn.Layer):
     """ FID block in inception v3
     """
     def __init__(self, in_channels, c7, name=None):
@@ -550,11 +544,10 @@ class Fid_inceptionC(fluid.dygraph.Layer):
                                         padding=(0, 3),
                                         name=name + '.branch7x7dbl_5')
 
-        self.branch_pool0 = Pool2D(pool_size=3,
-                                   pool_stride=1,
-                                   pool_padding=1,
-                                   exclusive=True,
-                                   pool_type='avg')
+        self.branch_pool0 = AvgPool2D(kernel_size=3,
+                                      stride=1,
+                                      padding=1,
+                                      exclusive=True)
         self.branch_pool = ConvBNLayer(in_channels,
                                        192,
                                        1,
@@ -576,11 +569,11 @@ class Fid_inceptionC(fluid.dygraph.Layer):
         branch_pool = self.branch_pool0(x)
         branch_pool = self.branch_pool(branch_pool)
 
-        return fluid.layers.concat(
+        return paddle.concat(
             [branch1x1, branch7x7, branch7x7dbl, branch_pool], axis=1)
 
 
-class Fid_inceptionE_1(fluid.dygraph.Layer):
+class Fid_inceptionE_1(nn.Layer):
     """ FID block in inception v3
     """
     def __init__(self, in_channels, name=None):
@@ -621,11 +614,10 @@ class
 Fid_inceptionE_1(fluid.dygraph.Layer):
                                         padding=(1, 0),
                                         name=name + '.branch3x3dbl_3b')
 
-        self.branch_pool0 = Pool2D(pool_size=3,
-                                   pool_stride=1,
-                                   pool_padding=1,
-                                   exclusive=True,
-                                   pool_type='avg')
+        self.branch_pool0 = AvgPool2D(kernel_size=3,
+                                      stride=1,
+                                      padding=1,
+                                      exclusive=True)
         self.branch_pool = ConvBNLayer(in_channels,
                                        192,
                                        1,
@@ -636,23 +628,23 @@ class Fid_inceptionE_1(fluid.dygraph.Layer):
         branch3x3_1 = self.branch3x3_1(x)
         branch3x3_2a = self.branch3x3_2a(branch3x3_1)
         branch3x3_2b = self.branch3x3_2b(branch3x3_1)
-        branch3x3 = fluid.layers.concat([branch3x3_2a, branch3x3_2b], axis=1)
+        branch3x3 = paddle.concat([branch3x3_2a, branch3x3_2b], axis=1)
 
         branch3x3dbl = self.branch3x3dbl_1(x)
         branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
         branch3x3dbl_3a = self.branch3x3dbl_3a(branch3x3dbl)
         branch3x3dbl_3b = self.branch3x3dbl_3b(branch3x3dbl)
-        branch3x3dbl = fluid.layers.concat([branch3x3dbl_3a, branch3x3dbl_3b],
-                                           axis=1)
+        branch3x3dbl = paddle.concat([branch3x3dbl_3a, branch3x3dbl_3b],
+                                     axis=1)
 
         branch_pool = self.branch_pool0(x)
         branch_pool = self.branch_pool(branch_pool)
 
-        return fluid.layers.concat(
+        return paddle.concat(
             [branch1x1, branch3x3, branch3x3dbl, branch_pool], axis=1)
 
 
-class Fid_inceptionE_2(fluid.dygraph.Layer):
+class Fid_inceptionE_2(nn.Layer):
     """ FID block in inception v3
    """
     def __init__(self, in_channels, name=None):
@@ -693,10 +685,9 @@ class Fid_inceptionE_2(fluid.dygraph.Layer):
                                         padding=(1, 0),
                                         name=name + '.branch3x3dbl_3b')
         ### same with paper
-        self.branch_pool0 = Pool2D(pool_size=3,
-                                   pool_stride=1,
-                                   pool_padding=1,
-                                   pool_type='max')
+        self.branch_pool0 = MaxPool2D(kernel_size=3,
+                                      stride=1,
+                                      padding=1)
         self.branch_pool = ConvBNLayer(in_channels,
                                        192,
                                        1,
@@ -707,23 +698,23 @@ class Fid_inceptionE_2(fluid.dygraph.Layer):
         branch3x3_1 = self.branch3x3_1(x)
         branch3x3_2a = self.branch3x3_2a(branch3x3_1)
         branch3x3_2b = self.branch3x3_2b(branch3x3_1)
-        branch3x3 = fluid.layers.concat([branch3x3_2a, branch3x3_2b], axis=1)
+        branch3x3 = paddle.concat([branch3x3_2a, branch3x3_2b], axis=1)
 
         branch3x3dbl = self.branch3x3dbl_1(x)
         branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
         branch3x3dbl_3a = self.branch3x3dbl_3a(branch3x3dbl)
         branch3x3dbl_3b = self.branch3x3dbl_3b(branch3x3dbl)
-        branch3x3dbl = fluid.layers.concat([branch3x3dbl_3a, branch3x3dbl_3b],
-                                           axis=1)
+        branch3x3dbl = paddle.concat([branch3x3dbl_3a, branch3x3dbl_3b],
+                                     axis=1)
 
         branch_pool = self.branch_pool0(x)
         branch_pool = self.branch_pool(branch_pool)
 
-        return fluid.layers.concat(
+        return paddle.concat(
             [branch1x1, branch3x3, branch3x3dbl, branch_pool], axis=1)
 
 
-class ConvBNLayer(fluid.dygraph.Layer):
+class ConvBNLayer(nn.Layer):
     def __init__(self,
                  in_channels,
                  num_filters,
@@ -741,13 +732,13 @@ class ConvBNLayer(fluid.dygraph.Layer):
                               padding=padding,
                               groups=groups,
-                              act=None,
-                              param_attr=ParamAttr(name=name + ".conv.weight"),
+                              weight_attr=paddle.ParamAttr(name=name + ".conv.weight"),
                               bias_attr=False)
         self.bn = BatchNorm(num_filters,
                             act=act,
                             epsilon=0.001,
-                            param_attr=ParamAttr(name=name + ".bn.weight"),
-                            bias_attr=ParamAttr(name=name + ".bn.bias"),
+                            param_attr=paddle.ParamAttr(name=name + ".bn.weight"),
+                            bias_attr=paddle.ParamAttr(name=name + ".bn.bias"),
                             moving_mean_name=name + '.bn.running_mean',
                             moving_variance_name=name + '.bn.running_var')
 
diff --git a/ppgan/models/drn_model.py b/ppgan/models/drn_model.py
index f1c41a14be8dc07fb5e45bce590bfe7014ff9e52..ce44e9888c3c39b0c03482ad549e229652436f18 100644
--- a/ppgan/models/drn_model.py
+++ b/ppgan/models/drn_model.py
@@ -79,7 +79,7 @@ class DRN(BaseSRModel):
             self.gan_criterion = build_criterion(gan_criterion)
 
     def setup_input(self, input):
-        self.lq = paddle.fluid.dygraph.to_variable(input['lq'])
+        self.lq = paddle.to_tensor(input['lq'])
         self.visual_items['lq'] = self.lq
 
         if isinstance(self.scale, (list, tuple)) and len(
@@ -87,7 +87,7 @@ class DRN(BaseSRModel):
             self.lqx2 = input['lqx2']
 
         if 'gt' in input:
-            self.gt = paddle.fluid.dygraph.to_variable(input['gt'])
+            self.gt = paddle.to_tensor(input['gt'])
             self.visual_items['gt'] = self.gt
 
         self.image_paths = input['lq_path']
diff --git a/ppgan/models/generators/hook.py b/ppgan/models/generators/hook.py
index ba1bcd4819096a4f7eb77a036a897ddf7122e3a2..e0c3055d7313259be6bcc2d81e4832312e76eecd 100644
--- a/ppgan/models/generators/hook.py
+++ b/ppgan/models/generators/hook.py
@@ -91,7 +91,7 @@ class Hooks():
 
 def _hook_inner(m, i, o):
     return o if isinstance(
-        o, paddle.fluid.framework.Variable) else o if is_listy(o) else list(o)
+        o, paddle.static.Variable) else o if is_listy(o) else list(o)
 
 
 def hook_output(module, detach=True, grad=False):
diff --git a/ppgan/models/pix2pix_model.py b/ppgan/models/pix2pix_model.py
index bfb3d0b849933909fd84a851aa456cc50d45b83e..6aec7e30a512484cc86ba9b6896bd8d542c6848c 100644
--- a/ppgan/models/pix2pix_model.py
+++ b/ppgan/models/pix2pix_model.py
@@ -74,9 +74,9 @@ class Pix2PixModel(BaseModel):
 
         AtoB = self.direction == 'AtoB'
 
-        self.real_A = paddle.fluid.dygraph.to_variable(
+        self.real_A = paddle.to_tensor(
             input['A' if AtoB else 'B'])
-        self.real_B = paddle.fluid.dygraph.to_variable(
+        self.real_B = paddle.to_tensor(
             input['B' if AtoB else 'A'])
 
         self.image_paths = input['A_path' if AtoB else 'B_path']
diff --git a/ppgan/models/sr_model.py b/ppgan/models/sr_model.py
index 565dc649f6a49d67e67b0ae6fdc5a25f25cf9d2e..5a48d869701b2c35e71fc2ddaa7e79d32ed7ce83 100644
--- a/ppgan/models/sr_model.py
+++ b/ppgan/models/sr_model.py
@@ -40,10 +40,10 @@ class BaseSRModel(BaseModel):
         self.pixel_criterion = build_criterion(pixel_criterion)
 
     def setup_input(self, input):
-        self.lq = paddle.fluid.dygraph.to_variable(input['lq'])
+        self.lq = paddle.to_tensor(input['lq'])
         self.visual_items['lq'] = self.lq
         if 'gt' in input:
-            self.gt = paddle.fluid.dygraph.to_variable(input['gt'])
+            self.gt = paddle.to_tensor(input['gt'])
             self.visual_items['gt'] = self.gt
 
         self.image_paths = input['lq_path']
diff --git a/ppgan/models/styleganv2_model.py b/ppgan/models/styleganv2_model.py
index 1f10ed0393d9eacce79b1e03350bca7dc1da96eb..36071e04b69ce3998b7eac39db177b93ecb29f9f 100644
--- a/ppgan/models/styleganv2_model.py
+++ b/ppgan/models/styleganv2_model.py
@@ -180,7 +180,7 @@ class StyleGAN2Model(BaseModel):
             input (dict): include the data itself and its metadata information.
 
         """
-        self.real_img = paddle.fluid.dygraph.to_variable(input['A'])
+        self.real_img = paddle.to_tensor(input['A'])
 
     def forward(self):
         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
diff --git a/ppgan/utils/download.py b/ppgan/utils/download.py
index 016358404180a29471da3cc55992f2becffd3867..69fbb3983ad9a22d7c5ebfa77ef75853cb3c8d5c 100644
--- a/ppgan/utils/download.py
+++ b/ppgan/utils/download.py
@@ -64,7 +64,7 @@ def get_path_from_url(url, md5sum=None, check_exist=True):
         str: a local path to save downloaded models & weights & datasets.
""" - from paddle.fluid.dygraph.parallel import ParallelEnv + from paddle.distributed import ParallelEnv assert is_url(url), "downloading from {} not a url".format(url) root_dir = PPGAN_HOME diff --git a/ppgan/utils/filesystem.py b/ppgan/utils/filesystem.py index 43774dcc8c83d13cc2491c1064387f78a2880848..9b0ce88bfa5a230a06c7e32bd4d9634de6e56c71 100644 --- a/ppgan/utils/filesystem.py +++ b/ppgan/utils/filesystem.py @@ -34,7 +34,7 @@ def save(state_dicts, file_name): for k, v in state_dict.items(): if isinstance( v, - (paddle.fluid.framework.Variable, paddle.fluid.core.VarBase)): + (paddle.static.Variable, paddle.Tensor)): model_dict[k] = v.numpy() else: model_dict[k] = v @@ -45,7 +45,7 @@ def save(state_dicts, file_name): for k, v in state_dicts.items(): if isinstance( v, - (paddle.fluid.framework.Variable, paddle.fluid.core.VarBase)): + (paddle.static.Variable, paddle.Tensor)): final_dict = convert(state_dicts) break elif isinstance(v, dict):