diff --git a/applications/tools/lapstyle.py b/applications/tools/lapstyle.py index e4f0fccd84f1a1024398fc2e6464402aa33b78b0..f44f4e3b4b2328720720e326f6ee2517e797d2d5 100644 --- a/applications/tools/lapstyle.py +++ b/applications/tools/lapstyle.py @@ -8,7 +8,9 @@ import argparse if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--content_img", type=str, help="path to content image") + parser.add_argument("--content_img_path", + type=str, + help="path to content image") parser.add_argument("--output_path", type=str, @@ -31,7 +33,7 @@ if __name__ == "__main__": parser.add_argument("--style_image_path", type=str, default=None, - help="if weight_path is not None, path to style image") + help="path to style image") parser.add_argument("--cpu", dest="cpu", @@ -45,6 +47,5 @@ if __name__ == "__main__": predictor = LapStylePredictor(output=args.output_path, style=args.style, - weight_path=args.weight_path, - style_image_path=args.style_image_path) - predictor.run(args.content_img) + weight_path=args.weight_path) + predictor.run(args.content_img_path, args.style_image_path) diff --git a/docs/en_US/tutorials/lap_style.md b/docs/en_US/tutorials/lap_style.md index 163a5efbee4d2c19d9ce5819fad9486a436c55b4..9449d1c0e95b8123ce3f24a041c05cc740fb00a2 100644 --- a/docs/en_US/tutorials/lap_style.md +++ b/docs/en_US/tutorials/lap_style.md @@ -14,16 +14,19 @@ Artistic style transfer aims at migrating the style from an example image to a c ## 2 Quick experience +Here four style images: +| [StarryNew](https://user-images.githubusercontent.com/79366697/118655415-1ec8c000-b81c-11eb-8002-90bf8d477860.png) | [Stars](https://user-images.githubusercontent.com/79366697/118655423-20928380-b81c-11eb-92bd-0deeb320ff14.png) | [Ocean](https://user-images.githubusercontent.com/79366697/118655407-1c666600-b81c-11eb-83a6-300ee1952415.png) | [Circuit](https://user-images.githubusercontent.com/79366697/118655399-196b7580-b81c-11eb-8bc5-d5ece80c18ba.jpg)| + ``` -python applications/tools/lapstyle.py --content_img ${PATH_OF_CONTENT_IMG} +python applications/tools/lapstyle.py --content_img_path ${PATH_OF_CONTENT_IMG} --style_image_path ${PATH_OF_STYLE_IMG} ``` ### Parameters -- `--content_img (str)`: path to content image. +- `--content_img_path (str)`: path to content image. +- `--style_image_path (str)`: path to style image. - `--output_path (str)`: path to output image dir, default value:`output_dir`. - `--weight_path (str)`: path to model weight path, if `weight_path` is `None`, the pre-training model will be downloaded automatically, default value:`None`. - `--style (str)`: style of output image, if `weight_path` is `None`, `style` can be chosen in `starrynew`, `circuit`, `ocean` and `stars`, default value:`starrynew`. -- `--style_image_path (str)`: path to style image, it need to input when `weight_path` is not `None`, default value:`None`. ## 3 How to use diff --git a/docs/zh_CN/tutorials/lap_style.md b/docs/zh_CN/tutorials/lap_style.md index d57389f01184eec79786c0def32a508cf7dbfeb0..7744ebb2a47cabf21d7778665ed7c48587775665 100644 --- a/docs/zh_CN/tutorials/lap_style.md +++ b/docs/zh_CN/tutorials/lap_style.md @@ -23,18 +23,21 @@ PaddleGAN为大家提供了四种不同艺术风格的预训练模型,风格 | :----------------------------------------------------------: | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | | | | | | | +4个风格图像下载地址如下: +| [StarryNew](https://user-images.githubusercontent.com/79366697/118655415-1ec8c000-b81c-11eb-8002-90bf8d477860.png) | [Stars](https://user-images.githubusercontent.com/79366697/118655423-20928380-b81c-11eb-92bd-0deeb320ff14.png) | [Ocean](https://user-images.githubusercontent.com/79366697/118655407-1c666600-b81c-11eb-83a6-300ee1952415.png) | [Circuit](https://user-images.githubusercontent.com/79366697/118655399-196b7580-b81c-11eb-8bc5-d5ece80c18ba.jpg)| + 只需运行下面的代码即可迁移至指定风格: ``` -python applications/tools/lapstyle.py --content_img ${PATH_OF_CONTENT_IMG} +python applications/tools/lapstyle.py --content_img_path ${PATH_OF_CONTENT_IMG} --style_image_path ${PATH_OF_STYLE_IMG} ``` ### **参数** -- `--content_img (str)`: 输入的内容图像路径。 +- `--content_img_path (str)`: 输入的内容图像路径。 +- `--style_image_path (str)`: 输入的风格图像路径。 - `--output_path (str)`: 输出的图像路径,默认为`output_dir`。 - `--weight_path (str)`: 模型权重路径,设置`None`时会自行下载预训练模型,默认为`None`。 - `--style (str)`: 生成图像风格,当`weight_path`为`None`时,可以在`starrynew`, `circuit`, `ocean` 和 `stars`中选择,默认为`starrynew`。 -- `--style_image_path (str)`: 输入的风格图像路径,当`weight_path`不为`None`时需要输入,默认为`None`。 ## 3. 模型训练 diff --git a/ppgan/apps/lapstyle_predictor.py b/ppgan/apps/lapstyle_predictor.py index 828f7521fc072fab769cc6e9c698da4a6cb30fbf..36ae6f00811809d82189641b55bcc523d3ea727f 100644 --- a/ppgan/apps/lapstyle_predictor.py +++ b/ppgan/apps/lapstyle_predictor.py @@ -28,13 +28,9 @@ from ppgan.models.generators import DecoderNet, Encoder, RevisionNet from .base_predictor import BasePredictor LapStyle_circuit_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_circuit.pdparams' -circuit_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655399-196b7580-b81c-11eb-8bc5-d5ece80c18ba.jpg' LapStyle_ocean_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_ocean.pdparams' -ocean_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655407-1c666600-b81c-11eb-83a6-300ee1952415.png' LapStyle_starrynew_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_starrynew.pdparams' -starrynew_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655415-1ec8c000-b81c-11eb-8002-90bf8d477860.png' LapStyle_stars_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_stars.pdparams' -stars_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655423-20928380-b81c-11eb-92bd-0deeb320ff14.png' def img(img): @@ -117,8 +113,7 @@ class LapStylePredictor(BasePredictor): def __init__(self, output='output_dir', style='starrynew', - weight_path=None, - style_image_path=None): + weight_path=None): self.input = input self.output = os.path.join(output, 'LapStyle') if not os.path.exists(self.output): @@ -127,31 +122,18 @@ class LapStylePredictor(BasePredictor): self.net_dec = DecoderNet() self.net_rev = RevisionNet() self.net_rev_2 = RevisionNet() + if weight_path is None: - self.style_image_path = os.path.join(self.output, 'style.png') if style == 'starrynew': weight_path = get_path_from_url(LapStyle_starrynew_WEIGHT_URL) - urllib.request.urlretrieve(starrynew_IMG_URL, - filename=self.style_image_path) elif style == 'circuit': weight_path = get_path_from_url(LapStyle_circuit_WEIGHT_URL) - urllib.request.urlretrieve(circuit_IMG_URL, - filename=self.style_image_path) elif style == 'ocean': weight_path = get_path_from_url(LapStyle_ocean_WEIGHT_URL) - urllib.request.urlretrieve(ocean_IMG_URL, - filename=self.style_image_path) elif style == 'stars': weight_path = get_path_from_url(LapStyle_stars_WEIGHT_URL) - urllib.request.urlretrieve(stars_IMG_URL, - filename=self.style_image_path) else: raise Exception(f'has not implemented {style}.') - else: - if style_image_path is None: - raise Exception('style_image_path can not be None.') - else: - self.style_image_path = style_image_path self.net_enc.set_dict(paddle.load(weight_path)['net_enc']) self.net_enc.eval() self.net_dec.set_dict(paddle.load(weight_path)['net_dec']) @@ -161,12 +143,15 @@ class LapStylePredictor(BasePredictor): self.net_rev_2.set_dict(paddle.load(weight_path)['net_rev_2']) self.net_rev_2.eval() - def run(self, content_img_path): + def run(self, content_img_path, style_image_path): content_img, style_img, h, w = img_read(content_img_path, - self.style_image_path) + style_image_path) content_img_visual = tensor2img(content_img, min_max=(0., 1.)) content_img_visual = cv.cvtColor(content_img_visual, cv.COLOR_RGB2BGR) cv.imwrite(os.path.join(self.output, 'content.png'), content_img_visual) + style_img_visual = tensor2img(style_img, min_max=(0., 1.)) + style_img_visual = cv.cvtColor(style_img_visual, cv.COLOR_RGB2BGR) + cv.imwrite(os.path.join(self.output, 'style.png'), style_img_visual) pyr_ci = make_laplace_pyramid(content_img, 2) pyr_si = make_laplace_pyramid(style_img, 2) pyr_ci.append(content_img)