未验证 提交 c0aed4f6 编写于 作者: W wangna11BD 提交者: GitHub

fix lapstyle runtime (#502)

上级 22208bf9
......@@ -8,7 +8,9 @@ import argparse
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--content_img", type=str, help="path to content image")
parser.add_argument("--content_img_path",
type=str,
help="path to content image")
parser.add_argument("--output_path",
type=str,
......@@ -31,7 +33,7 @@ if __name__ == "__main__":
parser.add_argument("--style_image_path",
type=str,
default=None,
help="if weight_path is not None, path to style image")
help="path to style image")
parser.add_argument("--cpu",
dest="cpu",
......@@ -45,6 +47,5 @@ if __name__ == "__main__":
predictor = LapStylePredictor(output=args.output_path,
style=args.style,
weight_path=args.weight_path,
style_image_path=args.style_image_path)
predictor.run(args.content_img)
weight_path=args.weight_path)
predictor.run(args.content_img_path, args.style_image_path)
......@@ -14,16 +14,19 @@ Artistic style transfer aims at migrating the style from an example image to a c
## 2 Quick experience
Here four style images:
| [StarryNew](https://user-images.githubusercontent.com/79366697/118655415-1ec8c000-b81c-11eb-8002-90bf8d477860.png) | [Stars](https://user-images.githubusercontent.com/79366697/118655423-20928380-b81c-11eb-92bd-0deeb320ff14.png) | [Ocean](https://user-images.githubusercontent.com/79366697/118655407-1c666600-b81c-11eb-83a6-300ee1952415.png) | [Circuit](https://user-images.githubusercontent.com/79366697/118655399-196b7580-b81c-11eb-8bc5-d5ece80c18ba.jpg)|
```
python applications/tools/lapstyle.py --content_img ${PATH_OF_CONTENT_IMG}
python applications/tools/lapstyle.py --content_img_path ${PATH_OF_CONTENT_IMG} --style_image_path ${PATH_OF_STYLE_IMG}
```
### Parameters
- `--content_img (str)`: path to content image.
- `--content_img_path (str)`: path to content image.
- `--style_image_path (str)`: path to style image.
- `--output_path (str)`: path to output image dir, default value:`output_dir`.
- `--weight_path (str)`: path to model weight path, if `weight_path` is `None`, the pre-training model will be downloaded automatically, default value:`None`.
- `--style (str)`: style of output image, if `weight_path` is `None`, `style` can be chosen in `starrynew`, `circuit`, `ocean` and `stars`, default value:`starrynew`.
- `--style_image_path (str)`: path to style image, it need to input when `weight_path` is not `None`, default value:`None`.
## 3 How to use
......
......@@ -23,18 +23,21 @@ PaddleGAN为大家提供了四种不同艺术风格的预训练模型,风格
| :----------------------------------------------------------: | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| <img src='https://user-images.githubusercontent.com/48054808/130388598-1e2b27e7-be66-49df-84d5-57b4dc7730d6.png' width='300'/> | <img src='https://user-images.githubusercontent.com/48054808/130388606-78a3a682-2ae4-4753-a07c-671a46930de8.png' width='300'/> | <img src='https://user-images.githubusercontent.com/48054808/130388615-b04197b3-2fdf-4494-ad17-490afe0fd1cd.png' width='300'/> | <img src='https://user-images.githubusercontent.com/48054808/130388623-2eec0cca-fee1-47f0-8398-cae0171aa7a5.png' width='300'/> | <img src='https://user-images.githubusercontent.com/48054808/130388624-f27d0712-ba71-42b2-ada4-44bf60e36512.png' width='300'/> |
4个风格图像下载地址如下:
| [StarryNew](https://user-images.githubusercontent.com/79366697/118655415-1ec8c000-b81c-11eb-8002-90bf8d477860.png) | [Stars](https://user-images.githubusercontent.com/79366697/118655423-20928380-b81c-11eb-92bd-0deeb320ff14.png) | [Ocean](https://user-images.githubusercontent.com/79366697/118655407-1c666600-b81c-11eb-83a6-300ee1952415.png) | [Circuit](https://user-images.githubusercontent.com/79366697/118655399-196b7580-b81c-11eb-8bc5-d5ece80c18ba.jpg)|
只需运行下面的代码即可迁移至指定风格:
```
python applications/tools/lapstyle.py --content_img ${PATH_OF_CONTENT_IMG}
python applications/tools/lapstyle.py --content_img_path ${PATH_OF_CONTENT_IMG} --style_image_path ${PATH_OF_STYLE_IMG}
```
### **参数**
- `--content_img (str)`: 输入的内容图像路径。
- `--content_img_path (str)`: 输入的内容图像路径。
- `--style_image_path (str)`: 输入的风格图像路径。
- `--output_path (str)`: 输出的图像路径,默认为`output_dir`
- `--weight_path (str)`: 模型权重路径,设置`None`时会自行下载预训练模型,默认为`None`
- `--style (str)`: 生成图像风格,当`weight_path``None`时,可以在`starrynew`, `circuit`, `ocean``stars`中选择,默认为`starrynew`
- `--style_image_path (str)`: 输入的风格图像路径,当`weight_path`不为`None`时需要输入,默认为`None`
## 3. 模型训练
......
......@@ -28,13 +28,9 @@ from ppgan.models.generators import DecoderNet, Encoder, RevisionNet
from .base_predictor import BasePredictor
LapStyle_circuit_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_circuit.pdparams'
circuit_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655399-196b7580-b81c-11eb-8bc5-d5ece80c18ba.jpg'
LapStyle_ocean_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_ocean.pdparams'
ocean_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655407-1c666600-b81c-11eb-83a6-300ee1952415.png'
LapStyle_starrynew_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_starrynew.pdparams'
starrynew_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655415-1ec8c000-b81c-11eb-8002-90bf8d477860.png'
LapStyle_stars_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_stars.pdparams'
stars_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655423-20928380-b81c-11eb-92bd-0deeb320ff14.png'
def img(img):
......@@ -117,8 +113,7 @@ class LapStylePredictor(BasePredictor):
def __init__(self,
output='output_dir',
style='starrynew',
weight_path=None,
style_image_path=None):
weight_path=None):
self.input = input
self.output = os.path.join(output, 'LapStyle')
if not os.path.exists(self.output):
......@@ -127,31 +122,18 @@ class LapStylePredictor(BasePredictor):
self.net_dec = DecoderNet()
self.net_rev = RevisionNet()
self.net_rev_2 = RevisionNet()
if weight_path is None:
self.style_image_path = os.path.join(self.output, 'style.png')
if style == 'starrynew':
weight_path = get_path_from_url(LapStyle_starrynew_WEIGHT_URL)
urllib.request.urlretrieve(starrynew_IMG_URL,
filename=self.style_image_path)
elif style == 'circuit':
weight_path = get_path_from_url(LapStyle_circuit_WEIGHT_URL)
urllib.request.urlretrieve(circuit_IMG_URL,
filename=self.style_image_path)
elif style == 'ocean':
weight_path = get_path_from_url(LapStyle_ocean_WEIGHT_URL)
urllib.request.urlretrieve(ocean_IMG_URL,
filename=self.style_image_path)
elif style == 'stars':
weight_path = get_path_from_url(LapStyle_stars_WEIGHT_URL)
urllib.request.urlretrieve(stars_IMG_URL,
filename=self.style_image_path)
else:
raise Exception(f'has not implemented {style}.')
else:
if style_image_path is None:
raise Exception('style_image_path can not be None.')
else:
self.style_image_path = style_image_path
self.net_enc.set_dict(paddle.load(weight_path)['net_enc'])
self.net_enc.eval()
self.net_dec.set_dict(paddle.load(weight_path)['net_dec'])
......@@ -161,12 +143,15 @@ class LapStylePredictor(BasePredictor):
self.net_rev_2.set_dict(paddle.load(weight_path)['net_rev_2'])
self.net_rev_2.eval()
def run(self, content_img_path):
def run(self, content_img_path, style_image_path):
content_img, style_img, h, w = img_read(content_img_path,
self.style_image_path)
style_image_path)
content_img_visual = tensor2img(content_img, min_max=(0., 1.))
content_img_visual = cv.cvtColor(content_img_visual, cv.COLOR_RGB2BGR)
cv.imwrite(os.path.join(self.output, 'content.png'), content_img_visual)
style_img_visual = tensor2img(style_img, min_max=(0., 1.))
style_img_visual = cv.cvtColor(style_img_visual, cv.COLOR_RGB2BGR)
cv.imwrite(os.path.join(self.output, 'style.png'), style_img_visual)
pyr_ci = make_laplace_pyramid(content_img, 2)
pyr_si = make_laplace_pyramid(style_img, 2)
pyr_ci.append(content_img)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册