Unverified commit dfe41a7c, authored by lvmengsi, committed by GitHub

update gan0727 (#2945)

* update gan0727
Parent ff28150a
@@ -20,9 +20,10 @@
Note:
1. StarGAN, AttGAN, and STGAN must be trained on GPU, because the operations required by the gradient penalty currently only support GPU.
2. The GAN models have so far only been validated for single-machine, single-GPU training and inference.
3. CGAN and DCGAN are trained on the MNIST dataset; StarGAN, AttGAN, and STGAN use the CelebA dataset. For the datasets supported by Pix2Pix and CycleGAN, see cycle_pix_dataset in download.py.
4. PaddlePaddle 1.5.1 and earlier do not support adding instance norm to the discriminators of the AttGAN and STGAN models. To use instance norm in the discriminator, build the develop branch from source and install it.
5. Intermediate result images are saved in the ${output_dir}/test folder. For Pix2Pix, inputA and inputB are the input images in the two styles and fakeB is the generated image; for CycleGAN, inputA is the input image, fakeB is the image generated from inputA, and cycA is the reconstruction of inputA obtained by passing fakeB back through the generator; for StarGAN, AttGAN, and STGAN, the first row shows the original images and each following row shows one attribute transformation.

The directory structure of the image generation model zoo is as follows:
```
@@ -52,6 +53,11 @@
│   ├── run_....py          Training launch examples
│   ├── infer_....py        Inference launch examples
│   ├── make_pair_data.py   Script that generates the data list for the pix2pix GAN
├── data                    Where the downloaded datasets are stored
│   ├── celeba
│   │   ├── ${image_dir}    Holds the actual images
│   │   ├── list            The list file
```
@@ -62,6 +68,9 @@
Running the sample code in this directory requires PaddlePaddle Fluid v1.5 or later. If the PaddlePaddle version in your environment is lower, update it following the instructions in the [installation guide](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/install/index_cn.html).

Other dependencies:
1. Install the imageio package (used by the image-saving code) with `pip install imageio` or `pip install -r requirements.txt`.
### Task overview

Pix2Pix and CycleGAN use the cityscapes\[[10](#参考文献)\] dataset for style transfer.
@@ -76,7 +85,7 @@ StarGAN, AttGAN, and STGAN use the celeba\[[11](#参考文献)\] dataset for attribute editing.
The [Celeba](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset required by StarGAN, AttGAN, and STGAN can be downloaded by yourself.

**Custom datasets:**
To use a custom dataset, arrange it in the data format required by the chosen generation model and place it under the data folder, then set the `--dataset` argument to the name of your dataset; data_reader.py will automatically look for the data in the data folder, as sketched below.
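For reference, a minimal sketch of the path resolution that data_reader.py performs, assuming a Pix2pix-style custom dataset named mydata and the default data directory (CycleGAN-style datasets use trainA.txt and trainB.txt instead):

```
import os

data_dir = "./data"     # --data_dir (assumed default)
dataset = "mydata"      # --dataset, i.e. the name of your custom dataset
dataset_dir = os.path.join(data_dir, dataset)
train_list = os.path.join(dataset_dir, "train.txt")  # the list file the reader expects
print(train_list)       # ./data/mydata/train.txt
```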
Note: the list file used when preparing a pix2pix dataset must be generated with make_pair_data.py in the scripts folder, for example:

python scripts/make_pair_data.py \
@@ -85,7 +94,7 @@ StarGAN, AttGAN, and STGAN require the [Celeba](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
### Model training

**Start training:** once the data is ready, training can be launched as follows:

python train.py \
    --model_net=$(name_of_model) \
@@ -242,15 +251,17 @@ The network architecture of STGAN [9]
**A:** Check whether all of the labels were converted correctly.

**Q:** The inference results look wrong. What is going on?
**A:** For some GANs the batch_norm settings at inference time must match the training-time behavior. Check whether the inference-time batch_norm behavior of the corresponding GAN in this model zoo matches the inference-time batch_norm behavior of your own model.

**Q:** In STGAN and AttGAN, why does transforming to male produce a female result?
**A:** This is due to how labels are set at inference time: the target label is derived by changing the original label. If the original image is male, the inference code automatically flips the label to its opposite, i.e. female, so the result is female. If you want an image that is male to stay male after the transformation, simply keep the label being transformed unchanged.

**Q:** How do I train with my own dataset?
**A:** For Pix2Pix, prepare paired data in different styles, similar to the Cityscapes dataset. For CycleGAN, prepare (unpaired) data in different styles, also similar to Cityscapes. For StarGAN, AttGAN, and STGAN, besides preparing images like those in the CelebA dataset and a list file containing the number of images, their names, and their label information, you also need to set the model's selected_attrs parameter to the target attributes you want to change, and set the c_dim parameter to the number of target attributes. A parsing sketch follows below.
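For example, a minimal sketch of how such a list file is parsed, matching the celeba_reader_creator logic in data_reader.py (the attribute names and file name here are illustrative):

```
# line 1 of the list: image count; line 2: attribute names; then one line per image
all_attr_names = "Male Young".split()
attr2idx = {n: i for i, n in enumerate(all_attr_names)}

entry = "000001.jpg 1 -1"        # image name followed by one value per attribute
arr = entry.strip().split()
selected_attrs = ["Male"]        # --selected_attrs; set c_dim to len(selected_attrs)
label = [arr[attr2idx[n] + 1] == "1" for n in selected_attrs]
print(label)                     # [True]
```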
**Q:** How do I pull a single model out of the model zoo?
**A:** The __init__.py file in the trainer folder imports all network architectures by default, so remove from __init__.py the imports for every model other than the one you need, then delete the unneeded model files from trainer and network, as in the illustrative snippet below.
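For instance, if trainer/__init__.py imports each model module by name (the exact import lines depend on your checkout, so treat this as a hypothetical sketch), keeping only CycleGAN could look like:

```
# trainer/__init__.py, trimmed to a single model (illustrative)
from .CycleGAN import *
```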
## References
......
@@ -66,27 +66,24 @@ class reader_creator(object):
                 list_filename,
                 shuffle=False,
                 batch_size=1,
                 mode="TRAIN"):
        self.image_dir = image_dir
        self.list_filename = list_filename
        self.batch_size = batch_size
        self.mode = mode
        self.name2id = {}
        self.id2name = {}

        self.lines = open(self.list_filename).readlines()

        if self.mode == "TRAIN":
            self.shuffle = shuffle
        else:
            self.shuffle = False

    def len(self):
        return len(self.lines) // self.batch_size

    def make_reader(self, args, return_name=False):
        print(self.image_dir, self.list_filename)
@@ -98,8 +95,10 @@ class reader_creator(object):
            if self.shuffle:
                np.random.shuffle(self.lines)
            for i, file in enumerate(self.lines):
                file = file.strip('\n\r\t ')
                self.name2id[os.path.basename(file)] = i
                self.id2name[i] = os.path.basename(file)
                img = Image.open(os.path.join(self.image_dir, file)).convert(
                    'RGB')
                if self.mode == "TRAIN":
@@ -117,7 +116,7 @@ class reader_creator(object):
                if return_name:
                    batch_out.append(img)
                    batch_out_name.append(i)
                else:
                    batch_out.append(img)
                if len(batch_out) == self.batch_size:
@@ -125,13 +124,8 @@ class reader_creator(object):
                        yield batch_out, batch_out_name
                        batch_out_name = []
                    else:
                        yield [batch_out]
                    batch_out = []

        return reader
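        # NOTE: when return_name is False the batch is wrapped as [batch_out]
        # because fluid.io.PyReader.decorate_batch_generator (see infer.py in
        # this change) expects each generated item to be a list/tuple with one
        # entry per variable in feed_list.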
@@ -144,14 +138,12 @@ class pair_reader_creator(reader_creator):
                 list_filename,
                 shuffle=False,
                 batch_size=1,
                 mode="TRAIN"):
        super(pair_reader_creator, self).__init__(
            image_dir,
            list_filename,
            shuffle=shuffle,
            batch_size=batch_size,
            mode=mode)

    def make_reader(self, args, return_name=False):
@@ -163,13 +155,16 @@ class pair_reader_creator(reader_creator):
            batch_out_name = []
            if self.shuffle:
                np.random.shuffle(self.lines)
            for i, line in enumerate(self.lines):
                files = line.strip('\n\r\t ').split('\t')
                img1 = Image.open(os.path.join(self.image_dir, files[
                    0])).convert('RGB')
                img2 = Image.open(os.path.join(self.image_dir, files[
                    1])).convert('RGB')
                self.name2id[os.path.basename(files[0])] = i
                self.id2name[i] = os.path.basename(files[0])
                if self.mode == "TRAIN":
                    param = get_preprocess_param(args.image_size,
                                                 args.crop_size)
@@ -201,7 +196,7 @@ class pair_reader_creator(reader_creator):
                batch_out_1.append(img1)
                batch_out_2.append(img2)
                if return_name:
                    batch_out_name.append(i)
                if len(batch_out_1) == self.batch_size:
                    if return_name:
                        yield batch_out_1, batch_out_2, batch_out_name
@@ -210,11 +205,6 @@ class pair_reader_creator(reader_creator):
                        yield batch_out_1, batch_out_2
                    batch_out_1 = []
                    batch_out_2 = []

        return reader
@@ -238,17 +228,14 @@ class celeba_reader_creator(reader_creator):
        attr2idx = {}
        for i, attr_name in enumerate(all_attr_names):
            attr2idx[attr_name] = i

        if self.mode == "TRAIN":
            self.batch_size = args.batch_size
            self.shuffle = args.shuffle
            lines = lines[2:train_end]
        else:
            self.batch_size = args.n_samples
            self.shuffle = False
            if self.mode == "TEST":
                lines = lines[train_end:test_end]
            else:
@@ -256,20 +243,17 @@ class celeba_reader_creator(reader_creator):
        self.images = []
        attr_names = args.selected_attrs.split(',')
        for i, line in enumerate(lines):
            arr = line.strip().split()
            name = os.path.join('img_align_celeba', arr[0])
            label = []
            for attr_name in attr_names:
                idx = attr2idx[attr_name]
                label.append(arr[idx + 1] == "1")
            self.images.append((name, label, arr[0]))

    def len(self):
        return len(self.images) // self.batch_size

    def make_reader(self, return_name=False):
        print(self.image_dir, self.list_filename)
@@ -277,10 +261,11 @@ class celeba_reader_creator(reader_creator):
        def reader():
            batch_out_1 = []
            batch_out_2 = []
            batch_out_3 = []
            batch_out_name = []
            if self.shuffle:
                np.random.shuffle(self.images)
            for file, label, f_name in self.images:
                img = Image.open(os.path.join(self.image_dir, file))
                label = np.array(label).astype("float32")
                if self.args.model_net == "StarGAN":
@@ -294,20 +279,19 @@ class celeba_reader_creator(reader_creator):
                batch_out_1.append(img)
                batch_out_2.append(label)
                if return_name:
                    batch_out_name.append(int(f_name.split('.')[0]))
                if len(batch_out_1) == self.batch_size:
                    batch_out_3 = np.copy(batch_out_2)
                    if self.shuffle:
                        np.random.shuffle(batch_out_3)
                    if return_name:
                        yield batch_out_1, batch_out_2, batch_out_3, batch_out_name
                        batch_out_name = []
                    else:
                        yield batch_out_1, batch_out_2, batch_out_3
                    batch_out_1 = []
                    batch_out_2 = []
                    batch_out_3 = []

        return reader
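        # NOTE: batch_out_3 above is a shuffled copy of the current batch's own
        # labels (np.copy followed by np.random.shuffle, which permutes rows in
        # place), so the random target labels fed to StarGAN/AttGAN/STGAN are
        # always label vectors that actually occur in the data.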
@@ -388,17 +372,17 @@ class data_reader(object):
                list_filename=trainA_list,
                shuffle=self.cfg.shuffle,
                batch_size=self.cfg.batch_size,
                mode="TRAIN")
            b_train_reader = reader_creator(
                image_dir=dataset_dir,
                list_filename=trainB_list,
                shuffle=self.cfg.shuffle,
                batch_size=self.cfg.batch_size,
                mode="TRAIN")
            a_reader_test = None
            b_reader_test = None
            a_id2name = None
            b_id2name = None
            if self.cfg.run_test:
                testA_list = os.path.join(dataset_dir, "testA.txt")
                testB_list = os.path.join(dataset_dir, "testB.txt")
@@ -407,25 +391,25 @@ class data_reader(object):
                    list_filename=testA_list,
                    shuffle=False,
                    batch_size=1,
                    mode="TEST")
                b_test_reader = reader_creator(
                    image_dir=dataset_dir,
                    list_filename=testB_list,
                    shuffle=False,
                    batch_size=1,
                    mode="TEST")
                a_reader_test = a_test_reader.make_reader(
                    self.cfg, return_name=True)
                b_reader_test = b_test_reader.make_reader(
                    self.cfg, return_name=True)
                a_id2name = a_test_reader.id2name
                b_id2name = b_test_reader.id2name

            batch_num = max(a_train_reader.len(), b_train_reader.len())
            a_reader = a_train_reader.make_reader(self.cfg)
            b_reader = b_train_reader.make_reader(self.cfg)
            return a_reader, b_reader, a_reader_test, b_reader_test, batch_num, a_id2name, b_id2name
        elif self.cfg.model_net in ['StarGAN', 'STGAN', 'AttGAN']:
            dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
@@ -450,7 +434,7 @@ class data_reader(object):
            reader_test = test_reader.make_reader(return_name=True)
            batch_num = train_reader.len()
            reader = train_reader.make_reader()
            return reader, reader_test, batch_num, None
        elif self.cfg.model_net in ['Pix2pix']:
            dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
@@ -462,9 +446,9 @@ class data_reader(object):
                list_filename=train_list,
                shuffle=self.cfg.shuffle,
                batch_size=self.cfg.batch_size,
                mode="TRAIN")
            reader_test = None
            id2name = None
            if self.cfg.run_test:
                test_list = os.path.join(dataset_dir, "test.txt")
                if self.cfg.test_list is not None:
@@ -474,13 +458,13 @@ class data_reader(object):
                    list_filename=test_list,
                    shuffle=False,
                    batch_size=1,
                    mode="TEST")
                reader_test = test_reader.make_reader(
                    self.cfg, return_name=True)
                id2name = test_reader.id2name
            batch_num = train_reader.len()
            reader = train_reader.make_reader(self.cfg)
            return reader, reader_test, batch_num, id2name
        else:
            dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
            train_list = os.path.join(dataset_dir, 'train.txt')
@@ -489,14 +473,15 @@ class data_reader(object):
            train_reader = reader_creator(
                image_dir=dataset_dir, list_filename=train_list)
            reader_test = None
            id2name = None
            if self.cfg.run_test:
                test_list = os.path.join(dataset_dir, "test.txt")
                test_reader = reader_creator(
                    image_dir=dataset_dir,
                    list_filename=test_list,
                    batch_size=self.cfg.n_samples)
                reader_test = test_reader.get_test_reader(
                    self.cfg, shuffle=False, return_name=True)
                id2name = test_reader.id2name
            batch_num = train_reader.len()
            return train_reader, reader_test, batch_num, id2name
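# NOTE: every branch above now also returns the id-to-name mapping(s) needed to
# recover file names at inference time: CycleGAN returns seven values (with the
# extra a_id2name/b_id2name maps), the other models four. Call sites must
# unpack the extra value(s); a hedged sketch, assuming the enclosing method is
# make_data as elsewhere in this file:
#     reader, reader_test, batch_num, id2name = data_reader(cfg).make_data()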
@@ -63,7 +63,7 @@ def download_mnist(dir_path):
    for url in URL_DIC:
        md5sum = URL_DIC[url]
        data_dir = os.path.join(dir_path, 'mnist')
        if not os.path.exists(data_dir):
            os.makedirs(data_dir)
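        # NOTE: the old code called os.path.join(dir_path + 'mnist'), which is
        # plain string concatenation before joining anything, e.g.
        # 'data' + 'mnist' -> 'datamnist'; os.path.join(dir_path, 'mnist')
        # produces the intended 'data/mnist'.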
@@ -157,11 +157,12 @@ if __name__ == '__main__':
        'facades', 'iphone2dslr_flower', 'ae_photos', 'mini'
    ]

    pwd = os.path.join(os.path.dirname(__file__), 'data')
    if args.dataset == 'mnist':
        print('Download dataset: {}'.format(args.dataset))
        download_mnist(pwd)
    elif args.dataset in cycle_pix_dataset:
        print('Download dataset: {}'.format(args.dataset))
        download_cycle_pix(pwd, args.dataset)
    else:
        print('Please download by yourself, thanks')
@@ -26,7 +26,7 @@ import numpy as np
import imageio
import glob
from util.config import add_arguments, print_arguments
from data_reader import celeba_reader_creator, reader_creator
from util.utility import check_attribute_conflict, check_gpu, save_batch_image
from util import utility
import copy
@@ -70,9 +70,16 @@ def infer(args):
        name='label_org_', shape=[args.c_dim], dtype='float32')
    label_trg_ = fluid.layers.data(
        name='label_trg_', shape=[args.c_dim], dtype='float32')
    image_name = fluid.layers.data(
        name='image_name', shape=[args.n_samples], dtype='int32')

    model_name = 'net_G'
    if args.model_net == 'CycleGAN':
        py_reader = fluid.io.PyReader(
            feed_list=[input, image_name],
            capacity=4,  ## batch_size * 4
            iterable=True,
            use_double_buffer=True)
        from network.CycleGAN_network import CycleGAN_model
        model = CycleGAN_model()
        if args.input_style == "A":
@@ -82,15 +89,35 @@ def infer(args):
        else:
            raise NotImplementedError(
                "Input with style [%s] is not supported." % args.input_style)
    elif args.model_net == 'Pix2pix':
        py_reader = fluid.io.PyReader(
            feed_list=[input, image_name],
            capacity=4,  ## batch_size * 4
            iterable=True,
            use_double_buffer=True)
        from network.Pix2pix_network import Pix2pix_model
        model = Pix2pix_model()
        fake = model.network_G(input, "generator", cfg=args)
    elif args.model_net == 'StarGAN':
        py_reader = fluid.io.PyReader(
            feed_list=[input, label_org_, label_trg_, image_name],
            capacity=32,
            iterable=True,
            use_double_buffer=True)
        from network.StarGAN_network import StarGAN_model
        model = StarGAN_model()
        fake = model.network_G(input, label_trg_, name="g_main", cfg=args)
    elif args.model_net == 'STGAN':
        from network.STGAN_network import STGAN_model
        py_reader = fluid.io.PyReader(
            feed_list=[input, label_org_, label_trg_, image_name],
            capacity=32,
            iterable=True,
            use_double_buffer=True)
        model = STGAN_model()
        fake, _ = model.network_G(
            input,
@@ -101,6 +128,13 @@ def infer(args):
            is_test=True)
    elif args.model_net == 'AttGAN':
        from network.AttGAN_network import AttGAN_model
        py_reader = fluid.io.PyReader(
            feed_list=[input, label_org_, label_trg_, image_name],
            capacity=32,
            iterable=True,
            use_double_buffer=True)
        model = AttGAN_model()
        fake, _ = model.network_G(
            input,
@@ -116,19 +150,26 @@ def infer(args):
            name='conditions', shape=[1], dtype='float32')
        from network.CGAN_network import CGAN_model
        model = CGAN_model(args.n_samples)
        fake = model.network_G(noise, conditions, name="G")
    elif args.model_net == 'DCGAN':
        noise = fluid.layers.data(
            name='noise', shape=[args.noise_size], dtype='float32')
        from network.DCGAN_network import DCGAN_model
        model = DCGAN_model(args.n_samples)
        fake = model.network_G(noise, name="G")
    else:
        raise NotImplementedError("model_net {} is not supported".format(
            args.model_net))
    def _compute_start_end(image_name):
        image_name_start = np.array(image_name)[0].astype('int32')
        image_name_end = image_name_start + args.n_samples - 1
        image_name_save = str(image_name_start) + '.jpg'
        print("read {}.jpg ~ {}.jpg".format(image_name_start, image_name_end))
        return image_name_save
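    # NOTE: each py_reader built above is an iterable fluid.io.PyReader; after
    # decorate_batch_generator, "for data in py_reader():" yields one dict per
    # place, keyed by feed-variable name (hence the data[0]['input'] indexing
    # below). CelebA test batches carry consecutive integer ids, so
    # _compute_start_end names the saved image grid after the first id in the
    # batch.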
    # prepare environment
    place = fluid.CPUPlace()
    if args.use_gpu:
@@ -152,36 +193,34 @@ def infer(args):
            args=args,
            mode="VAL")
        reader_test = test_reader.make_reader(return_name=True)
        py_reader.decorate_batch_generator(
            reader_test,
            places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
        for data in py_reader():
            real_img, label_org, label_trg, image_name = data[0]['input'], data[
                0]['label_org_'], data[0]['label_trg_'], data[0]['image_name']
            image_name_save = _compute_start_end(image_name)
            real_img_temp = save_batch_image(np.array(real_img))
            images = [real_img_temp]
            for i in range(args.c_dim):
                label_trg_tmp = copy.deepcopy(np.array(label_trg))
                for j in range(len(label_trg_tmp)):
                    label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
                label_trg_tmp = check_attribute_conflict(
                    label_trg_tmp, attr_names[i], attr_names)
                label_org_tmp = list(
                    map(lambda x: ((x * 2) - 1) * 0.5, np.array(label_org)))
                label_trg_tmp = list(
                    map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp))
                if args.model_net == 'AttGAN':
                    for k in range(len(label_trg_tmp)):
                        label_trg_tmp[k][i] = label_trg_tmp[k][i] * 2.0
                tensor_label_org_ = fluid.LoDTensor()
                tensor_label_trg_ = fluid.LoDTensor()
                tensor_label_org_.set(label_org_tmp, place)
                tensor_label_trg_.set(label_trg_tmp, place)
                out = exe.run(feed={
                    "input": real_img,
                    "label_org_": tensor_label_org_,
                    "label_trg_": tensor_label_trg_
                },
@@ -189,10 +228,11 @@ def infer(args):
                fake_temp = save_batch_image(out[0])
                images.append(fake_temp)
            images_concat = np.concatenate(images, 1)
            if len(np.array(label_org)) > 1:
                images_concat = np.concatenate(images_concat, 1)
            imageio.imwrite(
                os.path.join(args.output, "fake_img_" + image_name_save), (
                    (images_concat + 1) * 127.5).astype(np.uint8))
    elif args.model_net == 'StarGAN':
        test_reader = celeba_reader_creator(
            image_dir=args.dataset_dir,
@@ -200,54 +240,60 @@ def infer(args):
            args=args,
            mode="VAL")
        reader_test = test_reader.make_reader(return_name=True)
        py_reader.decorate_batch_generator(
            reader_test,
            places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
        for data in py_reader():
            real_img, label_org, label_trg, image_name = data[0]['input'], data[
                0]['label_org_'], data[0]['label_trg_'], data[0]['image_name']
            image_name_save = _compute_start_end(image_name)
            real_img_temp = save_batch_image(np.array(real_img))
            images = [real_img_temp]
            for i in range(args.c_dim):
                label_trg_tmp = copy.deepcopy(np.array(label_org))
                for j in range(len(np.array(label_org))):
                    label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
                label_trg_tmp = check_attribute_conflict(
                    label_trg_tmp, attr_names[i], attr_names)
                tensor_label_trg_ = fluid.LoDTensor()
                tensor_label_trg_.set(label_trg_tmp, place)
                out = exe.run(
                    feed={"input": real_img,
                          "label_trg_": tensor_label_trg_},
                    fetch_list=[fake.name])
                fake_temp = save_batch_image(out[0])
                images.append(fake_temp)
            images_concat = np.concatenate(images, 1)
            if len(np.array(label_org)) > 1:
                images_concat = np.concatenate(images_concat, 1)
            imageio.imwrite(
                os.path.join(args.output, "fake_img_" + image_name_save), (
                    (images_concat + 1) * 127.5).astype(np.uint8))
    elif args.model_net == 'Pix2pix' or args.model_net == 'CycleGAN':
        test_reader = reader_creator(
            image_dir=args.dataset_dir,
            list_filename=args.test_list,
            shuffle=False,
            batch_size=args.n_samples,
            mode="VAL")
        reader_test = test_reader.make_reader(args, return_name=True)
        py_reader.decorate_batch_generator(
            reader_test,
            places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
        id2name = test_reader.id2name
        for data in py_reader():
            real_img, image_name = data[0]['input'], data[0]['image_name']
            image_name = id2name[np.array(image_name).astype('int32')[0]]
            print("read: ", image_name)
            fake_temp = exe.run(fetch_list=[fake.name],
                                feed={"input": real_img})
            fake_temp = np.squeeze(fake_temp[0]).transpose([1, 2, 0])
            input_temp = np.squeeze(np.array(real_img)[0]).transpose([1, 2, 0])

            imageio.imwrite(
                os.path.join(args.output, "fake_" + image_name), (
                    (fake_temp + 1) * 127.5).astype(np.uint8))
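        # NOTE: feed variables must be numeric, so the reader feeds integer row
        # ids through 'image_name', and id2name (built by reader_creator) maps
        # each id back to the original file name when saving results.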
    elif args.model_net == 'CGAN':
        noise_data = np.random.uniform(
@@ -282,7 +328,7 @@ def infer(args):
        fig = utility.plot(fake_image)
        plt.savefig(
            os.path.join(args.output, 'fake_dcgan.png'), bbox_inches='tight')
        plt.close(fig)
    else:
        raise NotImplementedError("model_net {} is not supported".format(
......
@@ -45,15 +45,15 @@ class CGAN_model(object):
        y = fluid.layers.reshape(label, shape=[-1, self.y_dim, 1, 1])
        xy = fluid.layers.concat([input, y], 1)
        o_l1 = linear(
            input=xy,
            output_size=self.gf_dim * 8,
            norm=self.norm,
            activation_fn='relu',
            name=name + '_l1')
        o_c1 = fluid.layers.concat([o_l1, y], 1)
        o_l2 = linear(
            input=o_c1,
            output_size=self.gf_dim * (self.img_w // 4) * (self.img_h // 4),
            norm=self.norm,
            activation_fn='relu',
            name=name + '_l2')
@@ -63,10 +63,10 @@ class CGAN_model(object):
            name=name + '_reshape')
        o_c2 = conv_cond_concat(o_r1, y)
        o_dc1 = deconv2d(
            input=o_c2,
            num_filters=self.gf_dim,
            filter_size=4,
            stride=2,
            padding=[1, 1],
            norm='batch_norm',
            activation_fn='relu',
@@ -74,10 +74,10 @@ class CGAN_model(object):
            output_size=[self.img_w // 2, self.img_h // 2])
        o_c3 = conv_cond_concat(o_dc1, y)
        o_dc2 = deconv2d(
            input=o_dc1,
            num_filters=1,
            filter_size=4,
            stride=2,
            padding=[1, 1],
            activation_fn='tanh',
            name=name + '_dc2',
@@ -91,26 +91,26 @@ class CGAN_model(object):
        y = fluid.layers.reshape(label, shape=[-1, self.y_dim, 1, 1])
        xy = conv_cond_concat(x, y)
        o_l1 = conv2d(
            input=xy,
            num_filters=self.df_dim,
            filter_size=3,
            stride=2,
            name=name + '_l1',
            activation_fn='leaky_relu')
        o_c1 = conv_cond_concat(o_l1, y)
        o_l2 = conv2d(
            input=o_c1,
            num_filters=self.df_dim,
            filter_size=3,
            stride=2,
            name=name + '_l2',
            norm='batch_norm',
            activation_fn='leaky_relu')
        o_f1 = fluid.layers.flatten(o_l2, axis=1)
        o_c2 = fluid.layers.concat([o_f1, y], 1)
        o_l3 = linear(
            input=o_c2,
            output_size=self.df_dim * 16,
            norm=self.norm,
            activation_fn='leaky_relu',
            name=name + '_l3')
......
@@ -97,11 +97,11 @@ def build_resnet_block(inputres,
                       norm_type='batch_norm'):
    out_res = fluid.layers.pad2d(inputres, [1, 1, 1, 1], mode="reflect")
    out_res = conv2d(
        input=out_res,
        num_filters=dim,
        filter_size=3,
        stride=1,
        stddev=0.02,
        name=name + "_c1",
        norm=norm_type,
        activation_fn='relu',
@@ -112,11 +112,11 @@ def build_resnet_block(inputres,
    out_res = fluid.layers.pad2d(out_res, [1, 1, 1, 1], mode="reflect")
    out_res = conv2d(
        input=out_res,
        num_filters=dim,
        filter_size=3,
        stride=1,
        stddev=0.02,
        name=name + "_c2",
        norm=norm_type,
        use_bias=use_bias)
@@ -134,31 +134,31 @@ def build_generator_resnet_blocks(inputgen,
    use_bias = norm_type == 'instance_norm'
    pad_input = fluid.layers.pad2d(inputgen, [3, 3, 3, 3], mode="reflect")
    o_c1 = conv2d(
        input=pad_input,
        num_filters=g_base_dims,
        filter_size=7,
        stride=1,
        stddev=0.02,
        name=name + "_c1",
        norm=norm_type,
        activation_fn='relu')
    o_c2 = conv2d(
        input=o_c1,
        num_filters=g_base_dims * 2,
        filter_size=3,
        stride=2,
        stddev=0.02,
        padding=1,
        name=name + "_c2",
        norm=norm_type,
        activation_fn='relu')
    res_input = conv2d(
        input=o_c2,
        num_filters=g_base_dims * 4,
        filter_size=3,
        stride=2,
        stddev=0.02,
        padding=1,
        name=name + "_c3",
        norm=norm_type,
        activation_fn='relu')
@@ -166,37 +166,41 @@ def build_generator_resnet_blocks(inputgen,
        conv_name = name + "_r{}".format(i + 1)
        res_output = build_resnet_block(
            res_input,
            dim=g_base_dims * 4,
            name=conv_name,
            use_bias=use_bias,
            use_dropout=use_dropout)
        res_input = res_output
    o_c4 = deconv2d(
        input=res_output,
        num_filters=g_base_dims * 2,
        filter_size=3,
        stride=2,
        stddev=0.02,
        padding=[1, 1],
        outpadding=[0, 1, 0, 1],
        name=name + "_c4",
        norm=norm_type,
        activation_fn='relu')
    o_c5 = deconv2d(
        input=o_c4,
        num_filters=g_base_dims,
        filter_size=3,
        stride=2,
        stddev=0.02,
        padding=[1, 1],
        outpadding=[0, 1, 0, 1],
        name=name + "_c5",
        norm=norm_type,
        activation_fn='relu')
    o_p2 = fluid.layers.pad2d(o_c5, [3, 3, 3, 3], mode="reflect")
    o_c6 = conv2d(
        input=o_p2,
        num_filters=3,
        filter_size=7,
        stride=1,
        stddev=0.02,
        name=name + "_c6",
        activation_fn='tanh',
        use_bias=True)
@@ -217,33 +221,33 @@ def Unet_block(inputunet,
               name=None):
    if outermost == True:
        downconv = conv2d(
            input=inputunet,
            num_filters=inner_dim,
            filter_size=4,
            stride=2,
            stddev=0.02,
            padding=1,
            name=name + '_outermost_dc1',
            use_bias=True)
        i += 1
        mid_block = Unet_block(
            downconv,
            i,
            outer_dim=inner_dim,
            inner_dim=inner_dim * 2,
            num_downsample=num_downsample,
            norm_type=norm_type,
            use_bias=use_bias,
            use_dropout=use_dropout,
            name=name)
        uprelu = fluid.layers.relu(mid_block, name=name + '_outermost_relu')
        updeconv = deconv2d(
            input=uprelu,
            num_filters=outer_dim,
            filter_size=4,
            stride=2,
            stddev=0.02,
            padding=1,
            name=name + '_outermost_uc1',
            activation_fn='tanh',
            use_bias=use_bias)
@@ -252,22 +256,22 @@ def Unet_block(inputunet,
        downrelu = fluid.layers.leaky_relu(
            inputunet, 0.2, name=name + '_innermost_leaky_relu')
        upconv = conv2d(
            input=downrelu,
            num_filters=inner_dim,
            filter_size=4,
            stride=2,
            stddev=0.02,
            padding=1,
            name=name + '_innermost_dc1',
            activation_fn='relu',
            use_bias=use_bias)
        updeconv = deconv2d(
            input=upconv,
            num_filters=outer_dim,
            filter_size=4,
            stride=2,
            stddev=0.02,
            padding=1,
            name=name + '_innermost_uc1',
            norm=norm_type,
            use_bias=use_bias)
@@ -276,12 +280,12 @@ def Unet_block(inputunet,
        downrelu = fluid.layers.leaky_relu(
            inputunet, 0.2, name=name + '_leaky_relu')
        downnorm = conv2d(
            input=downrelu,
            num_filters=inner_dim,
            filter_size=4,
            stride=2,
            stddev=0.02,
            padding=1,
            name=name + 'dc1',
            norm=norm_type,
            use_bias=use_bias)
@@ -290,9 +294,9 @@ def Unet_block(inputunet,
            mid_block = Unet_block(
                downnorm,
                i,
                outer_dim=inner_dim,
                inner_dim=inner_dim * 2,
                num_downsample=num_downsample,
                norm_type=norm_type,
                use_bias=use_bias,
                name=name + '_mid{}'.format(i))
@@ -300,9 +304,9 @@ def Unet_block(inputunet,
            mid_block = Unet_block(
                downnorm,
                i,
                outer_dim=inner_dim,
                inner_dim=inner_dim,
                num_downsample=num_downsample,
                norm_type=norm_type,
                use_bias=use_bias,
                use_dropout=use_dropout,
@@ -311,21 +315,21 @@ def Unet_block(inputunet,
            mid_block = Unet_block(
                downnorm,
                i,
                outer_dim=inner_dim,
                inner_dim=inner_dim,
                num_downsample=num_downsample,
                innermost=True,
                norm_type=norm_type,
                use_bias=use_bias,
                name=name + '_innermost')
        uprelu = fluid.layers.relu(mid_block, name=name + '_relu')
        updeconv = deconv2d(
            input=uprelu,
            num_filters=outer_dim,
            filter_size=4,
            stride=2,
            stddev=0.02,
            padding=1,
            name=name + '_uc1',
            norm=norm_type,
            use_bias=use_bias)
@@ -345,10 +349,10 @@ def build_generator_Unet(inputgen,
    use_bias = norm_type == 'instance_norm'
    unet_block = Unet_block(
        inputgen,
        i=0,
        outer_dim=3,
        inner_dim=g_base_dims,
        num_downsample=num_downsample,
        outermost=True,
        norm_type=norm_type,
        use_bias=use_bias,
@@ -364,12 +368,12 @@ def build_discriminator_Nlayers(inputdisc,
                                norm_type='batch_norm'):
    use_bias = norm_type != 'batch_norm'
    dis_input = conv2d(
        input=inputdisc,
        num_filters=d_base_dims,
        filter_size=4,
        stride=2,
        stddev=0.02,
        padding=1,
        name=name + "_c1",
        activation_fn='leaky_relu',
        relufactor=0.2,
@@ -379,12 +383,12 @@ def build_discriminator_Nlayers(inputdisc,
        conv_name = name + "_c{}".format(i + 2)
        d_dims *= 2
        dis_output = conv2d(
            input=dis_input,
            num_filters=d_dims,
            filter_size=4,
            stride=2,
            stddev=0.02,
            padding=1,
            name=conv_name,
            norm=norm_type,
            activation_fn='leaky_relu',
dis_input = dis_output dis_input = dis_output
last_dims = min(2**d_nlayers, 8) last_dims = min(2**d_nlayers, 8)
o_c4 = conv2d( o_c4 = conv2d(
dis_output, input=dis_output,
d_base_dims * last_dims, num_filters=d_base_dims * last_dims,
4, filter_size=4,
1, stride=1,
0.02, stddev=0.02,
1, padding=1,
name + "_c{}".format(d_nlayers + 1), name=name + "_c{}".format(d_nlayers + 1),
norm=norm_type, norm=norm_type,
activation_fn='leaky_relu', activation_fn='leaky_relu',
relufactor=0.2, relufactor=0.2,
use_bias=use_bias) use_bias=use_bias)
o_c5 = conv2d( o_c5 = conv2d(
o_c4, input=o_c4,
1, num_filters=1,
4, filter_size=4,
1, stride=1,
0.02, stddev=0.02,
1, padding=1,
name + "_c{}".format(d_nlayers + 2), name=name + "_c{}".format(d_nlayers + 2),
use_bias=True) use_bias=True)
return o_c5 return o_c5
@@ -422,25 +426,32 @@ def build_discriminator_Pixel(inputdisc,
                              norm_type='batch_norm'):
    use_bias = norm_type != 'instance_norm'
    o_c1 = conv2d(
        input=inputdisc,
        num_filters=d_base_dims,
        filter_size=1,
        stride=1,
        stddev=0.02,
        name=name + '_c1',
        activation_fn='leaky_relu',
        relufactor=0.2,
        use_bias=True)
    o_c2 = conv2d(
        input=o_c1,
        num_filters=d_base_dims * 2,
        filter_size=1,
        stride=1,
        stddev=0.02,
        name=name + '_c2',
        norm=norm_type,
        activation_fn='leaky_relu',
        relufactor=0.2,
        use_bias=use_bias)
    o_c3 = conv2d(
        o_c2,
        num_filters=1,
        filter_size=1,
        stride=1,
        stddev=0.02,
        name=name + '_c3',
        use_bias=use_bias)
    return o_c3
@@ -16,7 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .base_network import norm_layer, deconv2d, linear, conv_and_pool
import paddle.fluid as fluid
import numpy as np
@@ -37,30 +37,38 @@ class DCGAN_model(object):
        self.norm = "batch_norm"

    def network_G(self, input, name="generator"):
        o_l1 = linear(
            input,
            self.gfc_dim,
            norm=self.norm,
            activation_fn='relu',
            name=name + '_l1')
        o_l2 = linear(
            input=o_l1,
            output_size=self.gf_dim * 2 * self.img_dim // 4 * self.img_dim // 4,
            norm=self.norm,
            activation_fn='relu',
            name=name + '_l2')
        o_r1 = fluid.layers.reshape(
            o_l2, [-1, self.df_dim * 2, self.img_dim // 4, self.img_dim // 4])
        o_dc1 = deconv2d(
            input=o_r1,
            num_filters=self.gf_dim * 2,
            filter_size=5,
            stride=2,
            padding=2,
            activation_fn='relu',
            output_size=[self.img_dim // 2, self.img_dim // 2],
            use_bias=True,
            name=name + '_dc1')
        o_dc2 = deconv2d(
            input=o_dc1,
            num_filters=1,
            filter_size=5,
            stride=2,
            padding=2,
            activation_fn='tanh',
            use_bias=True,
            output_size=[self.img_dim, self.img_dim],
            name=name + '_dc2')
        out = fluid.layers.reshape(o_dc2, shape=[-1, 28 * 28])
@@ -69,26 +77,15 @@ class DCGAN_model(object):
    def network_D(self, input, name="discriminator"):
        o_r1 = fluid.layers.reshape(
            input, shape=[-1, 1, self.img_dim, self.img_dim])
        o_c1 = conv_and_pool(
            o_r1, self.df_dim, name=name + '_c1', act='leaky_relu')
        o_c2_1 = conv_and_pool(o_c1, self.df_dim * 2, name=name + '_c2')
        o_c2_2 = norm_layer(
            o_c2_1, norm_type='batch_norm', name=name + '_c2_bn')
        o_c2 = fluid.layers.leaky_relu(o_c2_2, name=name + '_c2_leaky_relu')
        o_l1 = linear(
            input=o_c2,
            output_size=self.dfc_dim,
            norm=self.norm,
            activation_fn='leaky_relu',
            name=name + '_l1')
......
@@ -97,11 +97,11 @@ def build_resnet_block(inputres,
                       norm_type='batch_norm'):
    out_res = fluid.layers.pad2d(inputres, [1, 1, 1, 1], mode="reflect")
    out_res = conv2d(
        input=out_res,
        num_filters=dim,
        filter_size=3,
        stride=1,
        stddev=0.02,
        name=name + "_c1",
        norm=norm_type,
        activation_fn='relu',
@@ -112,11 +112,11 @@ def build_resnet_block(inputres,
    out_res = fluid.layers.pad2d(out_res, [1, 1, 1, 1], mode="reflect")
    out_res = conv2d(
        input=out_res,
        num_filters=dim,
        filter_size=3,
        stride=1,
        stddev=0.02,
        name=name + "_c2",
        norm=norm_type,
        use_bias=use_bias)
@@ -134,31 +134,31 @@ def build_generator_resnet_blocks(inputgen,
    use_bias = norm_type == 'instance_norm'
    pad_input = fluid.layers.pad2d(inputgen, [3, 3, 3, 3], mode="reflect")
    o_c1 = conv2d(
        input=pad_input,
        num_filters=g_base_dims,
        filter_size=7,
        stride=1,
        stddev=0.02,
        name=name + "_c1",
        norm=norm_type,
        activation_fn='relu')
    o_c2 = conv2d(
        input=o_c1,
        num_filters=g_base_dims * 2,
        filter_size=3,
        stride=2,
        stddev=0.02,
        padding=1,
        name=name + "_c2",
        norm=norm_type,
        activation_fn='relu')
    res_input = conv2d(
        input=o_c2,
        num_filters=g_base_dims * 4,
        filter_size=3,
        stride=2,
        stddev=0.02,
        padding=1,
        name=name + "_c3",
        norm=norm_type,
        activation_fn='relu')
@@ -173,30 +173,34 @@ def build_generator_resnet_blocks(inputgen,
        res_input = res_output
    o_c4 = deconv2d(
        input=res_output,
        num_filters=g_base_dims * 2,
        filter_size=3,
        stride=2,
        stddev=0.02,
        padding=[1, 1],
        outpadding=[0, 1, 0, 1],
        name=name + "_c4",
        norm=norm_type,
        activation_fn='relu')
    o_c5 = deconv2d(
        input=o_c4,
        num_filters=g_base_dims,
        filter_size=3,
        stride=2,
        stddev=0.02,
        padding=[1, 1],
        outpadding=[0, 1, 0, 1],
        name=name + "_c5",
        norm=norm_type,
        activation_fn='relu')
    o_p2 = fluid.layers.pad2d(o_c5, [3, 3, 3, 3], mode="reflect")
    o_c6 = conv2d(
        input=o_p2,
        num_filters=3,
        filter_size=7,
        stride=1,
        stddev=0.02,
        name=name + "_c6",
        activation_fn='tanh',
        use_bias=True)
@@ -217,12 +221,12 @@ def Unet_block(inputunet,
               name=None):
    if outermost == True:
        downconv = conv2d(
            input=inputunet,
            num_filters=inner_dim,
            filter_size=4,
            stride=2,
            stddev=0.02,
            padding=1,
            name=name + '_outermost_dc1',
            use_bias=True)
        i += 1
@@ -238,12 +242,12 @@ def Unet_block(inputunet,
            name=name)
        uprelu = fluid.layers.relu(mid_block, name=name + '_outermost_relu')
        updeconv = deconv2d(
            input=uprelu,
            num_filters=outer_dim,
            filter_size=4,
            stride=2,
            stddev=0.02,
            padding=1,
            name=name + '_outermost_uc1',
            activation_fn='tanh',
            use_bias=use_bias)
@@ -252,22 +256,22 @@ def Unet_block(inputunet,
        downrelu = fluid.layers.leaky_relu(
            inputunet, 0.2, name=name + '_innermost_leaky_relu')
        upconv = conv2d(
            input=downrelu,
            num_filters=inner_dim,
            filter_size=4,
            stride=2,
            stddev=0.02,
            padding=1,
            name=name + '_innermost_dc1',
            activation_fn='relu',
            use_bias=use_bias)
        updeconv = deconv2d(
            input=upconv,
            num_filters=outer_dim,
            filter_size=4,
            stride=2,
            stddev=0.02,
            padding=1,
            name=name + '_innermost_uc1',
            norm=norm_type,
            use_bias=use_bias)
@@ -276,12 +280,12 @@ def Unet_block(inputunet,
        downrelu = fluid.layers.leaky_relu(
            inputunet, 0.2, name=name + '_leaky_relu')
        downnorm = conv2d(
            input=downrelu,
            num_filters=inner_dim,
            filter_size=4,
            stride=2,
            stddev=0.02,
            padding=1,
            name=name + 'dc1',
            norm=norm_type,
            use_bias=use_bias)
...@@ -320,12 +324,12 @@ def Unet_block(inputunet, ...@@ -320,12 +324,12 @@ def Unet_block(inputunet,
name=name + '_innermost') name=name + '_innermost')
uprelu = fluid.layers.relu(mid_block, name=name + '_relu') uprelu = fluid.layers.relu(mid_block, name=name + '_relu')
updeconv = deconv2d( updeconv = deconv2d(
uprelu, input=uprelu,
outer_dim, num_filters=outer_dim,
4, filter_size=4,
2, stride=2,
0.02, stddev=0.02,
1, padding=1,
name=name + '_uc1', name=name + '_uc1',
norm=norm_type, norm=norm_type,
use_bias=use_bias) use_bias=use_bias)
...@@ -349,9 +353,9 @@ def UnetSkipConnectionBlock(input, ...@@ -349,9 +353,9 @@ def UnetSkipConnectionBlock(input,
if outermost: if outermost:
downconv = conv2d( downconv = conv2d(
input, input,
inner_nc, num_filters=inner_nc,
4, filter_size=4,
2, stride=2,
padding=1, padding=1,
use_bias=use_bias, use_bias=use_bias,
name=name + '_down_conv') name=name + '_down_conv')
...@@ -367,10 +371,10 @@ def UnetSkipConnectionBlock(input, ...@@ -367,10 +371,10 @@ def UnetSkipConnectionBlock(input,
name=name + '_u%d' % i) name=name + '_u%d' % i)
uprelu = fluid.layers.relu(sub_res) uprelu = fluid.layers.relu(sub_res)
upconv = deconv2d( upconv = deconv2d(
uprelu, input=uprelu,
outer_nc, num_filters=outer_nc,
4, filter_size=4,
2, stride=2,
padding=1, padding=1,
activation_fn='tanh', activation_fn='tanh',
name=name + '_up_conv') name=name + '_up_conv')
...@@ -378,19 +382,19 @@ def UnetSkipConnectionBlock(input, ...@@ -378,19 +382,19 @@ def UnetSkipConnectionBlock(input,
elif innermost: elif innermost:
downrelu = fluid.layers.leaky_relu(input, 0.2) downrelu = fluid.layers.leaky_relu(input, 0.2)
downconv = conv2d( downconv = conv2d(
downrelu, input=downrelu,
inner_nc, num_filters=inner_nc,
4, filter_size=4,
2, stride=2,
padding=1, padding=1,
use_bias=use_bias, use_bias=use_bias,
name=name + '_down_conv') name=name + '_down_conv')
uprelu = fluid.layers.relu(downconv) uprelu = fluid.layers.relu(downconv)
upconv = deconv2d( upconv = deconv2d(
uprelu, input=uprelu,
outer_nc, num_filters=outer_nc,
4, filter_size=4,
2, stride=2,
padding=1, padding=1,
use_bias=use_bias, use_bias=use_bias,
norm=norm, norm=norm,
...@@ -399,10 +403,10 @@ def UnetSkipConnectionBlock(input, ...@@ -399,10 +403,10 @@ def UnetSkipConnectionBlock(input,
else: else:
downrelu = fluid.layers.leaky_relu(input, 0.2) downrelu = fluid.layers.leaky_relu(input, 0.2)
downconv = conv2d( downconv = conv2d(
downrelu, input=downrelu,
inner_nc, num_filters=inner_nc,
4, filter_size=4,
2, stride=2,
padding=1, padding=1,
use_bias=use_bias, use_bias=use_bias,
norm=norm, norm=norm,
...@@ -441,10 +445,10 @@ def UnetSkipConnectionBlock(input, ...@@ -441,10 +445,10 @@ def UnetSkipConnectionBlock(input,
uprelu = fluid.layers.relu(sub_res) uprelu = fluid.layers.relu(sub_res)
upconv = deconv2d( upconv = deconv2d(
uprelu, input=uprelu,
outer_nc, num_filters=outer_nc,
4, filter_size=4,
2, stride=2,
padding=1, padding=1,
use_bias=use_bias, use_bias=use_bias,
norm=norm, norm=norm,
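UnetSkipConnectionBlock builds the U-Net recursively: each level downsamples, delegates to the next inner block, upsamples, and then concatenates its own input with the upsampled result along the channel axis. Stripped of layer details, the wiring is just this sketch (not the repo's code):

```
import paddle.fluid as fluid

def unet_skip_sketch(input, down, up, outermost=False):
    # down/up stand for the conv2d/deconv2d stacks above; the concat is the
    # skip connection that gives the decoder direct access to encoder features.
    inner = up(down(input))
    if outermost:
        return inner  # the outermost block returns the tanh image directly
    return fluid.layers.concat([input, inner], axis=1)
```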
@@ -483,12 +487,12 @@ def build_discriminator_Nlayers(inputdisc,
                                 norm_type='batch_norm'):
     use_bias = norm_type != 'batch_norm'
     dis_input = conv2d(
-        inputdisc,
-        d_base_dims,
-        4,
-        2,
-        0.02,
-        1,
+        input=inputdisc,
+        num_filters=d_base_dims,
+        filter_size=4,
+        stride=2,
+        stddev=0.02,
+        padding=1,
         name=name + "_c1",
         activation_fn='leaky_relu',
         relufactor=0.2,
@@ -498,12 +502,12 @@ def build_discriminator_Nlayers(inputdisc,
         conv_name = name + "_c{}".format(i + 2)
         d_dims *= 2
         dis_output = conv2d(
-            dis_input,
-            d_dims,
-            4,
-            2,
-            0.02,
-            1,
+            input=dis_input,
+            num_filters=d_dims,
+            filter_size=4,
+            stride=2,
+            stddev=0.02,
+            padding=1,
             name=conv_name,
             norm=norm_type,
             activation_fn='leaky_relu',
@@ -512,25 +516,25 @@ def build_discriminator_Nlayers(inputdisc,
         dis_input = dis_output
     last_dims = min(2**d_nlayers, 8)
     o_c4 = conv2d(
-        dis_output,
-        d_base_dims * last_dims,
-        4,
-        1,
-        0.02,
-        1,
-        name + "_c{}".format(d_nlayers + 1),
+        input=dis_output,
+        num_filters=d_base_dims * last_dims,
+        filter_size=4,
+        stride=1,
+        stddev=0.02,
+        padding=1,
+        name=name + "_c{}".format(d_nlayers + 1),
         norm=norm_type,
         activation_fn='leaky_relu',
         relufactor=0.2,
         use_bias=use_bias)
     o_c5 = conv2d(
-        o_c4,
-        1,
-        4,
-        1,
-        0.02,
-        1,
-        name + "_c{}".format(d_nlayers + 2),
+        input=o_c4,
+        num_filters=1,
+        filter_size=4,
+        stride=1,
+        stddev=0.02,
+        padding=1,
+        name=name + "_c{}".format(d_nlayers + 2),
         use_bias=True)
     return o_c5
@@ -541,25 +545,32 @@ def build_discriminator_Pixel(inputdisc,
                               norm_type='batch_norm'):
     use_bias = norm_type != 'instance_norm'
     o_c1 = conv2d(
-        inputdisc,
-        d_base_dims,
-        1,
-        1,
-        0.02,
+        input=inputdisc,
+        num_filters=d_base_dims,
+        filter_size=1,
+        stride=1,
+        stddev=0.02,
         name=name + '_c1',
         activation_fn='leaky_relu',
         relufactor=0.2,
         use_bias=True)
     o_c2 = conv2d(
-        o_c1,
-        d_base_dims * 2,
-        1,
-        1,
-        0.02,
+        input=o_c1,
+        num_filters=d_base_dims * 2,
+        filter_size=1,
+        stride=1,
+        stddev=0.02,
         name=name + '_c2',
         norm=norm_type,
         activation_fn='leaky_relu',
         relufactor=0.2,
         use_bias=use_bias)
-    o_c3 = conv2d(o_c2, 1, 1, 1, 0.02, name=name + '_c3', use_bias=use_bias)
+    o_c3 = conv2d(
+        o_c2,
+        num_filters=1,
+        filter_size=1,
+        stride=1,
+        stddev=0.02,
+        name=name + '_c3',
+        use_bias=use_bias)
     return o_c3
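build_discriminator_Pixel is the 1x1 PatchGAN variant: all three convs use filter_size=1 and stride=1, so each value in the output map judges exactly one input pixel. A toy version of the same shape arithmetic, assuming d_base_dims=64 (an illustration, not repo code):

```
import paddle.fluid as fluid

img = fluid.layers.data(name='img', shape=[3, 256, 256], dtype='float32')
h = fluid.layers.conv2d(img, num_filters=64, filter_size=1)       # _c1
h = fluid.layers.leaky_relu(h, alpha=0.2)
h = fluid.layers.conv2d(h, num_filters=128, filter_size=1)        # _c2 (+ norm in the repo)
h = fluid.layers.leaky_relu(h, alpha=0.2)
score_map = fluid.layers.conv2d(h, num_filters=1, filter_size=1)  # N x 1 x 256 x 256
```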
@@ -92,10 +92,10 @@ class STGAN_model(object):
         for i in range(n_layers):
             d = min(dim * 2**i, MAX_DIM)
             z = conv2d(
-                z,
-                d,
-                4,
-                2,
+                input=z,
+                num_filters=d,
+                filter_size=4,
+                stride=2,
                 padding_type='SAME',
                 norm="batch_norm",
                 activation_fn='leaky_relu',
@@ -125,9 +125,9 @@ class STGAN_model(object):
         for i in range(n_layers):
             d = min(dim * 2**(n_layers - 1 - i), MAX_DIM)
             output = self.gru_cell(
-                zs[n_layers - 1 - i],
-                state,
-                d,
+                in_data=zs[n_layers - 1 - i],
+                state=state,
+                out_channel=d,
                 kernel_size=kernel_size,
                 norm=norm,
                 pass_state=pass_state,
@@ -157,10 +157,10 @@ class STGAN_model(object):
             if i < n_layers - 1:
                 d = min(dim * 2**(n_layers - 1 - i), MAX_DIM)
                 z = deconv2d(
-                    z,
-                    d,
-                    4,
-                    2,
+                    input=z,
+                    num_filters=d,
+                    filter_size=4,
+                    stride=2,
                     padding_type='SAME',
                     name=name + str(i),
                     norm='batch_norm',
@@ -174,10 +174,10 @@ class STGAN_model(object):
                 z = self.concat(z, a)
             else:
                 x = z = deconv2d(
-                    z,
-                    3,
-                    4,
-                    2,
+                    input=z,
+                    num_filters=3,
+                    filter_size=4,
+                    stride=2,
                     padding_type='SAME',
                     name=name + str(i),
                     activation_fn='tanh',
@@ -199,10 +199,10 @@ class STGAN_model(object):
         for i in range(n_layers):
             d = min(dim * 2**i, MAX_DIM)
             y = conv2d(
-                y,
-                d,
-                4,
-                2,
+                input=y,
+                num_filters=d,
+                filter_size=4,
+                stride=2,
                 norm=norm,
                 padding_type="SAME",
                 activation_fn='leaky_relu',
@@ -212,22 +212,25 @@ class STGAN_model(object):
                 initial='kaiming')
         logit_gan = linear(
-            y,
-            fc_dim,
+            input=y,
+            output_size=fc_dim,
             activation_fn='leaky_relu',
             name=name + 'fc_adv_1',
             initial='kaiming')
         logit_gan = linear(
-            logit_gan, 1, name=name + 'fc_adv_2', initial='kaiming')
+            logit_gan, output_size=1, name=name + 'fc_adv_2', initial='kaiming')
         logit_att = linear(
-            y,
-            fc_dim,
+            input=y,
+            output_size=fc_dim,
             activation_fn='leaky_relu',
             name=name + 'fc_cls_1',
             initial='kaiming')
         logit_att = linear(
-            logit_att, n_atts, name=name + 'fc_cls_2', initial='kaiming')
+            logit_att,
+            output_size=n_atts,
+            name=name + 'fc_cls_2',
+            initial='kaiming')
         return logit_gan, logit_att
@@ -241,10 +244,10 @@ class STGAN_model(object):
                  name='gru',
                  is_test=False):
         state_ = deconv2d(
-            state,
-            out_channel,
-            4,
-            2,
+            input=state,
+            num_filters=out_channel,
+            filter_size=4,
+            stride=2,
             padding_type='SAME',
             name=name + '_deconv2d',
             use_bias=True,
@@ -252,10 +255,10 @@ class STGAN_model(object):
             is_test=is_test,
         )  # upsample and make `channel` identical to `out_channel`
         reset_gate = conv2d(
-            fluid.layers.concat(
+            input=fluid.layers.concat(
                 [in_data, state_], axis=1),
-            out_channel,
-            kernel_size,
+            num_filters=out_channel,
+            filter_size=kernel_size,
             norm=norm,
             activation_fn='sigmoid',
             padding_type='SAME',
@@ -264,10 +267,10 @@ class STGAN_model(object):
             initial='kaiming',
             is_test=is_test)
         update_gate = conv2d(
-            fluid.layers.concat(
+            input=fluid.layers.concat(
                 [in_data, state_], axis=1),
-            out_channel,
-            kernel_size,
+            num_filters=out_channel,
+            filter_size=kernel_size,
             norm=norm,
             activation_fn='sigmoid',
             padding_type='SAME',
@@ -277,10 +280,10 @@ class STGAN_model(object):
             is_test=is_test)
         left_state = reset_gate * state_
         new_info = conv2d(
-            fluid.layers.concat(
+            input=fluid.layers.concat(
                 [in_data, left_state], axis=1),
-            out_channel,
-            kernel_size,
+            num_filters=out_channel,
+            filter_size=kernel_size,
             norm=norm,
             activation_fn='tanh',
             name=name + '_info',
...
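The hunk ends before the gates are combined. gru_cell implements a convolutional GRU, so the pieces above close with the standard mixing step — sketched here under that assumption, since the diff cuts it off rather than quoting the repo's exact line:

```
def gru_combine_sketch(state_, update_gate, new_info):
    # Conventional GRU update assumed for the missing tail of gru_cell:
    # keep (1 - u) of the upsampled previous state, take u of the tanh
    # candidate computed from the reset-gated state.
    return (1.0 - update_gate) * state_ + update_gate * new_info
```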
@@ -28,9 +28,9 @@ class StarGAN_model(object):
     def ResidualBlock(self, input, dim, name):
         conv0 = conv2d(
             input,
-            dim,
-            3,
-            1,
+            num_filters=dim,
+            filter_size=3,
+            stride=1,
             padding=1,
             use_bias=False,
             norm="instance_norm",
@@ -38,10 +38,10 @@ class StarGAN_model(object):
             name=name + ".main0",
             initial='kaiming')
         conv1 = conv2d(
-            conv0,
-            dim,
-            3,
-            1,
+            input=conv0,
+            num_filters=dim,
+            filter_size=3,
+            stride=1,
             padding=1,
             use_bias=False,
             norm="instance_norm",
@@ -59,10 +59,10 @@ class StarGAN_model(object):
             x=label_trg_e, expand_times=[1, 1, shape[2], shape[3]])
         input1 = fluid.layers.concat([input, label_trg_e], 1)
         conv0 = conv2d(
-            input1,
-            cfg.g_base_dims,
-            7,
-            1,
+            input=input1,
+            num_filters=cfg.g_base_dims,
+            filter_size=7,
+            stride=1,
             padding=3,
             use_bias=False,
             norm="instance_norm",
@@ -73,10 +73,10 @@ class StarGAN_model(object):
         for i in range(2):
             rate = 2**(i + 1)
             conv_down = conv2d(
-                conv_down,
-                cfg.g_base_dims * rate,
-                4,
-                2,
+                input=conv_down,
+                num_filters=cfg.g_base_dims * rate,
+                filter_size=4,
+                stride=2,
                 padding=1,
                 use_bias=False,
                 norm="instance_norm",
@@ -93,10 +93,10 @@ class StarGAN_model(object):
         for i in range(2):
             rate = 2**(1 - i)
             deconv = deconv2d(
-                deconv,
-                cfg.g_base_dims * rate,
-                4,
-                2,
+                input=deconv,
+                num_filters=cfg.g_base_dims * rate,
+                filter_size=4,
+                stride=2,
                 padding=1,
                 use_bias=False,
                 norm="instance_norm",
@@ -104,10 +104,10 @@ class StarGAN_model(object):
                 name=name + str(15 + i * 3),
                 initial='kaiming')
         out = conv2d(
-            deconv,
-            3,
-            7,
-            1,
+            input=deconv,
+            num_filters=3,
+            filter_size=7,
+            stride=1,
             padding=3,
             use_bias=False,
             norm=None,
@@ -118,10 +118,10 @@ class StarGAN_model(object):
     def network_D(self, input, cfg, name="discriminator"):
         conv0 = conv2d(
-            input,
-            cfg.d_base_dims,
-            4,
-            2,
+            input=input,
+            num_filters=cfg.d_base_dims,
+            filter_size=4,
+            stride=2,
             padding=1,
             activation_fn='leaky_relu',
             name=name + '0',
@@ -132,28 +132,28 @@ class StarGAN_model(object):
         for i in range(1, repeat_num):
             curr_dim *= 2
             conv = conv2d(
-                conv,
-                curr_dim,
-                4,
-                2,
+                input=conv,
+                num_filters=curr_dim,
+                filter_size=4,
+                stride=2,
                 padding=1,
                 activation_fn='leaky_relu',
                 name=name + str(i * 2),
                 initial='kaiming')
         kernel_size = int(cfg.image_size / np.power(2, repeat_num))
         out1 = conv2d(
-            conv,
-            1,
-            3,
-            1,
+            input=conv,
+            num_filters=1,
+            filter_size=3,
+            stride=1,
             padding=1,
             use_bias=False,
             name="d_conv1",
             initial='kaiming')
         out2 = conv2d(
-            conv,
-            cfg.c_dim,
-            kernel_size,
+            input=conv,
+            num_filters=cfg.c_dim,
+            filter_size=kernel_size,
             use_bias=False,
             name="d_conv2",
             initial='kaiming')
...
@@ -272,7 +272,7 @@ def deconv2d(input,
         param_attr=param_attr,
         bias_attr=bias_attr)
-    if outpadding != 0 and padding_type == None:
+    if np.mean(outpadding) != 0 and padding_type == None:
         conv = fluid.layers.pad2d(
             conv, paddings=outpadding, mode='constant', pad_value=0.0)
...
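The guard changes because outpadding is now a per-side list such as [0, 1, 0, 1]. A list never compares equal to the integer 0, so the old check padded unconditionally; np.mean correctly skips the all-zero case. A plain-Python illustration (not repo code):

```
import numpy as np

outpadding = [0, 0, 0, 0]
print(outpadding != 0)           # True  -- list vs int, the old check always pads
print(np.mean(outpadding) != 0)  # False -- all-zero padding is skipped
```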
-python infer.py --model_net AttGAN --init_model output/checkpoints/119/ --dataset_dir "data/celeba/" --image_size 128
+python infer.py --model_net AttGAN --init_model output/checkpoints/119/ --dataset_dir ./data/celeba/ --image_size 128
-python infer.py --model_net CGAN --init_model ./output/checkpoints/9/ --batch_size 32 --noise_size 100
+python infer.py --model_net CGAN --init_model ./output/checkpoints/19/ --n_samples 32 --noise_size 100
-python infer.py --init_model output/checkpoints/199/ --dataset_dir "data/cityscapes/testB/*" --image_size 256 --input_style B --model_net CycleGAN --net_G resnet_9block --g_base_dims 32
+python infer.py --init_model output/checkpoints/199/ --dataset_dir data/cityscapes/ --image_size 256 --n_samples 1 --crop_size 256 --input_style B --test_list ./data/cityscapes/testB.txt --model_net CycleGAN --net_G resnet_9block --g_base_dims 32
-python infer.py --model_net DCGAN --init_model ./output/checkpoints/9/ --batch_size 32 --noise_size 100
+python infer.py --model_net DCGAN --init_model ./output/checkpoints/19/ --n_samples 32 --noise_size 100
-python infer.py --init_model output/checkpoints/199/ --image_size 256 --dataset_dir "data/cityscapes/testB/*" --model_net Pix2pix --net_G unet_256
+python infer.py --init_model output/checkpoints/199/ --image_size 256 --n_samples 1 --crop_size 256 --dataset_dir data/cityscapes/ --model_net Pix2pix --net_G unet_256 --test_list data/cityscapes/testB.txt
-python infer.py --model_net StarGAN --init_model output/checkpoints/19/ --dataset_dir "data/celeba/" --image_size 128 --c_dim 5 --selected_attrs "Black_Hair,Blond_Hair,Brown_Hair,Male,Young"
+python infer.py --model_net StarGAN --init_model ./output/checkpoints/19/ --dataset_dir ./data/celeba/ --image_size 128 --c_dim 5 --selected_attrs "Black_Hair,Blond_Hair,Brown_Hair,Male,Young"
-python infer.py --model_net STGAN --init_model output/checkpoints/19/ --dataset_dir "data/celeba/" --image_size 128 --use_gru True
+python infer.py --model_net STGAN --init_model ./output/checkpoints/19/ --dataset_dir ./data/celeba/ --image_size 128 --use_gru True
@@ -3,7 +3,7 @@ import argparse
 parser = argparse.ArgumentParser(description='the direction of data list')
 parser.add_argument(
-    '--direction', type=str, default='A2B', help='the direction of data list')
+    '--direction', type=str, default='B2A', help='the direction of data list')
 def make_pair_data(fileA, file, d):
@@ -27,10 +27,10 @@ def make_pair_data(fileA, file, d):
 if __name__ == "__main__":
     args = parser.parse_args()
-    trainA_file = "./data/cityscapes/trainA.txt"
-    train_file = "./data/cityscapes/pix2pix_train_list"
+    trainA_file = os.path.join("data", "cityscapes", "trainA.txt")
+    train_file = os.path.join("data", "cityscapes", "pix2pix_train_list")
     make_pair_data(trainA_file, train_file, args.direction)
-    testA_file = "./data/cityscapes/testA.txt"
-    test_file = "./data/cityscapes/pix2pix_test_list"
+    testA_file = os.path.join("data", "cityscapes", "testA.txt")
+    test_file = os.path.join("data", "cityscapes", "pix2pix_test_list")
     make_pair_data(testA_file, test_file, args.direction)
-python train.py --model_net CGAN --dataset mnist --noise_size 100 --batch_size 32 --epoch 10 >log_out 2>log_err
+python train.py --model_net CGAN --dataset mnist --noise_size 100 --batch_size 121 --epoch 20 >log_out 2>log_err
-python train.py --model_net DCGAN --dataset mnist --noise_size 100 --batch_size 32 --epoch 10 >log_out 2>log_err
+python train.py --model_net DCGAN --dataset mnist --noise_size 100 --batch_size 128 --epoch 20 >log_out 2>log_err
-python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 16 --epoch 20 > log_out 2>log_err
+python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --batch_size 16 --epoch 20 > log_out 2>log_err
@@ -38,24 +38,25 @@ def train(cfg):
     reader = data_reader(cfg)
     if cfg.model_net in ['CycleGAN']:
-        a_reader, b_reader, a_reader_test, b_reader_test, batch_num = reader.make_data(
+        a_reader, b_reader, a_reader_test, b_reader_test, batch_num, a_id2name, b_id2name = reader.make_data(
         )
     else:
         if cfg.dataset in ['mnist']:
             train_reader = reader.make_data()
         else:
-            train_reader, test_reader, batch_num = reader.make_data()
+            train_reader, test_reader, batch_num, id2name = reader.make_data()
     if cfg.model_net in ['CGAN', 'DCGAN']:
         if cfg.dataset != 'mnist':
             raise NotImplementedError("CGAN/DCGAN only support MNIST now!")
         model = trainer.__dict__[cfg.model_net](cfg, train_reader)
     elif cfg.model_net in ['CycleGAN']:
-        model = trainer.__dict__[cfg.model_net](
-            cfg, a_reader, b_reader, a_reader_test, b_reader_test, batch_num)
+        model = trainer.__dict__[cfg.model_net](cfg, a_reader, b_reader,
+                                                a_reader_test, b_reader_test,
+                                                batch_num, a_id2name, b_id2name)
     else:
         model = trainer.__dict__[cfg.model_net](cfg, train_reader, test_reader,
-                                                batch_num)
+                                                batch_num, id2name)
     model.build_model()
...
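make_data now returns one extra value per reader: an id-to-filename map, so test outputs can be saved under their source names. A sketch of the assumed contract — the real reader lives in data_reader.py, and load_and_preprocess below is a hypothetical helper:

```
def make_data_sketch(file_list):
    # Batches carry an integer id per sample; id2name turns it back into the
    # original file name when utility.save_test_image writes results.
    id2name = {i: name for i, name in enumerate(file_list)}

    def reader():
        for i, name in enumerate(file_list):
            yield load_and_preprocess(name), i  # hypothetical helper

    batch_num = len(file_list)
    return reader, reader, batch_num, id2name
```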
@@ -51,12 +51,9 @@ class GTrainer():
             self.g_loss_fake.persistable = True
             self.g_loss_rec.persistable = True
             self.g_loss_cls.persistable = True
-            if cfg.epoch <= 100:
-                lr = cfg.g_lr
-            else:
-                lr = fluid.layers.piecewise_decay(
-                    boundaries=[99 * step_per_epoch],
-                    values=[cfg.g_lr, cfg.g_lr * 0.1], )
+            lr = fluid.layers.piecewise_decay(
+                boundaries=[99 * step_per_epoch],
+                values=[cfg.g_lr, cfg.g_lr * 0.1])
             vars = []
             for var in self.program.list_vars():
                 if fluid.io.is_parameter(var) and var.name.startswith(
@@ -76,11 +73,6 @@ class DTrainer():
         lr = cfg.d_lr
         with fluid.program_guard(self.program):
             model = AttGAN_model()
-            clone_image_real = []
-            for b in self.program.blocks:
-                if b.has_var('image_real'):
-                    clone_image_real = b.var('image_real')
-                    break
             self.fake_img, _ = model.network_G(
                 image_real, label_org, label_trg_, cfg, name="generator")
             self.pred_real, self.cls_real = model.network_D(
@@ -96,7 +88,7 @@ class DTrainer():
                 self.d_loss_real = -1 * fluid.layers.reduce_mean(self.pred_real)
                 self.d_loss_gp = self.gradient_penalty(
                     model.network_D,
-                    clone_image_real,
+                    image_real,
                     self.fake_img,
                     cfg=cfg,
                     name="discriminator")
@@ -114,7 +106,13 @@ class DTrainer():
                         x=self.pred_real, y=ones)))
                 self.d_loss_fake = fluid.layers.mean(
                     fluid.layers.square(x=self.pred_fake))
-                self.d_loss = self.d_loss_real + self.d_loss_fake + self.d_loss_cls
+                self.d_loss_gp = self.gradient_penalty(
+                    model.network_D,
+                    image_real,
+                    None,
+                    cfg=cfg,
+                    name="discriminator")
+                self.d_loss = self.d_loss_real + self.d_loss_fake + 1.0 * self.d_loss_cls + cfg.lambda_gp * self.d_loss_gp
             self.d_loss_real.persistable = True
             self.d_loss_fake.persistable = True
@@ -128,12 +126,9 @@ class DTrainer():
                     vars.append(var.name)
             self.param = vars
-            if cfg.epoch <= 100:
-                lr = cfg.d_lr
-            else:
-                lr = fluid.layers.piecewise_decay(
-                    boundaries=[99 * step_per_epoch],
-                    values=[cfg.g_lr, cfg.g_lr * 0.1], )
+            lr = fluid.layers.piecewise_decay(
+                boundaries=[99 * step_per_epoch],
+                values=[cfg.g_lr, cfg.g_lr * 0.1])
             optimizer = fluid.optimizer.Adam(
                 learning_rate=lr, beta1=0.5, beta2=0.999, name="net_D")
@@ -141,14 +136,21 @@ class DTrainer():
     def gradient_penalty(self, f, real, fake=None, cfg=None, name=None):
         def _interpolate(a, b=None):
+            if b is None:
+                beta = fluid.layers.uniform_random_batch_size_like(
+                    input=a, shape=a.shape, min=0.0, max=1.0)
+                mean = fluid.layers.reduce_mean(
+                    a, range(len(a.shape)), keep_dim=True)
+                input_sub_mean = fluid.layers.elementwise_sub(a, mean, axis=0)
+                var = fluid.layers.reduce_mean(
+                    fluid.layers.square(input_sub_mean),
+                    range(len(a.shape)),
+                    keep_dim=True)
+                b = beta * fluid.layers.sqrt(var) * 0.5 + a
             shape = [a.shape[0]]
             alpha = fluid.layers.uniform_random_batch_size_like(
                 input=a, shape=shape, min=0.0, max=1.0)
-            tmp = fluid.layers.elementwise_mul(
-                fluid.layers.elementwise_sub(b, a), alpha, axis=0)
-            alpha.stop_gradient = True
-            tmp.stop_gradient = True
-            inner = fluid.layers.elementwise_add(a, tmp, axis=0)
+            inner = (b - a) * alpha + a
             return inner
         x = _interpolate(real, fake)
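_interpolate now also covers the fake=None case: it perturbs the real batch with noise scaled to half its standard deviation (a DRAGAN-style penalty used by the lsgan branch), while the wgan branch keeps interpolating between real and fake samples. The diff stops before the penalty itself; the standard WGAN-GP tail that follows `x = _interpolate(real, fake)` looks roughly like this sketch — fluid.gradients is the modern spelling, older Fluid exposes fluid.backward.calc_gradient, and the repo's reduction may differ:

```
import paddle.fluid as fluid

def gradient_penalty_tail_sketch(f, x, cfg=None, name=None):
    # E[(||dD(x)/dx||_2 - 1)^2]: score the interpolated batch, take the
    # gradient w.r.t. the images, and penalize per-sample L2 norms far from 1.
    pred, _ = f(x, cfg, name=name)          # call shape mirrors network_D
    grad = fluid.gradients([pred], [x])[0]
    grad = fluid.layers.flatten(grad, axis=1)
    norm = fluid.layers.sqrt(
        fluid.layers.reduce_sum(fluid.layers.square(grad), dim=1))
    return fluid.layers.reduce_mean(fluid.layers.square(norm - 1.0))
```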
@@ -240,11 +242,13 @@ class AttGAN(object):
                  cfg=None,
                  train_reader=None,
                  test_reader=None,
-                 batch_num=1):
+                 batch_num=1,
+                 id2name=None):
         self.cfg = cfg
         self.train_reader = train_reader
         self.test_reader = test_reader
         self.batch_num = batch_num
+        self.id2name = id2name
     def build_model(self):
         data_shape = [-1, 3, self.cfg.image_size, self.cfg.image_size]
@@ -259,6 +263,20 @@ class AttGAN(object):
             name='label_org_', shape=[self.cfg.c_dim], dtype='float32')
         label_trg_ = fluid.layers.data(
             name='label_trg_', shape=[self.cfg.c_dim], dtype='float32')
+        py_reader = fluid.io.PyReader(
+            feed_list=[image_real, label_org, label_trg],
+            capacity=64,
+            iterable=True,
+            use_double_buffer=True)
+        test_gen_trainer = GTrainer(image_real, label_org, label_org_,
+                                    label_trg, label_trg_, self.cfg,
+                                    self.batch_num)
+        label_org_ = (label_org * 2.0 - 1.0) * self.cfg.thres_int
+        label_trg_ = (label_trg * 2.0 - 1.0) * self.cfg.thres_int
         gen_trainer = GTrainer(image_real, label_org, label_org_, label_trg,
                                label_trg_, self.cfg, self.batch_num)
         dis_trainer = DTrainer(image_real, label_org, label_org_, label_trg,
@@ -266,6 +284,7 @@ class AttGAN(object):
         # prepare environment
         place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
+        py_reader.decorate_batch_generator(self.train_reader, places=place)
         exe = fluid.Executor(place)
         exe.run(fluid.default_startup_program())
@@ -276,7 +295,6 @@ class AttGAN(object):
         ### memory optim
         build_strategy = fluid.BuildStrategy()
         build_strategy.enable_inplace = False
-        build_strategy.memory_optimize = False
         gen_trainer_program = fluid.CompiledProgram(
             gen_trainer.program).with_data_parallel(
@@ -291,84 +309,61 @@ class AttGAN(object):
         for epoch_id in range(self.cfg.epoch):
             batch_id = 0
-            for i in range(self.batch_num):
-                image, label_org = next(self.train_reader())
-                label_trg = copy.deepcopy(label_org)
-                np.random.shuffle(label_trg)
-                label_org_ = list(
-                    map(lambda x: (x * 2.0 - 1.0) * self.cfg.thres_int,
-                        label_org))
-                label_trg_ = list(
-                    map(lambda x: (x * 2.0 - 1.0) * self.cfg.thres_int,
-                        label_trg))
-                tensor_img = fluid.LoDTensor()
-                tensor_label_org = fluid.LoDTensor()
-                tensor_label_trg = fluid.LoDTensor()
-                tensor_label_org_ = fluid.LoDTensor()
-                tensor_label_trg_ = fluid.LoDTensor()
-                tensor_img.set(image, place)
-                tensor_label_org.set(label_org, place)
-                tensor_label_trg.set(label_trg, place)
-                tensor_label_org_.set(label_org_, place)
-                tensor_label_trg_.set(label_trg_, place)
-                label_shape = tensor_label_trg.shape
+            for data in py_reader():
                 s_time = time.time()
                 # optimize the discriminator network
-                if (batch_id + 1) % self.cfg.num_discriminator_time != 0:
-                    fetches = [
-                        dis_trainer.d_loss.name, dis_trainer.d_loss_real.name,
-                        dis_trainer.d_loss_fake.name,
-                        dis_trainer.d_loss_cls.name, dis_trainer.d_loss_gp.name
-                    ]
-                    d_loss, d_loss_real, d_loss_fake, d_loss_cls, d_loss_gp = exe.run(
-                        dis_trainer_program,
-                        fetch_list=fetches,
-                        feed={
-                            "image_real": tensor_img,
-                            "label_org": tensor_label_org,
-                            "label_org_": tensor_label_org_,
-                            "label_trg": tensor_label_trg,
-                            "label_trg_": tensor_label_trg_
-                        })
-                    batch_time = time.time() - s_time
-                    t_time += batch_time
-                    print("epoch{}: batch{}: \n\
-                         d_loss: {}; d_loss_real: {}; d_loss_fake: {}; d_loss_cls: {}; d_loss_gp: {} \n\
-                         Batch_time_cost: {}".format(epoch_id, batch_id, d_loss[
-                        0], d_loss_real[0], d_loss_fake[0], d_loss_cls[0],
-                        d_loss_gp[0], batch_time))
-                # optimize the generator network
-                else:
+                fetches = [
+                    dis_trainer.d_loss.name,
+                    dis_trainer.d_loss_real.name,
+                    dis_trainer.d_loss_fake.name,
+                    dis_trainer.d_loss_cls.name,
+                    dis_trainer.d_loss_gp.name,
+                ]
+                d_loss, d_loss_real, d_loss_fake, d_loss_cls, d_loss_gp = exe.run(
+                    dis_trainer_program, fetch_list=fetches, feed=data)
+                if (batch_id + 1) % self.cfg.num_discriminator_time == 0:
+                    # optimize the generator network
                     d_fetches = [
                         gen_trainer.g_loss_fake.name,
                         gen_trainer.g_loss_rec.name,
                         gen_trainer.g_loss_cls.name, gen_trainer.fake_img.name
                     ]
                     g_loss_fake, g_loss_rec, g_loss_cls, fake_img = exe.run(
-                        gen_trainer_program,
-                        fetch_list=d_fetches,
-                        feed={
-                            "image_real": tensor_img,
-                            "label_org": tensor_label_org,
-                            "label_org_": tensor_label_org_,
-                            "label_trg": tensor_label_trg,
-                            "label_trg_": tensor_label_trg_
-                        })
+                        gen_trainer_program, fetch_list=d_fetches, feed=data)
                     print("epoch{}: batch{}: \n\
                          g_loss_fake: {}; g_loss_rec: {}; g_loss_cls: {}"
                           .format(epoch_id, batch_id, g_loss_fake[0],
                                   g_loss_rec[0], g_loss_cls[0]))
+                batch_time = time.time() - s_time
+                t_time += batch_time
+                if (batch_id + 1) % self.cfg.print_freq == 0:
+                    print("epoch{}: batch{}: \n\
+                         d_loss: {}; d_loss_real: {}; d_loss_fake: {}; d_loss_cls: {}; d_loss_gp: {} \n\
+                         Batch_time_cost: {}".format(epoch_id, batch_id, d_loss[
+                        0], d_loss_real[0], d_loss_fake[0], d_loss_cls[0],
+                        d_loss_gp[0], batch_time))
                 sys.stdout.flush()
                 batch_id += 1
             if self.cfg.run_test:
-                test_program = gen_trainer.infer_program
-                utility.save_test_image(epoch_id, self.cfg, exe, place,
-                                        test_program, gen_trainer,
-                                        self.test_reader)
+                image_name = fluid.layers.data(
+                    name='image_name',
+                    shape=[self.cfg.n_samples],
+                    dtype='int32')
+                test_py_reader = fluid.io.PyReader(
+                    feed_list=[image_real, label_org, label_trg, image_name],
+                    capacity=32,
+                    iterable=True,
+                    use_double_buffer=True)
+                test_py_reader.decorate_batch_generator(
+                    self.test_reader, places=place)
+                test_program = test_gen_trainer.infer_program
+                utility.save_test_image(epoch_id, self.cfg, exe, place,
+                                        test_program, test_gen_trainer,
+                                        test_py_reader)
             if self.cfg.save_checkpoints:
                 utility.checkpoints(epoch_id, self.cfg, exe, gen_trainer,
...
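The rewritten loop is the fluid.io.PyReader pattern this commit applies to every trainer: declare the reader over the feed variables, bind a batch generator per place, and iterate — each step yields feed dicts that go to exe.run unchanged, replacing the hand-built LoDTensor plumbing. A minimal standalone version of the same pattern:

```
import numpy as np
import paddle.fluid as fluid

place = fluid.CPUPlace()
x = fluid.layers.data(name='x', shape=[3, 8, 8], dtype='float32')
loss = fluid.layers.reduce_mean(x)

py_reader = fluid.io.PyReader(
    feed_list=[x], capacity=8, iterable=True, use_double_buffer=True)

def batches():
    for _ in range(4):
        yield [np.random.rand(2, 3, 8, 8).astype('float32')]  # one array per feed var

py_reader.decorate_batch_generator(batches, places=place)

exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
for data in py_reader():  # data: a list of feed dicts, one per place
    loss_v, = exe.run(fluid.default_main_program(), feed=data, fetch_list=[loss])
```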
@@ -33,10 +33,10 @@ class GTrainer():
     def __init__(self, input, conditions, cfg):
         self.program = fluid.default_main_program().clone()
         with fluid.program_guard(self.program):
-            model = CGAN_model()
+            model = CGAN_model(cfg.batch_size)
             self.fake = model.network_G(input, conditions, name="G")
             self.fake.persistable = True
-            self.infer_program = self.program.clone()
+            self.infer_program = self.program.clone(for_test=True)
             d_fake = model.network_D(self.fake, conditions, name="D")
             fake_labels = fluid.layers.fill_constant_batch_size_like(
                 input=input, dtype='float32', shape=[-1, 1], value=1.0)
@@ -59,7 +59,7 @@ class DTrainer():
     def __init__(self, input, conditions, labels, cfg):
         self.program = fluid.default_main_program().clone()
         with fluid.program_guard(self.program):
-            model = CGAN_model()
+            model = CGAN_model(cfg.batch_size)
             d_logit = model.network_D(input, conditions, name="D")
             self.d_loss = fluid.layers.reduce_mean(
                 fluid.layers.sigmoid_cross_entropy_with_logits(
@@ -114,7 +114,6 @@ class CGAN(object):
         ### memory optim
         build_strategy = fluid.BuildStrategy()
         build_strategy.enable_inplace = True
-        build_strategy.memory_optimize = False
         g_trainer_program = fluid.CompiledProgram(
             g_trainer.program).with_data_parallel(
...
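clone(for_test=True) is more than cosmetic: it prunes backward and optimizer ops and switches layers such as batch_norm and dropout into inference mode, whereas the old bare clone() kept training behavior in the sampling program. A self-contained illustration:

```
import paddle.fluid as fluid

main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    x = fluid.layers.data(name='x', shape=[16], dtype='float32')
    h = fluid.layers.fc(input=x, size=8)
    h = fluid.layers.dropout(h, dropout_prob=0.5)
    loss = fluid.layers.reduce_mean(h)

infer_program = main.clone(for_test=True)  # dropout becomes a no-op here
train_like = main.clone()                  # dropout still active -- the old behavior
```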
@@ -18,6 +18,7 @@ from __future__ import print_function
 from network.CycleGAN_network import CycleGAN_model
 from util import utility
 import paddle.fluid as fluid
+import paddle
 import sys
 import time
@@ -215,13 +216,17 @@ class CycleGAN(object):
                  B_reader=None,
                  A_test_reader=None,
                  B_test_reader=None,
-                 batch_num=1):
+                 batch_num=1,
+                 A_id2name=None,
+                 B_id2name=None):
         self.cfg = cfg
         self.A_reader = A_reader
         self.B_reader = B_reader
         self.A_test_reader = A_test_reader
         self.B_test_reader = B_test_reader
         self.batch_num = batch_num
+        self.A_id2name = A_id2name
+        self.B_id2name = B_id2name
     def build_model(self):
         data_shape = [-1, 3, self.cfg.crop_size, self.cfg.crop_size]
@@ -235,12 +240,28 @@ class CycleGAN(object):
         fake_pool_B = fluid.layers.data(
             name='fake_pool_B', shape=data_shape, dtype='float32')
+        A_py_reader = fluid.io.PyReader(
+            feed_list=[input_A],
+            capacity=4,
+            iterable=True,
+            use_double_buffer=True)
+        B_py_reader = fluid.io.PyReader(
+            feed_list=[input_B],
+            capacity=4,
+            iterable=True,
+            use_double_buffer=True)
         gen_trainer = GTrainer(input_A, input_B, self.cfg, self.batch_num)
         d_A_trainer = DATrainer(input_B, fake_pool_B, self.cfg, self.batch_num)
         d_B_trainer = DBTrainer(input_A, fake_pool_A, self.cfg, self.batch_num)
         # prepare environment
         place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
+        A_py_reader.decorate_batch_generator(self.A_reader, places=place)
+        B_py_reader.decorate_batch_generator(self.B_reader, places=place)
         exe = fluid.Executor(place)
         exe.run(fluid.default_startup_program())
@@ -255,7 +276,6 @@ class CycleGAN(object):
         ### memory optim
         build_strategy = fluid.BuildStrategy()
         build_strategy.enable_inplace = True
-        build_strategy.memory_optimize = False
         gen_trainer_program = fluid.CompiledProgram(
             gen_trainer.program).with_data_parallel(
@@ -270,20 +290,14 @@ class CycleGAN(object):
                 loss_name=d_B_trainer.d_loss_B.name,
                 build_strategy=build_strategy)
-        losses = [[], []]
         t_time = 0
         for epoch_id in range(self.cfg.epoch):
             batch_id = 0
-            for i in range(self.batch_num):
-                data_A = next(self.A_reader())
-                data_B = next(self.B_reader())
-                tensor_A = fluid.LoDTensor()
-                tensor_B = fluid.LoDTensor()
-                tensor_A.set(data_A, place)
-                tensor_B.set(data_B, place)
+            for data_A, data_B in zip(A_py_reader(), B_py_reader()):
                 s_time = time.time()
-                # optimize the g_A network
+                tensor_A, tensor_B = data_A[0]['input_A'], data_B[0]['input_B']
+                ## optimize the g_A network
                 g_A_loss, g_A_cyc_loss, g_A_idt_loss, g_B_loss, g_B_cyc_loss,\
                     g_B_idt_loss, fake_A_tmp, fake_B_tmp = exe.run(
                         gen_trainer_program,
@@ -324,16 +338,42 @@ class CycleGAN(object):
                         g_A_cyc_loss[0], g_A_idt_loss[0], d_B_loss[0], g_B_loss[
                             0], g_B_cyc_loss[0], g_B_idt_loss[0], batch_time))
-                losses[0].append(g_A_loss[0])
-                losses[1].append(d_A_loss[0])
                 sys.stdout.flush()
                 batch_id += 1
             if self.cfg.run_test:
+                A_image_name = fluid.layers.data(
+                    name='A_image_name', shape=[1], dtype='int32')
+                B_image_name = fluid.layers.data(
+                    name='B_image_name', shape=[1], dtype='int32')
+                A_test_py_reader = fluid.io.PyReader(
+                    feed_list=[input_A, A_image_name],
+                    capacity=4,
+                    iterable=True,
+                    use_double_buffer=True)
+                B_test_py_reader = fluid.io.PyReader(
+                    feed_list=[input_B, B_image_name],
+                    capacity=4,
+                    iterable=True,
+                    use_double_buffer=True)
+                A_test_py_reader.decorate_batch_generator(
+                    self.A_test_reader, places=place)
+                B_test_py_reader.decorate_batch_generator(
+                    self.B_test_reader, places=place)
                 test_program = gen_trainer.infer_program
-                utility.save_test_image(epoch_id, self.cfg, exe, place,
-                                        test_program, gen_trainer,
-                                        self.A_test_reader, self.B_test_reader)
+                utility.save_test_image(
+                    epoch_id,
+                    self.cfg,
+                    exe,
+                    place,
+                    test_program,
+                    gen_trainer,
+                    A_test_py_reader,
+                    B_test_py_reader,
+                    A_id2name=self.A_id2name,
+                    B_id2name=self.B_id2name)
             if self.cfg.save_checkpoints:
                 utility.checkpoints(epoch_id, self.cfg, exe, gen_trainer,
...
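Pairing the two domains via zip(A_py_reader(), B_py_reader()) makes one epoch run for min(len(A), len(B)) steps, which is the natural epoch length for unpaired data. One portability caveat, sketched below: Python 2's zip drains both readers eagerly, so a lazy variant is safer if the code must run there.

```
import itertools
import sys

def paired_batches(reader_a, reader_b):
    # Lazily pair one A-batch with one B-batch; iteration stops when the
    # smaller dataset is exhausted.
    pair = zip if sys.version_info[0] >= 3 else itertools.izip
    return pair(reader_a(), reader_b())
```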
@@ -33,10 +33,10 @@ class GTrainer():
     def __init__(self, input, label, cfg):
         self.program = fluid.default_main_program().clone()
         with fluid.program_guard(self.program):
-            model = DCGAN_model()
+            model = DCGAN_model(cfg.batch_size)
             self.fake = model.network_G(input, name='G')
             self.fake.persistable = True
-            self.infer_program = self.program.clone()
+            self.infer_program = self.program.clone(for_test=True)
             d_fake = model.network_D(self.fake, name="D")
             fake_labels = fluid.layers.fill_constant_batch_size_like(
                 input, dtype='float32', shape=[-1, 1], value=1.0)
@@ -58,7 +58,7 @@ class DTrainer():
     def __init__(self, input, labels, cfg):
         self.program = fluid.default_main_program().clone()
         with fluid.program_guard(self.program):
-            model = DCGAN_model()
+            model = DCGAN_model(cfg.batch_size)
             d_logit = model.network_D(input, name="D")
             self.d_loss = fluid.layers.reduce_mean(
                 fluid.layers.sigmoid_cross_entropy_with_logits(
@@ -110,7 +110,6 @@ class DCGAN(object):
         ### memory optim
         build_strategy = fluid.BuildStrategy()
         build_strategy.enable_inplace = True
-        build_strategy.memory_optimize = False
         g_trainer_program = fluid.CompiledProgram(
             g_trainer.program).with_data_parallel(
@@ -120,7 +119,6 @@ class DCGAN(object):
             loss_name=d_trainer.d_loss.name, build_strategy=build_strategy)
         t_time = 0
-        losses = [[], []]
         for epoch_id in range(self.cfg.epoch):
             for batch_id, data in enumerate(self.train_reader()):
                 if len(data) != self.cfg.batch_size:
@@ -139,7 +137,7 @@ class DCGAN(object):
                     shape=[real_image.shape[0], 1], dtype='float32')
                 s_time = time.time()
-                generate_image = exe.run(g_trainer.infer_program,
+                generate_image = exe.run(g_trainer_program,
                                          feed={'noise': noise_data},
                                          fetch_list=[g_trainer.fake])
@@ -154,13 +152,16 @@ class DCGAN(object):
                                        'label': fake_label},
                                       fetch_list=[d_trainer.d_loss])[0]
                 d_loss = d_real_loss + d_fake_loss
-                losses[1].append(d_loss)
                 for _ in six.moves.xrange(self.cfg.num_generator_time):
+                    noise_data = np.random.uniform(
+                        low=-1.0,
+                        high=1.0,
+                        size=[self.cfg.batch_size, self.cfg.noise_size]).astype(
+                            'float32')
                     g_loss = exe.run(g_trainer_program,
                                      feed={'noise': noise_data},
                                      fetch_list=[g_trainer.g_loss])[0]
-                    losses[0].append(g_loss)
                 batch_time = time.time() - s_time
                 t_time += batch_time
...
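Two behavioral fixes hide in this hunk: the generator sample for the discriminator step now runs the compiled training program instead of the inference clone, and every one of the num_generator_time updates draws fresh noise rather than reusing the discriminator's batch. The sampling helper is just:

```
import numpy as np

def sample_noise(batch_size, noise_size):
    # Fresh uniform noise in [-1, 1) per generator step, matching the
    # np.random.uniform call added above; reusing one noise batch across
    # several G updates would overfit the generator to that batch.
    return np.random.uniform(
        low=-1.0, high=1.0, size=[batch_size, noise_size]).astype('float32')
```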
@@ -195,11 +195,13 @@ class Pix2pix(object):
                  cfg=None,
                  train_reader=None,
                  test_reader=None,
-                 batch_num=1):
+                 batch_num=1,
+                 id2name=None):
         self.cfg = cfg
         self.train_reader = train_reader
         self.test_reader = test_reader
         self.batch_num = batch_num
+        self.id2name = id2name
     def build_model(self):
         data_shape = [-1, 3, self.cfg.crop_size, self.cfg.crop_size]
@@ -211,12 +213,19 @@ class Pix2pix(object):
         input_fake = fluid.layers.data(
             name='input_fake', shape=data_shape, dtype='float32')
+        py_reader = fluid.io.PyReader(
+            feed_list=[input_A, input_B],
+            capacity=4,  ## batch_size * 4
+            iterable=True,
+            use_double_buffer=True)
         gen_trainer = GTrainer(input_A, input_B, self.cfg, self.batch_num)
         dis_trainer = DTrainer(input_A, input_B, input_fake, self.cfg,
                                self.batch_num)
         # prepare environment
         place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
+        py_reader.decorate_batch_generator(self.train_reader, places=place)
         exe = fluid.Executor(place)
         exe.run(fluid.default_startup_program())
@@ -227,7 +236,6 @@ class Pix2pix(object):
         ### memory optim
         build_strategy = fluid.BuildStrategy()
         build_strategy.enable_inplace = False
-        build_strategy.memory_optimize = False
         gen_trainer_program = fluid.CompiledProgram(
             gen_trainer.program).with_data_parallel(
@@ -242,13 +250,10 @@ class Pix2pix(object):
         for epoch_id in range(self.cfg.epoch):
             batch_id = 0
-            for i in range(self.batch_num):
-                data_A, data_B = next(self.train_reader())
-                tensor_A = fluid.LoDTensor()
-                tensor_B = fluid.LoDTensor()
-                tensor_A.set(data_A, place)
-                tensor_B.set(data_B, place)
+            for tensor in py_reader():
                 s_time = time.time()
+                tensor_A, tensor_B = tensor[0]['input_A'], tensor[0]['input_B']
                 # optimize the generator network
                 g_loss_gan, g_loss_l1, fake_B_tmp = exe.run(
                     gen_trainer_program,
@@ -256,8 +261,7 @@ class Pix2pix(object):
                         gen_trainer.g_loss_gan, gen_trainer.g_loss_L1,
                         gen_trainer.fake_B
                     ],
-                    feed={"input_A": tensor_A,
-                          "input_B": tensor_B})
+                    feed=tensor)
                 # optimize the discriminator network
                 d_loss_real, d_loss_fake = exe.run(dis_trainer_program,
@@ -285,10 +289,27 @@ class Pix2pix(object):
                 batch_id += 1
             if self.cfg.run_test:
+                image_name = fluid.layers.data(
+                    name='image_name',
+                    shape=[self.cfg.batch_size],
+                    dtype="int32")
+                test_py_reader = fluid.io.PyReader(
+                    feed_list=[input_A, input_B, image_name],
+                    capacity=4,  ## batch_size * 4
+                    iterable=True,
+                    use_double_buffer=True)
+                test_py_reader.decorate_batch_generator(
+                    self.test_reader, places=place)
                 test_program = gen_trainer.infer_program
-                utility.save_test_image(epoch_id, self.cfg, exe, place,
-                                        test_program, gen_trainer,
-                                        self.test_reader)
+                utility.save_test_image(
+                    epoch_id,
+                    self.cfg,
+                    exe,
+                    place,
+                    test_program,
+                    gen_trainer,
+                    test_py_reader,
+                    A_id2name=self.id2name)
             if self.cfg.save_checkpoints:
                 utility.checkpoints(epoch_id, self.cfg, exe, gen_trainer,
...
@@ -73,11 +73,6 @@ class DTrainer():
         lr = cfg.d_lr
         with fluid.program_guard(self.program):
             model = STGAN_model()
-            clone_image_real = []
-            for b in self.program.blocks:
-                if b.has_var('image_real'):
-                    clone_image_real = b.var('image_real')
-                    break
             self.fake_img, _ = model.network_G(
                 image_real, label_org_, label_trg_, cfg, name="generator")
             self.pred_real, self.cls_real = model.network_D(
@@ -95,7 +90,7 @@ class DTrainer():
                 self.d_loss_real = -1 * fluid.layers.reduce_mean(self.pred_real)
                 self.d_loss_gp = self.gradient_penalty(
                     model.network_D,
-                    clone_image_real,
+                    image_real,
                     self.fake_img,
                     cfg=cfg,
                     name="discriminator")
@@ -113,7 +108,13 @@ class DTrainer():
                         x=self.pred_real, y=ones)))
                 self.d_loss_fake = fluid.layers.mean(
                     fluid.layers.square(x=self.pred_fake))
-                self.d_loss = self.d_loss_real + self.d_loss_fake + self.d_loss_cls
+                self.d_loss_gp = self.gradient_penalty(
+                    model.network_D,
+                    image_real,
+                    None,
+                    cfg=cfg,
+                    name="discriminator")
+                self.d_loss = self.d_loss_real + self.d_loss_fake + 1.0 * self.d_loss_cls + cfg.lambda_gp * self.d_loss_gp
             self.d_loss_real.persistable = True
             self.d_loss_fake.persistable = True
@@ -136,17 +137,26 @@ class DTrainer():
                 name="net_D")
             optimizer.minimize(self.d_loss, parameter_list=vars)
+            f = open('G_program.txt', 'w')
+            print(self.program, file=f)
     def gradient_penalty(self, f, real, fake=None, cfg=None, name=None):
         def _interpolate(a, b=None):
+            if b is None:
+                beta = fluid.layers.uniform_random_batch_size_like(
+                    input=a, shape=a.shape, min=0.0, max=1.0)
+                mean = fluid.layers.reduce_mean(
+                    a, range(len(a.shape)), keep_dim=True)
+                input_sub_mean = fluid.layers.elementwise_sub(a, mean, axis=0)
+                var = fluid.layers.reduce_mean(
+                    fluid.layers.square(input_sub_mean),
+                    range(len(a.shape)),
+                    keep_dim=True)
+                b = beta * fluid.layers.sqrt(var) * 0.5 + a
             shape = [a.shape[0]]
             alpha = fluid.layers.uniform_random_batch_size_like(
                 input=a, shape=shape, min=0.0, max=1.0)
-            tmp = fluid.layers.elementwise_mul(
-                fluid.layers.elementwise_sub(b, a), alpha, axis=0)
-            alpha.stop_gradient = True
-            tmp.stop_gradient = True
-            inner = fluid.layers.elementwise_add(a, tmp, axis=0)
+            inner = (b - a) * alpha + a
             return inner
         x = _interpolate(real, fake)
...@@ -245,7 +255,8 @@ class STGAN(object): ...@@ -245,7 +255,8 @@ class STGAN(object):
cfg=None, cfg=None,
train_reader=None, train_reader=None,
test_reader=None, test_reader=None,
batch_num=1): batch_num=1,
id2name=None):
self.cfg = cfg self.cfg = cfg
self.train_reader = train_reader self.train_reader = train_reader
self.test_reader = test_reader self.test_reader = test_reader
...@@ -264,6 +275,19 @@ class STGAN(object): ...@@ -264,6 +275,19 @@ class STGAN(object):
name='label_org_', shape=[self.cfg.c_dim], dtype='float32') name='label_org_', shape=[self.cfg.c_dim], dtype='float32')
label_trg_ = fluid.layers.data( label_trg_ = fluid.layers.data(
name='label_trg_', shape=[self.cfg.c_dim], dtype='float32') name='label_trg_', shape=[self.cfg.c_dim], dtype='float32')
test_gen_trainer = GTrainer(image_real, label_org, label_org_,
label_trg, label_trg_, self.cfg,
self.batch_num)
py_reader = fluid.io.PyReader(
feed_list=[image_real, label_org, label_trg],
capacity=64,
iterable=True,
use_double_buffer=True)
label_org_ = (label_org * 2.0 - 1.0) * self.cfg.thres_int
label_trg_ = (label_trg * 2.0 - 1.0) * self.cfg.thres_int
gen_trainer = GTrainer(image_real, label_org, label_org_, label_trg, gen_trainer = GTrainer(image_real, label_org, label_org_, label_trg,
label_trg_, self.cfg, self.batch_num) label_trg_, self.cfg, self.batch_num)
dis_trainer = DTrainer(image_real, label_org, label_org_, label_trg, dis_trainer = DTrainer(image_real, label_org, label_org_, label_trg,
...@@ -271,6 +295,8 @@ class STGAN(object): ...@@ -271,6 +295,8 @@ class STGAN(object):
# prepare environment # prepare environment
place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
py_reader.decorate_batch_generator(self.train_reader, places=place)
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
...@@ -281,7 +307,6 @@ class STGAN(object): ...@@ -281,7 +307,6 @@ class STGAN(object):
### memory optim ### memory optim
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = False build_strategy.enable_inplace = False
build_strategy.memory_optimize = False
gen_trainer_program = fluid.CompiledProgram( gen_trainer_program = fluid.CompiledProgram(
gen_trainer.program).with_data_parallel( gen_trainer.program).with_data_parallel(
@@ -296,83 +321,57 @@ class STGAN(object):
         for epoch_id in range(self.cfg.epoch):
             batch_id = 0
-            for i in range(self.batch_num):
-                image, label_org = next(self.train_reader())
-                label_trg = copy.deepcopy(label_org)
-                np.random.shuffle(label_trg)
-                label_org_ = list(
-                    map(lambda x: (x * 2.0 - 1.0) * self.cfg.thres_int,
-                        label_org))
-                label_trg_ = list(
-                    map(lambda x: (x * 2.0 - 1.0) * self.cfg.thres_int,
-                        label_trg))
-                tensor_img = fluid.LoDTensor()
-                tensor_label_org = fluid.LoDTensor()
-                tensor_label_trg = fluid.LoDTensor()
-                tensor_label_org_ = fluid.LoDTensor()
-                tensor_label_trg_ = fluid.LoDTensor()
-                tensor_img.set(image, place)
-                tensor_label_org.set(label_org, place)
-                tensor_label_trg.set(label_trg, place)
-                tensor_label_org_.set(label_org_, place)
-                tensor_label_trg_.set(label_trg_, place)
-                label_shape = tensor_label_trg.shape
+            for data in py_reader():
                 s_time = time.time()
                 # optimize the discriminator network
-                if (batch_id + 1) % self.cfg.num_discriminator_time != 0:
-                    fetches = [
-                        dis_trainer.d_loss.name, dis_trainer.d_loss_real.name,
-                        dis_trainer.d_loss_fake.name,
-                        dis_trainer.d_loss_cls.name, dis_trainer.d_loss_gp.name
-                    ]
-                    d_loss, d_loss_real, d_loss_fake, d_loss_cls, d_loss_gp = exe.run(
-                        dis_trainer_program,
-                        fetch_list=fetches,
-                        feed={
-                            "image_real": tensor_img,
-                            "label_org": tensor_label_org,
-                            "label_org_": tensor_label_org_,
-                            "label_trg": tensor_label_trg,
-                            "label_trg_": tensor_label_trg_
-                        })
-                    batch_time = time.time() - s_time
-                    t_time += batch_time
-                    print("epoch{}: batch{}: \n\
-                         d_loss: {}; d_loss_real: {}; d_loss_fake: {}; d_loss_cls: {}; d_loss_gp: {} \n\
-                         Batch_time_cost: {}".format(epoch_id, batch_id, d_loss[
-                        0], d_loss_real[0], d_loss_fake[0], d_loss_cls[0],
-                        d_loss_gp[0], batch_time))
-                # optimize the generator network
-                else:
+                fetches = [
+                    dis_trainer.d_loss.name,
+                    dis_trainer.d_loss_real.name,
+                    dis_trainer.d_loss_fake.name,
+                    dis_trainer.d_loss_cls.name,
+                    dis_trainer.d_loss_gp.name,
+                ]
+                d_loss, d_loss_real, d_loss_fake, d_loss_cls, d_loss_gp, = exe.run(
+                    dis_trainer_program, fetch_list=fetches, feed=data)
+                if (batch_id + 1) % self.cfg.num_discriminator_time == 0:
+                    # optimize the generator network
                     d_fetches = [
                         gen_trainer.g_loss_fake.name,
                         gen_trainer.g_loss_rec.name, gen_trainer.g_loss_cls.name
                     ]
                     g_loss_fake, g_loss_rec, g_loss_cls = exe.run(
-                        gen_trainer_program,
-                        fetch_list=d_fetches,
-                        feed={
-                            "image_real": tensor_img,
-                            "label_org": tensor_label_org,
-                            "label_org_": tensor_label_org_,
-                            "label_trg": tensor_label_trg,
-                            "label_trg_": tensor_label_trg_
-                        })
+                        gen_trainer_program, fetch_list=d_fetches, feed=data)
                     print("epoch{}: batch{}: \n\
                          g_loss_fake: {}; g_loss_rec: {}; g_loss_cls: {}"
                           .format(epoch_id, batch_id, g_loss_fake[0],
                                   g_loss_rec[0], g_loss_cls[0]))
+                batch_time = time.time() - s_time
+                t_time += batch_time
+                if (batch_id + 1) % self.cfg.print_freq == 0:
+                    print("epoch{}: batch{}: \n\
+                         d_loss: {}; d_loss_real: {}; d_loss_fake: {}; d_loss_cls: {}; d_loss_gp: {} \n\
+                         Batch_time_cost: {}".format(epoch_id, batch_id, d_loss[
+                        0], d_loss_real[0], d_loss_fake[0], d_loss_cls[0],
+                        d_loss_gp[0], batch_time))
                 sys.stdout.flush()
                 batch_id += 1
             if self.cfg.run_test:
-                test_program = gen_trainer.infer_program
+                image_name = fluid.layers.data(
+                    name='image_name',
+                    shape=[self.cfg.n_samples],
+                    dtype='int32')
+                test_py_reader = fluid.io.PyReader(
+                    feed_list=[image_real, label_org, label_trg, image_name],
+                    capacity=32,
+                    iterable=True,
+                    use_double_buffer=True)
+                test_py_reader.decorate_batch_generator(
+                    self.test_reader, places=place)
+                test_program = test_gen_trainer.infer_program
                 utility.save_test_image(epoch_id, self.cfg, exe, place,
-                                        test_program, gen_trainer,
-                                        self.test_reader)
+                                        test_program, test_gen_trainer,
+                                        test_py_reader)
             if self.cfg.save_checkpoints:
                 utility.checkpoints(epoch_id, self.cfg, exe, gen_trainer,
......
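The STGAN hunks above replace the hand-rolled `fluid.LoDTensor` feeding loop with an iterable `fluid.io.PyReader`, so each `exe.run` call can take `feed=data` directly. Below is a minimal, self-contained sketch of that pattern, assuming the Paddle 1.5 static-graph API; the toy loss and the `random_batches` generator are stand-ins for illustration, not code from this commit:

```
# Minimal sketch of the iterable PyReader pattern the diff migrates to.
# Assumes paddle.fluid 1.5.x; shapes and the loss are placeholders.
import numpy as np
import paddle.fluid as fluid

image = fluid.layers.data(name='image', shape=[3, 128, 128], dtype='float32')
label = fluid.layers.data(name='label', shape=[13], dtype='float32')
loss = fluid.layers.reduce_mean(image) + fluid.layers.reduce_mean(label)

py_reader = fluid.io.PyReader(
    feed_list=[image, label],  # feed dict keys are taken from these variables
    capacity=64,
    iterable=True,
    use_double_buffer=True)

def random_batches(n=3, batch_size=4):
    # a batch generator: each yield is one already-batched tuple of arrays,
    # in the same order as feed_list
    def reader():
        for _ in range(n):
            yield (np.random.rand(batch_size, 3, 128, 128).astype('float32'),
                   np.random.rand(batch_size, 13).astype('float32'))
    return reader

place = fluid.CPUPlace()
py_reader.decorate_batch_generator(random_batches(), places=place)

exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
for data in py_reader():
    # `data` is already a feed structure, so no manual LoDTensor.set() calls
    loss_v, = exe.run(fluid.default_main_program(),
                      fetch_list=[loss],
                      feed=data)
    print(loss_v)
```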
@@ -30,19 +30,7 @@ class GTrainer():
             self.pred_fake, self.cls_fake = model.network_D(
                 self.fake_img, cfg, name="d_main")
             #wgan
-            if cfg.gan_mode == "wgan":
-                self.g_loss_fake = -1 * fluid.layers.mean(self.pred_fake)
-            #lsgan
-            elif cfg.gan_mode == "lsgan":
-                ones = fluid.layers.fill_constant_batch_size_like(
-                    input=self.pred_fake,
-                    shape=self.pred_fake.shape,
-                    value=1,
-                    dtype='float32')
-                self.g_loss_fake = fluid.layers.mean(
-                    fluid.layers.square(
-                        fluid.layers.elementwise_sub(
-                            x=self.pred_fake, y=ones)))
+            self.g_loss_fake = -1 * fluid.layers.mean(self.pred_fake)
             cls_shape = self.cls_fake.shape
             self.cls_fake = fluid.layers.reshape(
@@ -87,11 +75,9 @@ class DTrainer():
         self.program = fluid.default_main_program().clone()
         with fluid.program_guard(self.program):
             model = StarGAN_model()
-            clone_image_real = []
-            for b in self.program.blocks:
-                if b.has_var('image_real'):
-                    clone_image_real = b.var('image_real')
-                    break
+            image_real = fluid.layers.data(
+                name='image_real', shape=image_real.shape, dtype='float32')
             self.fake_img = model.network_G(
                 image_real, label_trg, cfg, name="g_main")
             self.pred_real, self.cls_real = model.network_D(
@@ -106,30 +92,15 @@ class DTrainer():
                 fluid.layers.sigmoid_cross_entropy_with_logits(
                     self.cls_real, label_org)) / cfg.batch_size
             #wgan
-            if cfg.gan_mode == "wgan":
-                self.d_loss_fake = fluid.layers.mean(self.pred_fake)
-                self.d_loss_real = -1 * fluid.layers.mean(self.pred_real)
-                self.d_loss_gp = self.gradient_penalty(
-                    getattr(model, "network_D"),
-                    clone_image_real,
-                    self.fake_img,
-                    cfg=cfg,
-                    name="d_main")
-                self.d_loss = self.d_loss_real + self.d_loss_fake + self.d_loss_cls + cfg.lambda_gp * self.d_loss_gp
-            #lsgan
-            elif cfg.gan_mode == "lsgan":
-                ones = fluid.layers.fill_constant_batch_size_like(
-                    input=self.pred_real,
-                    shape=self.pred_real.shape,
-                    value=1,
-                    dtype='float32')
-                self.d_loss_real = fluid.layers.mean(
-                    fluid.layers.square(
-                        fluid.layers.elementwise_sub(
-                            x=self.pred_real, y=ones)))
-                self.d_loss_fake = fluid.layers.mean(
-                    fluid.layers.square(x=self.pred_fake))
-                self.d_loss = self.d_loss_real + self.d_loss_fake + cfg.lambda_cls * self.d_loss_cls
+            self.d_loss_fake = fluid.layers.mean(self.pred_fake)
+            self.d_loss_real = -1 * fluid.layers.mean(self.pred_real)
+            self.d_loss_gp = self.gradient_penalty(
+                getattr(model, "network_D"),
+                image_real,
+                self.fake_img,
+                cfg=cfg,
+                name="d_main")
+            self.d_loss = self.d_loss_real + self.d_loss_fake + self.d_loss_cls + cfg.lambda_gp * self.d_loss_gp
             self.d_loss_real.persistable = True
             self.d_loss_fake.persistable = True
@@ -168,13 +139,8 @@ class DTrainer():
                 shape = [a.shape[0]]
             alpha = fluid.layers.uniform_random_batch_size_like(
                 input=a, shape=shape, min=0.0, max=1.0)
-            a.stop_gradient = True
-            b.stop_gradient = True
-            inner1 = fluid.layers.elementwise_mul(a, alpha, axis=0)
-            inner2 = fluid.layers.elementwise_mul(b, (1.0 - alpha), axis=0)
-            inner1.stop_gradient = True
-            inner2.stop_gradient = True
-            inner = inner1 + inner2
+            inner = b * (1.0 - alpha) + a * alpha
             return inner
         x = _interpolate(real, fake)
@@ -264,7 +230,8 @@ class StarGAN(object):
                  cfg=None,
                  train_reader=None,
                  test_reader=None,
-                 batch_num=1):
+                 batch_num=1,
+                 id2name=None):
         self.cfg = cfg
         self.train_reader = train_reader
         self.test_reader = test_reader
@@ -279,6 +246,13 @@ class StarGAN(object):
             name='label_org', shape=[self.cfg.c_dim], dtype='float32')
         label_trg = fluid.layers.data(
             name='label_trg', shape=[self.cfg.c_dim], dtype='float32')
+        py_reader = fluid.io.PyReader(
+            feed_list=[image_real, label_org, label_trg],
+            capacity=128,
+            iterable=True,
+            use_double_buffer=True)
         gen_trainer = GTrainer(image_real, label_org, label_trg, self.cfg,
                                self.batch_num)
         dis_trainer = DTrainer(image_real, label_org, label_trg, self.cfg,
@@ -286,6 +260,7 @@ class StarGAN(object):
         # prepare environment
         place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
+        py_reader.decorate_batch_generator(self.train_reader, places=place)
         exe = fluid.Executor(place)
         exe.run(fluid.default_startup_program())
@@ -296,7 +271,6 @@ class StarGAN(object):
         ### memory optim
         build_strategy = fluid.BuildStrategy()
         build_strategy.enable_inplace = False
-        build_strategy.memory_optimize = False
         gen_trainer_program = fluid.CompiledProgram(
             gen_trainer.program).with_data_parallel(
@@ -311,19 +285,8 @@ class StarGAN(object):
         for epoch_id in range(self.cfg.epoch):
             batch_id = 0
-            for i in range(self.batch_num):
-                image, label_org = next(self.train_reader())
-                label_trg = copy.deepcopy(label_org)
-                np.random.shuffle(label_trg)
-                tensor_img = fluid.LoDTensor()
-                tensor_label_org = fluid.LoDTensor()
-                tensor_label_trg = fluid.LoDTensor()
-                tensor_img.set(image, place)
-                tensor_label_org.set(label_org, place)
-                tensor_label_trg.set(label_trg, place)
+            for data in py_reader():
                 s_time = time.time()
-                # optimize the discriminator network
                 d_loss_real, d_loss_fake, d_loss, d_loss_cls, d_loss_gp = exe.run(
                     dis_trainer_program,
                     fetch_list=[
@@ -331,11 +294,7 @@ class StarGAN(object):
                         dis_trainer.d_loss, dis_trainer.d_loss_cls,
                         dis_trainer.d_loss_gp
                     ],
-                    feed={
-                        "image_real": tensor_img,
-                        "label_org": tensor_label_org,
-                        "label_trg": tensor_label_trg
-                    })
+                    feed=data)
                 # optimize the generator network
                 if (batch_id + 1) % self.cfg.n_critic == 0:
                     g_loss_fake, g_loss_rec, g_loss_cls, fake_img, rec_img = exe.run(
@@ -345,11 +304,7 @@ class StarGAN(object):
                             gen_trainer.g_loss_cls, gen_trainer.fake_img,
                             gen_trainer.rec_img
                         ],
-                        feed={
-                            "image_real": tensor_img,
-                            "label_org": tensor_label_org,
-                            "label_trg": tensor_label_trg
-                        })
+                        feed=data)
                     print("epoch{}: batch{}: \n\
                          g_loss_fake: {}; g_loss_rec: {}; g_loss_cls: {}"
                           .format(epoch_id, batch_id, g_loss_fake[0],
@@ -357,7 +312,7 @@ class StarGAN(object):
                 batch_time = time.time() - s_time
                 t_time += batch_time
-                if batch_id % self.cfg.print_freq == 0:
+                if (batch_id + 1) % self.cfg.print_freq == 0:
                     print("epoch{}: batch{}: \n\
                          d_loss_real: {}; d_loss_fake: {}; d_loss_cls: {}; d_loss_gp: {} \n\
                          Batch_time_cost: {}".format(
@@ -368,10 +323,21 @@ class StarGAN(object):
                 batch_id += 1
             if self.cfg.run_test:
+                image_name = fluid.layers.data(
+                    name='image_name',
+                    shape=[self.cfg.n_samples],
+                    dtype='int32')
+                test_py_reader = fluid.io.PyReader(
+                    feed_list=[image_real, label_org, label_trg, image_name],
+                    capacity=32,
+                    iterable=True,
+                    use_double_buffer=True)
+                test_py_reader.decorate_batch_generator(
+                    self.test_reader, places=place)
                 test_program = gen_trainer.infer_program
                 utility.save_test_image(epoch_id, self.cfg, exe, place,
                                         test_program, gen_trainer,
-                                        self.test_reader)
+                                        test_py_reader)
             if self.cfg.save_checkpoints:
                 utility.checkpoints(epoch_id, self.cfg, exe, gen_trainer,
......
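The `_interpolate` change in the StarGAN trainer above simplifies the WGAN-GP interpolation to `inner = b * (1.0 - alpha) + a * alpha`: one uniform `alpha` per sample picks a random point on the segment between a real and a fake image, and the penalty `cfg.lambda_gp * mean((||grad D(x_hat)|| - 1)^2)` is evaluated there. Below is a NumPy-only illustration of that arithmetic, using a linear toy critic whose gradient is known in closed form so no framework is needed; it is not the repository's implementation:

```
# Toy WGAN-GP interpolation + penalty, matching the shapes used in the diff:
# one alpha scalar per sample, broadcast across all feature dimensions.
import numpy as np

rng = np.random.RandomState(0)
batch, dim = 4, 6
real = rng.rand(batch, dim)
fake = rng.rand(batch, dim)

alpha = rng.uniform(0.0, 1.0, size=(batch, 1))  # one alpha per sample
x_hat = real * alpha + fake * (1.0 - alpha)     # same as `inner` in the diff

w = rng.randn(dim)                 # toy critic D(x) = <w, x>
grad = np.tile(w, (batch, 1))      # grad_x D(x_hat) = w for every sample
grad_norm = np.linalg.norm(grad, axis=1)
lambda_gp = 10.0                   # plays the role of cfg.lambda_gp
d_loss_gp = lambda_gp * np.mean((grad_norm - 1.0) ** 2)
print(x_hat.shape, d_loss_gp)
```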
@@ -67,6 +67,66 @@ def init_checkpoints(cfg, exe, trainer, name):
     sys.stdout.flush()
+### the checkpoint is a single file named checkpoint.pdparams
+def init_from_checkpoint(args, exe, trainer, name):
+    if not os.path.exists(args.init_model):
+        print("Warning: the checkpoint path does not exist.")
+        return False
+    fluid.io.load_persistables(
+        executor=exe,
+        dirname=os.path.join(args.init_model, name),
+        main_program=trainer.program,
+        filename="checkpoint.pdparams")
+    print("finished initializing the model from checkpoint %s" %
+          (args.init_model))
+    return True
+
+
+### save the generator parameters to a single file
+def save_param(args, exe, program, dirname, var_name="generator"):
+    param_dir = os.path.join(args.output, 'infer_vars')
+    if not os.path.exists(param_dir):
+        os.makedirs(param_dir)
+
+    def _name_has_generator(var):
+        res = (fluid.io.is_parameter(var) and var.name.startswith(var_name))
+        print(var.name, res)
+        return res
+
+    fluid.io.save_vars(
+        exe,
+        os.path.join(param_dir, dirname),
+        main_program=program,
+        predicate=_name_has_generator,
+        filename="params.pdparams")
+    print("save parameters at %s" % (os.path.join(param_dir, dirname)))
+    return True
+
+
+### save the checkpoint to a single file
+def save_checkpoint(epoch, args, exe, trainer, dirname):
+    checkpoint_dir = os.path.join(args.output, 'checkpoints', str(epoch))
+    if not os.path.exists(checkpoint_dir):
+        os.makedirs(checkpoint_dir)
+    fluid.io.save_persistables(
+        exe,
+        os.path.join(checkpoint_dir, dirname),
+        main_program=trainer.program,
+        filename="checkpoint.pdparams")
+    print("save checkpoint at %s" % (os.path.join(checkpoint_dir, dirname)))
+    return True
 def save_test_image(epoch,
                     cfg,
                     exe,
@@ -74,57 +134,60 @@ def save_test_image(epoch,
                     test_program,
                     g_trainer,
                     A_test_reader,
-                    B_test_reader=None):
+                    B_test_reader=None,
+                    A_id2name=None,
+                    B_id2name=None):
     out_path = os.path.join(cfg.output, 'test')
     if not os.path.exists(out_path):
         os.makedirs(out_path)
     if cfg.model_net == "Pix2pix":
-        for data in zip(A_test_reader()):
-            data_A, data_B, name = data[0]
-            name = name[0]
-            tensor_A = fluid.LoDTensor()
-            tensor_B = fluid.LoDTensor()
-            tensor_A.set(data_A, place)
-            tensor_B.set(data_B, place)
-            fake_B_temp = exe.run(
-                test_program,
-                fetch_list=[g_trainer.fake_B],
-                feed={"input_A": tensor_A,
-                      "input_B": tensor_B})
+        for data in A_test_reader():
+            A_data, B_data, image_name = data[0]['input_A'], data[0][
+                'input_B'], data[0]['image_name']
+            fake_B_temp = exe.run(test_program,
+                                  fetch_list=[g_trainer.fake_B],
+                                  feed={"input_A": A_data,
+                                        "input_B": B_data})
             fake_B_temp = np.squeeze(fake_B_temp[0]).transpose([1, 2, 0])
-            input_A_temp = np.squeeze(data_A[0]).transpose([1, 2, 0])
-            input_B_temp = np.squeeze(data_A[0]).transpose([1, 2, 0])
-            imageio.imwrite(out_path + "/fakeB_" + str(epoch) + "_" + name, (
-                (fake_B_temp + 1) * 127.5).astype(np.uint8))
-            imageio.imwrite(out_path + "/inputA_" + str(epoch) + "_" + name, (
-                (input_A_temp + 1) * 127.5).astype(np.uint8))
-            imageio.imwrite(out_path + "/inputB_" + str(epoch) + "_" + name, (
-                (input_B_temp + 1) * 127.5).astype(np.uint8))
+            input_A_temp = np.squeeze(np.array(A_data)[0]).transpose([1, 2, 0])
+            input_B_temp = np.squeeze(np.array(B_data)[0]).transpose([1, 2, 0])
+            fakeB_name = "fakeB_" + str(epoch) + "_" + A_id2name[np.array(
+                image_name).astype('int32')[0]]
+            inputA_name = "inputA_" + str(epoch) + "_" + A_id2name[np.array(
+                image_name).astype('int32')[0]]
+            inputB_name = "inputB_" + str(epoch) + "_" + A_id2name[np.array(
+                image_name).astype('int32')[0]]
+            imageio.imwrite(
+                os.path.join(out_path, fakeB_name), (
+                    (fake_B_temp + 1) * 127.5).astype(np.uint8))
+            imageio.imwrite(
+                os.path.join(out_path, inputA_name), (
+                    (input_A_temp + 1) * 127.5).astype(np.uint8))
+            imageio.imwrite(
+                os.path.join(out_path, inputB_name), (
+                    (input_B_temp + 1) * 127.5).astype(np.uint8))
elif cfg.model_net == "StarGAN": elif cfg.model_net == "StarGAN":
for data in zip(A_test_reader()): for data in A_test_reader():
real_img, label_org, name = data[0] real_img, label_org, label_trg, image_name = data[0][
'image_real'], data[0]['label_org'], data[0]['label_trg'], data[
0]['image_name']
attr_names = cfg.selected_attrs.split(',') attr_names = cfg.selected_attrs.split(',')
tensor_img = fluid.LoDTensor() real_img_temp = save_batch_image(np.array(real_img))
tensor_label_org = fluid.LoDTensor()
tensor_img.set(real_img, place)
tensor_label_org.set(label_org, place)
real_img_temp = save_batch_image(real_img)
images = [real_img_temp] images = [real_img_temp]
for i in range(cfg.c_dim): for i in range(cfg.c_dim):
label_trg_tmp = copy.deepcopy(label_org) label_trg_tmp = copy.deepcopy(np.array(label_org))
for j in range(len(label_org)): for j in range(len(np.array(label_org))):
label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i] label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
label_trg = check_attribute_conflict( np_label_trg = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names) label_trg_tmp, attr_names[i], attr_names)
tensor_label_trg = fluid.LoDTensor() label_trg.set(np_label_trg, place)
tensor_label_trg.set(label_trg, place)
fake_temp, rec_temp = exe.run( fake_temp, rec_temp = exe.run(
test_program, test_program,
feed={ feed={
"image_real": tensor_img, "image_real": real_img,
"label_org": tensor_label_org, "label_org": label_org,
"label_trg": tensor_label_trg "label_trg": label_trg
}, },
fetch_list=[g_trainer.fake_img, g_trainer.rec_img]) fetch_list=[g_trainer.fake_img, g_trainer.rec_img])
fake_temp = save_batch_image(fake_temp) fake_temp = save_batch_image(fake_temp)
...@@ -132,95 +195,110 @@ def save_test_image(epoch, ...@@ -132,95 +195,110 @@ def save_test_image(epoch,
images.append(fake_temp) images.append(fake_temp)
images.append(rec_temp) images.append(rec_temp)
images_concat = np.concatenate(images, 1) images_concat = np.concatenate(images, 1)
if len(label_org) > 1: if len(np.array(label_org)) > 1:
images_concat = np.concatenate(images_concat, 1) images_concat = np.concatenate(images_concat, 1)
imageio.imwrite(out_path + "/fake_img" + str(epoch) + "_" + name[0], image_name_save = "fake_img" + str(epoch) + "_" + str(
((images_concat + 1) * 127.5).astype(np.uint8)) np.array(image_name)[0].astype('int32')) + '.jpg'
imageio.imwrite(
os.path.join(out_path, image_name_save), (
(images_concat + 1) * 127.5).astype(np.uint8))
     elif cfg.model_net == 'AttGAN' or cfg.model_net == 'STGAN':
-        for data in zip(A_test_reader()):
-            real_img, label_org, name = data[0]
+        for data in A_test_reader():
+            real_img, label_org, label_trg, image_name = data[0][
+                'image_real'], data[0]['label_org'], data[0]['label_trg'], data[
+                    0]['image_name']
             attr_names = cfg.selected_attrs.split(',')
-            label_trg = copy.deepcopy(label_org)
-            tensor_img = fluid.LoDTensor()
-            tensor_label_org = fluid.LoDTensor()
-            tensor_label_trg = fluid.LoDTensor()
-            tensor_label_org_ = fluid.LoDTensor()
-            tensor_label_trg_ = fluid.LoDTensor()
-            tensor_img.set(real_img, place)
-            tensor_label_org.set(label_org, place)
-            real_img_temp = save_batch_image(real_img)
+            real_img_temp = save_batch_image(np.array(real_img))
             images = [real_img_temp]
             for i in range(cfg.c_dim):
-                label_trg_tmp = copy.deepcopy(label_trg)
-                for j in range(len(label_org)):
+                label_trg_tmp = copy.deepcopy(np.array(label_trg))
+                for j in range(len(label_trg_tmp)):
                     label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
                 label_trg_tmp = check_attribute_conflict(
                     label_trg_tmp, attr_names[i], attr_names)
-                label_org_ = list(map(lambda x: ((x * 2) - 1) * 0.5, label_org))
-                label_trg_ = list(
+                label_org_tmp = list(
+                    map(lambda x: ((x * 2) - 1) * 0.5, np.array(label_org)))
+                label_trg_tmp = list(
                     map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp))
                 if cfg.model_net == 'AttGAN':
-                    for k in range(len(label_org)):
-                        label_trg_[k][i] = label_trg_[k][i] * 2.0
-                tensor_label_org_.set(label_org_, place)
-                tensor_label_trg.set(label_trg, place)
-                tensor_label_trg_.set(label_trg_, place)
+                    for k in range(len(label_trg_tmp)):
+                        label_trg_tmp[k][i] = label_trg_tmp[k][i] * 2.0
+                tensor_label_org_ = fluid.LoDTensor()
+                tensor_label_org_.set(label_org_tmp, place)
+                tensor_label_trg_ = fluid.LoDTensor()
+                tensor_label_trg_.set(label_trg_tmp, place)
                 out = exe.run(test_program,
                               feed={
-                                  "image_real": tensor_img,
-                                  "label_org": tensor_label_org,
+                                  "image_real": real_img,
+                                  "label_org": label_org,
                                   "label_org_": tensor_label_org_,
-                                  "label_trg": tensor_label_trg,
+                                  "label_trg": label_trg,
                                   "label_trg_": tensor_label_trg_
                               },
                               fetch_list=[g_trainer.fake_img])
                 fake_temp = save_batch_image(out[0])
                 images.append(fake_temp)
             images_concat = np.concatenate(images, 1)
-            if len(label_org) > 1:
+            if len(label_trg_tmp) > 1:
                 images_concat = np.concatenate(images_concat, 1)
-            imageio.imwrite(out_path + "/fake_img" + str(epoch) + '_' + name[0],
-                            ((images_concat + 1) * 127.5).astype(np.uint8))
+            image_name_save = 'fake_img_' + str(epoch) + '_' + str(
+                np.array(image_name)[0].astype('int32')) + '.jpg'
+            image_path = os.path.join(out_path, image_name_save)
+            imageio.imwrite(image_path, (
+                (images_concat + 1) * 127.5).astype(np.uint8))
     else:
         for data_A, data_B in zip(A_test_reader(), B_test_reader()):
-            A_data, A_name = data_A
-            B_data, B_name = data_B
-            tensor_A = fluid.LoDTensor()
-            tensor_B = fluid.LoDTensor()
-            tensor_A.set(A_data, place)
-            tensor_B.set(B_data, place)
+            A_data, A_name = data_A[0]['input_A'], data_A[0]['A_image_name']
+            B_data, B_name = data_B[0]['input_B'], data_B[0]['B_image_name']
             fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = exe.run(
                 test_program,
                 fetch_list=[
                     g_trainer.fake_A, g_trainer.fake_B, g_trainer.cyc_A,
                     g_trainer.cyc_B
                 ],
-                feed={"input_A": tensor_A,
-                      "input_B": tensor_B})
+                feed={"input_A": A_data,
+                      "input_B": B_data})
             fake_A_temp = np.squeeze(fake_A_temp[0]).transpose([1, 2, 0])
             fake_B_temp = np.squeeze(fake_B_temp[0]).transpose([1, 2, 0])
             cyc_A_temp = np.squeeze(cyc_A_temp[0]).transpose([1, 2, 0])
             cyc_B_temp = np.squeeze(cyc_B_temp[0]).transpose([1, 2, 0])
-            input_A_temp = np.squeeze(data_A[0][0]).transpose([1, 2, 0])
-            input_B_temp = np.squeeze(data_B[0][0]).transpose([1, 2, 0])
-            imageio.imwrite(out_path + "/fakeB_" + str(epoch) + "_" + A_name[0],
-                            ((fake_B_temp + 1) * 127.5).astype(np.uint8))
-            imageio.imwrite(out_path + "/fakeA_" + str(epoch) + "_" + B_name[0],
-                            ((fake_A_temp + 1) * 127.5).astype(np.uint8))
-            imageio.imwrite(out_path + "/cycA_" + str(epoch) + "_" + A_name[0],
-                            ((cyc_A_temp + 1) * 127.5).astype(np.uint8))
-            imageio.imwrite(out_path + "/cycB_" + str(epoch) + "_" + B_name[0],
-                            ((cyc_B_temp + 1) * 127.5).astype(np.uint8))
-            imageio.imwrite(
-                out_path + "/inputA_" + str(epoch) + "_" + A_name[0], (
-                    (input_A_temp + 1) * 127.5).astype(np.uint8))
-            imageio.imwrite(
-                out_path + "/inputB_" + str(epoch) + "_" + B_name[0], (
-                    (input_B_temp + 1) * 127.5).astype(np.uint8))
+            input_A_temp = np.squeeze(np.array(A_data)).transpose([1, 2, 0])
+            input_B_temp = np.squeeze(np.array(B_data)).transpose([1, 2, 0])
+            fakeA_name = "fakeA_" + str(epoch) + "_" + A_id2name[np.array(
+                A_name).astype('int32')[0]]
+            fakeB_name = "fakeB_" + str(epoch) + "_" + B_id2name[np.array(
+                B_name).astype('int32')[0]]
+            inputA_name = "inputA_" + str(epoch) + "_" + A_id2name[np.array(
+                A_name).astype('int32')[0]]
+            inputB_name = "inputB_" + str(epoch) + "_" + B_id2name[np.array(
+                B_name).astype('int32')[0]]
+            cycA_name = "cycA_" + str(epoch) + "_" + A_id2name[np.array(
+                A_name).astype('int32')[0]]
+            cycB_name = "cycB_" + str(epoch) + "_" + B_id2name[np.array(
+                B_name).astype('int32')[0]]
+            imageio.imwrite(
+                os.path.join(out_path, fakeB_name), (
+                    (fake_B_temp + 1) * 127.5).astype(np.uint8))
+            imageio.imwrite(
+                os.path.join(out_path, fakeA_name), (
+                    (fake_A_temp + 1) * 127.5).astype(np.uint8))
+            imageio.imwrite(
+                os.path.join(out_path, cycA_name), (
+                    (cyc_A_temp + 1) * 127.5).astype(np.uint8))
+            imageio.imwrite(
+                os.path.join(out_path, cycB_name), (
+                    (cyc_B_temp + 1) * 127.5).astype(np.uint8))
+            imageio.imwrite(
+                os.path.join(out_path, inputA_name), (
+                    (input_A_temp + 1) * 127.5).astype(np.uint8))
+            imageio.imwrite(
+                os.path.join(out_path, inputB_name), (
+                    (input_B_temp + 1) * 127.5).astype(np.uint8))
@@ -273,6 +351,7 @@ def check_attribute_conflict(label_batch, attr, attrs):
 def save_batch_image(img):
+    #if img.shape[0] == 1:
     if len(img) == 1:
         res_img = np.squeeze(img).transpose([1, 2, 0])
     else:
......
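All of the `imageio.imwrite` calls in `save_test_image` above share one conversion: generator outputs live in [-1, 1], and `((img + 1) * 127.5).astype(np.uint8)` maps them onto the [0, 255] pixel range. A tiny self-contained check of that mapping (illustration only, not repository code):

```
# Denormalize a [-1, 1] tensor to uint8 pixels, as done before every imwrite.
import numpy as np

x = np.array([-1.0, 0.0, 1.0], dtype='float32')  # extremes and midpoint
pixels = ((x + 1.0) * 127.5).astype(np.uint8)
print(pixels)  # [  0 127 255]
```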