Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
156cb03f
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
156cb03f
编写于
4月 14, 2020
作者:
L
LielinJiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add md and py
上级
27c66cfa
变更
3
展开全部
隐藏空白更改
内联
并排
Showing
3 changed file
with
787 addition
and
0 deletion
+787
-0
tutorials/style-transfer/README.md
tutorials/style-transfer/README.md
+38
-0
tutorials/style-transfer/style-transfer.ipynb
tutorials/style-transfer/style-transfer.ipynb
+575
-0
tutorials/style-transfer/style-transfer.py
tutorials/style-transfer/style-transfer.py
+174
-0
未找到文件。
tutorials/style-transfer/README.md
0 → 100644
浏览文件 @
156cb03f
# 图像风格迁移
图像的风格迁移是卷积神经网络有趣的应用之一。那什么是风格迁移呢?下图第一列左边的图为相机拍摄的一张普通图片,右边的图为梵高的名画星空。那如何让左边的普通图片拥有星空的风格呢。神经网络的风格迁移就可以帮助你生成第二列的这样的图片。
<div
align=
center
>
<img
src=
"images/markdown/img1.png"
width =
"600"
height =
"300"
/>
</br>
<img
src=
"images/markdown/img2.png"
width =
"300"
height =
"300"
divalign=
center
/>
<div
align=
left
>
## 基本原理
风格迁移的目标就是使得生成图片的内容与内容图片(content image)尽可能相似。由于在计算机中,我们用一个一个像素点表示图片,所以两个图片的相似程度我们可以用每个像素点的欧式距离来表示。而两个图片的风格相似度,我们采用两个图片在卷积神经网络中相同的一层特征图的gram矩阵的欧式距离来表示。对于一个特征图gram矩阵的计算如下所示:
```
python
# tensor shape is [1, c, h, w]
_
,
c
,
h
,
w
=
tensor
.
shape
tensor
=
fluid
.
layers
.
reshape
(
c
,
h
*
w
)
# gram matrix with shape: [c, c]
gram_matrix
=
fluid
.
layers
.
matmul
(
tensor
,
fluid
.
layers
.
transpose
(
tensor
,
[
1
,
0
]))
```
最终风格迁移的问题转化为优化上述的两个欧式距离的问题。这里要注意的是,我们使用一个在imagenet上预训练好的模型vgg16,并且固定参数,优化器只更新输入的生成图像的值。
## 风格迁移
执行如下命令,就可以进行风格迁移。生成的图像会保存在
```--save-dir```
中。
```
python
python
-
u
style
-
transfer
.
py
--
content
-
image
/
path
/
to
/
your
-
content
-
image
--
style
-
image
/
path
/
to
/
your
-
content
-
image
--
save
-
dir
/
path
/
to
/
your
-
output
-
dir
```
具体的生成过程也可以参考
[
style-transfer.ipynb
](
./hapi-style-transfer.ipynb
)
## 参考文献
[
A Neural Algorithm of Artistic Style
](
https://arxiv.org/abs/1508.06576
)
tutorials/style-transfer/style-transfer.ipynb
0 → 100644
浏览文件 @
156cb03f
此差异已折叠。
点击以展开。
tutorials/style-transfer/style-transfer.py
0 → 100644
浏览文件 @
156cb03f
import
os
import
argparse
import
numpy
as
np
import
matplotlib.pyplot
as
plt
from
hapi.model
import
Model
,
Loss
from
hapi.vision.models
import
vgg16
from
hapi.vision.transform
import
transforms
from
paddle
import
fluid
from
paddle.fluid.io
import
Dataset
import
cv2
import
copy
def
load_image
(
image_path
,
max_size
=
400
,
shape
=
None
):
image
=
cv2
.
imread
(
image_path
)
image
=
image
.
astype
(
'float32'
)
/
255.0
size
=
shape
if
shape
is
not
None
else
max_size
if
max
(
image
.
shape
[:
2
])
>
max_size
else
max
(
image
.
shape
[:
2
])
transform
=
transforms
.
Compose
([
transforms
.
Resize
(
size
),
transforms
.
Permute
(),
transforms
.
Normalize
([
0.485
,
0.456
,
0.406
],
[
0.229
,
0.224
,
0.225
])
])
image
=
transform
(
image
)[
np
.
newaxis
,
:
3
,
:,
:]
image
=
fluid
.
dygraph
.
to_variable
(
image
)
return
image
def
image_restore
(
image
):
image
=
np
.
squeeze
(
image
.
numpy
(),
0
)
image
=
image
.
transpose
(
1
,
2
,
0
)
image
=
image
*
np
.
array
((
0.229
,
0.224
,
0.225
))
+
np
.
array
(
(
0.485
,
0.456
,
0.406
))
image
=
image
.
clip
(
0
,
1
)
return
image
class
StyleTransferModel
(
Model
):
def
__init__
(
self
):
super
(
StyleTransferModel
,
self
).
__init__
()
# pretrained设置为true,会自动下载imagenet上的预训练权重并加载
vgg
=
vgg16
(
pretrained
=
True
)
self
.
base_model
=
vgg
.
features
for
p
in
self
.
base_model
.
parameters
():
p
.
stop_gradient
=
True
self
.
layers
=
{
'0'
:
'conv1_1'
,
'3'
:
'conv2_1'
,
'6'
:
'conv3_1'
,
'10'
:
'conv4_1'
,
'11'
:
'conv4_2'
,
## content representation
'14'
:
'conv5_1'
}
def
forward
(
self
,
image
):
outputs
=
[]
for
name
,
layer
in
self
.
base_model
.
named_sublayers
():
image
=
layer
(
image
)
if
name
in
self
.
layers
:
outputs
.
append
(
image
)
return
outputs
class
StyleTransferLoss
(
Loss
):
def
__init__
(
self
,
content_loss_weight
=
1
,
style_loss_weight
=
1e5
,
style_weights
=
[
1.0
,
0.8
,
0.5
,
0.3
,
0.1
]):
super
(
StyleTransferLoss
,
self
).
__init__
()
self
.
content_loss_weight
=
content_loss_weight
self
.
style_loss_weight
=
style_loss_weight
self
.
style_weights
=
style_weights
def
forward
(
self
,
outputs
,
labels
):
content_features
=
labels
[
-
1
]
style_features
=
labels
[:
-
1
]
# 计算图像内容相似度的loss
content_loss
=
fluid
.
layers
.
mean
((
outputs
[
-
2
]
-
content_features
)
**
2
)
# 计算风格相似度的loss
style_loss
=
0
style_grams
=
[
self
.
gram_matrix
(
feat
)
for
feat
in
style_features
]
style_weights
=
self
.
style_weights
for
i
,
weight
in
enumerate
(
style_weights
):
target_gram
=
self
.
gram_matrix
(
outputs
[
i
])
layer_loss
=
weight
*
fluid
.
layers
.
mean
((
target_gram
-
style_grams
[
i
])
**
2
)
b
,
d
,
h
,
w
=
outputs
[
i
].
shape
style_loss
+=
layer_loss
/
(
d
*
h
*
w
)
total_loss
=
self
.
content_loss_weight
*
content_loss
+
self
.
style_loss_weight
*
style_loss
return
total_loss
def
gram_matrix
(
self
,
A
):
if
len
(
A
.
shape
)
==
4
:
_
,
c
,
h
,
w
=
A
.
shape
A
=
fluid
.
layers
.
reshape
(
A
,
(
c
,
h
*
w
))
GA
=
fluid
.
layers
.
matmul
(
A
,
fluid
.
layers
.
transpose
(
A
,
[
1
,
0
]))
return
GA
def
main
():
# 启动动态图模式
fluid
.
enable_dygraph
()
content
=
load_image
(
FLAGS
.
content_image
)
style
=
load_image
(
FLAGS
.
style_image
,
shape
=
tuple
(
content
.
shape
[
-
2
:]))
model
=
StyleTransferModel
()
style_loss
=
StyleTransferLoss
()
# 使用内容图像初始化要生成的图像
target
=
Model
.
create_parameter
(
model
,
shape
=
content
.
shape
)
target
.
set_value
(
content
.
numpy
())
optimizer
=
fluid
.
optimizer
.
Adam
(
parameter_list
=
[
target
],
learning_rate
=
FLAGS
.
lr
)
model
.
prepare
(
optimizer
,
style_loss
)
content_fetures
=
model
.
test
(
content
)
style_features
=
model
.
test
(
style
)
# 将两个特征组合,作为损失函数的label传给模型
feats
=
style_features
+
[
content_fetures
[
-
2
]]
# 训练5000个step,每500个step画一下生成的图像查看效果
steps
=
FLAGS
.
steps
for
i
in
range
(
steps
):
outs
=
model
.
train
(
target
,
feats
)
if
i
%
500
==
0
:
print
(
'iters:'
,
i
,
'loss:'
,
outs
[
0
][
0
])
if
not
os
.
path
.
exists
(
FLAGS
.
save_dir
):
os
.
makedirs
(
FLAGS
.
save_dir
)
# 保存生成好的图像
name
=
FLAGS
.
content_image
.
split
(
os
.
sep
)[
-
1
]
output_path
=
os
.
path
.
join
(
FLAGS
.
save_dir
,
'generated_'
+
name
)
cv2
.
imwrite
(
output_path
,
cv2
.
cvtColor
((
image_restore
(
target
)
*
255
).
astype
(
'uint8'
),
cv2
.
COLOR_RGB2BGR
))
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
"Resnet Training on ImageNet"
)
parser
.
add_argument
(
"--content-image"
,
type
=
str
,
default
=
'./images/chicago_cropped.jpg'
,
help
=
"content image"
)
parser
.
add_argument
(
"--style-image"
,
type
=
str
,
default
=
'./images/Starry-Night-by-Vincent-Van-Gogh-painting.jpg'
,
help
=
"style image"
)
parser
.
add_argument
(
"--save-dir"
,
type
=
str
,
default
=
'./output'
,
help
=
"output dir"
)
parser
.
add_argument
(
"--steps"
,
default
=
5000
,
type
=
int
,
help
=
"number of steps to run"
)
parser
.
add_argument
(
'--lr'
,
'--learning-rate'
,
default
=
1e-3
,
type
=
float
,
metavar
=
'LR'
,
help
=
'initial learning rate'
)
FLAGS
=
parser
.
parse_args
()
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录