From 49e549ee73883952fb6f13824ad59163fe30467d Mon Sep 17 00:00:00 2001
From: wangna11BD <79366697+wangna11BD@users.noreply.github.com>
Date: Tue, 13 Jul 2021 12:11:37 +0800
Subject: [PATCH] add lapstyle predictor (#362)

* add lapstyle predictor

* add docs
---
 applications/tools/lapstyle.py    |  50 ++++++++
 docs/en_US/tutorials/lap_style.md |  24 +++-
 docs/zh_CN/tutorials/lap_style.md |  24 +++-
 ppgan/apps/__init__.py            |   1 +
 ppgan/apps/lapstyle_predictor.py  | 204 ++++++++++++++++++++++++++++++
 5 files changed, 291 insertions(+), 12 deletions(-)
 create mode 100644 applications/tools/lapstyle.py
 create mode 100644 ppgan/apps/lapstyle_predictor.py

diff --git a/applications/tools/lapstyle.py b/applications/tools/lapstyle.py
new file mode 100644
index 0000000..e4f0fcc
--- /dev/null
+++ b/applications/tools/lapstyle.py
@@ -0,0 +1,50 @@
+import argparse
+import os
+import sys
+
+import paddle
+
+sys.path.insert(0, os.getcwd())
+from ppgan.apps import LapStylePredictor
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--content_img", type=str, help="path to content image")
+
+    parser.add_argument("--output_path",
+                        type=str,
+                        default='output_dir',
+                        help="path to output image dir")
+
+    parser.add_argument("--weight_path",
+                        type=str,
+                        default=None,
+                        help="path to model weight file")
+
+    parser.add_argument(
+        "--style",
+        type=str,
+        default='starrynew',
+        help=
+        "if weight_path is None, style can be chosen from 'starrynew', 'circuit', 'ocean' and 'stars'"
+    )
+
+    parser.add_argument("--style_image_path",
+                        type=str,
+                        default=None,
+                        help="path to style image, required when weight_path is not None")
+
+    parser.add_argument("--cpu",
+                        dest="cpu",
+                        action="store_true",
+                        help="cpu mode.")
+
+    args = parser.parse_args()
+
+    if args.cpu:
+        paddle.set_device('cpu')
+
+    predictor = LapStylePredictor(output=args.output_path,
+                                  style=args.style,
+                                  weight_path=args.weight_path,
+                                  style_image_path=args.style_image_path)
+    predictor.run(args.content_img)
diff --git a/docs/en_US/tutorials/lap_style.md b/docs/en_US/tutorials/lap_style.md
index 9f52dda..48bc99b 100644
--- a/docs/en_US/tutorials/lap_style.md
+++ b/docs/en_US/tutorials/lap_style.md
@@ -13,14 +13,26 @@ Artistic style transfer aims at migrating the style from an example image to a c
 
 ![lapstyle_overview](https://user-images.githubusercontent.com/79366697/118654987-b24dc100-b81b-11eb-9430-d84630f80511.png)
 
-## 2 How to use
+## 2 Quick experience
+```
+python applications/tools/lapstyle.py --content_img ${PATH_OF_CONTENT_IMG}
+```
+### Parameters
+
+- `--content_img (str)`: path to the content image.
+- `--output_path (str)`: path to the output image directory. Default: `output_dir`.
+- `--weight_path (str)`: path to the model weight file; if `weight_path` is `None`, the pre-trained model is downloaded automatically. Default: `None`.
+- `--style (str)`: style of the output image; if `weight_path` is `None`, `style` can be chosen from `starrynew`, `circuit`, `ocean` and `stars`. Default: `starrynew`.
+- `--style_image_path (str)`: path to the style image; it must be provided when `weight_path` is not `None`. Default: `None`.
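+
+The predictor can also be called from Python. Below is a minimal sketch of the same call the tool script makes (the content image path is a placeholder; with these defaults the pre-trained `starrynew` weights are downloaded automatically and the results are written under `output_dir/LapStyle`):
+
+```
+from ppgan.apps import LapStylePredictor
+
+predictor = LapStylePredictor(output='output_dir', style='starrynew')
+predictor.run('path/to/content.png')
+```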
+
+## 3 How to use
 
-### 2.1 Prepare Datasets
+### 3.1 Prepare Datasets
 
 To train LapStyle, we use the COCO dataset as the content set, and you can choose any style image you like. Before training or testing, remember to modify the data path of the style image in the config file.
 
-### 2.2 Train
+### 3.2 Train
 
 The dataset used in this example is COCO; you can also change it to your own dataset in the config file.
 
@@ -40,14 +52,14 @@ python -u tools/main.py --config-file configs/lapstyle_rev_first.yaml --load ${P
 python -u tools/main.py --config-file configs/lapstyle_rev_second.yaml --load ${PATH_OF_LAST_STAGE_WEIGHT}
 ```
 
-### 2.4 Test
+### 3.4 Test
 
 To test the trained model, you can directly test "lapstyle_rev_second", since it also contains the trained weights of the previous stages:
 ```
 python tools/main.py --config-file configs/lapstyle_rev_second.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
 ```
 
-## 3 Results
+## 4 Results
 
 | Style | Stylized Results |
 | --- | --- |
@@ -56,7 +68,7 @@ python tools/main.py --config-file configs/lapstyle_rev_second.yaml --evaluate-o
 | ![stars](https://user-images.githubusercontent.com/79366697/118655423-20928380-b81c-11eb-92bd-0deeb320ff14.png) | ![chicago_stylized_stars_512](https://user-images.githubusercontent.com/79366697/118655638-50da2200-b81c-11eb-9223-58d5df022fa5.png)|
 | ![circuit](https://user-images.githubusercontent.com/79366697/118655399-196b7580-b81c-11eb-8bc5-d5ece80c18ba.jpg) | ![chicago_stylized_circuit](https://user-images.githubusercontent.com/79366697/118655660-56376c80-b81c-11eb-87f2-64ae5a82375c.png)|
 
-## 4 Pre-trained models
+## 5 Pre-trained models
 
 We also provide several trained models.
diff --git a/docs/zh_CN/tutorials/lap_style.md b/docs/zh_CN/tutorials/lap_style.md
index 339391c..9aff005 100644
--- a/docs/zh_CN/tutorials/lap_style.md
+++ b/docs/zh_CN/tutorials/lap_style.md
@@ -12,13 +12,25 @@ LapStyle first transfers low-resolution global style patterns through a Drafting Network
 
 ![lapstyle_overview](https://user-images.githubusercontent.com/79366697/118654987-b24dc100-b81b-11eb-9430-d84630f80511.png)
 
-## 2 How to use
+## 2 Quick experience
+```
+python applications/tools/lapstyle.py --content_img ${PATH_OF_CONTENT_IMG}
+```
+### **Parameters**
+
+- `--content_img (str)`: path to the input content image.
+- `--output_path (str)`: path for the output images. Default: `output_dir`.
+- `--weight_path (str)`: path to the model weights; when set to `None`, the pre-trained model is downloaded automatically. Default: `None`.
+- `--style (str)`: style of the generated image; when `weight_path` is `None`, it can be chosen from `starrynew`, `circuit`, `ocean` and `stars`. Default: `starrynew`.
+- `--style_image_path (str)`: path to the input style image; required when `weight_path` is not `None`. Default: `None`.
+
+## 3 How to use
 
-### 2.1 Data preparation
+### 3.1 Data preparation
 
 To train LapStyle, we use the COCO dataset as the content dataset. You may choose any style image you like. Before starting training and testing, remember to modify the data path in the config file.
 
-### 2.2 Training
+### 3.2 Training
 
 This example uses the COCO dataset; if you want to use your own dataset, you can change it in the config file.
 
@@ -37,14 +49,14 @@ python -u tools/main.py --config-file configs/lapstyle_rev_first.yaml --load ${P
 python -u tools/main.py --config-file configs/lapstyle_rev_second.yaml --load ${PATH_OF_LAST_STAGE_WEIGHT}
 ```
 
-### 2.4 Testing
+### 3.4 Testing
 
 To test the trained model, you can directly test "lapstyle_rev_second", since it contains the trained weights from the previous stages:
 ```
 python tools/main.py --config-file configs/lapstyle_rev_second.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
 ```
 
-## 3 Results
+## 4 Results
 
 | Style | Stylized Results |
 | --- | --- |
 | ![circuit](https://user-images.githubusercontent.com/79366697/118655399-196b7580-b81c-11eb-8bc5-d5ece80c18ba.jpg) | ![chicago_stylized_circuit](https://user-images.githubusercontent.com/79366697/118655660-56376c80-b81c-11eb-87f2-64ae5a82375c.png)|
 
-## 4 Model download
+## 5 Model download
 
 We provide several trained weights.
diff --git a/ppgan/apps/__init__.py b/ppgan/apps/__init__.py
index f0bfc1c..9df590a 100644
--- a/ppgan/apps/__init__.py
+++ b/ppgan/apps/__init__.py
@@ -26,3 +26,4 @@ from .styleganv2_predictor import StyleGANv2Predictor
 from .pixel2style2pixel_predictor import Pixel2Style2PixelPredictor
 from .wav2lip_predictor import Wav2LipPredictor
 from .mpr_predictor import MPRPredictor
+from .lapstyle_predictor import LapStylePredictor
diff --git a/ppgan/apps/lapstyle_predictor.py b/ppgan/apps/lapstyle_predictor.py
new file mode 100644
index 0000000..5ea8cd4
--- /dev/null
+++ b/ppgan/apps/lapstyle_predictor.py
@@ -0,0 +1,204 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import cv2 as cv
+import numpy as np
+import urllib.request
+from PIL import Image
+
+import paddle
+import paddle.nn.functional as F
+from paddle.vision.transforms import functional
+
+from ppgan.utils.download import get_path_from_url
+from ppgan.utils.visual import tensor2img
+from ppgan.models.generators import DecoderNet, Encoder, RevisionNet
+from .base_predictor import BasePredictor
+
+LapStyle_circuit_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_circuit.pdparams'
+circuit_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655399-196b7580-b81c-11eb-8bc5-d5ece80c18ba.jpg'
+LapStyle_ocean_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_ocean.pdparams'
+ocean_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655407-1c666600-b81c-11eb-83a6-300ee1952415.png'
+LapStyle_starrynew_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_starrynew.pdparams'
+starrynew_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655415-1ec8c000-b81c-11eb-8002-90bf8d477860.png'
+LapStyle_stars_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_stars.pdparams'
+stars_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655423-20928380-b81c-11eb-92bd-0deeb320ff14.png'
+
+
+def img(img):
+    # some images have 4 channels; keep only the first three (RGB).
+    # The HWC-to-CHW conversion happens later, in `functional.to_tensor`.
+    if img.shape[2] > 3:
+        img = img[:, :, :3]
+    return img
+
+
+def img_read(content_img_path, style_image_path):
+    """Read the content and style images and return them as 1x3x512x512
+    tensors with values in [0, 1]."""
+    content_img = cv.imread(content_img_path)
+    if content_img.ndim == 2:
+        content_img = cv.cvtColor(content_img, cv.COLOR_GRAY2RGB)
+    else:
+        content_img = cv.cvtColor(content_img, cv.COLOR_BGR2RGB)
+    content_img = Image.fromarray(content_img)
+    content_img = content_img.resize((512, 512), Image.BILINEAR)
+    content_img = np.array(content_img)
+    content_img = img(content_img)
+    content_img = functional.to_tensor(content_img)
+
+    style_img = cv.imread(style_image_path)
+    style_img = cv.cvtColor(style_img, cv.COLOR_BGR2RGB)
+    style_img = Image.fromarray(style_img)
+    style_img = style_img.resize((512, 512), Image.BILINEAR)
+    style_img = np.array(style_img)
+    style_img = img(style_img)
+    style_img = functional.to_tensor(style_img)
+
+    content_img = paddle.unsqueeze(content_img, axis=0)
+    style_img = paddle.unsqueeze(style_img, axis=0)
+    return content_img, style_img
+
+
+def tensor_resample(tensor, dst_size, mode='bilinear'):
+    return F.interpolate(tensor, dst_size, mode=mode, align_corners=False)
+
+
+def laplacian(x):
+    """Compute one Laplacian band of x: x - upsample(downsample(x))."""
+    return x - tensor_resample(
+        tensor_resample(x, [x.shape[2] // 2, x.shape[3] // 2]),
+        [x.shape[2], x.shape[3]])
+
+
+def make_laplace_pyramid(x, levels):
+    """Decompose x into `levels` Laplacian detail bands followed by the
+    final low-resolution residual."""
+    pyramid = []
+    current = x
+    for i in range(levels):
+        pyramid.append(laplacian(current))
+        current = tensor_resample(
+            current,
+            (max(current.shape[2] // 2, 1), max(current.shape[3] // 2, 1)))
+    pyramid.append(current)
+    return pyramid
+
+
+def fold_laplace_pyramid(pyramid):
+    """Reconstruct an image by folding a Laplacian pyramid back together."""
+    current = pyramid[-1]
+    for i in range(len(pyramid) - 2, -1, -1):  # iterate from len-2 to 0
+        up_h, up_w = pyramid[i].shape[2], pyramid[i].shape[3]
+        current = pyramid[i] + tensor_resample(current, (up_h, up_w))
+    return current
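+
+
+# Illustrative sketch (hypothetical helper, not used by the predictor below):
+# make/fold are exact inverses, because each Laplacian band stores precisely
+# the detail removed by the matching down/up-sampling step.
+def _check_pyramid_roundtrip():
+    x = paddle.rand([1, 3, 512, 512])
+    pyr = make_laplace_pyramid(x, 2)  # [band0, band1, low-res residual]
+    x_rec = fold_laplace_pyramid(pyr)
+    # maximum reconstruction error; ~0 up to floating-point precision
+    return float((x - x_rec).abs().max())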
+
+
+class LapStylePredictor(BasePredictor):
+    def __init__(self,
+                 output='output_dir',
+                 style='starrynew',
+                 weight_path=None,
+                 style_image_path=None):
+        self.output = os.path.join(output, 'LapStyle')
+        if not os.path.exists(self.output):
+            os.makedirs(self.output)
+        self.net_enc = Encoder()
+        self.net_dec = DecoderNet()
+        self.net_rev = RevisionNet()
+        self.net_rev_2 = RevisionNet()
+        if weight_path is None:
+            self.style_image_path = os.path.join(self.output, 'style.png')
+            if style == 'starrynew':
+                weight_path = get_path_from_url(LapStyle_starrynew_WEIGHT_URL)
+                urllib.request.urlretrieve(starrynew_IMG_URL,
+                                           filename=self.style_image_path)
+            elif style == 'circuit':
+                weight_path = get_path_from_url(LapStyle_circuit_WEIGHT_URL)
+                urllib.request.urlretrieve(circuit_IMG_URL,
+                                           filename=self.style_image_path)
+            elif style == 'ocean':
+                weight_path = get_path_from_url(LapStyle_ocean_WEIGHT_URL)
+                urllib.request.urlretrieve(ocean_IMG_URL,
+                                           filename=self.style_image_path)
+            elif style == 'stars':
+                weight_path = get_path_from_url(LapStyle_stars_WEIGHT_URL)
+                urllib.request.urlretrieve(stars_IMG_URL,
+                                           filename=self.style_image_path)
+            else:
+                raise Exception(f'style {style} is not implemented.')
+        else:
+            if style_image_path is None:
+                raise Exception(
+                    'style_image_path cannot be None when weight_path is set.')
+            self.style_image_path = style_image_path
+        # load all sub-network weights from a single checkpoint read
+        weights = paddle.load(weight_path)
+        self.net_enc.set_dict(weights['net_enc'])
+        self.net_enc.eval()
+        self.net_dec.set_dict(weights['net_dec'])
+        self.net_dec.eval()
+        self.net_rev.set_dict(weights['net_rev'])
+        self.net_rev.eval()
+        self.net_rev_2.set_dict(weights['net_rev_2'])
+        self.net_rev_2.eval()
+
+    def run(self, content_img_path):
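+        # Coarse-to-fine inference (descriptive summary): stylize a small
+        # draft with the encoder/decoder, then run two RevisionNet passes
+        # that re-inject the content image's Laplacian detail bands, doubling
+        # the resolution each time; intermediate results are saved alongside
+        # the final image.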
+        content_img, style_img = img_read(content_img_path,
+                                          self.style_image_path)
+        content_img_visual = tensor2img(content_img, min_max=(0., 1.))
+        content_img_visual = cv.cvtColor(content_img_visual, cv.COLOR_RGB2BGR)
+        cv.imwrite(os.path.join(self.output, 'content.png'),
+                   content_img_visual)
+        pyr_ci = make_laplace_pyramid(content_img, 2)
+        pyr_si = make_laplace_pyramid(style_img, 2)
+        pyr_ci.append(content_img)
+        pyr_si.append(style_img)
+        cF = self.net_enc(pyr_ci[2])
+        sF = self.net_enc(pyr_si[2])
+        stylized_small = self.net_dec(cF, sF)
+        stylized_small_visual = tensor2img(stylized_small, min_max=(0., 1.))
+        stylized_small_visual = cv.cvtColor(stylized_small_visual,
+                                            cv.COLOR_RGB2BGR)
+        cv.imwrite(os.path.join(self.output, 'stylized_small.png'),
+                   stylized_small_visual)
+        stylized_up = F.interpolate(stylized_small, scale_factor=2)
+
+        revnet_input = paddle.concat(x=[pyr_ci[1], stylized_up], axis=1)
+        stylized_rev_lap = self.net_rev(revnet_input)
+        stylized_rev = fold_laplace_pyramid([stylized_rev_lap, stylized_small])
+        stylized_rev_visual = tensor2img(stylized_rev, min_max=(0., 1.))
+        stylized_rev_visual = cv.cvtColor(stylized_rev_visual,
+                                          cv.COLOR_RGB2BGR)
+        cv.imwrite(os.path.join(self.output, 'stylized_rev_first.png'),
+                   stylized_rev_visual)
+        stylized_up = F.interpolate(stylized_rev, scale_factor=2)
+
+        revnet_input = paddle.concat(x=[pyr_ci[0], stylized_up], axis=1)
+        stylized_rev_lap_second = self.net_rev_2(revnet_input)
+        stylized_rev_second = fold_laplace_pyramid(
+            [stylized_rev_lap_second, stylized_rev_lap, stylized_small])
+
+        stylized = stylized_rev_second
+        stylized_visual = tensor2img(stylized, min_max=(0., 1.))
+        stylized_visual = cv.cvtColor(stylized_visual, cv.COLOR_RGB2BGR)
+        cv.imwrite(os.path.join(self.output, 'stylized.png'), stylized_visual)
+
+        print('Model LapStyle output images path:', self.output)
+
+        return stylized
-- 
GitLab