未验证 提交 49e549ee 编写于 作者: W wangna11BD 提交者: GitHub

add lapstyle predictor (#362)

* add lapstyle predictor

* add docs
上级 c7d7c4f5
import paddle
import os
import sys
sys.path.insert(0, os.getcwd())
from ppgan.apps import LapStylePredictor
import argparse
if __name__ == "__main__":
    # Command-line entry point: parse the options, optionally force CPU
    # execution, then run LapStyle style transfer on one content image.
    parser = argparse.ArgumentParser()
    # The content image is the only input without a usable default; make it
    # mandatory so a missing flag is reported by argparse instead of failing
    # later inside image loading (after the weights were already fetched).
    parser.add_argument("--content_img",
                        type=str,
                        required=True,
                        help="path to content image")
    parser.add_argument("--output_path",
                        type=str,
                        default='output_dir',
                        help="path to output image dir")
    parser.add_argument("--weight_path",
                        type=str,
                        default=None,
                        help="path to model weight path")
    parser.add_argument(
        "--style",
        type=str,
        default='starrynew',
        help=
        "if weight_path is None, style can be chosen in 'starrynew', 'circuit', 'ocean' and 'stars'"
    )
    parser.add_argument("--style_image_path",
                        type=str,
                        default=None,
                        help="if weight_path is not None, path to style image")
    parser.add_argument("--cpu",
                        dest="cpu",
                        action="store_true",
                        help="cpu mode.")
    args = parser.parse_args()

    if args.cpu:
        paddle.set_device('cpu')

    predictor = LapStylePredictor(output=args.output_path,
                                  style=args.style,
                                  weight_path=args.weight_path,
                                  style_image_path=args.style_image_path)
    predictor.run(args.content_img)
......@@ -13,14 +13,26 @@ Artistic style transfer aims at migrating the style from an example image to a c
![lapstyle_overview](https://user-images.githubusercontent.com/79366697/118654987-b24dc100-b81b-11eb-9430-d84630f80511.png)
## 2 How to use
## 2 Quick experience
```
python applications/tools/lapstyle.py --content_img ${PATH_OF_CONTENT_IMG}
```
### Parameters
- `--content_img (str)`: path to content image.
- `--output_path (str)`: path to output image dir, default value:`output_dir`.
- `--weight_path (str)`: path to the model weights; if `weight_path` is `None`, the pre-trained model will be downloaded automatically, default value:`None`.
- `--style (str)`: style of output image, if `weight_path` is `None`, `style` can be chosen in `starrynew`, `circuit`, `ocean` and `stars`, default value:`starrynew`.
- `--style_image_path (str)`: path to style image, required when `weight_path` is not `None`, default value:`None`.
## 3 How to use
### 2.1 Prepare Datasets
### 3.1 Prepare Datasets
To train LapStyle, we use the COCO dataset as the content set, and you can choose any style image you like. Before training or testing, remember to modify the data path of the style image in the config file.
### 2.2 Train
### 3.2 Train
The dataset used in the example is COCO; you can also change it to your own dataset in the config file.
......@@ -40,14 +52,14 @@ python -u tools/main.py --config-file configs/lapstyle_rev_first.yaml --load ${P
python -u tools/main.py --config-file configs/lapstyle_rev_second.yaml --load ${PATH_OF_LAST_STAGE_WEIGHT}
```
### 2.4 Test
### 3.4 Test
To test the trained model, you can directly test the "lapstyle_rev_second", since it also contains the trained weight of previous stages:
```
python tools/main.py --config-file configs/lapstyle_rev_second.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 3 Results
## 4 Results
| Style | Stylized Results |
| --- | --- |
......@@ -56,7 +68,7 @@ python tools/main.py --config-file configs/lapstyle_rev_second.yaml --evaluate-o
| ![stars](https://user-images.githubusercontent.com/79366697/118655423-20928380-b81c-11eb-92bd-0deeb320ff14.png) | ![chicago_stylized_stars_512](https://user-images.githubusercontent.com/79366697/118655638-50da2200-b81c-11eb-9223-58d5df022fa5.png)|
| ![circuit](https://user-images.githubusercontent.com/79366697/118655399-196b7580-b81c-11eb-8bc5-d5ece80c18ba.jpg) | ![chicago_stylized_circuit](https://user-images.githubusercontent.com/79366697/118655660-56376c80-b81c-11eb-87f2-64ae5a82375c.png)|
## 4 Pre-trained models
## 5 Pre-trained models
We also provide several trained models.
......
......@@ -12,13 +12,25 @@ LapStyle首先通过绘图网络(Drafting Network)传输低分辨率的全
![lapstyle_overview](https://user-images.githubusercontent.com/79366697/118654987-b24dc100-b81b-11eb-9430-d84630f80511.png)
## 2 如何使用
## 2 快速体验
```
python applications/tools/lapstyle.py --content_img ${PATH_OF_CONTENT_IMG}
```
### **参数**
- `--content_img (str)`: 输入的内容图像路径。
- `--output_path (str)`: 输出的图像路径,默认为`output_dir`
- `--weight_path (str)`: 模型权重路径,设置`None`时会自行下载预训练模型,默认为`None`
- `--style (str)`: 生成图像风格,当`weight_path`为`None`时,可以在`starrynew`, `circuit`, `ocean` 和 `stars`中选择,默认为`starrynew`
- `--style_image_path (str)`: 输入的风格图像路径,当`weight_path`不为`None`时需要输入,默认为`None`
## 3 如何使用
### 2.1 数据准备
### 3.1 数据准备
为了训练LapStyle,我们使用COCO数据集作为内容数据集。您可以任意选择您喜欢的风格图片。在开始训练与测试之前,记得修改配置文件的数据路径。
### 2.2 训练
### 3.2 训练
示例以COCO数据为例。如果您想使用自己的数据集,可以在配置文件中修改数据集为您自己的数据集。
......@@ -37,14 +49,14 @@ python -u tools/main.py --config-file configs/lapstyle_rev_first.yaml --load ${P
python -u tools/main.py --config-file configs/lapstyle_rev_second.yaml --load ${PATH_OF_LAST_STAGE_WEIGHT}
```
### 2.4 测试
### 3.4 测试
测试训练好的模型,您可以直接测试 "lapstyle_rev_second",因为它包含了之前步骤里的训练权重:
```
python tools/main.py --config-file configs/lapstyle_rev_second.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 3 结果展示
## 4 结果展示
| Style | Stylized Results |
| --- | --- |
......@@ -54,7 +66,7 @@ python tools/main.py --config-file configs/lapstyle_rev_second.yaml --evaluate-o
| ![circuit](https://user-images.githubusercontent.com/79366697/118655399-196b7580-b81c-11eb-8bc5-d5ece80c18ba.jpg) | ![chicago_stylized_circuit](https://user-images.githubusercontent.com/79366697/118655660-56376c80-b81c-11eb-87f2-64ae5a82375c.png)|
## 4 模型下载
## 5 模型下载
我们提供几个训练好的权重。
......
......@@ -26,3 +26,4 @@ from .styleganv2_predictor import StyleGANv2Predictor
from .pixel2style2pixel_predictor import Pixel2Style2PixelPredictor
from .wav2lip_predictor import Wav2LipPredictor
from .mpr_predictor import MPRPredictor
from .lapstyle_predictor import LapStylePredictor
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2 as cv
import numpy as np
import urllib.request
from PIL import Image
import paddle
import paddle.nn.functional as F
from paddle.vision.transforms import functional
from ppgan.utils.download import get_path_from_url
from ppgan.utils.visual import tensor2img
from ppgan.models.generators import DecoderNet, Encoder, RevisionNet
from .base_predictor import BasePredictor
# Download locations used by LapStylePredictor when no weight_path is given.
# Each built-in style has a pair: the weight file URL and the URL of the
# style image that the predictor saves locally as its style reference.

# 'circuit' style
LapStyle_circuit_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_circuit.pdparams'
circuit_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655399-196b7580-b81c-11eb-8bc5-d5ece80c18ba.jpg'
# 'ocean' style
LapStyle_ocean_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_ocean.pdparams'
ocean_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655407-1c666600-b81c-11eb-83a6-300ee1952415.png'
# 'starrynew' style (the predictor's default)
LapStyle_starrynew_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_starrynew.pdparams'
starrynew_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655415-1ec8c000-b81c-11eb-8002-90bf8d477860.png'
# 'stars' style
LapStyle_stars_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_stars.pdparams'
stars_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655423-20928380-b81c-11eb-92bd-0deeb320ff14.png'
def img(img):
    """Drop any channels beyond the first three from an image array.

    Args:
        img: image array, either 2-D ``(H, W)`` or 3-D ``(H, W, C)``.

    Returns:
        The input unchanged when it is 2-D or has at most three channels;
        otherwise the array sliced to its first three channels. The axis
        order is left as-is (no transpose is performed here).
    """
    # Some images have 4 channels (e.g. an alpha plane). A 2-D array has no
    # channel axis at all, so only inspect shape[2] when it exists.
    if img.ndim > 2 and img.shape[2] > 3:
        img = img[:, :, :3]
    return img
def _load_rgb_tensor(image_path):
    """Read one image file and return it as a 512x512 RGB tensor (no batch axis)."""
    image = cv.imread(image_path)
    if image is None:
        # cv.imread reports a missing/unreadable file by returning None
        # rather than raising, so fail here with the offending path.
        raise OSError(f'failed to read image: {image_path}')
    if image.ndim == 2:
        image = cv.cvtColor(image, cv.COLOR_GRAY2RGB)
    else:
        # OpenCV decodes to BGR channel order; the rest of the pipeline uses RGB.
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
    image = Image.fromarray(image).resize((512, 512), Image.BILINEAR)
    image = img(np.array(image))
    return functional.to_tensor(image)


def img_read(content_img_path, style_image_path):
    """Load the content and style images as batched tensors.

    Both files go through the same steps: read with OpenCV, convert to RGB,
    resize to 512x512 with bilinear resampling, strip extra channels,
    convert with ``functional.to_tensor`` and prepend a batch axis.

    Args:
        content_img_path: path of the content image file.
        style_image_path: path of the style image file.

    Returns:
        Tuple ``(content_img, style_img)`` of tensors, each with a leading
        axis of size 1 added by ``paddle.unsqueeze``.

    Raises:
        OSError: if either file cannot be read by ``cv.imread``.
    """
    content_img = paddle.unsqueeze(_load_rgb_tensor(content_img_path), axis=0)
    style_img = paddle.unsqueeze(_load_rgb_tensor(style_image_path), axis=0)
    return content_img, style_img
def tensor_resample(tensor, dst_size, mode='bilinear'):
    """Resize ``tensor`` to ``dst_size`` using ``F.interpolate``.

    Args:
        tensor: input passed straight to ``F.interpolate``.
        dst_size: target size handed to ``F.interpolate``; the callers in
            this file pass a (height, width) pair.
        mode: interpolation mode, ``'bilinear'`` unless overridden.

    Returns:
        The interpolated tensor (``align_corners`` is always ``False``).
    """
    resized = F.interpolate(tensor,
                            size=dst_size,
                            mode=mode,
                            align_corners=False)
    return resized
def laplacian(x):
    """Return the detail lost when ``x`` is halved and scaled back up.

    Computes ``x - upsample(downsample(x))``, where both resampling steps
    use ``tensor_resample`` and the sizes come from axes 2 and 3 of ``x``.
    """
    full_size = [x.shape[2], x.shape[3]]
    half_size = [x.shape[2] // 2, x.shape[3] // 2]
    # Down to half size, then back to the original size.
    round_trip = tensor_resample(tensor_resample(x, half_size), full_size)
    return x - round_trip
def make_laplace_pyramid(x, levels):
    """Build a Laplacian pyramid with ``levels`` detail bands.

    Returns:
        A list of ``levels + 1`` tensors: one ``laplacian`` band per level,
        starting at the input resolution, followed by the remaining
        low-resolution tensor after ``levels`` halvings.
    """
    bands = []
    remaining = x
    for _ in range(levels):
        bands.append(laplacian(remaining))
        # Halve axes 2 and 3, but never let either drop below 1.
        half_h = max(remaining.shape[2] // 2, 1)
        half_w = max(remaining.shape[3] // 2, 1)
        remaining = tensor_resample(remaining, (half_h, half_w))
    bands.append(remaining)
    return bands
def fold_laplace_pyramid(pyramid):
    """Collapse a Laplacian pyramid back into a single tensor.

    Starts from the last entry and walks towards the first, resizing the
    running result to each band's size (axes 2 and 3) and adding the band.
    """
    folded = pyramid[-1]
    # Visit every band except the last, from second-to-last down to index 0.
    for band in reversed(pyramid[:-1]):
        target_size = (band.shape[2], band.shape[3])
        folded = band + tensor_resample(folded, target_size)
    return folded
class LapStylePredictor(BasePredictor):
    """Stylize a content image with the LapStyle networks.

    The predictor owns an encoder, a decoder and two revision networks. It
    writes the intermediate and final images into ``<output>/LapStyle``.
    """

    def __init__(self,
                 output='output_dir',
                 style='starrynew',
                 weight_path=None,
                 style_image_path=None):
        """Create the networks and load their weights.

        Args:
            output: parent directory; results go to ``<output>/LapStyle``.
            style: name of a built-in style, used only when ``weight_path``
                is ``None``. One of ``'starrynew'``, ``'circuit'``,
                ``'ocean'`` and ``'stars'``.
            weight_path: path of a weight file. When ``None`` the weights
                and style image of ``style`` are downloaded.
            style_image_path: path of the style image; mandatory when
                ``weight_path`` is given.

        Raises:
            ValueError: if ``style`` is not a built-in style, or if
                ``weight_path`` is given without ``style_image_path``.
        """
        # There is no ``input`` parameter; the attribute used to capture the
        # ``input`` builtin by accident. Keep it defined, with a neutral value.
        self.input = None
        self.output = os.path.join(output, 'LapStyle')
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.output, exist_ok=True)

        self.net_enc = Encoder()
        self.net_dec = DecoderNet()
        self.net_rev = RevisionNet()
        self.net_rev_2 = RevisionNet()

        if weight_path is None:
            # style name -> (weight URL, style image URL)
            pretrained = {
                'starrynew': (LapStyle_starrynew_WEIGHT_URL, starrynew_IMG_URL),
                'circuit': (LapStyle_circuit_WEIGHT_URL, circuit_IMG_URL),
                'ocean': (LapStyle_ocean_WEIGHT_URL, ocean_IMG_URL),
                'stars': (LapStyle_stars_WEIGHT_URL, stars_IMG_URL),
            }
            if style not in pretrained:
                raise ValueError(f'has not implemented {style}.')
            weight_url, style_img_url = pretrained[style]
            self.style_image_path = os.path.join(self.output, 'style.png')
            weight_path = get_path_from_url(weight_url)
            urllib.request.urlretrieve(style_img_url,
                                       filename=self.style_image_path)
        else:
            if style_image_path is None:
                raise ValueError('style_image_path can not be None.')
            self.style_image_path = style_image_path

        # Read the weight file once and hand each network its own entry.
        checkpoint = paddle.load(weight_path)
        for net, key in ((self.net_enc, 'net_enc'),
                         (self.net_dec, 'net_dec'),
                         (self.net_rev, 'net_rev'),
                         (self.net_rev_2, 'net_rev_2')):
            net.set_dict(checkpoint[key])
            net.eval()

    def _save_visual(self, tensor, filename):
        """Convert ``tensor`` to an image and write it into the output dir."""
        visual = tensor2img(tensor, min_max=(0., 1.))
        # cv.imwrite expects BGR channel order.
        visual = cv.cvtColor(visual, cv.COLOR_RGB2BGR)
        cv.imwrite(os.path.join(self.output, filename), visual)

    def run(self, content_img_path):
        """Stylize one content image and save every stage to disk.

        Writes ``content.png``, ``stylized_small.png``,
        ``stylized_rev_first.png`` and ``stylized.png`` into ``self.output``.

        Args:
            content_img_path: path of the content image file.

        Returns:
            The final stylized tensor (the result of the second revision).
        """
        content_img, style_img = img_read(content_img_path,
                                          self.style_image_path)
        self._save_visual(content_img, 'content.png')

        # Pure inference: no gradients are needed.
        with paddle.no_grad():
            # Each pyramid is [detail @ full size, detail @ half size,
            # quarter-size image].
            pyr_ci = make_laplace_pyramid(content_img, 2)
            pyr_si = make_laplace_pyramid(style_img, 2)

            # Draft: stylize the lowest-resolution images.
            cF = self.net_enc(pyr_ci[2])
            sF = self.net_enc(pyr_si[2])
            stylized_small = self.net_dec(cF, sF)
            self._save_visual(stylized_small, 'stylized_small.png')

            # First revision: predict the detail band at twice the draft size.
            stylized_up = F.interpolate(stylized_small, scale_factor=2)
            revnet_input = paddle.concat(x=[pyr_ci[1], stylized_up], axis=1)
            stylized_rev_lap = self.net_rev(revnet_input)
            stylized_rev = fold_laplace_pyramid(
                [stylized_rev_lap, stylized_small])
            self._save_visual(stylized_rev, 'stylized_rev_first.png')

            # Second revision: predict the detail band at full size.
            stylized_up = F.interpolate(stylized_rev, scale_factor=2)
            revnet_input = paddle.concat(x=[pyr_ci[0], stylized_up], axis=1)
            stylized_rev_lap_second = self.net_rev_2(revnet_input)
            stylized = fold_laplace_pyramid(
                [stylized_rev_lap_second, stylized_rev_lap, stylized_small])
            self._save_visual(stylized, 'stylized.png')

        print('Model LapStyle output images path:', self.output)
        return stylized
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册