Commit 1f1ab47d authored by xiaoting, committed by Hongyu Liu

add cycle gan for dygraph (#2362)

* add cycle_gan for dygraph

* upload download

* polish code

* fix bug

* fix some errors

* add infer

* refine code

* refine code

* matching bn

* add infer images

* update readme

* update readme

* replace pics

* Update README.md

* modified for conv2d
Parent ccb40836
# Cycle GAN
---
## Contents
- [Installation](#installation)
- [Introduction](#introduction)
- [Code structure](#code-structure)
- [Data preparation](#data-preparation)
- [Training and inference](#training-and-inference)
## Installation
Running the example in this directory requires the latest PaddlePaddle develop version. If your installed version of PaddlePaddle is older than this, please update it following the instructions in the [installation guide](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html).
## Introduction
Cycle GAN is an image-to-image generation network that learns to translate between two unpaired image domains. As shown in the figure below, the model consists of two generators, G: X → Y and F: Y → X, together with the corresponding discriminators DY and DX. Training DY pushes G to translate images from X into images indistinguishable from domain Y, and vice versa. Two "cycle consistency losses" are added to guarantee that an image translated into the other domain can be translated back: (b) the forward cycle-consistency loss x→G(x)→F(G(x))≈x, and (c) the backward cycle-consistency loss y→F(y)→G(F(y))≈y.
<p align="center">
<img src="image/net.png" hspace='10'/> <br />
Figure 1. Network architecture
</p>
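Written out (a sketch following the original CycleGAN formulation; `lambda_A`, `lambda_B` and `lambda_identity` in `trainer.py` are the corresponding weights, and the implementation additionally adds an identity term), the objective is:
```
\mathcal{L}_{cyc}(G, F) = \mathbb{E}_{x}\big[\lVert F(G(x)) - x \rVert_1\big] + \mathbb{E}_{y}\big[\lVert G(F(y)) - y \rVert_1\big]
\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{GAN}(G, D_Y) + \mathcal{L}_{GAN}(F, D_X) + \lambda\,\mathcal{L}_{cyc}(G, F) + \lambda_{idt}\,\mathcal{L}_{idt}(G, F)
```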
## Code structure
```
├── data_reader.py # Reads and preprocesses the data.
├── layers.py # Defines the basic layers.
├── model.py # Defines the generator and discriminator networks.
├── trainer.py # Builds the losses and the training networks.
├── train.py # Training script.
└── infer.py # Inference script.
```
## Data preparation
This tutorial trains and tests the model on the cityscapes dataset, which can be downloaded by running `python download.py --dataset cityscapes`.
The cityscapes training set contains 2975 real street-scene photos and 2975 corresponding semantic segmentation images; the test set contains 499 photos and 499 segmentation images.
After the data has been downloaded and processed, it is organized in the following directory structure:
```
data
|-- cityscapes
| |-- testA
| |-- testA.txt
| |-- testB
| |-- testB.txt
| |-- trainA
| |-- trainA.txt
| |-- trainB
| `-- trainB.txt
```
In the layout above, the `data` folder must sit in the same directory as the training script `train.py`. `testA` holds the real street-scene photos, `testB` holds the semantic segmentation images, and `testA.txt` and `testB.txt` are the corresponding lists of test image paths, formatted as follows:
```
data/cityscapes/testA/234_A.jpg
data/cityscapes/testA/292_A.jpg
data/cityscapes/testA/412_A.jpg
```
The training data is organized in the same way as the test data.
## Training and inference
### Training
Train on a single GPU:
```
env CUDA_VISIBLE_DEVICES=0 python train.py
```
Run `python train.py --help` for more usage information and detailed parameter descriptions.
The figure below shows the training losses over 152 epochs; the x-axis is the training epoch and the y-axis is the loss on the training set. 'g_loss', 'd_A_loss', and 'd_B_loss' are the training losses of the generator, discriminator A, and discriminator B, respectively.
TODO: loss curves
### Testing
Run the following command to evaluate saved training weights on the test set; use `--epoch` to specify which epoch's weights to load:
```
env CUDA_VISIBLE_DEVICES=0 python test.py --epoch=200
```
### Inference
Run the following command to run inference on one or more images:
Generate segmentation images from real street scenes:
```
env CUDA_VISIBLE_DEVICES=0 python infer.py \
--init_model="./G/199" --input="./image/testA/123_A.jpg" \
--input_style=A
```
Generate real street scenes from segmentation images:
```
env CUDA_VISIBLE_DEVICES=0 python infer.py \
--init_model="./G/199" --input="./image/testB/78_B.jpg" \
--input_style=B
```
The predictions of a model trained for 180 epochs are shown below as fakeA and fakeB:
<p align="center">
<img src="image/A2B.png" width="620" hspace='10'/> <br/>
<strong>A2B</strong>
</p>
<p align="center">
<img src="image/B2A.png" width="620" hspace='10'/> <br/>
<strong>B2A</strong>
</p>
>In all of the examples above, you can change which GPU is used by modifying `CUDA_VISIBLE_DEVICES`.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from PIL import Image, ImageOps
import numpy as np
###A_LIST_FILE = "./train_data/trainA.txt"
###B_LIST_FILE = "./train_data/trainB.txt"
###A_TEST_LIST_FILE = "./train_data/testA.txt"
###B_TEST_LIST_FILE = "./train_data/testB.txt"
###IMAGES_ROOT = "./train_data/"
A_LIST_FILE = "./data/cityscapes/trainA.txt"
B_LIST_FILE = "./data/cityscapes/trainB.txt"
A_TEST_LIST_FILE = "./data/cityscapes/testA.txt"
B_TEST_LIST_FILE = "./data/cityscapes/testB.txt"
IMAGES_ROOT = "./data/cityscapes/"
def image_shape():
return [3, 256, 256]
def max_images_num():
return 2974
def reader_creater(list_file, cycle=True, shuffle=True, return_name=False):
with open(list_file, 'r') as f:
    images = [IMAGES_ROOT + line for line in f.readlines()]
def reader():
while True:
if shuffle:
np.random.shuffle(images)
for file in images:
file = file.strip("\n\r\t ")
image = Image.open(file)
## Resize
image = image.resize((286, 286), Image.BICUBIC)
## RandomCrop
i = np.random.randint(0, 30)
j = np.random.randint(0, 30)
image = image.crop((i, j , i+256, j+256))
# RandomHorizontalFlip
sed = np.random.rand()
if sed > 0.5:
image = ImageOps.mirror(image)
# ToTensor
image = np.array(image).transpose([2, 0, 1]).astype('float32')
image = image / 255.0
# Normalize, mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]
image = (image - 0.5) / 0.5
if return_name:
yield image[np.newaxis, :], os.path.basename(file)
else:
yield image
if not cycle:
break
return reader
def a_reader(shuffle=True):
"""
Reader of images with A style for training.
"""
return reader_creater(A_LIST_FILE, shuffle=shuffle)
def b_reader(shuffle=True):
"""
Reader of images with B style for training.
"""
return reader_creater(B_LIST_FILE, shuffle=shuffle)
def a_test_reader():
"""
Reader of images with A style for test.
"""
return reader_creater(A_TEST_LIST_FILE, cycle=False, return_name=True)
def b_test_reader():
"""
Reader of images with B style for test.
"""
return reader_creater(B_TEST_LIST_FILE, cycle=False, return_name=True)
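# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal example of consuming the readers above, assuming the cityscapes
# data has already been downloaded into ./data/cityscapes. Training readers
# yield float32 CHW arrays normalized to [-1, 1]; test readers additionally
# yield the file name and add a batch dimension.
#
#   import data_reader
#
#   train_gen = data_reader.a_reader(shuffle=False)()
#   image = next(train_gen)            # shape (3, 256, 256)
#
#   test_gen = data_reader.a_test_reader()()
#   image, name = next(test_gen)       # shape (1, 3, 256, 256), e.g. "123_A.jpg"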
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import print_function
from PIL import Image
import numpy as np
import os
import sys
import gzip
import zipfile
import shutil
import argparse
import requests
import six
import hashlib
parser = argparse.ArgumentParser(description='Download dataset.')
#TODO add celeA dataset
parser.add_argument(
'--dataset',
type=str,
default='mnist',
help='name of dataset to download [mnist]')
def md5file(fname):
hash_md5 = hashlib.md5()
f = open(fname, "rb")
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
f.close()
return hash_md5.hexdigest()
def download_mnist(dir_path):
URL_DIC = {}
URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
TEST_IMAGE_MD5 = '9fb629c4189551a2d022fa330f9573f3'
TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
TEST_LABEL_MD5 = 'ec29112dd5afa0611ce80d1b7f02629c'
TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz'
TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz'
TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432'
URL_DIC[TRAIN_IMAGE_URL] = TRAIN_IMAGE_MD5
URL_DIC[TRAIN_LABEL_URL] = TRAIN_LABEL_MD5
URL_DIC[TEST_IMAGE_URL] = TEST_IMAGE_MD5
URL_DIC[TEST_LABEL_URL] = TEST_LABEL_MD5
### print(url)
for url in URL_DIC:
md5sum = URL_DIC[url]
data_dir = os.path.join(dir_path, 'mnist')
if not os.path.exists(data_dir):
os.makedirs(data_dir)
filename = os.path.join(data_dir, url.split('/')[-1])
retry = 0
retry_limit = 3
while not (os.path.exists(filename) and md5file(filename) == md5sum):
if os.path.exists(filename):
sys.stderr.write("file %s md5 %s" %
(md5file(filename), md5sum))
if retry < retry_limit:
retry += 1
else:
raise RuntimeError("Cannot download {0} within retry limit {1}".
format(url, retry_limit))
sys.stderr.write("Cache file %s not found, downloading %s" %
(filename, url))
r = requests.get(url, stream=True)
total_length = r.headers.get('content-length')
if total_length is None:
with open(filename, 'wb') as f:
shutil.copyfileobj(r.raw, f)
else:
with open(filename, 'wb') as f:
dl = 0
total_length = int(total_length)
for data in r.iter_content(chunk_size=4096):
if six.PY2:
data = six.b(data)
dl += len(data)
f.write(data)
done = int(50 * dl / total_length)
sys.stderr.write("\r[%s%s]" % ('=' * done,
' ' * (50 - done)))
sys.stdout.flush()
sys.stderr.write("\n")
sys.stdout.flush()
print(filename)
def download_cycle_pix(dir_path, dataname):
URL_PREFIX = 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/'
IMAGE_URL = '{}.zip'.format(dataname)
url = URL_PREFIX + IMAGE_URL
if not os.path.exists(dir_path):
os.makedirs(dir_path)
r = requests.get(url, stream=True)
total_length = float(r.headers.get('content-length'))
filename = os.path.join(dir_path, IMAGE_URL)
print(filename)
if not os.path.exists(filename):
dl = 0
with open(filename, "wb") as f:
for data in r.iter_content(chunk_size=4096):
if six.PY2:
data = six.b(data)
dl += len(data)
f.write(data)
done = int(100 * dl / total_length)
sys.stderr.write("\r[{}{}] {}% ".format('=' * done, ' ' * (
100 - done), done))
sys.stdout.flush()
else:
sys.stderr.write('{}.zip already exists, no need to download it again.'.
format(dataname))
### unzip .zip file
if not os.path.exists(os.path.join(dir_path, '{}'.format(dataname))):
zip_f = zipfile.ZipFile(filename, 'r')
for zip_file in zip_f.namelist():
zip_f.extract(zip_file, dir_path)
### generate .txt list files according to the image dirs
dirs = os.listdir(os.path.join(dir_path, '{}'.format(dataname)))
for d in dirs:
txt_file = d + '.txt'
txt_dir = os.path.join(dir_path, dataname)
f = open(os.path.join(txt_dir, txt_file), 'w')
for fil in os.listdir(os.path.join(txt_dir, d)):
wl = d + '/' + fil + '\n'
f.write(wl)
f.close()
sys.stderr.write("\n")
if __name__ == '__main__':
args = parser.parse_args()
cycle_pix_dataset = [
'apple2orange', 'summer2winter_yosemite', 'horse2zebra', 'monet2photo',
'cezanne2photo', 'ukiyoe2photo', 'vangogh2photo', 'maps', 'cityscapes',
'facades', 'iphone2dslr_flower', 'ae_photos', 'mini'
]
if args.dataset == 'mnist':
print('Download dataset: {}'.format(args.dataset))
download_mnist('./data/')
elif args.dataset in cycle_pix_dataset:
print('Download dataset: {}'.format(args.dataset))
download_cycle_pix('./data/', args.dataset)
else:
print('Please download the dataset yourself, thanks')
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import sys
import paddle
import argparse
import functools
import time
import numpy as np
import glob
from PIL import Image
from scipy.misc import imsave
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
from paddle.fluid import core
import data_reader
from utility import add_arguments, print_arguments, ImagePool
from trainer import *
from paddle.fluid.dygraph.base import to_variable
import six
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('input', str, "123_A.jpg", "input image")
add_arg('output', str, "./output_0", "The directory the model and the test result to be saved to.")
add_arg('init_model', str, './G/150', "The init model file of directory.")
add_arg('input_style', str, "A", "A or B")
def infer():
with fluid.dygraph.guard():
data_shape = [-1,3,256,256]
out_path = args.output + "/single" + "/" + str(args.input)
if not os.path.exists(out_path):
os.makedirs(out_path)
cycle_gan = Cycle_Gan("cycle_gan")
save_dir = args.init_model
restore = fluid.dygraph.load_persistables(save_dir)
cycle_gan.load_dict(restore)
cycle_gan.eval()
for file in glob.glob(args.input):
print ("read %s" % file)
image_name = os.path.basename(file)
image = Image.open(file).convert('RGB')
image = image.resize((256, 256), Image.BICUBIC)
image = np.array(image) / 127.5 - 1
image = image[:, :, 0:3].astype("float32")
data = image.transpose([2, 0, 1])[np.newaxis,:]
data_A_tmp = to_variable(data)
fake_A_temp,fake_B_temp,cyc_A_temp,cyc_B_temp,g_A_loss,g_B_loss,idt_loss_A,idt_loss_B,cyc_A_loss,cyc_B_loss,g_loss = cycle_gan(data_A_tmp,data_A_tmp,True,False,False)
fake_A_temp = np.squeeze(fake_A_temp.numpy()[0]).transpose([1, 2, 0])
fake_B_temp = np.squeeze(fake_B_temp.numpy()[0]).transpose([1, 2, 0])
if args.input_style == "A":
imsave(out_path + "/fakeB_" + image_name, (
(fake_B_temp + 1) * 127.5).astype(np.uint8))
if args.input_style == "B":
imsave(out_path + "/fakeA_" + image_name, (
(fake_A_temp + 1) * 127.5).astype(np.uint8))
if __name__ == "__main__":
args = parser.parse_args()
print_arguments(args)
infer()
from __future__ import division
import paddle.fluid as fluid
import numpy as np
from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose , BatchNorm ,Pool2D
import os
# cudnn is not better when batch size is 1.
use_cudnn = False
class conv2d(fluid.dygraph.Layer):
"""docstring for Conv2D"""
def __init__(self,
name_scope,
num_filters=64,
filter_size=7,
stride=1,
stddev=0.02,
padding=0,
norm=True,
relu=True,
relufactor=0.0,
use_bias=False):
super(conv2d, self).__init__(name_scope)
if use_bias == False:
con_bias_attr = False
else:
con_bias_attr = fluid.ParamAttr(name="conv_bias",initializer=fluid.initializer.Constant(0.0))
self.conv = Conv2D(
self.full_name(),
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
use_cudnn=use_cudnn,
param_attr=fluid.ParamAttr(
name="conv2d_weights",
initializer=fluid.initializer.NormalInitializer(loc=0.0,scale=stddev)),
bias_attr=con_bias_attr)
if norm:
self.bn = BatchNorm(self.full_name(),
num_channels=num_filters,
param_attr=fluid.ParamAttr(
name="scale",
initializer=fluid.initializer.NormalInitializer(1.0,0.02)),
bias_attr=fluid.ParamAttr(
name="bias",
initializer=fluid.initializer.Constant(0.0)),
trainable_statistics=True
)
self.relufactor = relufactor
self.use_bias = use_bias
self.norm = norm
self.relu = relu
def forward(self,inputs):
conv = self.conv(inputs)
if self.norm:
conv = self.bn(conv)
if self.relu:
conv = fluid.layers.leaky_relu(conv,alpha=self.relufactor)
return conv
class DeConv2D(fluid.dygraph.Layer):
def __init__(self,
name_scope,
num_filters=64,
filter_size=7,
stride=1,
stddev=0.02,
padding=[0,0],
outpadding=[0,0,0,0],
relu=True,
norm=True,
relufactor=0.0,
use_bias=False
):
super(DeConv2D,self).__init__(name_scope)
if use_bias == False:
de_bias_attr = False
else:
de_bias_attr = fluid.ParamAttr(name="de_bias",initializer=fluid.initializer.Constant(0.0))
self._deconv = Conv2DTranspose(self.full_name(),
num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
param_attr=fluid.ParamAttr(
name="this_is_deconv_weights",
initializer=fluid.initializer.NormalInitializer(loc=0.0, scale=stddev)),
bias_attr=de_bias_attr)
if norm:
self.bn = BatchNorm(self.full_name(),
num_channels=num_filters,
param_attr=fluid.ParamAttr(
name="de_wights",
initializer=fluid.initializer.NormalInitializer(1.0, 0.02)),
bias_attr=fluid.ParamAttr(name="de_bn_bias",initializer=fluid.initializer.Constant(0.0)),
trainable_statistics=True)
self.outpadding = outpadding
self.relufactor = relufactor
self.use_bias = use_bias
self.norm = norm
self.relu = relu
def forward(self,inputs):
#todo: add use_bias
#if self.use_bias==False:
conv = self._deconv(inputs)
#else:
# conv = self._deconv(inputs)
conv = fluid.layers.pad2d(conv, paddings=self.outpadding, mode='constant', pad_value=0.0)
if self.norm:
conv = self.bn(conv)
if self.relu:
conv = fluid.layers.leaky_relu(conv,alpha=self.relufactor)
return conv
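# Note (added for clarity): conv2d above is a Conv2D -> optional BatchNorm ->
# optional leaky_relu block; with the default relufactor=0.0 the leaky_relu
# degenerates to a plain ReLU. DeConv2D mirrors it with Conv2DTranspose plus
# an explicit pad2d, so that a stride-2 deconv with padding=[1, 1] and
# outpadding=[0, 1, 0, 1] exactly doubles the spatial size (e.g. 128 -> 256).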
from layers import *
import paddle.fluid as fluid
class build_resnet_block(fluid.dygraph.Layer):
def __init__(self,
name_scope,
dim,
use_bias=False):
super(build_resnet_block,self).__init__(name_scope)
self.conv0 = conv2d(self.full_name(),
num_filters=dim,
filter_size=3,
stride=1,
stddev=0.02,
use_bias=False)
self.conv1 = conv2d(self.full_name(),
num_filters=dim,
filter_size=3,
stride=1,
stddev=0.02,
relu=False,
use_bias=False)
self.dim = dim
def forward(self,inputs):
out_res = fluid.layers.pad2d(inputs, [1, 1, 1, 1], mode="reflect")
out_res = self.conv0(out_res)
#if self.use_dropout:
# out_res = fluid.layers.dropout(out_res,dropout_prod=0.5)
out_res = fluid.layers.pad2d(out_res, [1, 1, 1, 1], mode="reflect")
out_res = self.conv1(out_res)
return out_res + inputs
class build_generator_resnet_9blocks(fluid.dygraph.Layer):
def __init__ (self,
name_scope):
super(build_generator_resnet_9blocks,self).__init__(name_scope)
self.conv0 = conv2d(self.full_name(),
num_filters=32,
filter_size=7,
stride=1,
padding=0,
stddev=0.02)
self.conv1 = conv2d(self.full_name(),
num_filters=64,
filter_size=3,
stride=2,
padding=1,
stddev=0.02)
self.conv2 = conv2d(self.full_name(),
num_filters=128,
filter_size=3,
stride=2,
padding=1,
stddev=0.02)
self.build_resnet_block_list=[]
dim = 32*4
for i in range(9):
Build_Resnet_Block = self.add_sublayer(
"generator_%d" % (i+1),
build_resnet_block(self.full_name(),
128))
self.build_resnet_block_list.append(Build_Resnet_Block)
self.deconv0 = DeConv2D(self.full_name(),
num_filters=32*2,
filter_size=3,
stride=2,
stddev=0.02,
padding=[1, 1],
outpadding=[0, 1, 0, 1],
)
self.deconv1 = DeConv2D(self.full_name(),
num_filters=32,
filter_size=3,
stride=2,
stddev=0.02,
padding=[1, 1],
outpadding=[0, 1, 0, 1])
self.conv3 = conv2d(self.full_name(),
num_filters=3,
filter_size=7,
stride=1,
stddev=0.02,
padding=0,
relu=False,
norm=False,
use_bias=True)
def forward(self,inputs):
pad_input = fluid.layers.pad2d(inputs, [3, 3, 3, 3], mode="reflect")
y = self.conv0(pad_input)
y = self.conv1(y)
y = self.conv2(y)
for build_resnet_block_i in self.build_resnet_block_list:
y = build_resnet_block_i(y)
y = self.deconv0(y)
y = self.deconv1(y)
y = fluid.layers.pad2d(y,[3,3,3,3],mode="reflect")
y = self.conv3(y)
y = fluid.layers.tanh(y)
return y
class build_gen_discriminator(fluid.dygraph.Layer):
def __init__(self,name_scope):
super(build_gen_discriminator,self).__init__(name_scope)
self.conv0 = conv2d(self.full_name(),
num_filters=64,
filter_size=4,
stride=2,
stddev=0.02,
padding=1,
norm=False,
use_bias=True,
relufactor=0.2)
self.conv1 = conv2d(self.full_name(),
num_filters=128,
filter_size=4,
stride=2,
stddev=0.02,
padding=1,
relufactor=0.2)
self.conv2 = conv2d(self.full_name(),
num_filters=256,
filter_size=4,
stride=2,
stddev=0.02,
padding=1,
relufactor=0.2)
self.conv3 = conv2d(self.full_name(),
num_filters=512,
filter_size=4,
stride=1,
stddev=0.02,
padding=1,
relufactor=0.2)
self.conv4 = conv2d(self.full_name(),
num_filters=1,
filter_size=4,
stride=1,
stddev=0.02,
padding=1,
norm=False,
relu=False,
use_bias=True)
def forward(self,inputs):
y = self.conv0(inputs)
y = self.conv1(y)
y = self.conv2(y)
y = self.conv3(y)
y = self.conv4(y)
return y
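# Shape walk-through (added for clarity), for a 3 x 256 x 256 input:
# - Generator: reflect-pad to 262, 7x7 conv -> 32 x 256 x 256, two stride-2
#   convs -> 128 x 64 x 64, nine residual blocks (shape-preserving), two
#   stride-2 deconvs back up to 32 x 256 x 256, then a final 7x7 conv + tanh
#   -> 3 x 256 x 256.
# - Discriminator: a PatchGAN; three stride-2 and two stride-1 4x4 convs map
#   3 x 256 x 256 to a 1 x 30 x 30 grid of per-patch real/fake scores.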
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import sys
import paddle
import argparse
import functools
import time
import numpy as np
from scipy.misc import imsave
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
from paddle.fluid import core
import data_reader
from utility import add_arguments, print_arguments, ImagePool
from trainer import *
from paddle.fluid.dygraph.base import to_variable
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 1, "Minibatch size.")
add_arg('epoch', int, None, "The epoch of the saved weights to be tested.")
add_arg('output', str, "./output_0", "The directory the model and the test result to be saved to.")
add_arg('init_model', str, './G/', "The init model file of directory.")
def test():
with fluid.dygraph.guard():
A_test_reader = data_reader.a_test_reader()
B_test_reader = data_reader.b_test_reader()
epoch = args.epoch
out_path = args.output + "/eval" + "/" + str(epoch)
if not os.path.exists(out_path):
os.makedirs(out_path)
cycle_gan = Cycle_Gan("cycle_gan")
save_dir = args.init_model + str(epoch)
restore = fluid.dygraph.load_persistables(save_dir)
cycle_gan.load_dict(restore)
cycle_gan.eval()
for data_A , data_B in zip(A_test_reader(), B_test_reader()):
A_name = data_A[1]
B_name = data_B[1]
print(A_name)
print(B_name)
tensor_A = np.array([data_A[0].reshape(3,256,256)]).astype("float32")
tensor_B = np.array([data_B[0].reshape(3,256,256)]).astype("float32")
data_A_tmp = to_variable(tensor_A)
data_B_tmp = to_variable(tensor_B)
fake_A_temp,fake_B_temp,cyc_A_temp,cyc_B_temp,g_A_loss,g_B_loss,idt_loss_A,idt_loss_B,cyc_A_loss,cyc_B_loss,g_loss = cycle_gan(data_A_tmp,data_B_tmp,True,False,False)
fake_A_temp = np.squeeze(fake_A_temp.numpy()[0]).transpose([1, 2, 0])
fake_B_temp = np.squeeze(fake_B_temp.numpy()[0]).transpose([1, 2, 0])
cyc_A_temp = np.squeeze(cyc_A_temp.numpy()[0]).transpose([1, 2, 0])
cyc_B_temp = np.squeeze(cyc_B_temp.numpy()[0]).transpose([1, 2, 0])
input_A_temp = np.squeeze(data_A[0]).transpose([1, 2, 0])
input_B_temp = np.squeeze(data_B[0]).transpose([1, 2, 0])
imsave(out_path + "/fakeB_" + str(epoch) + "_" + A_name, (
(fake_B_temp + 1) * 127.5).astype(np.uint8))
imsave(out_path + "/fakeA_" + str(epoch) + "_" + B_name, (
(fake_A_temp + 1) * 127.5).astype(np.uint8))
imsave(out_path + "/cycA_" + str(epoch) + "_" + A_name, (
(cyc_A_temp + 1) * 127.5).astype(np.uint8))
imsave(out_path + "/cycB_" + str(epoch) + "_" + B_name, (
(cyc_B_temp + 1) * 127.5).astype(np.uint8))
imsave(out_path + "/inputA_" + str(epoch) + "_" + A_name, (
(input_A_temp + 1) * 127.5).astype(np.uint8))
imsave(out_path + "/inputB_" + str(epoch) + "_" + B_name, (
(input_B_temp + 1) * 127.5).astype(np.uint8))
if __name__ == "__main__":
args = parser.parse_args()
print_arguments(args)
test()
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import sys
import paddle
import argparse
import functools
import time
import numpy as np
from scipy.misc import imsave
import paddle.fluid as fluid
import data_reader
from utility import add_arguments, print_arguments, ImagePool
from trainer import *
from paddle.fluid.dygraph.base import to_variable
import six
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 1, "Minibatch size.")
add_arg('epoch', int, 200, "The number of epochs to be trained.")
add_arg('output', str, "./output_0", "The directory the model and the test result to be saved to.")
add_arg('init_model', str, None, "The init model file of directory.")
add_arg('save_checkpoints', bool, True, "Whether to save checkpoints.")
# yapf: enable
lambda_A = 10.0
lambda_B = 10.0
lambda_identity = 0.5
step_per_epoch = 2974
def optimizer_setting():
lr=0.0002
optimizer = fluid.optimizer.Adam(
learning_rate=fluid.layers.piecewise_decay(
boundaries=[
100 * step_per_epoch, 120 * step_per_epoch,
140 * step_per_epoch, 160 * step_per_epoch,
180 * step_per_epoch
],
values=[
lr , lr * 0.8, lr * 0.6, lr * 0.4, lr * 0.2, lr * 0.1
]),
beta1=0.5)
return optimizer
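# The piecewise schedule above keeps the learning rate at 2e-4 for the first
# 100 epochs (step_per_epoch = 2974 iterations each) and then steps it down
# at epochs 100/120/140/160/180 to 0.8x, 0.6x, 0.4x, 0.2x and finally 0.1x
# of the base rate. Three independent Adam optimizers (beta1 = 0.5) are
# created below: one for the generators and one for each discriminator.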
def train(args):
with fluid.dygraph.guard():
max_images_num = data_reader.max_images_num()
shuffle = True
data_shape = [-1] + data_reader.image_shape()
print(data_shape)
A_pool = ImagePool()
B_pool = ImagePool()
A_reader = paddle.batch(
data_reader.a_reader(shuffle=shuffle), args.batch_size)()
B_reader = paddle.batch(
data_reader.b_reader(shuffle=shuffle), args.batch_size)()
A_test_reader = data_reader.a_test_reader()
B_test_reader = data_reader.b_test_reader()
cycle_gan = Cycle_Gan("cycle_gan",istrain=True)
losses = [[], []]
t_time = 0
optimizer1 = optimizer_setting()
optimizer2 = optimizer_setting()
optimizer3 = optimizer_setting()
for epoch in range(args.epoch):
batch_id = 0
for i in range(max_images_num):
data_A = next(A_reader)
data_B = next(B_reader)
s_time = time.time()
data_A = np.array([data_A[0].reshape(3,256,256)]).astype("float32")
data_B = np.array([data_B[0].reshape(3,256,256)]).astype("float32")
data_A = to_variable(data_A)
data_B = to_variable(data_B)
# optimize the g_A network
fake_A,fake_B,cyc_A,cyc_B,g_A_loss,g_B_loss,idt_loss_A,idt_loss_B,cyc_A_loss,cyc_B_loss,g_loss = cycle_gan(data_A,data_B,True,False,False)
g_loss_out = g_loss.numpy()
g_loss.backward()
vars_G = []
for param in cycle_gan.parameters():
if param.name[:52]=="cycle_gan/Cycle_Gan_0/build_generator_resnet_9blocks":
vars_G.append(param)
optimizer1.minimize(g_loss,parameter_list=vars_G)
cycle_gan.clear_gradients()
fake_pool_B = B_pool.pool_image(fake_B).numpy()
fake_pool_B = np.array([fake_pool_B[0].reshape(3,256,256)]).astype("float32")
fake_pool_B = to_variable(fake_pool_B)
fake_pool_A = A_pool.pool_image(fake_A).numpy()
fake_pool_A = np.array([fake_pool_A[0].reshape(3,256,256)]).astype("float32")
fake_pool_A = to_variable(fake_pool_A)
# optimize the d_A network
rec_B, fake_pool_rec_B = cycle_gan(data_B,fake_pool_B,False,True,False)
d_loss_A = (fluid.layers.square(fake_pool_rec_B) +
fluid.layers.square(rec_B - 1)) / 2.0
d_loss_A = fluid.layers.reduce_mean(d_loss_A)
d_loss_A.backward()
vars_da = []
for param in cycle_gan.parameters():
if param.name[:47]=="cycle_gan/Cycle_Gan_0/build_gen_discriminator_0":
vars_da.append(param)
optimizer2.minimize(d_loss_A,parameter_list=vars_da)
cycle_gan.clear_gradients()
# optimize the d_B network
rec_A, fake_pool_rec_A = cycle_gan(data_A,fake_pool_A,False,False,True)
d_loss_B = (fluid.layers.square(fake_pool_rec_A) +
fluid.layers.square(rec_A - 1)) / 2.0
d_loss_B = fluid.layers.reduce_mean(d_loss_B)
d_loss_B.backward()
vars_db = []
for param in cycle_gan.parameters():
if param.name[:47]=="cycle_gan/Cycle_Gan_0/build_gen_discriminator_1":
vars_db.append(param)
optimizer3.minimize(d_loss_B,parameter_list=vars_db)
cycle_gan.clear_gradients()
batch_time = time.time() - s_time
t_time += batch_time
print(
"epoch{}; batch{}; g_loss:{}; d_A_loss: {}; d_B_loss:{} ; \n g_A_loss: {}; g_A_cyc_loss: {}; g_A_idt_loss: {}; g_B_loss: {}; g_B_cyc_loss: {}; g_B_idt_loss: {};Batch_time_cost: {:.2f}".format(epoch, batch_id,g_loss_out[0],d_loss_A.numpy()[0], d_loss_B.numpy()[0],g_A_loss.numpy()[0],cyc_A_loss.numpy()[0], idt_loss_A.numpy()[0], g_B_loss.numpy()[0],cyc_B_loss.numpy()[0],idt_loss_B.numpy()[0], batch_time))
with open('logging_train.txt', 'a') as log_file:
now = time.strftime("%c")
log_file.write(
"time: {}; epoch{}; batch{}; d_A_loss: {}; g_A_loss: {}; \
g_A_cyc_loss: {}; g_A_idt_loss: {}; d_B_loss: {}; \
g_B_loss: {}; g_B_cyc_loss: {}; g_B_idt_loss: {}; \
Batch_time_cost: {:.2f}\n".format(now, epoch, \
batch_id, d_loss_A.numpy()[0], g_A_loss.numpy()[0], cyc_A_loss.numpy()[0], \
idt_loss_A.numpy()[0], d_loss_B.numpy()[0], g_B_loss.numpy()[0], \
cyc_B_loss.numpy()[0], idt_loss_B.numpy()[0], batch_time))
losses[0].append(g_A_loss.numpy()[0])
losses[1].append(d_loss_A.numpy()[0])
sys.stdout.flush()
batch_id += 1
if args.save_checkpoints:
fluid.dygraph.save_persistables(cycle_gan.state_dict(),args.output+"/checkpoints/{}".format(epoch))
if __name__ == "__main__":
args = parser.parse_args()
print_arguments(args)
train(args)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from model import *
import paddle.fluid as fluid
step_per_epoch = 2974
lambda_A = 10.0
lambda_B = 10.0
lambda_identity = 0.5
class Cycle_Gan(fluid.dygraph.Layer):
def __init__(self, name_scope,istrain=True):
super (Cycle_Gan, self).__init__(name_scope)
self.build_generator_resnet_9blocks_a = build_generator_resnet_9blocks(self.full_name())
self.build_generator_resnet_9blocks_b = build_generator_resnet_9blocks(self.full_name())
if istrain:
self.build_gen_discriminator_a = build_gen_discriminator(self.full_name())
self.build_gen_discriminator_b = build_gen_discriminator(self.full_name())
def forward(self,input_A,input_B,is_G,is_DA,is_DB):
if is_G:
fake_B = self.build_generator_resnet_9blocks_a(input_A)
fake_A = self.build_generator_resnet_9blocks_b(input_B)
cyc_A = self.build_generator_resnet_9blocks_b(fake_B)
cyc_B = self.build_generator_resnet_9blocks_a(fake_A)
diff_A = fluid.layers.abs(
fluid.layers.elementwise_sub(
x=input_A,y=cyc_A))
diff_B = fluid.layers.abs(
fluid.layers.elementwise_sub(
x=input_B, y=cyc_B))
cyc_A_loss = fluid.layers.reduce_mean(diff_A) * lambda_A
cyc_B_loss = fluid.layers.reduce_mean(diff_B) * lambda_B
cyc_loss = cyc_A_loss + cyc_B_loss
fake_rec_A = self.build_gen_discriminator_a(fake_B)
g_A_loss = fluid.layers.reduce_mean(fluid.layers.square(fake_rec_A-1))
fake_rec_B = self.build_gen_discriminator_b(fake_A)
g_B_loss = fluid.layers.reduce_mean(fluid.layers.square(fake_rec_B-1))
G = g_A_loss + g_B_loss
idt_A = self.build_generator_resnet_9blocks_a(input_B)
idt_loss_A = fluid.layers.reduce_mean(fluid.layers.abs(fluid.layers.elementwise_sub(x = input_B , y = idt_A))) * lambda_B * lambda_identity
idt_B = self.build_generator_resnet_9blocks_b(input_A)
idt_loss_B = fluid.layers.reduce_mean(fluid.layers.abs(fluid.layers.elementwise_sub(x = input_A , y = idt_B))) * lambda_A * lambda_identity
idt_loss = fluid.layers.elementwise_add(idt_loss_A,idt_loss_B)
g_loss = cyc_loss + G + idt_loss
return fake_A,fake_B,cyc_A,cyc_B,g_A_loss,g_B_loss,idt_loss_A,idt_loss_B,cyc_A_loss,cyc_B_loss,g_loss
if is_DA:
### D
rec_B = self.build_gen_discriminator_a(input_A)
fake_pool_rec_B = self.build_gen_discriminator_a(input_B)
return rec_B, fake_pool_rec_B
if is_DB:
rec_A = self.build_gen_discriminator_b(input_A)
fake_pool_rec_A = self.build_gen_discriminator_b(input_B)
return rec_A, fake_pool_rec_A
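# Note (added for clarity): the forward pass is multiplexed by the is_G /
# is_DA / is_DB flags. With is_G it returns the full LSGAN generator
# objective: cycle losses mean|F(G(x)) - x| * lambda_A and
# mean|G(F(y)) - y| * lambda_B, adversarial terms mean((D(G(x)) - 1)^2), and
# identity losses weighted by lambda * lambda_identity. With is_DA / is_DB it
# only scores a (real, pooled-fake) pair with the corresponding
# discriminator; train.py then forms the discriminator loss as
# ((D(fake_pool))^2 + (D(real) - 1)^2) / 2.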
"""Contains common utility functions."""
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import distutils.util
import six
import random
import glob
import numpy as np
from paddle.fluid import core
def print_arguments(args):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
print("----------- Configuration Arguments -----------")
for arg, value in sorted(six.iteritems(vars(args))):
print("%s: %s" % (arg, value))
print("------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
"""Add argparse's argument.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
add_argument("name", str, "Jonh", "User name.", parser)
args = parser.parse_args()
"""
type = distutils.util.strtobool if type == bool else type
argparser.add_argument(
"--" + argname,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
class ImagePool(object):
def __init__(self, pool_size=50):
self.pool = []
self.count = 0
self.pool_size = pool_size
def pool_image(self, image):
if self.count < self.pool_size:
self.pool.append(image)
self.count += 1
return image
else:
p = random.random()
if p > 0.5:
random_id = random.randint(0, self.pool_size - 1)
temp = self.pool[random_id]
self.pool[random_id] = image
return temp
else:
return image
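# --- Usage sketch (illustrative; not part of the original file) ---
# ImagePool implements the generated-image history buffer CycleGAN uses to
# stabilize discriminator training: until the pool is full, every image is
# stored and returned unchanged; afterwards, with probability 0.5 a randomly
# chosen earlier image is swapped out and returned in place of the current
# one.
#
#   pool = ImagePool(pool_size=50)
#   fake_B = generator(real_A)              # hypothetical generator output
#   d_input = pool.pool_image(fake_B)       # image actually fed to D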