From c0aed4f618c77b40767ed8898316579503838e4f Mon Sep 17 00:00:00 2001
From: wangna11BD <79366697+wangna11BD@users.noreply.github.com>
Date: Tue, 30 Nov 2021 09:33:38 +0800
Subject: [PATCH] fix lapstyle runtime (#502)
---
applications/tools/lapstyle.py | 11 ++++++-----
docs/en_US/tutorials/lap_style.md | 9 ++++++---
docs/zh_CN/tutorials/lap_style.md | 9 ++++++---
ppgan/apps/lapstyle_predictor.py | 29 +++++++----------------------
4 files changed, 25 insertions(+), 33 deletions(-)
diff --git a/applications/tools/lapstyle.py b/applications/tools/lapstyle.py
index e4f0fcc..f44f4e3 100644
--- a/applications/tools/lapstyle.py
+++ b/applications/tools/lapstyle.py
@@ -8,7 +8,9 @@ import argparse
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument("--content_img", type=str, help="path to content image")
+ parser.add_argument("--content_img_path",
+ type=str,
+ help="path to content image")
parser.add_argument("--output_path",
type=str,
@@ -31,7 +33,7 @@ if __name__ == "__main__":
parser.add_argument("--style_image_path",
type=str,
default=None,
- help="if weight_path is not None, path to style image")
+ help="path to style image")
parser.add_argument("--cpu",
dest="cpu",
@@ -45,6 +47,5 @@ if __name__ == "__main__":
predictor = LapStylePredictor(output=args.output_path,
style=args.style,
- weight_path=args.weight_path,
- style_image_path=args.style_image_path)
- predictor.run(args.content_img)
+ weight_path=args.weight_path)
+ predictor.run(args.content_img_path, args.style_image_path)
diff --git a/docs/en_US/tutorials/lap_style.md b/docs/en_US/tutorials/lap_style.md
index 163a5ef..9449d1c 100644
--- a/docs/en_US/tutorials/lap_style.md
+++ b/docs/en_US/tutorials/lap_style.md
@@ -14,16 +14,19 @@ Artistic style transfer aims at migrating the style from an example image to a c
## 2 Quick experience
+Here four style images:
+| [StarryNew](https://user-images.githubusercontent.com/79366697/118655415-1ec8c000-b81c-11eb-8002-90bf8d477860.png) | [Stars](https://user-images.githubusercontent.com/79366697/118655423-20928380-b81c-11eb-92bd-0deeb320ff14.png) | [Ocean](https://user-images.githubusercontent.com/79366697/118655407-1c666600-b81c-11eb-83a6-300ee1952415.png) | [Circuit](https://user-images.githubusercontent.com/79366697/118655399-196b7580-b81c-11eb-8bc5-d5ece80c18ba.jpg)|
+
```
-python applications/tools/lapstyle.py --content_img ${PATH_OF_CONTENT_IMG}
+python applications/tools/lapstyle.py --content_img_path ${PATH_OF_CONTENT_IMG} --style_image_path ${PATH_OF_STYLE_IMG}
```
### Parameters
-- `--content_img (str)`: path to content image.
+- `--content_img_path (str)`: path to content image.
+- `--style_image_path (str)`: path to style image.
- `--output_path (str)`: path to output image dir, default value:`output_dir`.
- `--weight_path (str)`: path to model weight path, if `weight_path` is `None`, the pre-training model will be downloaded automatically, default value:`None`.
- `--style (str)`: style of output image, if `weight_path` is `None`, `style` can be chosen in `starrynew`, `circuit`, `ocean` and `stars`, default value:`starrynew`.
-- `--style_image_path (str)`: path to style image, it need to input when `weight_path` is not `None`, default value:`None`.
## 3 How to use
diff --git a/docs/zh_CN/tutorials/lap_style.md b/docs/zh_CN/tutorials/lap_style.md
index d57389f..7744ebb 100644
--- a/docs/zh_CN/tutorials/lap_style.md
+++ b/docs/zh_CN/tutorials/lap_style.md
@@ -23,18 +23,21 @@ PaddleGAN为大家提供了四种不同艺术风格的预训练模型,风格
| :----------------------------------------------------------: | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| | | | | |
+4个风格图像下载地址如下:
+| [StarryNew](https://user-images.githubusercontent.com/79366697/118655415-1ec8c000-b81c-11eb-8002-90bf8d477860.png) | [Stars](https://user-images.githubusercontent.com/79366697/118655423-20928380-b81c-11eb-92bd-0deeb320ff14.png) | [Ocean](https://user-images.githubusercontent.com/79366697/118655407-1c666600-b81c-11eb-83a6-300ee1952415.png) | [Circuit](https://user-images.githubusercontent.com/79366697/118655399-196b7580-b81c-11eb-8bc5-d5ece80c18ba.jpg)|
+
只需运行下面的代码即可迁移至指定风格:
```
-python applications/tools/lapstyle.py --content_img ${PATH_OF_CONTENT_IMG}
+python applications/tools/lapstyle.py --content_img_path ${PATH_OF_CONTENT_IMG} --style_image_path ${PATH_OF_STYLE_IMG}
```
### **参数**
-- `--content_img (str)`: 输入的内容图像路径。
+- `--content_img_path (str)`: 输入的内容图像路径。
+- `--style_image_path (str)`: 输入的风格图像路径。
- `--output_path (str)`: 输出的图像路径,默认为`output_dir`。
- `--weight_path (str)`: 模型权重路径,设置`None`时会自行下载预训练模型,默认为`None`。
- `--style (str)`: 生成图像风格,当`weight_path`为`None`时,可以在`starrynew`, `circuit`, `ocean` 和 `stars`中选择,默认为`starrynew`。
-- `--style_image_path (str)`: 输入的风格图像路径,当`weight_path`不为`None`时需要输入,默认为`None`。
## 3. 模型训练
diff --git a/ppgan/apps/lapstyle_predictor.py b/ppgan/apps/lapstyle_predictor.py
index 828f752..36ae6f0 100644
--- a/ppgan/apps/lapstyle_predictor.py
+++ b/ppgan/apps/lapstyle_predictor.py
@@ -28,13 +28,9 @@ from ppgan.models.generators import DecoderNet, Encoder, RevisionNet
from .base_predictor import BasePredictor
LapStyle_circuit_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_circuit.pdparams'
-circuit_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655399-196b7580-b81c-11eb-8bc5-d5ece80c18ba.jpg'
LapStyle_ocean_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_ocean.pdparams'
-ocean_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655407-1c666600-b81c-11eb-83a6-300ee1952415.png'
LapStyle_starrynew_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_starrynew.pdparams'
-starrynew_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655415-1ec8c000-b81c-11eb-8002-90bf8d477860.png'
LapStyle_stars_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_stars.pdparams'
-stars_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655423-20928380-b81c-11eb-92bd-0deeb320ff14.png'
def img(img):
@@ -117,8 +113,7 @@ class LapStylePredictor(BasePredictor):
def __init__(self,
output='output_dir',
style='starrynew',
- weight_path=None,
- style_image_path=None):
+ weight_path=None):
self.input = input
self.output = os.path.join(output, 'LapStyle')
if not os.path.exists(self.output):
@@ -127,31 +122,18 @@ class LapStylePredictor(BasePredictor):
self.net_dec = DecoderNet()
self.net_rev = RevisionNet()
self.net_rev_2 = RevisionNet()
+
if weight_path is None:
- self.style_image_path = os.path.join(self.output, 'style.png')
if style == 'starrynew':
weight_path = get_path_from_url(LapStyle_starrynew_WEIGHT_URL)
- urllib.request.urlretrieve(starrynew_IMG_URL,
- filename=self.style_image_path)
elif style == 'circuit':
weight_path = get_path_from_url(LapStyle_circuit_WEIGHT_URL)
- urllib.request.urlretrieve(circuit_IMG_URL,
- filename=self.style_image_path)
elif style == 'ocean':
weight_path = get_path_from_url(LapStyle_ocean_WEIGHT_URL)
- urllib.request.urlretrieve(ocean_IMG_URL,
- filename=self.style_image_path)
elif style == 'stars':
weight_path = get_path_from_url(LapStyle_stars_WEIGHT_URL)
- urllib.request.urlretrieve(stars_IMG_URL,
- filename=self.style_image_path)
else:
raise Exception(f'has not implemented {style}.')
- else:
- if style_image_path is None:
- raise Exception('style_image_path can not be None.')
- else:
- self.style_image_path = style_image_path
self.net_enc.set_dict(paddle.load(weight_path)['net_enc'])
self.net_enc.eval()
self.net_dec.set_dict(paddle.load(weight_path)['net_dec'])
@@ -161,12 +143,15 @@ class LapStylePredictor(BasePredictor):
self.net_rev_2.set_dict(paddle.load(weight_path)['net_rev_2'])
self.net_rev_2.eval()
- def run(self, content_img_path):
+ def run(self, content_img_path, style_image_path):
content_img, style_img, h, w = img_read(content_img_path,
- self.style_image_path)
+ style_image_path)
content_img_visual = tensor2img(content_img, min_max=(0., 1.))
content_img_visual = cv.cvtColor(content_img_visual, cv.COLOR_RGB2BGR)
cv.imwrite(os.path.join(self.output, 'content.png'), content_img_visual)
+ style_img_visual = tensor2img(style_img, min_max=(0., 1.))
+ style_img_visual = cv.cvtColor(style_img_visual, cv.COLOR_RGB2BGR)
+ cv.imwrite(os.path.join(self.output, 'style.png'), style_img_visual)
pyr_ci = make_laplace_pyramid(content_img, 2)
pyr_si = make_laplace_pyramid(style_img, 2)
pyr_ci.append(content_img)
--
GitLab