未验证 提交 030a1529 编写于 作者: L lvmengsi 提交者: GitHub

revert some action (#3509)

上级 ab3db149
......@@ -21,7 +21,7 @@
注意:
1. StarGAN,AttGAN和STGAN由于梯度惩罚所需的操作目前只支持GPU,需使用GPU训练。
2. GAN模型目前仅仅验证了单机单卡训练和预测结果。
3. CGAN和DCGAN两个模型训练使用的数据集为MNIST数据集;StarGAN,AttGAN和STGAN的数据集为CelebA数据集。Pix2Pix和CycleGAN支持的数据集可以参考download.py中的cycle_pix_dataset。
3. CGAN和DCGAN两个模型训练使用的数据集为MNIST数据集;StarGAN,AttGAN和STGAN的数据集为CelebA数据集。Pix2Pix和CycleGAN支持的数据集可以参考download.py中的cycle_pix_dataset。cityscapes数据集需要从[官方](https://www.cityscapes-dataset.com)下载数据,下载完之后使用`scripts/prepare_cityscapes_dataset.py`处理,处理后的文件夹命名为cityscapes并放入data目录下即可。
4. PaddlePaddle1.5.1及之前的版本不支持在AttGAN和STGAN模型里的判别器加上的instance norm。如果要在判别器中加上instance norm,请源码编译develop分支并安装。
5. 中间效果图保存在${output_dir}/test文件夹中。对于Pix2Pix来说,inputA 和inputB 代表输入的两种风格的图片,fakeB表示生成图片;对于CycleGAN来说,inputA表示输入图片,fakeB表示inputA根据生成的图片,cycA表示fakeB经过生成器重构出来的对应于inputA的重构图片;对于StarGAN,AttGAN和STGAN来说,第一行表示原图,之后的每一行都代表一种属性变换。
6. infer过程使用的test_list文件和训练过程中使用的train_list具有相同格式,第一行为样本数量,第二行为属性,之后的行中第一个表示图片名称,之后的-1和1表示该图片是否拥有该属性(1为有该属性,-1为没有该属性)。
......
......@@ -153,8 +153,8 @@ if __name__ == '__main__':
args = parser.parse_args()
cycle_pix_dataset = [
'apple2orange', 'summer2winter_yosemite', 'horse2zebra', 'monet2photo',
'cezanne2photo', 'ukiyoe2photo', 'vangogh2photo', 'maps', 'cityscapes',
'facades', 'iphone2dslr_flower', 'ae_photos', 'mini'
'cezanne2photo', 'ukiyoe2photo', 'vangogh2photo', 'maps', 'facades',
'iphone2dslr_flower', 'ae_photos', 'mini'
]
pwd = os.path.join(os.path.dirname(__file__), 'data')
......
......@@ -82,6 +82,7 @@ def infer(args):
name='image_name', shape=[args.n_samples], dtype='int32')
model_name = 'net_G'
if args.model_net == 'CycleGAN':
py_reader = fluid.io.PyReader(
feed_list=[input, image_name],
......
import os
import argparse
import functools
import glob
from PIL import Image
''' Based on https://github.com/junyanz/CycleGAN'''
def load_image(path):
return Image.open(path).convert('RGB').resize((256, 256))
def propress_cityscapes(gtFine_dir, leftImg8bit_dir, output_dir, phase):
save_dir = os.path.join(output_dir, phase)
try:
os.makedirs(save_dir)
except Exception as e:
print("{} makedirs".format(e))
pass
try:
os.makedirs(os.path.join(save_dir, 'A'))
except Exception as e:
print("{} makedirs".format(e))
try:
os.makedirs(os.path.join(save_dir, 'B'))
except Exception as e:
print("{} makedirs".format(e))
seg_expr = os.path.join(gtFine_dir, phase, "*", "*_color.png")
seg_paths = glob.glob(seg_expr)
seg_paths = sorted(seg_paths)
photo_expr = os.path.join(leftImg8bit_dir, phase, "*", '*_leftImg8bit.png')
photo_paths = glob.glob(photo_expr)
photo_paths = sorted(photo_paths)
assert len(seg_paths) == len(photo_paths), \
"[%d] gtFine images NOT match [%d] leftImg8bit images. Aborting." % (len(segmap_paths), len(photo_paths))
for i, (seg_path, photo_path) in enumerate(zip(seg_paths, photo_paths)):
seg_image = load_image(seg_path)
photo_image = load_image(photo_path)
# save image
save_path = os.path.join(save_dir, 'A', "%d_A.jpg" % i)
photo_image.save(save_path, format='JPEG', subsampling=0, quality=100)
save_path = os.path.join(save_dir, 'B', "%d_B.jpg" % i)
seg_image.save(save_path, format='JPEG', subsampling=0, quality=100)
if i % 10 == 0:
print("proprecess %d ~ %d images." % (i, i + 10))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description=__doc__)
# yapf: disable
parser.add_argument('--gtFine_dir', type=str, default=None, help='Path to Cityscapes gtFine directory.')
parser.add_argument('--leftImg8bit_dir', type=str, default=None, help='Path to Cityscapes leftImg8bit_trainvaltest directory.')
parser.add_argument('--output_dir', type=str, default=None, help='Path to output Cityscapes directory.')
# yapf: enable
args = parser.parse_args()
print('Preparing Cityscapes Dataset for val phase')
propress_cityscapes(args.gtFine_dir, args.leftImg8bit_dir, args.output_dir,
'val')
print('Preparing Cityscapes Dataset for train phase')
propress_cityscapes(args.gtFine_dir, args.leftImg8bit_dir, args.output_dir,
'train')
print("DONE!!!")
python train.py --model_net AttGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 120 --output ./output/attgan/ >log_out 2>log_err
python train.py --model_net AttGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 120 --dis_norm instance_norm --output ./output/attgan/ >log_out 2>log_err
python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 50 --output ./output/stgan/ >log_out 2>log_err
python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 50 --dis_norm instance_norm --output ./output/stgan/ >log_out 2>log_err
......@@ -30,7 +30,8 @@ import trainer
def train(cfg):
MODELS = [
"CGAN", "DCGAN", "Pix2pix", "CycleGAN", "StarGAN", "AttGAN", "STGAN", "SPADE"
"CGAN", "DCGAN", "Pix2pix", "CycleGAN", "StarGAN", "AttGAN", "STGAN",
"SPADE"
]
if cfg.model_net not in MODELS:
raise NotImplementedError("{} is not support!".format(cfg.model_net))
......
......@@ -325,7 +325,7 @@ class SPADE(object):
### memory optim
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = False
build_strategy.enable_inplace = True
build_strategy.sync_batch_norm = False
gen_trainer_program = fluid.CompiledProgram(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册