diff --git a/PaddleCV/PaddleGAN/README.md b/PaddleCV/PaddleGAN/README.md index c2e3fb4082c73e26f810374351aaeacbad8cb874..2c1de576970c8014dac3c51b39100758a542b6f1 100644 --- a/PaddleCV/PaddleGAN/README.md +++ b/PaddleCV/PaddleGAN/README.md @@ -21,7 +21,7 @@ 注意: 1. StarGAN,AttGAN和STGAN由于梯度惩罚所需的操作目前只支持GPU,需使用GPU训练。 2. GAN模型目前仅仅验证了单机单卡训练和预测结果。 -3. CGAN和DCGAN两个模型训练使用的数据集为MNIST数据集;StarGAN,AttGAN和STGAN的数据集为CelebA数据集。Pix2Pix和CycleGAN支持的数据集可以参考download.py中的cycle_pix_dataset。 +3. CGAN和DCGAN两个模型训练使用的数据集为MNIST数据集;StarGAN,AttGAN和STGAN的数据集为CelebA数据集。Pix2Pix和CycleGAN支持的数据集可以参考download.py中的cycle_pix_dataset。cityscapes数据集需要从[官方](https://www.cityscapes-dataset.com)下载数据,下载完之后使用`scripts/prepare_cityscapes_dataset.py`处理,处理后的文件夹命名为cityscapes并放入data目录下即可。 4. PaddlePaddle1.5.1及之前的版本不支持在AttGAN和STGAN模型里的判别器加上的instance norm。如果要在判别器中加上instance norm,请源码编译develop分支并安装。 5. 中间效果图保存在${output_dir}/test文件夹中。对于Pix2Pix来说,inputA 和inputB 代表输入的两种风格的图片,fakeB表示生成图片;对于CycleGAN来说,inputA表示输入图片,fakeB表示inputA根据生成的图片,cycA表示fakeB经过生成器重构出来的对应于inputA的重构图片;对于StarGAN,AttGAN和STGAN来说,第一行表示原图,之后的每一行都代表一种属性变换。 6. infer过程使用的test_list文件和训练过程中使用的train_list具有相同格式,第一行为样本数量,第二行为属性,之后的行中第一个表示图片名称,之后的-1和1表示该图片是否拥有该属性(1为有该属性,-1为没有该属性)。 diff --git a/PaddleCV/PaddleGAN/download.py b/PaddleCV/PaddleGAN/download.py index 295f8f1e8492b83ca8534c56942cb1a817696d7c..b3452325fda06a855f2a7a80aae74e65c18212b0 100644 --- a/PaddleCV/PaddleGAN/download.py +++ b/PaddleCV/PaddleGAN/download.py @@ -153,8 +153,8 @@ if __name__ == '__main__': args = parser.parse_args() cycle_pix_dataset = [ 'apple2orange', 'summer2winter_yosemite', 'horse2zebra', 'monet2photo', - 'cezanne2photo', 'ukiyoe2photo', 'vangogh2photo', 'maps', 'cityscapes', - 'facades', 'iphone2dslr_flower', 'ae_photos', 'mini' + 'cezanne2photo', 'ukiyoe2photo', 'vangogh2photo', 'maps', 'facades', + 'iphone2dslr_flower', 'ae_photos', 'mini' ] pwd = os.path.join(os.path.dirname(__file__), 'data') diff --git a/PaddleCV/PaddleGAN/infer.py b/PaddleCV/PaddleGAN/infer.py index 347852a267c22c43fa64b897b8fb1bec1295d6eb..9b2a1202032802d79f3180cfe25ae4145c6127d8 100644 --- a/PaddleCV/PaddleGAN/infer.py +++ b/PaddleCV/PaddleGAN/infer.py @@ -82,6 +82,7 @@ def infer(args): name='image_name', shape=[args.n_samples], dtype='int32') model_name = 'net_G' + if args.model_net == 'CycleGAN': py_reader = fluid.io.PyReader( feed_list=[input, image_name], diff --git a/PaddleCV/PaddleGAN/scripts/prepare_cityscapes_dataset.py b/PaddleCV/PaddleGAN/scripts/prepare_cityscapes_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0dc8422ac32eacf0f85c48b596cc8e0cb669f388 --- /dev/null +++ b/PaddleCV/PaddleGAN/scripts/prepare_cityscapes_dataset.py @@ -0,0 +1,70 @@ +import os +import argparse +import functools +import glob +from PIL import Image +''' Based on https://github.com/junyanz/CycleGAN''' + + +def load_image(path): + return Image.open(path).convert('RGB').resize((256, 256)) + + +def propress_cityscapes(gtFine_dir, leftImg8bit_dir, output_dir, phase): + save_dir = os.path.join(output_dir, phase) + try: + os.makedirs(save_dir) + except Exception as e: + print("{} makedirs".format(e)) + pass + try: + os.makedirs(os.path.join(save_dir, 'A')) + except Exception as e: + print("{} makedirs".format(e)) + try: + os.makedirs(os.path.join(save_dir, 'B')) + except Exception as e: + print("{} makedirs".format(e)) + + seg_expr = os.path.join(gtFine_dir, phase, "*", "*_color.png") + seg_paths = glob.glob(seg_expr) + seg_paths = sorted(seg_paths) + + photo_expr = os.path.join(leftImg8bit_dir, phase, "*", '*_leftImg8bit.png') + photo_paths = glob.glob(photo_expr) + photo_paths = sorted(photo_paths) + + assert len(seg_paths) == len(photo_paths), \ + "[%d] gtFine images NOT match [%d] leftImg8bit images. Aborting." % (len(segmap_paths), len(photo_paths)) + + for i, (seg_path, photo_path) in enumerate(zip(seg_paths, photo_paths)): + seg_image = load_image(seg_path) + photo_image = load_image(photo_path) + # save image + save_path = os.path.join(save_dir, 'A', "%d_A.jpg" % i) + photo_image.save(save_path, format='JPEG', subsampling=0, quality=100) + save_path = os.path.join(save_dir, 'B', "%d_B.jpg" % i) + seg_image.save(save_path, format='JPEG', subsampling=0, quality=100) + + if i % 10 == 0: + print("proprecess %d ~ %d images." % (i, i + 10)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description=__doc__) + # yapf: disable + parser.add_argument('--gtFine_dir', type=str, default=None, help='Path to Cityscapes gtFine directory.') + parser.add_argument('--leftImg8bit_dir', type=str, default=None, help='Path to Cityscapes leftImg8bit_trainvaltest directory.') + parser.add_argument('--output_dir', type=str, default=None, help='Path to output Cityscapes directory.') + # yapf: enable + args = parser.parse_args() + + print('Preparing Cityscapes Dataset for val phase') + propress_cityscapes(args.gtFine_dir, args.leftImg8bit_dir, args.output_dir, + 'val') + + print('Preparing Cityscapes Dataset for train phase') + propress_cityscapes(args.gtFine_dir, args.leftImg8bit_dir, args.output_dir, + 'train') + + print("DONE!!!") diff --git a/PaddleCV/PaddleGAN/scripts/run_attgan.sh b/PaddleCV/PaddleGAN/scripts/run_attgan.sh index b0fce2de8a3e03346cfae5bc91717130ba25c015..c339d05857765b480b910d7ff64701d8d061d4b2 100644 --- a/PaddleCV/PaddleGAN/scripts/run_attgan.sh +++ b/PaddleCV/PaddleGAN/scripts/run_attgan.sh @@ -1 +1 @@ -python train.py --model_net AttGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 120 --output ./output/attgan/ >log_out 2>log_err +python train.py --model_net AttGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 120 --dis_norm instance_norm --output ./output/attgan/ >log_out 2>log_err diff --git a/PaddleCV/PaddleGAN/scripts/run_stgan.sh b/PaddleCV/PaddleGAN/scripts/run_stgan.sh index 44d616112c13b5238b954b27861a59c69e9dc3c2..d258de51930cec8d82f147039ad3d622610ca8e7 100644 --- a/PaddleCV/PaddleGAN/scripts/run_stgan.sh +++ b/PaddleCV/PaddleGAN/scripts/run_stgan.sh @@ -1 +1 @@ -python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 50 --output ./output/stgan/ >log_out 2>log_err +python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 50 --dis_norm instance_norm --output ./output/stgan/ >log_out 2>log_err diff --git a/PaddleCV/PaddleGAN/train.py b/PaddleCV/PaddleGAN/train.py index be10d4ac4e785efa2e3014d0312b6dbd194b7e38..62da9d8afac568dcc23b51c541f0ba805ce077c8 100644 --- a/PaddleCV/PaddleGAN/train.py +++ b/PaddleCV/PaddleGAN/train.py @@ -30,7 +30,8 @@ import trainer def train(cfg): MODELS = [ - "CGAN", "DCGAN", "Pix2pix", "CycleGAN", "StarGAN", "AttGAN", "STGAN", "SPADE" + "CGAN", "DCGAN", "Pix2pix", "CycleGAN", "StarGAN", "AttGAN", "STGAN", + "SPADE" ] if cfg.model_net not in MODELS: raise NotImplementedError("{} is not support!".format(cfg.model_net)) diff --git a/PaddleCV/PaddleGAN/trainer/SPADE.py b/PaddleCV/PaddleGAN/trainer/SPADE.py index 5c427c4a1350e3e10f1afe8694da0ec10b1fd714..1ef1f4f74f45116ace16b94986e7d23ee95797ae 100644 --- a/PaddleCV/PaddleGAN/trainer/SPADE.py +++ b/PaddleCV/PaddleGAN/trainer/SPADE.py @@ -325,7 +325,7 @@ class SPADE(object): ### memory optim build_strategy = fluid.BuildStrategy() - build_strategy.enable_inplace = False + build_strategy.enable_inplace = True build_strategy.sync_batch_norm = False gen_trainer_program = fluid.CompiledProgram(