未验证 提交 b91d222e 编写于 作者: L lvmengsi 提交者: GitHub

Update gan1.6 (#3551)

* replace instance_norm

* add version check

* update check version

* -1->None

* fix mistake
上级 ecd723a1
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
注意: 注意:
1. StarGAN,AttGAN和STGAN由于梯度惩罚所需的操作目前只支持GPU,需使用GPU训练。 1. StarGAN,AttGAN和STGAN由于梯度惩罚所需的操作目前只支持GPU,需使用GPU训练。
2. GAN模型目前仅仅验证了单机单卡训练和预测结果。 2. GAN模型目前仅仅验证了单机单卡训练和预测结果。
3. CGAN和DCGAN两个模型训练使用的数据集为MNIST数据集;StarGAN,AttGAN和STGAN的数据集为CelebA数据集。Pix2Pix和CycleGAN支持的数据集可以参考download.py中的cycle_pix_dataset。 3. CGAN和DCGAN两个模型训练使用的数据集为MNIST数据集;StarGAN,AttGAN和STGAN的数据集为CelebA数据集。Pix2Pix和CycleGAN支持的数据集可以参考download.py中的cycle_pix_dataset。cityscapes数据集需要从[官方](https://www.cityscapes-dataset.com)下载数据,下载完之后使用`scripts/prepare_cityscapes_dataset.py`处理,处理后的文件夹命名为cityscapes并放入data目录下即可。
4. PaddlePaddle1.5.1及之前的版本不支持在AttGAN和STGAN模型里的判别器加上的instance norm。如果要在判别器中加上instance norm,请源码编译develop分支并安装。 4. PaddlePaddle1.5.1及之前的版本不支持在AttGAN和STGAN模型里的判别器加上的instance norm。如果要在判别器中加上instance norm,请源码编译develop分支并安装。
5. 中间效果图保存在${output_dir}/test文件夹中。对于Pix2Pix来说,inputA 和inputB 代表输入的两种风格的图片,fakeB表示生成图片;对于CycleGAN来说,inputA表示输入图片,fakeB表示inputA根据生成的图片,cycA表示fakeB经过生成器重构出来的对应于inputA的重构图片;对于StarGAN,AttGAN和STGAN来说,第一行表示原图,之后的每一行都代表一种属性变换。 5. 中间效果图保存在${output_dir}/test文件夹中。对于Pix2Pix来说,inputA 和inputB 代表输入的两种风格的图片,fakeB表示生成图片;对于CycleGAN来说,inputA表示输入图片,fakeB表示inputA根据生成的图片,cycA表示fakeB经过生成器重构出来的对应于inputA的重构图片;对于StarGAN,AttGAN和STGAN来说,第一行表示原图,之后的每一行都代表一种属性变换。
6. infer过程使用的test_list文件和训练过程中使用的train_list具有相同格式,第一行为样本数量,第二行为属性,之后的行中第一个表示图片名称,之后的-1和1表示该图片是否拥有该属性(1为有该属性,-1为没有该属性)。 6. infer过程使用的test_list文件和训练过程中使用的train_list具有相同格式,第一行为样本数量,第二行为属性,之后的行中第一个表示图片名称,之后的-1和1表示该图片是否拥有该属性(1为有该属性,-1为没有该属性)。
...@@ -67,7 +67,7 @@ ...@@ -67,7 +67,7 @@
### 安装说明 ### 安装说明
**安装[PaddlePaddle](https://github.com/PaddlePaddle/Paddle):** **安装[PaddlePaddle](https://github.com/PaddlePaddle/Paddle):**
在当前目录下运行样例代码需要PadddlePaddle Fluid的v.1.5或以上的版本。如果你的运行环境中的PaddlePaddle低于此版本,请根据[安装文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/install/index_cn.html)中的说明来更新PaddlePaddle。 在当前目录下运行样例代码需要PadddlePaddle Fluid的v.1.6或以上的版本。如果你的运行环境中的PaddlePaddle低于此版本,请根据[安装文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/install/index_cn.html)中的说明来更新PaddlePaddle。
其他依赖包: 其他依赖包:
1. `pip install imageio` 或者 `pip install -r requirements.txt` 安装imageio包(保存图片代码中所依赖的包) 1. `pip install imageio` 或者 `pip install -r requirements.txt` 安装imageio包(保存图片代码中所依赖的包)
......
...@@ -153,8 +153,8 @@ if __name__ == '__main__': ...@@ -153,8 +153,8 @@ if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
cycle_pix_dataset = [ cycle_pix_dataset = [
'apple2orange', 'summer2winter_yosemite', 'horse2zebra', 'monet2photo', 'apple2orange', 'summer2winter_yosemite', 'horse2zebra', 'monet2photo',
'cezanne2photo', 'ukiyoe2photo', 'vangogh2photo', 'maps', 'cityscapes', 'cezanne2photo', 'ukiyoe2photo', 'vangogh2photo', 'maps', 'facades',
'facades', 'iphone2dslr_flower', 'ae_photos', 'mini' 'iphone2dslr_flower', 'ae_photos', 'mini'
] ]
pwd = os.path.join(os.path.dirname(__file__), 'data') pwd = os.path.join(os.path.dirname(__file__), 'data')
......
...@@ -27,7 +27,7 @@ import imageio ...@@ -27,7 +27,7 @@ import imageio
import glob import glob
from util.config import add_arguments, print_arguments from util.config import add_arguments, print_arguments
from data_reader import celeba_reader_creator, reader_creator, triplex_reader_creator from data_reader import celeba_reader_creator, reader_creator, triplex_reader_creator
from util.utility import check_attribute_conflict, check_gpu, save_batch_image from util.utility import check_attribute_conflict, check_gpu, save_batch_image, check_version
from util import utility from util import utility
import copy import copy
...@@ -82,6 +82,7 @@ def infer(args): ...@@ -82,6 +82,7 @@ def infer(args):
name='image_name', shape=[args.n_samples], dtype='int32') name='image_name', shape=[args.n_samples], dtype='int32')
model_name = 'net_G' model_name = 'net_G'
if args.model_net == 'CycleGAN': if args.model_net == 'CycleGAN':
loader = fluid.io.DataLoader.from_generator( loader = fluid.io.DataLoader.from_generator(
feed_list=[input, image_name], feed_list=[input, image_name],
...@@ -383,4 +384,5 @@ if __name__ == "__main__": ...@@ -383,4 +384,5 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
print_arguments(args) print_arguments(args)
check_gpu(args.use_gpu) check_gpu(args.use_gpu)
check_version()
infer(args) infer(args)
...@@ -64,12 +64,6 @@ def norm_layer(input, ...@@ -64,12 +64,6 @@ def norm_layer(input,
moving_variance_name=name + '_var') moving_variance_name=name + '_var')
elif norm_type == 'instance_norm': elif norm_type == 'instance_norm':
helper = fluid.layer_helper.LayerHelper("instance_norm", **locals())
dtype = helper.input_dtype()
epsilon = 1e-5
mean = fluid.layers.reduce_mean(input, dim=[2, 3], keep_dim=True)
var = fluid.layers.reduce_mean(
fluid.layers.square(input - mean), dim=[2, 3], keep_dim=True)
if name is not None: if name is not None:
scale_name = name + "_scale" scale_name = name + "_scale"
offset_name = name + "_offset" offset_name = name + "_offset"
...@@ -91,15 +85,8 @@ def norm_layer(input, ...@@ -91,15 +85,8 @@ def norm_layer(input,
name=offset_name, name=offset_name,
initializer=fluid.initializer.Constant(0.0), initializer=fluid.initializer.Constant(0.0),
trainable=False) trainable=False)
scale = helper.create_parameter( return fluid.layers.instance_norm(
attr=scale_param, shape=input.shape[1:2], dtype=dtype) input, param_attr=scale_param, bias_attr=offset_param)
offset = helper.create_parameter(
attr=offset_param, shape=input.shape[1:2], dtype=dtype)
tmp = fluid.layers.elementwise_mul(x=(input - mean), y=scale, axis=1)
tmp = tmp / fluid.layers.sqrt(var + epsilon)
tmp = fluid.layers.elementwise_add(tmp, offset, axis=1)
return tmp
else: else:
raise NotImplementedError("norm type: [%s] is not support" % norm_type) raise NotImplementedError("norm type: [%s] is not support" % norm_type)
......
import os
import argparse
import functools
import glob
from PIL import Image
''' Based on https://github.com/junyanz/CycleGAN'''
def load_image(path):
return Image.open(path).convert('RGB').resize((256, 256))
def propress_cityscapes(gtFine_dir, leftImg8bit_dir, output_dir, phase):
save_dir = os.path.join(output_dir, phase)
try:
os.makedirs(save_dir)
except Exception as e:
print("{} makedirs".format(e))
pass
try:
os.makedirs(os.path.join(save_dir, 'A'))
except Exception as e:
print("{} makedirs".format(e))
try:
os.makedirs(os.path.join(save_dir, 'B'))
except Exception as e:
print("{} makedirs".format(e))
seg_expr = os.path.join(gtFine_dir, phase, "*", "*_color.png")
seg_paths = glob.glob(seg_expr)
seg_paths = sorted(seg_paths)
photo_expr = os.path.join(leftImg8bit_dir, phase, "*", '*_leftImg8bit.png')
photo_paths = glob.glob(photo_expr)
photo_paths = sorted(photo_paths)
assert len(seg_paths) == len(photo_paths), \
"[%d] gtFine images NOT match [%d] leftImg8bit images. Aborting." % (len(segmap_paths), len(photo_paths))
for i, (seg_path, photo_path) in enumerate(zip(seg_paths, photo_paths)):
seg_image = load_image(seg_path)
photo_image = load_image(photo_path)
# save image
save_path = os.path.join(save_dir, 'A', "%d_A.jpg" % i)
photo_image.save(save_path, format='JPEG', subsampling=0, quality=100)
save_path = os.path.join(save_dir, 'B', "%d_B.jpg" % i)
seg_image.save(save_path, format='JPEG', subsampling=0, quality=100)
if i % 10 == 0:
print("preprocess %d ~ %d images." % (i, i + 10))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description=__doc__)
# yapf: disable
parser.add_argument('--gtFine_dir', type=str, default=None, help='Path to Cityscapes gtFine directory.')
parser.add_argument('--leftImg8bit_dir', type=str, default=None, help='Path to Cityscapes leftImg8bit_trainvaltest directory.')
parser.add_argument('--output_dir', type=str, default=None, help='Path to output Cityscapes directory.')
# yapf: enable
args = parser.parse_args()
print('Preparing Cityscapes Dataset for val phase')
propress_cityscapes(args.gtFine_dir, args.leftImg8bit_dir, args.output_dir,
'val')
print('Preparing Cityscapes Dataset for train phase')
propress_cityscapes(args.gtFine_dir, args.leftImg8bit_dir, args.output_dir,
'train')
print("DONE!!!")
python train.py --model_net AttGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 120 --output ./output/attgan/ >log_out 2>log_err python train.py --model_net AttGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 120 --dis_norm instance_norm --output ./output/attgan/ >log_out 2>log_err
python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 50 --output ./output/stgan/ >log_out 2>log_err python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 50 --dis_norm instance_norm --output ./output/stgan/ >log_out 2>log_err
...@@ -30,7 +30,8 @@ import trainer ...@@ -30,7 +30,8 @@ import trainer
def train(cfg): def train(cfg):
MODELS = [ MODELS = [
"CGAN", "DCGAN", "Pix2pix", "CycleGAN", "StarGAN", "AttGAN", "STGAN", "SPADE" "CGAN", "DCGAN", "Pix2pix", "CycleGAN", "StarGAN", "AttGAN", "STGAN",
"SPADE"
] ]
if cfg.model_net not in MODELS: if cfg.model_net not in MODELS:
raise NotImplementedError("{} is not support!".format(cfg.model_net)) raise NotImplementedError("{} is not support!".format(cfg.model_net))
...@@ -65,6 +66,7 @@ if __name__ == "__main__": ...@@ -65,6 +66,7 @@ if __name__ == "__main__":
cfg = config.parse_args() cfg = config.parse_args()
config.print_arguments(cfg) config.print_arguments(cfg)
utility.check_gpu(cfg.use_gpu) utility.check_gpu(cfg.use_gpu)
utility.check_version()
if cfg.profile: if cfg.profile:
if cfg.use_gpu: if cfg.use_gpu:
with fluid.profiler.profiler('All', 'total', with fluid.profiler.profiler('All', 'total',
......
...@@ -270,18 +270,18 @@ class AttGAN(object): ...@@ -270,18 +270,18 @@ class AttGAN(object):
self.id2name = id2name self.id2name = id2name
def build_model(self): def build_model(self):
data_shape = [-1, 3, self.cfg.image_size, self.cfg.image_size] data_shape = [None, 3, self.cfg.image_size, self.cfg.image_size]
image_real = fluid.data( image_real = fluid.data(
name='image_real', shape=data_shape, dtype='float32') name='image_real', shape=data_shape, dtype='float32')
label_org = fluid.data( label_org = fluid.data(
name='label_org', shape=[-1, self.cfg.c_dim], dtype='float32') name='label_org', shape=[None, self.cfg.c_dim], dtype='float32')
label_trg = fluid.data( label_trg = fluid.data(
name='label_trg', shape=[-1, self.cfg.c_dim], dtype='float32') name='label_trg', shape=[None, self.cfg.c_dim], dtype='float32')
label_org_ = fluid.data( label_org_ = fluid.data(
name='label_org_', shape=[-1, self.cfg.c_dim], dtype='float32') name='label_org_', shape=[None, self.cfg.c_dim], dtype='float32')
label_trg_ = fluid.data( label_trg_ = fluid.data(
name='label_trg_', shape=[-1, self.cfg.c_dim], dtype='float32') name='label_trg_', shape=[None, self.cfg.c_dim], dtype='float32')
py_reader = fluid.io.PyReader( py_reader = fluid.io.PyReader(
feed_list=[image_real, label_org, label_trg], feed_list=[image_real, label_org, label_trg],
...@@ -369,9 +369,9 @@ class AttGAN(object): ...@@ -369,9 +369,9 @@ class AttGAN(object):
batch_id += 1 batch_id += 1
if self.cfg.run_test: if self.cfg.run_test:
image_name = fluid.layers.data( image_name = fluid.data(
name='image_name', name='image_name',
shape=[self.cfg.n_samples], shape=[None, self.cfg.n_samples],
dtype='int32') dtype='int32')
test_py_reader = fluid.io.PyReader( test_py_reader = fluid.io.PyReader(
feed_list=[image_real, label_org, label_trg, image_name], feed_list=[image_real, label_org, label_trg, image_name],
......
...@@ -88,11 +88,12 @@ class CGAN(object): ...@@ -88,11 +88,12 @@ class CGAN(object):
def build_model(self): def build_model(self):
img = fluid.data(name='img', shape=[-1, 784], dtype='float32') img = fluid.data(name='img', shape=[None, 784], dtype='float32')
condition = fluid.data(name='condition', shape=[-1, 1], dtype='float32') condition = fluid.data(
name='condition', shape=[None, 1], dtype='float32')
noise = fluid.data( noise = fluid.data(
name='noise', shape=[-1, self.cfg.noise_size], dtype='float32') name='noise', shape=[None, self.cfg.noise_size], dtype='float32')
label = fluid.data(name='label', shape=[-1, 1], dtype='float32') label = fluid.data(name='label', shape=[None, 1], dtype='float32')
g_trainer = GTrainer(noise, condition, self.cfg) g_trainer = GTrainer(noise, condition, self.cfg)
d_trainer = DTrainer(img, condition, label, self.cfg) d_trainer = DTrainer(img, condition, label, self.cfg)
......
...@@ -229,7 +229,7 @@ class CycleGAN(object): ...@@ -229,7 +229,7 @@ class CycleGAN(object):
self.B_id2name = B_id2name self.B_id2name = B_id2name
def build_model(self): def build_model(self):
data_shape = [-1, 3, self.cfg.crop_size, self.cfg.crop_size] data_shape = [None, 3, self.cfg.crop_size, self.cfg.crop_size]
input_A = fluid.data(name='input_A', shape=data_shape, dtype='float32') input_A = fluid.data(name='input_A', shape=data_shape, dtype='float32')
input_B = fluid.data(name='input_B', shape=data_shape, dtype='float32') input_B = fluid.data(name='input_B', shape=data_shape, dtype='float32')
...@@ -347,9 +347,9 @@ class CycleGAN(object): ...@@ -347,9 +347,9 @@ class CycleGAN(object):
if self.cfg.run_test: if self.cfg.run_test:
A_image_name = fluid.data( A_image_name = fluid.data(
name='A_image_name', shape=[-1, 1], dtype='int32') name='A_image_name', shape=[None, 1], dtype='int32')
B_image_name = fluid.data( B_image_name = fluid.data(
name='B_image_name', shape=[-1, 1], dtype='int32') name='B_image_name', shape=[None, 1], dtype='int32')
A_test_py_reader = fluid.io.PyReader( A_test_py_reader = fluid.io.PyReader(
feed_list=[input_A, A_image_name], feed_list=[input_A, A_image_name],
capacity=4, capacity=4,
......
...@@ -86,10 +86,10 @@ class DCGAN(object): ...@@ -86,10 +86,10 @@ class DCGAN(object):
self.train_reader = train_reader self.train_reader = train_reader
def build_model(self): def build_model(self):
img = fluid.data(name='img', shape=[-1, 784], dtype='float32') img = fluid.data(name='img', shape=[None, 784], dtype='float32')
noise = fluid.data( noise = fluid.data(
name='noise', shape=[-1, self.cfg.noise_size], dtype='float32') name='noise', shape=[None, self.cfg.noise_size], dtype='float32')
label = fluid.data(name='label', shape=[-1, 1], dtype='float32') label = fluid.data(name='label', shape=[None, 1], dtype='float32')
g_trainer = GTrainer(noise, label, self.cfg) g_trainer = GTrainer(noise, label, self.cfg)
d_trainer = DTrainer(img, label, self.cfg) d_trainer = DTrainer(img, label, self.cfg)
......
...@@ -211,7 +211,7 @@ class Pix2pix(object): ...@@ -211,7 +211,7 @@ class Pix2pix(object):
self.id2name = id2name self.id2name = id2name
def build_model(self): def build_model(self):
data_shape = [-1, 3, self.cfg.crop_size, self.cfg.crop_size] data_shape = [None, 3, self.cfg.crop_size, self.cfg.crop_size]
input_A = fluid.data(name='input_A', shape=data_shape, dtype='float32') input_A = fluid.data(name='input_A', shape=data_shape, dtype='float32')
input_B = fluid.data(name='input_B', shape=data_shape, dtype='float32') input_B = fluid.data(name='input_B', shape=data_shape, dtype='float32')
...@@ -298,7 +298,7 @@ class Pix2pix(object): ...@@ -298,7 +298,7 @@ class Pix2pix(object):
if self.cfg.run_test: if self.cfg.run_test:
image_name = fluid.data( image_name = fluid.data(
name='image_name', name='image_name',
shape=[-1, self.cfg.batch_size], shape=[None, self.cfg.batch_size],
dtype="int32") dtype="int32")
test_loader = fluid.io.DataLoader.from_generator( test_loader = fluid.io.DataLoader.from_generator(
feed_list=[input_A, input_B, image_name], feed_list=[input_A, input_B, image_name],
......
...@@ -283,11 +283,11 @@ class SPADE(object): ...@@ -283,11 +283,11 @@ class SPADE(object):
self.id2name = id2name self.id2name = id2name
def build_model(self): def build_model(self):
data_shape = [-1, 3, self.cfg.crop_height, self.cfg.crop_width] data_shape = [None, 3, self.cfg.crop_height, self.cfg.crop_width]
label_shape = [ label_shape = [
-1, self.cfg.label_nc, self.cfg.crop_height, self.cfg.crop_width None, self.cfg.label_nc, self.cfg.crop_height, self.cfg.crop_width
] ]
edge_shape = [-1, 1, self.cfg.crop_height, self.cfg.crop_width] edge_shape = [None, 1, self.cfg.crop_height, self.cfg.crop_width]
input_A = fluid.data( input_A = fluid.data(
name='input_label', shape=label_shape, dtype='float32') name='input_label', shape=label_shape, dtype='float32')
...@@ -389,7 +389,7 @@ class SPADE(object): ...@@ -389,7 +389,7 @@ class SPADE(object):
test_program = gen_trainer.infer_program test_program = gen_trainer.infer_program
image_name = fluid.data( image_name = fluid.data(
name='image_name', name='image_name',
shape=[-1, self.cfg.batch_size], shape=[None, self.cfg.batch_size],
dtype="int32") dtype="int32")
test_py_reader = fluid.io.PyReader( test_py_reader = fluid.io.PyReader(
feed_list=[input_A, input_B, input_C, image_name], feed_list=[input_A, input_B, input_C, image_name],
......
...@@ -282,18 +282,18 @@ class STGAN(object): ...@@ -282,18 +282,18 @@ class STGAN(object):
self.batch_num = batch_num self.batch_num = batch_num
def build_model(self): def build_model(self):
data_shape = [-1, 3, self.cfg.image_size, self.cfg.image_size] data_shape = [None, 3, self.cfg.image_size, self.cfg.image_size]
image_real = fluid.data( image_real = fluid.data(
name='image_real', shape=data_shape, dtype='float32') name='image_real', shape=data_shape, dtype='float32')
label_org = fluid.data( label_org = fluid.data(
name='label_org', shape=[-1, self.cfg.c_dim], dtype='float32') name='label_org', shape=[None, self.cfg.c_dim], dtype='float32')
label_trg = fluid.data( label_trg = fluid.data(
name='label_trg', shape=[-1, self.cfg.c_dim], dtype='float32') name='label_trg', shape=[None, self.cfg.c_dim], dtype='float32')
label_org_ = fluid.data( label_org_ = fluid.data(
name='label_org_', shape=[-1, self.cfg.c_dim], dtype='float32') name='label_org_', shape=[None, self.cfg.c_dim], dtype='float32')
label_trg_ = fluid.data( label_trg_ = fluid.data(
name='label_trg_', shape=[-1, self.cfg.c_dim], dtype='float32') name='label_trg_', shape=[None, self.cfg.c_dim], dtype='float32')
test_gen_trainer = GTrainer(image_real, label_org, label_org_, test_gen_trainer = GTrainer(image_real, label_org, label_org_,
label_trg, label_trg_, self.cfg, label_trg, label_trg_, self.cfg,
...@@ -380,7 +380,7 @@ class STGAN(object): ...@@ -380,7 +380,7 @@ class STGAN(object):
if self.cfg.run_test: if self.cfg.run_test:
image_name = fluid.data( image_name = fluid.data(
name='image_name', name='image_name',
shape=[-1, self.cfg.n_samples], shape=[None, self.cfg.n_samples],
dtype='int32') dtype='int32')
test_py_reader = fluid.io.PyReader( test_py_reader = fluid.io.PyReader(
feed_list=[image_real, label_org, label_trg, image_name], feed_list=[image_real, label_org, label_trg, image_name],
......
...@@ -259,14 +259,14 @@ class StarGAN(object): ...@@ -259,14 +259,14 @@ class StarGAN(object):
self.batch_num = batch_num self.batch_num = batch_num
def build_model(self): def build_model(self):
data_shape = [-1, 3, self.cfg.image_size, self.cfg.image_size] data_shape = [None, 3, self.cfg.image_size, self.cfg.image_size]
image_real = fluid.data( image_real = fluid.data(
name='image_real', shape=data_shape, dtype='float32') name='image_real', shape=data_shape, dtype='float32')
label_org = fluid.data( label_org = fluid.data(
name='label_org', shape=[-1, self.cfg.c_dim], dtype='float32') name='label_org', shape=[None, self.cfg.c_dim], dtype='float32')
label_trg = fluid.data( label_trg = fluid.data(
name='label_trg', shape=[-1, self.cfg.c_dim], dtype='float32') name='label_trg', shape=[None, self.cfg.c_dim], dtype='float32')
py_reader = fluid.io.PyReader( py_reader = fluid.io.PyReader(
feed_list=[image_real, label_org, label_trg], feed_list=[image_real, label_org, label_trg],
...@@ -348,7 +348,7 @@ class StarGAN(object): ...@@ -348,7 +348,7 @@ class StarGAN(object):
if self.cfg.run_test: if self.cfg.run_test:
image_name = fluid.data( image_name = fluid.data(
name='image_name', name='image_name',
shape=[-1, self.cfg.n_samples], shape=[None, self.cfg.n_samples],
dtype='int32') dtype='int32')
test_py_reader = fluid.io.PyReader( test_py_reader = fluid.io.PyReader(
feed_list=[image_real, label_org, label_trg, image_name], feed_list=[image_real, label_org, label_trg, image_name],
......
...@@ -409,3 +409,19 @@ def check_gpu(use_gpu): ...@@ -409,3 +409,19 @@ def check_gpu(use_gpu):
sys.exit(1) sys.exit(1)
except Exception as e: except Exception as e:
pass pass
def check_version():
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
print(err)
sys.exit(1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册