提交 156cb03f 编写于 作者: L LielinJiang

add md and py

上级 27c66cfa
# 图像风格迁移
图像的风格迁移是卷积神经网络有趣的应用之一。那什么是风格迁移呢?下图第一列左边的图为相机拍摄的一张普通图片,右边的图为梵高的名画星空。那如何让左边的普通图片拥有星空的风格呢。神经网络的风格迁移就可以帮助你生成第二列的这样的图片。
<div align=center>
<img src="images/markdown/img1.png" width = "600" height = "300" />
</br>
<img src="images/markdown/img2.png" width = "300" height = "300" divalign=center />
<div align=left>
## 基本原理
风格迁移的目标就是使得生成图片的内容与内容图片(content image)尽可能相似。由于在计算机中,我们用一个一个像素点表示图片,所以两个图片的相似程度我们可以用每个像素点的欧式距离来表示。而两个图片的风格相似度,我们采用两个图片在卷积神经网络中相同的一层特征图的gram矩阵的欧式距离来表示。对于一个特征图gram矩阵的计算如下所示:
```python
# tensor shape is [1, c, h, w]
_, c, h, w = tensor.shape
tensor = fluid.layers.reshape(c, h * w)
# gram matrix with shape: [c, c]
gram_matrix = fluid.layers.matmul(tensor, fluid.layers.transpose(tensor, [1, 0]))
```
最终风格迁移的问题转化为优化上述的两个欧式距离的问题。这里要注意的是,我们使用一个在imagenet上预训练好的模型vgg16,并且固定参数,优化器只更新输入的生成图像的值。
## 风格迁移
执行如下命令,就可以进行风格迁移。生成的图像会保存在```--save-dir```中。
```python
python -u style-transfer.py --content-image /path/to/your-content-image --style-image /path/to/your-content-image --save-dir /path/to/your-output-dir
```
具体的生成过程也可以参考[style-transfer.ipynb](./hapi-style-transfer.ipynb)
## 参考文献
[A Neural Algorithm of Artistic Style](https://arxiv.org/abs/1508.06576)
因为 它太大了无法显示 source diff 。你可以改为 查看blob
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from hapi.model import Model, Loss
from hapi.vision.models import vgg16
from hapi.vision.transform import transforms
from paddle import fluid
from paddle.fluid.io import Dataset
import cv2
import copy
def load_image(image_path, max_size=400, shape=None):
image = cv2.imread(image_path)
image = image.astype('float32') / 255.0
size = shape if shape is not None else max_size if max(
image.shape[:2]) > max_size else max(image.shape[:2])
transform = transforms.Compose([
transforms.Resize(size), transforms.Permute(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image = transform(image)[np.newaxis, :3, :, :]
image = fluid.dygraph.to_variable(image)
return image
def image_restore(image):
image = np.squeeze(image.numpy(), 0)
image = image.transpose(1, 2, 0)
image = image * np.array((0.229, 0.224, 0.225)) + np.array(
(0.485, 0.456, 0.406))
image = image.clip(0, 1)
return image
class StyleTransferModel(Model):
def __init__(self):
super(StyleTransferModel, self).__init__()
# pretrained设置为true,会自动下载imagenet上的预训练权重并加载
vgg = vgg16(pretrained=True)
self.base_model = vgg.features
for p in self.base_model.parameters():
p.stop_gradient = True
self.layers = {
'0': 'conv1_1',
'3': 'conv2_1',
'6': 'conv3_1',
'10': 'conv4_1',
'11': 'conv4_2', ## content representation
'14': 'conv5_1'
}
def forward(self, image):
outputs = []
for name, layer in self.base_model.named_sublayers():
image = layer(image)
if name in self.layers:
outputs.append(image)
return outputs
class StyleTransferLoss(Loss):
def __init__(self,
content_loss_weight=1,
style_loss_weight=1e5,
style_weights=[1.0, 0.8, 0.5, 0.3, 0.1]):
super(StyleTransferLoss, self).__init__()
self.content_loss_weight = content_loss_weight
self.style_loss_weight = style_loss_weight
self.style_weights = style_weights
def forward(self, outputs, labels):
content_features = labels[-1]
style_features = labels[:-1]
# 计算图像内容相似度的loss
content_loss = fluid.layers.mean((outputs[-2] - content_features)**2)
# 计算风格相似度的loss
style_loss = 0
style_grams = [self.gram_matrix(feat) for feat in style_features]
style_weights = self.style_weights
for i, weight in enumerate(style_weights):
target_gram = self.gram_matrix(outputs[i])
layer_loss = weight * fluid.layers.mean((target_gram - style_grams[
i])**2)
b, d, h, w = outputs[i].shape
style_loss += layer_loss / (d * h * w)
total_loss = self.content_loss_weight * content_loss + self.style_loss_weight * style_loss
return total_loss
def gram_matrix(self, A):
if len(A.shape) == 4:
_, c, h, w = A.shape
A = fluid.layers.reshape(A, (c, h * w))
GA = fluid.layers.matmul(A, fluid.layers.transpose(A, [1, 0]))
return GA
def main():
# 启动动态图模式
fluid.enable_dygraph()
content = load_image(FLAGS.content_image)
style = load_image(FLAGS.style_image, shape=tuple(content.shape[-2:]))
model = StyleTransferModel()
style_loss = StyleTransferLoss()
# 使用内容图像初始化要生成的图像
target = Model.create_parameter(model, shape=content.shape)
target.set_value(content.numpy())
optimizer = fluid.optimizer.Adam(
parameter_list=[target], learning_rate=FLAGS.lr)
model.prepare(optimizer, style_loss)
content_fetures = model.test(content)
style_features = model.test(style)
# 将两个特征组合,作为损失函数的label传给模型
feats = style_features + [content_fetures[-2]]
# 训练5000个step,每500个step画一下生成的图像查看效果
steps = FLAGS.steps
for i in range(steps):
outs = model.train(target, feats)
if i % 500 == 0:
print('iters:', i, 'loss:', outs[0][0])
if not os.path.exists(FLAGS.save_dir):
os.makedirs(FLAGS.save_dir)
# 保存生成好的图像
name = FLAGS.content_image.split(os.sep)[-1]
output_path = os.path.join(FLAGS.save_dir, 'generated_' + name)
cv2.imwrite(output_path,
cv2.cvtColor((image_restore(target) * 255).astype('uint8'),
cv2.COLOR_RGB2BGR))
if __name__ == '__main__':
parser = argparse.ArgumentParser("Resnet Training on ImageNet")
parser.add_argument(
"--content-image",
type=str,
default='./images/chicago_cropped.jpg',
help="content image")
parser.add_argument(
"--style-image",
type=str,
default='./images/Starry-Night-by-Vincent-Van-Gogh-painting.jpg',
help="style image")
parser.add_argument(
"--save-dir", type=str, default='./output', help="output dir")
parser.add_argument(
"--steps", default=5000, type=int, help="number of steps to run")
parser.add_argument(
'--lr',
'--learning-rate',
default=1e-3,
type=float,
metavar='LR',
help='initial learning rate')
FLAGS = parser.parse_args()
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册