PaddlePaddle / PaddleHub
Commit abab2ad2 (unverified)
Authored Mar 14, 2022 by KP; committed by GitHub on Mar 14, 2022
Merge pull request #1718 from rainyfly/add_painttransformer_module
add painttransformer module
Parents: f2efef7a, 9fda0947
Showing 9 changed files with 1083 additions and 0 deletions (+1083 -0)
modules/image/Image_gan/style_transfer/paint_transformer/README.md (+134 -0)
modules/image/Image_gan/style_transfer/paint_transformer/inference.py (+72 -0)
modules/image/Image_gan/style_transfer/paint_transformer/model.py (+68 -0)
modules/image/Image_gan/style_transfer/paint_transformer/module.py (+160 -0)
modules/image/Image_gan/style_transfer/paint_transformer/render_parallel.py (+247 -0)
modules/image/Image_gan/style_transfer/paint_transformer/render_serial.py (+280 -0)
modules/image/Image_gan/style_transfer/paint_transformer/render_utils.py (+111 -0)
modules/image/Image_gan/style_transfer/paint_transformer/requirements.txt (+1 -0)
modules/image/Image_gan/style_transfer/paint_transformer/util.py (+10 -0)
modules/image/Image_gan/style_transfer/paint_transformer/README.md
0 → 100644
# paint_transformer

|Model name|paint_transformer|
| :--- | :---: |
|Category|Image - Style transfer|
|Network|Paint Transformer|
|Dataset|Baidu in-house dataset|
|Fine-tuning supported|No|
|Model size|77MB|
|Last updated|2021-12-07|
|Metrics|-|
## I. Basic Model Information

- ### Application demo
  - Sample results:
    <p align="center">
    <img src="https://user-images.githubusercontent.com/22424850/145002878-ffdeea71-8ff4-48cc-88d0-fba1aa1dce4b.jpg" width = "40%" hspace='10'/>
    <br />
    Input image
    <br />
    <img src="https://user-images.githubusercontent.com/22424850/145002301-97c45887-cb2e-4a06-9d00-07b74080effa.png" width = "40%" hspace='10'/>
    <br />
    Output image
    <br />
    </p>

- ### Model introduction
  - This model converts an image into an oil-painting style.
  - For more details, see: [Paint Transformer: Feed Forward Neural Painting with Stroke Prediction](https://github.com/wzmsltw/PaintTransformer)
## II. Installation

- ### 1. Environment dependencies
  - ppgan

- ### 2. Installation
  - ```shell
    $ hub install paint_transformer
    ```
  - If you run into problems during installation, see: [Windows installation from scratch](../../../../docs/docs_ch/get_start/windows_quickstart.md) | [Linux installation from scratch](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [MacOS installation from scratch](../../../../docs/docs_ch/get_start/mac_quickstart.md)
## III. Model API Prediction

- ### 1. Command-line prediction
  - ```shell
    # Read from a file
    $ hub run paint_transformer --input_path "/PATH/TO/IMAGE"
    ```
  - This invokes the style transfer model from the command line. For more, see [PaddleHub command-line instructions](../../../../docs/docs_ch/tutorial/cmd_usage.rst)

- ### 2. Prediction code example
  - ```python
    import paddlehub as hub

    module = hub.Module(name="paint_transformer")
    input_path = ["/PATH/TO/IMAGE"]
    # Read from a file
    module.style_transfer(paths=input_path, output_dir='./transfer_result/', use_gpu=True)
    ```

- ### 3. API
  - ```python
    style_transfer(images=None, paths=None, output_dir='./transfer_result/', use_gpu=False, need_animation=False, visualization=True):
    ```
    - Oil-painting style transfer API.
    - **Parameters**
      - images (list\[numpy.ndarray\]): image data; each ndarray has shape \[H, W, C\] in BGR order (as read by cv2);<br/>
      - paths (list\[str\]): paths to the images;<br/>
      - output\_dir (str): directory where results are saved;<br/>
      - use\_gpu (bool): whether to use the GPU;<br/>
      - need\_animation (bool): whether to save the intermediate frames that make up the painting animation;<br/>
      - visualization (bool): whether to save the results to the local output directory.
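
The `module.py` file in this commit determines which files `style_transfer` writes when `visualization=True`. The sketch below reproduces only that naming scheme (`output_{i}.png` for the finished painting, or `output_{i}/frame_{j}.png` per frame when `need_animation=True`) without running the model. The helper name `expected_output_paths` and its `frames_per_image` argument are hypothetical, introduced here for illustration.

```python
import os


def expected_output_paths(output_dir, frames_per_image, need_animation=False):
    """List the files style_transfer would write for each input image.

    frames_per_image is a hypothetical stand-in for the number of frames
    the renderer returns per image; the real count depends on the input.
    """
    paths = []
    for i, n_frames in enumerate(frames_per_image):
        if n_frames == 0:
            # module.py skips empty results (`if out:`).
            continue
        if need_animation:
            frame_dir = os.path.join(output_dir, 'output_{}'.format(i))
            paths.extend(os.path.join(frame_dir, 'frame_{}.png'.format(j)) for j in range(n_frames))
        else:
            # Only the last frame (the finished painting) is saved.
            paths.append(os.path.join(output_dir, 'output_{}.png'.format(i)))
    return paths


print(expected_output_paths('./transfer_result/', [3], need_animation=True))
```

Note that the index `i` counts inputs in the order they are processed: entries from `images` first, then entries from `paths`.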
## IV. Service Deployment

- PaddleHub Serving can deploy an online oil-painting style transfer service.

- ### Step 1: Start PaddleHub Serving
  - Run the start command:
  - ```shell
    $ hub serving start -m paint_transformer
    ```
  - This deploys an online API service for oil-painting style transfer; the default port is 8866.
  - **NOTE:** To predict on GPU, set the CUDA\_VISIBLE\_DEVICES environment variable before starting the service; otherwise it does not need to be set.

- ### Step 2: Send a prediction request
  - With the server configured, the few lines of code below send a prediction request and fetch the result.
  - ```python
    import requests
    import json
    import cv2
    import base64


    def cv2_to_base64(image):
        data = cv2.imencode('.jpg', image)[1]
        return base64.b64encode(data.tobytes()).decode('utf8')


    # Send the HTTP request
    data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
    headers = {"Content-type": "application/json"}
    url = "http://127.0.0.1:8866/predict/paint_transformer"
    r = requests.post(url=url, headers=headers, data=json.dumps(data))

    # Print the prediction result
    print(r.json()["results"])
    ```

## V. Release History

* 1.0.0

  Initial release

  - ```shell
    $ hub install paint_transformer==1.0.0
    ```
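
As a complement to the prediction request in the service deployment section above: the client does not strictly need OpenCV to build the request body. The sketch below base64-encodes the bytes of an already-encoded image file (JPEG or PNG) directly. It assumes the server-side `base64_to_cv2` helper decodes those bytes with an image decoder; that helper lives in `util.py`, so treat this as an assumption to verify. `build_payload` is a hypothetical name.

```python
import base64
import json


def build_payload(encoded_images):
    """Build the JSON body for the /predict/paint_transformer endpoint.

    encoded_images: list of bytes objects, each the raw contents of an
    image file (for example, the result of open(path, 'rb').read()).
    """
    images = [base64.b64encode(raw).decode('utf8') for raw in encoded_images]
    return json.dumps({'images': images})


# Stand-in bytes; in practice read them from an image file on disk.
fake_jpeg = b'\xff\xd8\xff\xe0 not a real image'
body = build_payload([fake_jpeg])
print(body)
```

The resulting string can be passed as `data=` to `requests.post` with the same URL and headers as in the request example.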
modules/image/Image_gan/style_transfer/paint_transformer/inference.py
0 → 100644
import numpy as np
from PIL import Image
import network
import os
import math
import render_utils
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import cv2
import render_parallel
import render_serial


def main(input_path, model_path, output_dir, need_animation=False, resize_h=None, resize_w=None, serial=False):
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    input_name = os.path.basename(input_path)
    output_path = os.path.join(output_dir, input_name)
    frame_dir = None
    if need_animation:
        if not serial:
            print('It must be under serial mode if animation results are required, so serial flag is set to True!')
            serial = True
        frame_dir = os.path.join(output_dir, input_name[:input_name.find('.')])
        if not os.path.exists(frame_dir):
            os.mkdir(frame_dir)
    stroke_num = 8

    #* ----- load model ----- *#
    paddle.set_device('gpu')
    net_g = network.Painter(5, stroke_num, 256, 8, 3, 3)
    net_g.set_state_dict(paddle.load(model_path))
    net_g.eval()
    for param in net_g.parameters():
        param.stop_gradient = True

    #* ----- load brush ----- *#
    brush_large_vertical = render_utils.read_img('brush/brush_large_vertical.png', 'L')
    brush_large_horizontal = render_utils.read_img('brush/brush_large_horizontal.png', 'L')
    meta_brushes = paddle.concat([brush_large_vertical, brush_large_horizontal], axis=0)

    import time
    t0 = time.time()
    original_img = render_utils.read_img(input_path, 'RGB', resize_h, resize_w)
    if serial:
        final_result_list = render_serial.render_serial(original_img, net_g, meta_brushes)
        if need_animation:
            print("total frame:", len(final_result_list))
            for idx, frame in enumerate(final_result_list):
                cv2.imwrite(os.path.join(frame_dir, '%03d.png' % idx), frame)
        else:
            cv2.imwrite(output_path, final_result_list[-1])
    else:
        final_result = render_parallel.render_parallel(original_img, net_g, meta_brushes)
        cv2.imwrite(output_path, final_result)

    print("total infer time:", time.time() - t0)


if __name__ == '__main__':

    main(
        input_path='input/chicago.jpg',
        model_path='paint_best.pdparams',
        output_dir='output/',
        need_animation=True,  # whether need intermediate results for animation.
        resize_h=512,  # resize original input to this size. None means do not resize.
        resize_w=512,  # resize original input to this size. None means do not resize.
        serial=True)  # if need animation, serial must be True.
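
`main` in `inference.py` derives the animation frame directory with `input_name[:input_name.find('.')]`. The model-free sketch below shows how that slice behaves next to `os.path.splitext`. The helper names are hypothetical, and the `splitext` variant is a suggested alternative, not what the commit uses.

```python
import os


def stem_with_find(name):
    # Mirrors inference.py: cut at the first '.'; find() returns -1 when
    # there is no dot, which silently drops the last character instead.
    return name[:name.find('.')]


def stem_with_splitext(name):
    # Suggested alternative: strip only the final extension, if any.
    return os.path.splitext(name)[0]


for name in ['chicago.jpg', 'my.photo.jpg', 'noext']:
    print(name, '->', stem_with_find(name), '|', stem_with_splitext(name))
```

The two approaches agree for simple names such as `chicago.jpg` and differ for names with several dots or with no extension.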
modules/image/Image_gan/style_transfer/paint_transformer/model.py
0 → 100644
import paddle
import paddle.nn as nn
import math


class Painter(nn.Layer):
    """
    network architecture written in paddle.
    """

    def __init__(self, param_per_stroke, total_strokes, hidden_dim, n_heads=8, n_enc_layers=3, n_dec_layers=3):
        super().__init__()
        self.enc_img = nn.Sequential(
            nn.Pad2D([1, 1, 1, 1], 'reflect'),
            nn.Conv2D(3, 32, 3, 1),
            nn.BatchNorm2D(32),
            nn.ReLU(),  # maybe replace with the inplace version
            nn.Pad2D([1, 1, 1, 1], 'reflect'),
            nn.Conv2D(32, 64, 3, 2),
            nn.BatchNorm2D(64),
            nn.ReLU(),
            nn.Pad2D([1, 1, 1, 1], 'reflect'),
            nn.Conv2D(64, 128, 3, 2),
            nn.BatchNorm2D(128),
            nn.ReLU())
        self.enc_canvas = nn.Sequential(
            nn.Pad2D([1, 1, 1, 1], 'reflect'),
            nn.Conv2D(3, 32, 3, 1),
            nn.BatchNorm2D(32),
            nn.ReLU(),
            nn.Pad2D([1, 1, 1, 1], 'reflect'),
            nn.Conv2D(32, 64, 3, 2),
            nn.BatchNorm2D(64),
            nn.ReLU(),
            nn.Pad2D([1, 1, 1, 1], 'reflect'),
            nn.Conv2D(64, 128, 3, 2),
            nn.BatchNorm2D(128),
            nn.ReLU())
        self.conv = nn.Conv2D(128 * 2, hidden_dim, 1)
        self.transformer = nn.Transformer(hidden_dim, n_heads, n_enc_layers, n_dec_layers)
        self.linear_param = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, param_per_stroke))
        self.linear_decider = nn.Linear(hidden_dim, 1)
        self.query_pos = paddle.static.create_parameter(
            [total_strokes, hidden_dim], dtype='float32', default_initializer=nn.initializer.Uniform(0, 1))
        self.row_embed = paddle.static.create_parameter(
            [8, hidden_dim // 2], dtype='float32', default_initializer=nn.initializer.Uniform(0, 1))
        self.col_embed = paddle.static.create_parameter(
            [8, hidden_dim // 2], dtype='float32', default_initializer=nn.initializer.Uniform(0, 1))

    def forward(self, img, canvas):
        """
        prediction
        """
        b, _, H, W = img.shape
        img_feat = self.enc_img(img)
        canvas_feat = self.enc_canvas(canvas)
        h, w = img_feat.shape[-2:]
        feat = paddle.concat([img_feat, canvas_feat], axis=1)
        feat_conv = self.conv(feat)

        pos_embed = paddle.concat(
            [
                self.col_embed[:w].unsqueeze(0).tile([h, 1, 1]),
                self.row_embed[:h].unsqueeze(1).tile([1, w, 1]),
            ],
            axis=-1).flatten(0, 1).unsqueeze(1)

        hidden_state = self.transformer(
            (pos_embed + feat_conv.flatten(2).transpose([2, 0, 1])).transpose([1, 0, 2]),
            self.query_pos.unsqueeze(1).tile([1, b, 1]).transpose([1, 0, 2]))

        param = self.linear_param(hidden_state)
        decision = self.linear_decider(hidden_state)
        return param, decision
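
Each encoder branch in `Painter` stacks three reflect-padded 3x3 convolutions with strides 1, 2, and 2. Under the standard output-size formula for a convolution, `(n + 2 * pad - kernel) // stride + 1`, a 32x32 patch (the `patch_size` used in `render_parallel.py`) shrinks to 8x8, which agrees with the 8 rows of `row_embed` and `col_embed`. A minimal arithmetic check, using hypothetical helper names and no Paddle dependency:

```python
def conv_out(n, kernel, stride, pad):
    """Output length of a convolution along one axis with symmetric padding."""
    return (n + 2 * pad - kernel) // stride + 1


def encoder_out(n):
    # Three stages: Pad2D(1) + Conv2D(kernel=3) with strides 1, 2, 2.
    for stride in (1, 2, 2):
        n = conv_out(n, kernel=3, stride=stride, pad=1)
    return n


side = encoder_out(32)
# Feature-map side length and the number of source tokens fed to the transformer.
print(side, side * side)
```

Because the positional embeddings are sliced with `[:h]` and `[:w]`, patches larger than 32x32 would need more than the 8 embedding rows defined in `__init__`.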
modules/image/Image_gan/style_transfer/paint_transformer/module.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import copy

import paddle
import paddlehub as hub
from paddlehub.module.module import moduleinfo, runnable, serving
import numpy as np
import cv2
from skimage.io import imread
from skimage.transform import rescale, resize

from .model import Painter
from .render_utils import totensor, read_img
from .render_serial import render_serial
from .util import base64_to_cv2


@moduleinfo(
    name="paint_transformer",
    type="CV/style_transfer",
    author="paddlepaddle",
    author_email="",
    summary="",
    version="1.0.0")
class paint_transformer:
    def __init__(self):
        self.pretrained_model = os.path.join(self.directory, "paint_best.pdparams")

        self.network = Painter(5, 8, 256, 8, 3, 3)
        self.network.set_state_dict(paddle.load(self.pretrained_model))
        self.network.eval()
        for param in self.network.parameters():
            param.stop_gradient = True
        #* ----- load brush ----- *#
        brush_large_vertical = read_img(os.path.join(self.directory, 'brush/brush_large_vertical.png'), 'L')
        brush_large_horizontal = read_img(os.path.join(self.directory, 'brush/brush_large_horizontal.png'), 'L')
        self.meta_brushes = paddle.concat([brush_large_vertical, brush_large_horizontal], axis=0)

    def style_transfer(self,
                       images: list = None,
                       paths: list = None,
                       output_dir: str = './transfer_result/',
                       use_gpu: bool = False,
                       need_animation: bool = False,
                       visualization: bool = True):
        '''
        images (list[numpy.ndarray]): data of images, shape of each is [H, W, C], color space must be BGR(read by cv2).
        paths (list[str]): paths to images
        output_dir (str): the dir to save the results
        use_gpu (bool): if True, use gpu to perform the computation, otherwise cpu.
        need_animation (bool): if True, save every frame to show the process of painting.
        visualization (bool): if True, save results in output_dir.
        '''
        results = []
        paddle.disable_static()
        place = 'gpu:0' if use_gpu else 'cpu'
        place = paddle.set_device(place)
        if images == None and paths == None:
            print('No image provided. Please input an image or a image path.')
            return

        if images != None:
            for image in images:
                image = image[:, :, ::-1]
                image = totensor(image)
                final_result_list = render_serial(image, self.network, self.meta_brushes)
                results.append(final_result_list)

        if paths != None:
            for path in paths:
                image = cv2.imread(path)[:, :, ::-1]
                image = totensor(image)
                final_result_list = render_serial(image, self.network, self.meta_brushes)
                results.append(final_result_list)

        if visualization == True:
            if not os.path.exists(output_dir):
                os.makedirs(output_dir, exist_ok=True)
            for i, out in enumerate(results):
                if out:
                    if need_animation:
                        curoutputdir = os.path.join(output_dir, 'output_{}'.format(i))
                        if not os.path.exists(curoutputdir):
                            os.makedirs(curoutputdir, exist_ok=True)
                        for j, outimg in enumerate(out):
                            cv2.imwrite(os.path.join(curoutputdir, 'frame_{}.png'.format(j)), outimg)
                    else:
                        cv2.imwrite(os.path.join(output_dir, 'output_{}.png'.format(i)), out[-1])

        return results

    @runnable
    def run_cmd(self, argvs: list):
        """
        Run as a command.
        """
        self.parser = argparse.ArgumentParser(
            description="Run the {} module.".format(self.name),
            prog='hub run {}'.format(self.name),
            usage='%(prog)s',
            add_help=True)

        self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
        self.arg_config_group = self.parser.add_argument_group(
            title="Config options", description="Run configuration for controlling module behavior, not required.")
        self.add_module_config_arg()
        self.add_module_input_arg()
        self.args = self.parser.parse_args(argvs)
        results = self.style_transfer(
            paths=[self.args.input_path],
            output_dir=self.args.output_dir,
            use_gpu=self.args.use_gpu,
            need_animation=self.args.need_animation,
            visualization=self.args.visualization)
        return results

    @serving
    def serving_method(self, images, **kwargs):
        """
        Run as a service.
        """
        images_decode = [base64_to_cv2(image) for image in images]
        results = self.style_transfer(images=images_decode, **kwargs)
        tolist = [result.tolist() for result in results]
        return tolist

    def add_module_config_arg(self):
        """
        Add the command config options.
        """
        self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not")

        self.arg_config_group.add_argument(
            '--output_dir', type=str, default='transfer_result', help='output directory for saving result.')
        self.arg_config_group.add_argument('--visualization', type=bool, default=False, help='save results or not.')
        self.arg_config_group.add_argument(
            '--need_animation', type=bool, default=False, help='save intermediate results or not.')

    def add_module_input_arg(self):
        """
        Add the command input options.
        """
        self.arg_input_group.add_argument('--input_path', type=str, help="path to input image.")
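
`add_module_config_arg` declares `--visualization` and `--need_animation` with `type=bool`. With `argparse`, that applies `bool()` to the raw string, so any non-empty value, including `False`, parses as `True`. The sketch below demonstrates this on standalone parsers and shows one common workaround. `str2bool` is a hypothetical helper, not part of PaddleHub or of this commit.

```python
import argparse


def str2bool(value):
    """Parse common textual booleans; raise a clear error otherwise."""
    text = value.strip().lower()
    if text in ('true', '1', 'yes', 'y'):
        return True
    if text in ('false', '0', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got {!r}'.format(value))


# Same declaration style as module.py.
naive = argparse.ArgumentParser()
naive.add_argument('--visualization', type=bool, default=False)

# Variant with an explicit string-to-bool converter.
fixed = argparse.ArgumentParser()
fixed.add_argument('--visualization', type=str2bool, default=False)

naive_args = naive.parse_args(['--visualization', 'False'])
fixed_args = fixed.parse_args(['--visualization', 'False'])
print(naive_args.visualization, fixed_args.visualization)
```

An `action='store_true'` flag, as already used for `--use_gpu`, is another option when a bare switch is enough.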
modules/image/Image_gan/style_transfer/paint_transformer/render_parallel.py
0 → 100644
import render_utils
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
import math


def crop(img, h, w):
    H, W = img.shape[-2:]
    pad_h = (H - h) // 2
    pad_w = (W - w) // 2
    remainder_h = (H - h) % 2
    remainder_w = (W - w) % 2
    img = img[:, :, pad_h:H - pad_h - remainder_h, pad_w:W - pad_w - remainder_w]
    return img


def stroke_net_predict(img_patch, result_patch, patch_size, net_g, stroke_num, patch_num):
    """
    stroke_net_predict
    """
    img_patch = img_patch.transpose([0, 2, 1]).reshape([-1, 3, patch_size, patch_size])
    result_patch = result_patch.transpose([0, 2, 1]).reshape([-1, 3, patch_size, patch_size])
    #*----- Stroke Predictor -----*#
    shape_param, stroke_decision = net_g(img_patch, result_patch)
    stroke_decision = (stroke_decision > 0).astype('float32')
    #*----- sampling color -----*#
    grid = shape_param[:, :, :2].reshape([img_patch.shape[0] * stroke_num, 1, 1, 2])
    img_temp = img_patch.unsqueeze(1).tile([1, stroke_num, 1, 1, 1]).reshape(
        [img_patch.shape[0] * stroke_num, 3, patch_size, patch_size])
    color = nn.functional.grid_sample(
        img_temp, 2 * grid - 1, align_corners=False).reshape([img_patch.shape[0], stroke_num, 3])
    param = paddle.concat([shape_param, color], axis=-1)

    param = param.reshape([-1, 8])
    param[:, :2] = param[:, :2] / 2 + 0.25
    param[:, 2:4] = param[:, 2:4] / 2
    param = param.reshape([1, patch_num, patch_num, stroke_num, 8])
    decision = stroke_decision.reshape([1, patch_num, patch_num, stroke_num])  #.astype('bool')
    return param, decision


def param2img_parallel(param, decision, meta_brushes, cur_canvas, stroke_num=8):
    """
        Input stroke parameters and decisions for each patch, meta brushes, current canvas, frame directory,
        and whether there is a border (if intermediate painting results are required).
        Output the painting results of adding the corresponding strokes on the current canvas.
        Args:
            param: a tensor with shape batch size x patch along height dimension x patch along width dimension
             x n_stroke_per_patch x n_param_per_stroke
            decision: a 01 tensor with shape batch size x patch along height dimension x patch along width dimension
             x n_stroke_per_patch
            meta_brushes: a tensor with shape 2 x 3 x meta_brush_height x meta_brush_width.
            The first slice on the batch dimension denotes vertical brush and the second one denotes horizontal brush.
            cur_canvas: a tensor with shape batch size x 3 x H x W,
             where H and W denote height and width of padded results of original images.
        Returns:
            cur_canvas: a tensor with shape batch size x 3 x H x W, denoting painting results.
    """
    # param: b, h, w, stroke_per_patch, param_per_stroke
    # decision: b, h, w, stroke_per_patch
    b, h, w, s, p = param.shape
    h, w = int(h), int(w)
    param = param.reshape([-1, 8])
    decision = decision.reshape([-1, 8])

    H, W = cur_canvas.shape[-2:]
    is_odd_y = h % 2 == 1
    is_odd_x = w % 2 == 1
    render_size_y = 2 * H // h
    render_size_x = 2 * W // w

    even_idx_y = paddle.arange(0, h, 2)
    even_idx_x = paddle.arange(0, w, 2)
    if h > 1:
        odd_idx_y = paddle.arange(1, h, 2)
    if w > 1:
        odd_idx_x = paddle.arange(1, w, 2)

    cur_canvas = F.pad(cur_canvas, [render_size_x // 4, render_size_x // 4, render_size_y // 4, render_size_y // 4])

    valid_foregrounds = render_utils.param2stroke(param, render_size_y, render_size_x, meta_brushes)

    #* ----- load dilation/erosion ---- *#
    dilation = render_utils.Dilation2d(m=1)
    erosion = render_utils.Erosion2d(m=1)

    #* ----- generate alphas ----- *#
    valid_alphas = (valid_foregrounds > 0).astype('float32')
    valid_foregrounds = valid_foregrounds.reshape([-1, stroke_num, 1, render_size_y, render_size_x])
    valid_alphas = valid_alphas.reshape([-1, stroke_num, 1, render_size_y, render_size_x])

    temp = [dilation(valid_foregrounds[:, i, :, :, :]) for i in range(stroke_num)]
    valid_foregrounds = paddle.stack(temp, axis=1)
    valid_foregrounds = valid_foregrounds.reshape([-1, 1, render_size_y, render_size_x])

    temp = [erosion(valid_alphas[:, i, :, :, :]) for i in range(stroke_num)]
    valid_alphas = paddle.stack(temp, axis=1)
    valid_alphas = valid_alphas.reshape([-1, 1, render_size_y, render_size_x])

    foregrounds = valid_foregrounds.reshape([-1, h, w, stroke_num, 1, render_size_y, render_size_x])
    alphas = valid_alphas.reshape([-1, h, w, stroke_num, 1, render_size_y, render_size_x])
    decision = decision.reshape([-1, h, w, stroke_num, 1, 1, 1])
    param = param.reshape([-1, h, w, stroke_num, 8])

    def partial_render(this_canvas, patch_coord_y, patch_coord_x):
        canvas_patch = F.unfold(
            this_canvas, [render_size_y, render_size_x], strides=[render_size_y // 2, render_size_x // 2])
        # canvas_patch: b, 3 * py * px, h * w
        canvas_patch = canvas_patch.reshape([b, 3, render_size_y, render_size_x, h, w])
        canvas_patch = canvas_patch.transpose([0, 4, 5, 1, 2, 3])
        selected_canvas_patch = paddle.gather(canvas_patch, patch_coord_y, 1)
        selected_canvas_patch = paddle.gather(selected_canvas_patch, patch_coord_x, 2)
        selected_canvas_patch = selected_canvas_patch.reshape([0, 0, 0, 1, 3, render_size_y, render_size_x])
        selected_foregrounds = paddle.gather(foregrounds, patch_coord_y, 1)
        selected_foregrounds = paddle.gather(selected_foregrounds, patch_coord_x, 2)
        selected_alphas = paddle.gather(alphas, patch_coord_y, 1)
        selected_alphas = paddle.gather(selected_alphas, patch_coord_x, 2)
        selected_decisions = paddle.gather(decision, patch_coord_y, 1)
        selected_decisions = paddle.gather(selected_decisions, patch_coord_x, 2)
        selected_color = paddle.gather(param, patch_coord_y, 1)
        selected_color = paddle.gather(selected_color, patch_coord_x, 2)
        selected_color = paddle.gather(selected_color, paddle.to_tensor([5, 6, 7]), 4)
        selected_color = selected_color.reshape([0, 0, 0, stroke_num, 3, 1, 1])

        for i in range(stroke_num):
            i = paddle.to_tensor(i)

            cur_foreground = paddle.gather(selected_foregrounds, i, 3)
            cur_alpha = paddle.gather(selected_alphas, i, 3)
            cur_decision = paddle.gather(selected_decisions, i, 3)
            cur_color = paddle.gather(selected_color, i, 3)
            cur_foreground = cur_foreground * cur_color
            selected_canvas_patch = cur_foreground * cur_alpha * cur_decision + selected_canvas_patch * (
                1 - cur_alpha * cur_decision)

        selected_canvas_patch = selected_canvas_patch.reshape([0, 0, 0, 3, render_size_y, render_size_x])
        this_canvas = selected_canvas_patch.transpose([0, 3, 1, 4, 2, 5])

        # this_canvas: b, 3, h_half, py, w_half, px
        h_half = this_canvas.shape[2]
        w_half = this_canvas.shape[4]
        this_canvas = this_canvas.reshape([b, 3, h_half * render_size_y, w_half * render_size_x])
        # this_canvas: b, 3, h_half * py, w_half * px
        return this_canvas

    # even - even area
    # 1 | 0
    # 0 | 0
    canvas = partial_render(cur_canvas, even_idx_y, even_idx_x)
    if not is_odd_y:
        canvas = paddle.concat([canvas, cur_canvas[:, :, -render_size_y // 2:, :canvas.shape[3]]], axis=2)
    if not is_odd_x:
        canvas = paddle.concat([canvas, cur_canvas[:, :, :canvas.shape[2], -render_size_x // 2:]], axis=3)
    cur_canvas = canvas

    # odd - odd area
    # 0 | 0
    # 0 | 1
    if h > 1 and w > 1:
        canvas = partial_render(cur_canvas, odd_idx_y, odd_idx_x)
        canvas = paddle.concat([cur_canvas[:, :, :render_size_y // 2, -canvas.shape[3]:], canvas], axis=2)
        canvas = paddle.concat([cur_canvas[:, :, -canvas.shape[2]:, :render_size_x // 2], canvas], axis=3)
        if is_odd_y:
            canvas = paddle.concat([canvas, cur_canvas[:, :, -render_size_y // 2:, :canvas.shape[3]]], axis=2)
        if is_odd_x:
            canvas = paddle.concat([canvas, cur_canvas[:, :, :canvas.shape[2], -render_size_x // 2:]], axis=3)
        cur_canvas = canvas

    # odd - even area
    # 0 | 0
    # 1 | 0
    if h > 1:
        canvas = partial_render(cur_canvas, odd_idx_y, even_idx_x)
        canvas = paddle.concat([cur_canvas[:, :, :render_size_y // 2, :canvas.shape[3]], canvas], axis=2)
        if is_odd_y:
            canvas = paddle.concat([canvas, cur_canvas[:, :, -render_size_y // 2:, :canvas.shape[3]]], axis=2)
        if not is_odd_x:
            canvas = paddle.concat([canvas, cur_canvas[:, :, :canvas.shape[2], -render_size_x // 2:]], axis=3)
        cur_canvas = canvas

    # even - odd area
    # 0 | 1
    # 0 | 0
    if w > 1:
        canvas = partial_render(cur_canvas, even_idx_y, odd_idx_x)
        canvas = paddle.concat([cur_canvas[:, :, :canvas.shape[2], :render_size_x // 2], canvas], axis=3)
        if not is_odd_y:
            canvas = paddle.concat([canvas, cur_canvas[:, :, -render_size_y // 2:, -canvas.shape[3]:]], axis=2)
        if is_odd_x:
            canvas = paddle.concat([canvas, cur_canvas[:, :, :canvas.shape[2], -render_size_x // 2:]], axis=3)
        cur_canvas = canvas

    cur_canvas = cur_canvas[:, :, render_size_y // 4:-render_size_y // 4, render_size_x // 4:-render_size_x // 4]

    return cur_canvas


def render_parallel(original_img, net_g, meta_brushes):

    patch_size = 32
    stroke_num = 8

    with paddle.no_grad():

        original_h, original_w = original_img.shape[-2:]
        K = max(math.ceil(math.log2(max(original_h, original_w) / patch_size)), 0)
        original_img_pad_size = patch_size * (2**K)
        original_img_pad = render_utils.pad(original_img, original_img_pad_size, original_img_pad_size)
        final_result = paddle.zeros_like(original_img)

        for layer in range(0, K + 1):
            layer_size = patch_size * (2**layer)

            img = F.interpolate(original_img_pad, (layer_size, layer_size))
            result = F.interpolate(final_result, (layer_size, layer_size))
            img_patch = F.unfold(img, [patch_size, patch_size], strides=[patch_size, patch_size])
            result_patch = F.unfold(result, [patch_size, patch_size], strides=[patch_size, patch_size])

            # There are patch_num * patch_num patches in total
            patch_num = (layer_size - patch_size) // patch_size + 1
            param, decision = stroke_net_predict(img_patch, result_patch, patch_size, net_g, stroke_num, patch_num)

            #print(param.shape, decision.shape)
            final_result = param2img_parallel(param, decision, meta_brushes, final_result)

        # paint another time for last layer
        border_size = original_img_pad_size // (2 * patch_num)
        img = F.interpolate(original_img_pad, (layer_size, layer_size))
        result = F.interpolate(final_result, (layer_size, layer_size))
        img = F.pad(img, [patch_size // 2, patch_size // 2, patch_size // 2, patch_size // 2])
        result = F.pad(result, [patch_size // 2, patch_size // 2, patch_size // 2, patch_size // 2])
        img_patch = F.unfold(img, [patch_size, patch_size], strides=[patch_size, patch_size])
        result_patch = F.unfold(result, [patch_size, patch_size], strides=[patch_size, patch_size])
        final_result = F.pad(final_result, [border_size, border_size, border_size, border_size])
        patch_num = (img.shape[2] - patch_size) // patch_size + 1
        #w = (img.shape[3] - patch_size) // patch_size + 1

        param, decision = stroke_net_predict(img_patch, result_patch, patch_size, net_g, stroke_num, patch_num)

        final_result = param2img_parallel(param, decision, meta_brushes, final_result)

        final_result = final_result[:, :, border_size:-border_size, border_size:-border_size]
        final_result = (final_result.numpy().squeeze().transpose([1, 2, 0])[:, :, ::-1] * 255).astype(np.uint8)
        return final_result
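
`render_parallel` paints coarse to fine: it picks `K` so that `patch_size * 2**K` covers the longer image side, then renders layers of side `patch_size * 2**layer` for `layer = 0..K`, each split into `patch_num x patch_num` patches. The sketch below reproduces only that scheduling arithmetic, without Paddle. The function name `layer_schedule` is hypothetical.

```python
import math


def layer_schedule(height, width, patch_size=32):
    """Return (K, padded_size, [(layer_size, patch_num), ...])."""
    K = max(math.ceil(math.log2(max(height, width) / patch_size)), 0)
    padded_size = patch_size * (2 ** K)
    layers = []
    for layer in range(K + 1):
        layer_size = patch_size * (2 ** layer)
        # Same formula as render_parallel; equals layer_size // patch_size here.
        patch_num = (layer_size - patch_size) // patch_size + 1
        layers.append((layer_size, patch_num))
    return K, padded_size, layers


print(layer_schedule(512, 512))
print(layer_schedule(300, 500))
```

Inputs smaller than one patch clamp to `K = 0`, so they are padded to a single 32x32 layer.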
modules/image/Image_gan/style_transfer/paint_transformer/render_serial.py
0 → 100644
# !/usr/bin/env python3
"""
codes for oilpainting style transfer.
"""
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
from PIL import Image
import math
import cv2
import time
from .render_utils import param2stroke, Dilation2d, Erosion2d


def get_single_layer_lists(param, decision, ori_img, render_size_x, render_size_y, h, w, meta_brushes, dilation, erosion,
                           stroke_num):
    """
    get_single_layer_lists
    """
    valid_foregrounds = param2stroke(param[:, :], render_size_y, render_size_x, meta_brushes)

    valid_alphas = (valid_foregrounds > 0).astype('float32')
    valid_foregrounds = valid_foregrounds.reshape([-1, stroke_num, 1, render_size_y, render_size_x])
    valid_alphas = valid_alphas.reshape([-1, stroke_num, 1, render_size_y, render_size_x])

    temp = [dilation(valid_foregrounds[:, i, :, :, :]) for i in range(stroke_num)]
    valid_foregrounds = paddle.stack(temp, axis=1)
    valid_foregrounds = valid_foregrounds.reshape([-1, 1, render_size_y, render_size_x])

    temp = [erosion(valid_alphas[:, i, :, :, :]) for i in range(stroke_num)]
    valid_alphas = paddle.stack(temp, axis=1)
    valid_alphas = valid_alphas.reshape([-1, 1, render_size_y, render_size_x])

    patch_y = 4 * render_size_y // 5
    patch_x = 4 * render_size_x // 5

    img_patch = ori_img.reshape([1, 3, h, ori_img.shape[2] // h, w, ori_img.shape[3] // w])
    img_patch = img_patch.transpose([0, 2, 4, 1, 3, 5])[0]

    xid_list = []
    yid_list = []
    error_list = []

    for flag_idx, flag in enumerate(decision.cpu().numpy()):
        if flag:
            flag_idx = flag_idx // stroke_num
            x_id = flag_idx % w
            flag_idx = flag_idx // w
            y_id = flag_idx % h
            xid_list.append(x_id)
            yid_list.append(y_id)

    inner_fores = valid_foregrounds[:, :, render_size_y // 10:9 * render_size_y // 10,
                                    render_size_x // 10:9 * render_size_x // 10]
    inner_alpha = valid_alphas[:, :, render_size_y // 10:9 * render_size_y // 10,
                               render_size_x // 10:9 * render_size_x // 10]
    inner_fores = inner_fores.reshape([h * w, stroke_num, 1, patch_y, patch_x])
    inner_alpha = inner_alpha.reshape([h * w, stroke_num, 1, patch_y, patch_x])
    inner_real = img_patch.reshape([h * w, 3, patch_y, patch_x]).unsqueeze(1)

    R = param[:, 5]
    G = param[:, 6]
    B = param[:, 7]  #, G, B = param[5:]
    R = R.reshape([-1, stroke_num]).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    G = G.reshape([-1, stroke_num]).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    B = B.reshape([-1, stroke_num]).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    error_R = R * inner_fores - inner_real[:, :, 0:1, :, :]
    error_G = G * inner_fores - inner_real[:, :, 1:2, :, :]
    error_B = B * inner_fores - inner_real[:, :, 2:3, :, :]
    error = paddle.abs(error_R) + paddle.abs(error_G) + paddle.abs(error_B)

    error = error * inner_alpha
    error = paddle.sum(error, axis=(2, 3, 4)) / paddle.sum(inner_alpha, axis=(2, 3, 4))
    error_list = error.reshape([-1]).numpy()[decision.numpy()]
    error_list = list(error_list)

    valid_foregrounds = paddle.to_tensor(valid_foregrounds.numpy()[decision.numpy()])
    valid_alphas = paddle.to_tensor(valid_alphas.numpy()[decision.numpy()])

    selected_param = paddle.to_tensor(param.numpy()[decision.numpy()])
    return xid_list, yid_list, valid_foregrounds, valid_alphas, error_list, selected_param


def get_single_stroke_on_full_image_A(x_id, y_id, valid_foregrounds, valid_alphas, param, original_img, render_size_x,
                                      render_size_y, patch_x, patch_y):
    """
    get_single_stroke_on_full_image_A
    """
    tmp_foreground = paddle.zeros_like(original_img)

    patch_y_num = original_img.shape[2] // patch_y
    patch_x_num = original_img.shape[3] // patch_x

    brush = valid_foregrounds.unsqueeze(0)
    color_map = param[5:]
    brush = brush.tile([1, 3, 1, 1])
    color_map = color_map.unsqueeze(-1).unsqueeze(-1).unsqueeze(0)  #.repeat(1, 1, H, W)
    brush = brush * color_map

    pad_l = x_id * patch_x
    pad_r = (patch_x_num - x_id - 1) * patch_x
    pad_t = y_id * patch_y
    pad_b = (patch_y_num - y_id - 1) * patch_y
    tmp_foreground = nn.functional.pad(brush, [pad_l, pad_r, pad_t, pad_b])
    tmp_foreground = tmp_foreground[:, :, render_size_y // 10:-render_size_y // 10,
                                    render_size_x // 10:-render_size_x // 10]

    tmp_alpha = nn.functional.pad(valid_alphas.unsqueeze(0), [pad_l, pad_r, pad_t, pad_b])
    tmp_alpha = tmp_alpha[:, :, render_size_y // 10:-render_size_y // 10, render_size_x // 10:-render_size_x // 10]
    return tmp_foreground, tmp_alpha


def get_single_stroke_on_full_image_B(x_id, y_id, valid_foregrounds, valid_alphas, param, original_img,
                                      render_size_x, render_size_y, patch_x, patch_y):
    """
    get_single_stroke_on_full_image_B
    """
    x_expand = patch_x // 2 + render_size_x // 10
    y_expand = patch_y // 2 + render_size_y // 10

    pad_l = x_id * patch_x
    pad_r = original_img.shape[3] + 2 * x_expand - (x_id * patch_x + render_size_x)
    pad_t = y_id * patch_y
    pad_b = original_img.shape[2] + 2 * y_expand - (y_id * patch_y + render_size_y)

    brush = valid_foregrounds.unsqueeze(0)
    color_map = param[5:]
    brush = brush.tile([1, 3, 1, 1])
    color_map = color_map.unsqueeze(-1).unsqueeze(-1).unsqueeze(0)  #.repeat(1, 1, H, W)
    brush = brush * color_map

    tmp_foreground = nn.functional.pad(brush, [pad_l, pad_r, pad_t, pad_b])
    tmp_foreground = tmp_foreground[:, :, y_expand:-y_expand, x_expand:-x_expand]
    tmp_alpha = nn.functional.pad(valid_alphas.unsqueeze(0), [pad_l, pad_r, pad_t, pad_b])
    tmp_alpha = tmp_alpha[:, :, y_expand:-y_expand, x_expand:-x_expand]
    return tmp_foreground, tmp_alpha
def stroke_net_predict(img_patch, result_patch, patch_size, net_g, stroke_num):
    """
    stroke_net_predict
    """
    img_patch = img_patch.transpose([0, 2, 1]).reshape([-1, 3, patch_size, patch_size])
    result_patch = result_patch.transpose([0, 2, 1]).reshape([-1, 3, patch_size, patch_size])
    #*----- Stroke Predictor -----*#
    shape_param, stroke_decision = net_g(img_patch, result_patch)
    stroke_decision = (stroke_decision > 0).astype('float32')
    #*----- sampling color -----*#
    grid = shape_param[:, :, :2].reshape([img_patch.shape[0] * stroke_num, 1, 1, 2])
    img_temp = img_patch.unsqueeze(1).tile([1, stroke_num, 1, 1, 1]).reshape(
        [img_patch.shape[0] * stroke_num, 3, patch_size, patch_size])
    color = nn.functional.grid_sample(
        img_temp, 2 * grid - 1, align_corners=False).reshape([img_patch.shape[0], stroke_num, 3])
    stroke_param = paddle.concat([shape_param, color], axis=-1)

    param = stroke_param.reshape([-1, 8])
    decision = stroke_decision.reshape([-1]).astype('bool')
    param[:, :2] = param[:, :2] / 1.25 + 0.1
    param[:, 2:4] = param[:, 2:4] / 1.25
    return param, decision
def sort_strokes(params, decision, scores):
    """
    sort_strokes
    """
    sorted_scores, sorted_index = paddle.sort(scores, axis=1, descending=False)
    sorted_params = []
    for idx in range(8):
        tmp_pick_params = paddle.gather(params[:, :, idx], axis=1, index=sorted_index)
        sorted_params.append(tmp_pick_params)
    sorted_params = paddle.stack(sorted_params, axis=2)
    sorted_decision = paddle.gather(decision.squeeze(2), axis=1, index=sorted_index)
    return sorted_params, sorted_decision
def render_serial(original_img, net_g, meta_brushes):

    patch_size = 32
    stroke_num = 8
    H, W = original_img.shape[-2:]
    K = max(math.ceil(math.log2(max(H, W) / patch_size)), 0)

    dilation = Dilation2d(m=1)
    erosion = Erosion2d(m=1)
    frames_per_layer = [20, 20, 30, 40, 60]
    final_frame_list = []

    with paddle.no_grad():
        #* ----- read in image and init canvas ----- *#
        final_result = paddle.zeros_like(original_img)

        for layer in range(0, K + 1):
            t0 = time.time()
            layer_size = patch_size * (2**layer)

            img = nn.functional.interpolate(original_img, (layer_size, layer_size))
            result = nn.functional.interpolate(final_result, (layer_size, layer_size))
            img_patch = nn.functional.unfold(img, [patch_size, patch_size], strides=[patch_size, patch_size])
            result_patch = nn.functional.unfold(result, [patch_size, patch_size], strides=[patch_size, patch_size])
            h = (img.shape[2] - patch_size) // patch_size + 1
            w = (img.shape[3] - patch_size) // patch_size + 1
            render_size_y = int(1.25 * H // h)
            render_size_x = int(1.25 * W // w)

            #* -------------------------------------------------------------*#
            #* -------------generate strokes on window type A---------------*#
            #* -------------------------------------------------------------*#
            param, decision = stroke_net_predict(img_patch, result_patch, patch_size, net_g, stroke_num)
            expand_img = original_img
            wA_xid_list, wA_yid_list, wA_fore_list, wA_alpha_list, wA_error_list, wA_params = \
                get_single_layer_lists(param, decision, original_img, render_size_x, render_size_y, h, w,
                                       meta_brushes, dilation, erosion, stroke_num)

            #* -------------------------------------------------------------*#
            #* -------------generate strokes on window type B---------------*#
            #* -------------------------------------------------------------*#
            #*----- generate input canvas and target patches -----*#
            wB_error_list = []

            img = nn.functional.pad(img, [patch_size // 2, patch_size // 2, patch_size // 2, patch_size // 2])
            result = nn.functional.pad(result, [patch_size // 2, patch_size // 2, patch_size // 2, patch_size // 2])
            img_patch = nn.functional.unfold(img, [patch_size, patch_size], strides=[patch_size, patch_size])
            result_patch = nn.functional.unfold(result, [patch_size, patch_size], strides=[patch_size, patch_size])
            h += 1
            w += 1

            param, decision = stroke_net_predict(img_patch, result_patch, patch_size, net_g, stroke_num)

            patch_y = 4 * render_size_y // 5
            patch_x = 4 * render_size_x // 5
            expand_img = nn.functional.pad(original_img, [patch_x // 2, patch_x // 2, patch_y // 2, patch_y // 2])
            wB_xid_list, wB_yid_list, wB_fore_list, wB_alpha_list, wB_error_list, wB_params = \
                get_single_layer_lists(param, decision, expand_img, render_size_x, render_size_y, h, w,
                                       meta_brushes, dilation, erosion, stroke_num)

            #* -------------------------------------------------------------*#
            #* -------------rank strokes and plot stroke one by one---------*#
            #* -------------------------------------------------------------*#
            numA = len(wA_error_list)
            numB = len(wB_error_list)
            total_error_list = wA_error_list + wB_error_list
            sort_list = list(np.argsort(total_error_list))

            sample = 0
            samples = np.linspace(0, len(sort_list) - 2, frames_per_layer[layer]).astype(int)
            for ii in sort_list:
                ii = int(ii)
                if ii < numA:
                    x_id = wA_xid_list[ii]
                    y_id = wA_yid_list[ii]
                    valid_foregrounds = wA_fore_list[ii]
                    valid_alphas = wA_alpha_list[ii]
                    sparam = wA_params[ii]
                    tmp_foreground, tmp_alpha = get_single_stroke_on_full_image_A(
                        x_id, y_id, valid_foregrounds, valid_alphas, sparam, original_img, render_size_x,
                        render_size_y, patch_x, patch_y)
                else:
                    x_id = wB_xid_list[ii - numA]
                    y_id = wB_yid_list[ii - numA]
                    valid_foregrounds = wB_fore_list[ii - numA]
                    valid_alphas = wB_alpha_list[ii - numA]
                    sparam = wB_params[ii - numA]
                    tmp_foreground, tmp_alpha = get_single_stroke_on_full_image_B(
                        x_id, y_id, valid_foregrounds, valid_alphas, sparam, original_img, render_size_x,
                        render_size_y, patch_x, patch_y)

                final_result = tmp_foreground * tmp_alpha + (1 - tmp_alpha) * final_result
                if sample in samples:
                    saveframe = (final_result.numpy().squeeze().transpose([1, 2, 0])[:, :, ::-1] * 255).astype(
                        np.uint8)
                    final_frame_list.append(saveframe)
                    #saveframe = cv2.resize(saveframe, (ow, oh))

                sample += 1
            print("layer %d cost: %.02f" % (layer, time.time() - t0))

        saveframe = (final_result.numpy().squeeze().transpose([1, 2, 0])[:, :, ::-1] * 255).astype(np.uint8)
        final_frame_list.append(saveframe)
    return final_frame_list
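Reviewer note: `render_serial` paints coarse to fine. It derives the number of layers `K` from the image size and doubles the working resolution on each layer. The sketch below reproduces only that arithmetic in plain Python, so the schedule can be checked without Paddle. The helper name `layer_schedule` is hypothetical and not part of this commit.

```python
import math


def layer_schedule(height, width, patch_size=32):
    """Return (layer_size, patches_per_side) for every coarse-to-fine layer.

    Illustrative helper: it mirrors the K, layer_size and h/w formulas
    used in render_serial, nothing more.
    """
    num_layers = max(math.ceil(math.log2(max(height, width) / patch_size)), 0)
    schedule = []
    for layer in range(num_layers + 1):
        layer_size = patch_size * (2 ** layer)
        patches_per_side = (layer_size - patch_size) // patch_size + 1
        schedule.append((layer_size, patches_per_side))
    return schedule


schedule_512 = layer_schedule(512, 512)
```

For a 512×512 input this gives five layers, from 32 px up to 512 px, which matches the five entries in `frames_per_layer`. A larger input would produce more layers than `frames_per_layer` has entries, so callers should resize first, as `preprocess` does with its 512×512 default.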
modules/image/Image_gan/style_transfer/paint_transformer/render_utils.py
0 → 100644
View file @ abab2ad2
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import cv2
import numpy as np
from PIL import Image
import math
class Erosion2d(nn.Layer):
    """
    Erosion2d
    """

    def __init__(self, m=1):
        super(Erosion2d, self).__init__()
        self.m = m
        self.pad = [m, m, m, m]

    def forward(self, x):
        batch_size, c, h, w = x.shape
        x_pad = F.pad(x, pad=self.pad, mode='constant', value=1e9)
        channel = nn.functional.unfold(
            x_pad, 2 * self.m + 1, strides=1, paddings=0).reshape([batch_size, c, -1, h, w])
        result = paddle.min(channel, axis=2)
        return result
class Dilation2d(nn.Layer):
    """
    Dilation2d
    """

    def __init__(self, m=1):
        super(Dilation2d, self).__init__()
        self.m = m
        self.pad = [m, m, m, m]

    def forward(self, x):
        batch_size, c, h, w = x.shape
        x_pad = F.pad(x, pad=self.pad, mode='constant', value=-1e9)
        channel = nn.functional.unfold(
            x_pad, 2 * self.m + 1, strides=1, paddings=0).reshape([batch_size, c, -1, h, w])
        result = paddle.max(channel, axis=2)
        return result
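Reviewer note: `Erosion2d` and `Dilation2d` implement grayscale morphology by taking the minimum or maximum over every (2m+1)×(2m+1) window. The ±1e9 padding ensures that border padding never wins the min/max. The sketch below shows the same idea on a nested list in plain Python; clipping the window at the border has the same effect as that padding. The names `morph`, `erode` and `dilate` are assumptions for illustration, not part of this commit.

```python
def morph(image, m, op):
    """Apply op (min or max) over each (2m+1)x(2m+1) window, clipped at the borders."""
    h, w = len(image), len(image[0])
    out = []
    for y in range(h):
        row = []
        for x in range(w):
            window = [image[yy][xx]
                      for yy in range(max(0, y - m), min(h, y + m + 1))
                      for xx in range(max(0, x - m), min(w, x + m + 1))]
            row.append(op(window))
        out.append(row)
    return out


def erode(image, m=1):
    return morph(image, m, min)


def dilate(image, m=1):
    return morph(image, m, max)


dot = [[0, 0, 0],
       [0, 1, 0],
       [0, 0, 0]]
dilated = dilate(dot)  # the single bright pixel spreads to its neighbours
eroded = erode(dot)    # the single bright pixel disappears
```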
def param2stroke(param, H, W, meta_brushes):
    """
    param2stroke
    """
    b = param.shape[0]
    param_list = paddle.split(param, 8, axis=1)
    x0, y0, w, h, theta = [item.squeeze(-1) for item in param_list[:5]]
    sin_theta = paddle.sin(math.pi * theta)
    cos_theta = paddle.cos(math.pi * theta)
    index = paddle.full((b, ), -1, dtype='int64').numpy()

    index[(h > w).numpy()] = 0
    index[(h <= w).numpy()] = 1
    meta_brushes_resize = F.interpolate(meta_brushes, (H, W)).numpy()
    brush = paddle.to_tensor(meta_brushes_resize[index])

    warp_00 = cos_theta / w
    warp_01 = sin_theta * H / (W * w)
    warp_02 = (1 - 2 * x0) * cos_theta / w + (1 - 2 * y0) * sin_theta * H / (W * w)
    warp_10 = -sin_theta * W / (H * h)
    warp_11 = cos_theta / h
    warp_12 = (1 - 2 * y0) * cos_theta / h - (1 - 2 * x0) * sin_theta * W / (H * h)
    warp_0 = paddle.stack([warp_00, warp_01, warp_02], axis=1)
    warp_1 = paddle.stack([warp_10, warp_11, warp_12], axis=1)
    warp = paddle.stack([warp_0, warp_1], axis=1)
    grid = nn.functional.affine_grid(warp, [b, 3, H, W])  # the paddle and torch default values are reversed
    brush = nn.functional.grid_sample(brush, grid)
    return brush
def read_img(img_path, img_type='RGB', h=None, w=None):
    """
    read img
    """
    img = Image.open(img_path).convert(img_type)
    if h is not None and w is not None:
        img = img.resize((w, h), resample=Image.NEAREST)
    img = np.array(img)
    if img.ndim == 2:
        img = np.expand_dims(img, axis=-1)
    img = img.transpose((2, 0, 1))
    img = paddle.to_tensor(img).unsqueeze(0).astype('float32') / 255.
    return img
def preprocess(img, w=512, h=512):
    # Pass the flag by keyword: the third positional argument of cv2.resize is dst, not interpolation.
    image = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST)
    image = image.transpose((2, 0, 1))
    image = paddle.to_tensor(image).unsqueeze(0).astype('float32') / 255.
    return image
def totensor(img):
    image = img.transpose((2, 0, 1))
    image = paddle.to_tensor(image).unsqueeze(0).astype('float32') / 255.
    return image
def pad(img, H, W):
    b, c, h, w = img.shape
    pad_h = (H - h) // 2
    pad_w = (W - w) // 2
    remainder_h = (H - h) % 2
    remainder_w = (W - w) % 2
    expand_img = nn.functional.pad(img, [pad_w, pad_w + remainder_w, pad_h, pad_h + remainder_h])
    return expand_img
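Reviewer note: `pad` centres the image inside an `H`×`W` canvas. When the size difference is odd, the extra pixel goes to the right and bottom edges. The one-dimensional sketch below isolates that split. The helper name `centered_padding` is hypothetical and not part of this commit.

```python
def centered_padding(target, current):
    """Split (target - current) into (before, after) padding, extra pixel last."""
    diff = target - current
    before = diff // 2
    after = before + diff % 2
    return before, after


left, right = centered_padding(512, 507)  # an odd difference of 5 pixels
```

The two parts always sum to the full difference, so the padded tensor ends up at exactly the requested size.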
modules/image/Image_gan/style_transfer/paint_transformer/requirements.txt
0 → 100644
View file @ abab2ad2
ppgan
modules/image/Image_gan/style_transfer/paint_transformer/util.py
0 → 100644
View file @ abab2ad2
import base64

import cv2
import numpy as np


def base64_to_cv2(b64str):
    data = base64.b64decode(b64str.encode('utf8'))
    # np.fromstring is deprecated for binary input; np.frombuffer is the replacement.
    data = np.frombuffer(data, np.uint8)
    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
    return data
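Reviewer note: `base64_to_cv2` expects a UTF-8 string holding the base64 encoding of an encoded image file (for example JPEG or PNG bytes). The sketch below shows the matching client-side encoding step and its inverse, using only the standard library. The helper names `bytes_to_b64str` and `b64str_to_bytes` are assumptions for illustration; the image decoding via `cv2.imdecode` is left out.

```python
import base64


def bytes_to_b64str(raw: bytes) -> str:
    """Encode raw image-file bytes as the UTF-8 base64 string a client would send."""
    return base64.b64encode(raw).decode('utf8')


def b64str_to_bytes(b64str: str) -> bytes:
    """Inverse step; mirrors the first line of base64_to_cv2."""
    return base64.b64decode(b64str.encode('utf8'))


png_signature = b"\x89PNG\r\n\x1a\n"
payload = bytes_to_b64str(png_signature)
restored = b64str_to_bytes(payload)
```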