未验证 提交 4a93b001 编写于 作者: L lvmengsi 提交者: GitHub

Update gan (#2871)

* refine gan
上级 7b1a5565
......@@ -61,76 +61,77 @@ def get_preprocess_param(load_size, crop_size):
class reader_creator(object):
''' read and preprocess dataset'''
def __init__(self, image_dir, list_filename, batch_size=1, drop_last=False):
def __init__(self,
image_dir,
list_filename,
shuffle=False,
batch_size=1,
drop_last=False,
mode="TRAIN"):
self.image_dir = image_dir
self.list_filename = list_filename
self.batch_size = batch_size
self.drop_last = drop_last
self.mode = mode
self.lines = open(self.list_filename).readlines()
if self.mode == "TRAIN":
self.shuffle = shuffle
self.drop_last = drop_last
else:
self.shuffle = False
self.drop_last = False
def len(self):
    """Return the number of batches one epoch of this reader yields.

    A trailing partial batch counts as one extra batch unless
    ``drop_last`` is set (or the line count divides evenly).
    """
    n_full, remainder = divmod(len(self.lines), self.batch_size)
    if self.drop_last or remainder == 0:
        return n_full
    return n_full + 1
def get_train_reader(self, args, shuffle=False, return_name=False):
def make_reader(self, args, return_name=False):
print(self.image_dir, self.list_filename)
def reader():
batch_out = []
while True:
if shuffle:
np.random.shuffle(self.lines)
for file in self.lines:
file = file.strip('\n\r\t ')
img = Image.open(os.path.join(self.image_dir,
file)).convert('RGB')
img = img.resize((args.load_size, args.load_size),
Image.BICUBIC)
if args.crop_type == 'Centor':
img = CentorCrop(img, args.crop_size, args.crop_size)
elif args.crop_type == 'Random':
img = RandomCrop(img, args.crop_size, args.crop_size)
img = (np.array(img).astype('float32') / 255.0 - 0.5) / 0.5
img = img.transpose([2, 0, 1])
batch_out_name = []
if return_name:
batch_out.append([img, os.path.basename(file)])
else:
batch_out.append(img)
if len(batch_out) == self.batch_size:
yield batch_out
batch_out = []
if self.drop_last == False and len(batch_out) != 0:
yield batch_out
if self.shuffle:
np.random.shuffle(self.lines)
return reader
def get_test_reader(self, args, shuffle=False, return_name=False):
print(self.image_dir, self.list_filename)
def reader():
batch_out = []
for file in self.lines:
file = file.strip('\n\r\t ')
img = Image.open(os.path.join(self.image_dir, file)).convert(
'RGB')
img = img.resize((args.crop_size, args.crop_size),
Image.BICUBIC)
if self.mode == "TRAIN":
img = img.resize((args.image_size, args.image_size),
Image.BICUBIC)
if args.crop_type == 'Centor':
img = CentorCrop(img, args.crop_size, args.crop_size)
elif args.crop_type == 'Random':
img = RandomCrop(img, args.crop_size, args.crop_size)
else:
img = img.resize((args.crop_size, args.crop_size),
Image.BICUBIC)
img = (np.array(img).astype('float32') / 255.0 - 0.5) / 0.5
img = img.transpose([2, 0, 1])
if return_name:
batch_out.append(
[img[np.newaxis, :], os.path.basename(file)])
batch_out.append(img)
batch_out_name.append(os.path.basename(file))
else:
batch_out.append(img)
if len(batch_out) == self.batch_size:
yield batch_out
if return_name:
yield batch_out, batch_out_name
batch_out_name = []
else:
yield batch_out
batch_out = []
if len(batch_out) != 0:
yield batch_out
if self.drop_last == False and len(batch_out) != 0:
if return_name:
yield batch_out, batch_out_name
else:
yield batch_out
return reader
......@@ -138,29 +139,43 @@ class reader_creator(object):
class pair_reader_creator(reader_creator):
''' read and preprocess dataset'''
def __init__(self, image_dir, list_filename, batch_size=1, drop_last=False):
def __init__(self,
image_dir,
list_filename,
shuffle=False,
batch_size=1,
drop_last=False,
mode="TRAIN"):
super(pair_reader_creator, self).__init__(
image_dir, list_filename, batch_size=1, drop_last=drop_last)
def get_train_reader(self, args, shuffle=False, return_name=False):
image_dir,
list_filename,
shuffle=shuffle,
batch_size=batch_size,
drop_last=drop_last,
mode=mode)
def make_reader(self, args, return_name=False):
print(self.image_dir, self.list_filename)
def reader():
batch_out_1 = []
batch_out_2 = []
while True:
if shuffle:
np.random.shuffle(self.lines)
for line in self.lines:
files = line.strip('\n\r\t ').split('\t')
img1 = Image.open(os.path.join(self.image_dir, files[
0])).convert('RGB')
img2 = Image.open(os.path.join(self.image_dir, files[
1])).convert('RGB')
param = get_preprocess_param(args.load_size, args.crop_size)
img1 = img1.resize((args.load_size, args.load_size),
batch_out_name = []
if self.shuffle:
np.random.shuffle(self.lines)
for line in self.lines:
files = line.strip('\n\r\t ').split('\t')
img1 = Image.open(os.path.join(self.image_dir, files[
0])).convert('RGB')
img2 = Image.open(os.path.join(self.image_dir, files[
1])).convert('RGB')
if self.mode == "TRAIN":
param = get_preprocess_param(args.image_size,
args.crop_size)
img1 = img1.resize((args.image_size, args.image_size),
Image.BICUBIC)
img2 = img2.resize((args.load_size, args.load_size),
img2 = img2.resize((args.image_size, args.image_size),
Image.BICUBIC)
if args.crop_type == 'Centor':
img1 = CentorCrop(img1, args.crop_size, args.crop_size)
......@@ -172,65 +187,32 @@ class pair_reader_creator(reader_creator):
(x, y, x + args.crop_size, y + args.crop_size))
img2 = img2.crop(
(x, y, x + args.crop_size, y + args.crop_size))
img1 = (
np.array(img1).astype('float32') / 255.0 - 0.5) / 0.5
img1 = img1.transpose([2, 0, 1])
img2 = (
np.array(img2).astype('float32') / 255.0 - 0.5) / 0.5
img2 = img2.transpose([2, 0, 1])
batch_out_1.append(img1)
batch_out_2.append(img2)
if len(batch_out_1) == self.batch_size:
yield batch_out_1, batch_out_2
batch_out_1 = []
batch_out_2 = []
if self.drop_last == False and len(batch_out_1) != 0:
yield batch_out_1, batch_out_2
return reader
def get_test_reader(self, args, shuffle=False, return_name=False):
print(self.image_dir, self.list_filename)
else:
img1 = img1.resize((args.crop_size, args.crop_size),
Image.BICUBIC)
img2 = img2.resize((args.crop_size, args.crop_size),
Image.BICUBIC)
def reader():
batch_out_1 = []
batch_out_2 = []
batch_out_3 = []
for line in self.lines:
files = line.strip('\n\r\t ').split('\t')
img1 = Image.open(os.path.join(self.image_dir, files[
0])).convert('RGB')
img2 = Image.open(os.path.join(self.image_dir, files[
1])).convert('RGB')
img1 = img1.resize((args.crop_size, args.crop_size),
Image.BICUBIC)
img2 = img2.resize((args.crop_size, args.crop_size),
Image.BICUBIC)
img1 = (np.array(img1).astype('float32') / 255.0 - 0.5) / 0.5
img1 = img1.transpose([2, 0, 1])
img2 = (np.array(img2).astype('float32') / 255.0 - 0.5) / 0.5
img2 = img2.transpose([2, 0, 1])
batch_out_1.append(img1)
batch_out_2.append(img2)
if return_name:
batch_out_1.append(img1)
batch_out_2.append(img2)
batch_out_3.append(os.path.basename(files[0]))
else:
batch_out_1.append(img1)
batch_out_2.append(img2)
batch_out_name.append(os.path.basename(files[0]))
if len(batch_out_1) == self.batch_size:
if return_name:
yield batch_out_1, batch_out_2, batch_out_3
batch_out_1 = []
batch_out_2 = []
batch_out_3 = []
yield batch_out_1, batch_out_2, batch_out_name
batch_out_name = []
else:
yield batch_out_1, batch_out_2
batch_out_1 = []
batch_out_2 = []
if len(batch_out_1) != 0:
batch_out_1 = []
batch_out_2 = []
if self.drop_last == False and len(batch_out_1) != 0:
if return_name:
yield batch_out_1, batch_out_2, batch_out_3
yield batch_out_1, batch_out_2, batch_out_name
else:
yield batch_out_1, batch_out_2
......@@ -240,24 +222,38 @@ class pair_reader_creator(reader_creator):
class celeba_reader_creator(reader_creator):
''' read and preprocess dataset'''
def __init__(self,
image_dir,
list_filename,
args,
batch_size=1,
drop_last=False):
def __init__(self, image_dir, list_filename, args, mode="TRAIN"):
self.image_dir = image_dir
self.list_filename = list_filename
self.batch_size = batch_size
self.drop_last = drop_last
self.mode = mode
self.args = args
print(self.image_dir, self.list_filename)
lines = open(self.list_filename).readlines()
all_num = int(lines[0])
train_end = 2 + int(all_num * 0.9)
test_end = train_end + int(all_num * 0.003)
all_attr_names = lines[1].split()
attr2idx = {}
for i, attr_name in enumerate(all_attr_names):
attr2idx[attr_name] = i
lines = lines[2:]
if self.mode == "TRAIN":
self.batch_size = args.batch_size
self.drop_last = args.drop_last
self.shuffle = args.shuffle
lines = lines[2:train_end]
else:
self.batch_size = args.n_samples
self.shuffle = False
self.drop_last = False
if self.mode == "TEST":
lines = lines[train_end:test_end]
else:
lines = lines[test_end:]
self.images = []
attr_names = args.selected_attrs.split(',')
for line in lines:
......@@ -275,75 +271,41 @@ class celeba_reader_creator(reader_creator):
else:
return len(self.images) // self.batch_size + 1
def get_train_reader(self, args, shuffle=False, return_name=False):
def reader():
batch_out_1 = []
batch_out_2 = []
while True:
if shuffle:
np.random.shuffle(self.images)
for file, label in self.images:
if args.model_net == "StarGAN":
img = Image.open(os.path.join(self.image_dir, file))
label = np.array(label).astype("float32")
img = RandomHorizonFlip(img)
img = CentorCrop(img, args.crop_size, args.crop_size)
img = img.resize((args.image_size, args.image_size),
Image.BILINEAR)
else:
img = Image.open(os.path.join(self.image_dir,
file)).convert('RGB')
label = np.array(label).astype("float32")
img = CentorCrop(img, args.crop_size, args.crop_size)
img = img.resize((args.image_size, args.image_size),
Image.BILINEAR)
img = (np.array(img).astype('float32') / 255.0 - 0.5) / 0.5
img = img.transpose([2, 0, 1])
batch_out_1.append(img)
batch_out_2.append(label)
if len(batch_out_1) == self.batch_size:
yield batch_out_1, batch_out_2
batch_out_1 = []
batch_out_2 = []
if self.drop_last == False and len(batch_out_1) != 0:
yield batch_out_1, batch_out_2
return reader
def make_reader(self, return_name=False):
print(self.image_dir, self.list_filename)
def get_test_reader(self, args, shuffle=False, return_name=False):
def reader():
batch_out_1 = []
batch_out_2 = []
batch_out_3 = []
batch_out_name = []
if self.shuffle:
np.random.shuffle(self.images)
for file, label in self.images:
img = Image.open(os.path.join(self.image_dir, file))
label = np.array(label).astype("float32")
img = CentorCrop(img, args.crop_size, args.crop_size)
img = img.resize((args.image_size, args.image_size),
if self.args.model_net == "StarGAN":
img = RandomHorizonFlip(img)
img = CentorCrop(img, self.args.crop_size, self.args.crop_size)
img = img.resize((self.args.image_size, self.args.image_size),
Image.BILINEAR)
img = (np.array(img).astype('float32') / 255.0 - 0.5) / 0.5
img = img.transpose([2, 0, 1])
batch_out_1.append(img)
batch_out_2.append(label)
if return_name:
batch_out_1.append(img)
batch_out_2.append(label)
batch_out_3.append(os.path.basename(file))
else:
batch_out_1.append(img)
batch_out_2.append(label)
batch_out_name.append(os.path.basename(file))
if len(batch_out_1) == self.batch_size:
if return_name:
yield batch_out_1, batch_out_2, batch_out_3
batch_out_1 = []
batch_out_2 = []
batch_out_3 = []
yield batch_out_1, batch_out_2, batch_out_name
batch_out_name = []
else:
yield batch_out_1, batch_out_2
batch_out_1 = []
batch_out_2 = []
if len(batch_out_1) != 0:
batch_out_1 = []
batch_out_2 = []
if self.drop_last == False and len(batch_out_1) != 0:
if return_name:
yield batch_out_1, batch_out_2, batch_out_3
yield batch_out_1, batch_out_2, batch_out_name
else:
yield batch_out_1, batch_out_2
......@@ -417,20 +379,24 @@ class data_reader(object):
batch_size=self.cfg.batch_size)
return train_reader
else:
if self.cfg.model_net == 'CycleGAN':
if self.cfg.model_net in ['CycleGAN']:
dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
trainA_list = os.path.join(dataset_dir, "trainA.txt")
trainB_list = os.path.join(dataset_dir, "trainB.txt")
a_train_reader = reader_creator(
image_dir=dataset_dir,
list_filename=trainA_list,
shuffle=self.cfg.shuffle,
batch_size=self.cfg.batch_size,
drop_last=self.cfg.drop_last)
drop_last=self.cfg.drop_last,
mode="TRAIN")
b_train_reader = reader_creator(
image_dir=dataset_dir,
list_filename=trainB_list,
shuffle=self.cfg.shuffle,
batch_size=self.cfg.batch_size,
drop_last=self.cfg.drop_last)
drop_last=self.cfg.drop_last,
mode="TRAIN")
a_reader_test = None
b_reader_test = None
if self.cfg.run_test:
......@@ -439,27 +405,29 @@ class data_reader(object):
a_test_reader = reader_creator(
image_dir=dataset_dir,
list_filename=testA_list,
shuffle=False,
batch_size=1,
drop_last=self.cfg.drop_last)
drop_last=self.cfg.drop_last,
mode="TEST")
b_test_reader = reader_creator(
image_dir=dataset_dir,
list_filename=testB_list,
shuffle=False,
batch_size=1,
drop_last=self.cfg.drop_last)
a_reader_test = a_test_reader.get_test_reader(
self.cfg, shuffle=False, return_name=True)
b_reader_test = b_test_reader.get_test_reader(
self.cfg, shuffle=False, return_name=True)
drop_last=self.cfg.drop_last,
mode="TEST")
a_reader_test = a_test_reader.make_reader(
self.cfg, return_name=True)
b_reader_test = b_test_reader.make_reader(
self.cfg, return_name=True)
batch_num = max(a_train_reader.len(), b_train_reader.len())
a_reader = a_train_reader.get_train_reader(
self.cfg, shuffle=self.shuffle)
b_reader = b_train_reader.get_train_reader(
self.cfg, shuffle=self.shuffle)
a_reader = a_train_reader.make_reader(self.cfg)
b_reader = b_train_reader.make_reader(self.cfg)
return a_reader, b_reader, a_reader_test, b_reader_test, batch_num
elif self.cfg.model_net == 'StarGAN' or self.cfg.model_net == 'STGAN' or self.cfg.model_net == 'AttGAN':
elif self.cfg.model_net in ['StarGAN', 'STGAN', 'AttGAN']:
dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
train_list = os.path.join(dataset_dir, 'train.txt')
if self.cfg.train_list is not None:
......@@ -467,28 +435,24 @@ class data_reader(object):
train_reader = celeba_reader_creator(
image_dir=dataset_dir,
list_filename=train_list,
batch_size=self.cfg.batch_size,
args=self.cfg,
drop_last=self.cfg.drop_last)
mode="TRAIN")
reader_test = None
if self.cfg.run_test:
test_list = os.path.join(dataset_dir, "test.txt")
test_list = train_list
if self.cfg.test_list is not None:
test_list = self.cfg.test_list
test_reader = celeba_reader_creator(
image_dir=dataset_dir,
list_filename=test_list,
batch_size=self.cfg.n_samples,
drop_last=self.cfg.drop_last,
args=self.cfg)
reader_test = test_reader.get_test_reader(
self.cfg, shuffle=False, return_name=True)
list_filename=train_list,
args=self.cfg,
mode="TEST")
reader_test = test_reader.make_reader(return_name=True)
batch_num = train_reader.len()
reader = train_reader.get_train_reader(
self.cfg, shuffle=self.shuffle)
reader = train_reader.make_reader()
return reader, reader_test, batch_num
elif self.cfg.model_net == 'Pix2pix':
elif self.cfg.model_net in ['Pix2pix']:
dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
train_list = os.path.join(dataset_dir, 'train.txt')
if self.cfg.train_list is not None:
......@@ -496,8 +460,10 @@ class data_reader(object):
train_reader = pair_reader_creator(
image_dir=dataset_dir,
list_filename=train_list,
shuffle=self.cfg.shuffle,
batch_size=self.cfg.batch_size,
drop_last=self.cfg.drop_last)
drop_last=self.cfg.drop_last,
mode="TRAIN")
reader_test = None
if self.cfg.run_test:
test_list = os.path.join(dataset_dir, "test.txt")
......@@ -506,13 +472,14 @@ class data_reader(object):
test_reader = pair_reader_creator(
image_dir=dataset_dir,
list_filename=test_list,
shuffle=False,
batch_size=1,
drop_last=self.cfg.drop_last)
reader_test = test_reader.get_test_reader(
self.cfg, shuffle=False, return_name=True)
drop_last=self.cfg.drop_last,
mode="TEST")
reader_test = test_reader.make_reader(
self.cfg, return_name=True)
batch_num = train_reader.len()
reader = train_reader.get_train_reader(
self.cfg, shuffle=self.shuffle)
reader = train_reader.make_reader(self.cfg)
return reader, reader_test, batch_num
else:
dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
......@@ -527,7 +494,7 @@ class data_reader(object):
test_reader = reader_creator(
image_dir=dataset_dir,
list_filename=test_list,
batch_size=1,
batch_size=self.cfg.n_samples,
drop_last=self.cfg.drop_last)
reader_test = test_reader.get_test_reader(
self.cfg, shuffle=False, return_name=True)
......
......@@ -54,8 +54,8 @@ add_arg('image_size', int, 128, "image size")
add_arg('selected_attrs', str,
"Bald,Bangs,Black_Hair,Blond_Hair,Brown_Hair,Bushy_Eyebrows,Eyeglasses,Male,Mouth_Slightly_Open,Mustache,No_Beard,Pale_Skin,Young",
"the attributes we selected to change")
add_arg('batch_size', int, 16, "batch size when test")
add_arg('test_list', str, "./data/celeba/test_list_attr_celeba.txt", "the test list file")
add_arg('n_samples', int, 16, "batch size when test")
add_arg('test_list', str, "./data/celeba/list_attr_celeba.txt", "the test list file")
add_arg('dataset_dir', str, "./data/celeba/", "the dataset directory to be infered")
add_arg('n_layers', int, 5, "default layers in generotor")
add_arg('gru_n_layers', int, 4, "default layers of GRU in generotor")
......@@ -149,11 +149,9 @@ def infer(args):
test_reader = celeba_reader_creator(
image_dir=args.dataset_dir,
list_filename=args.test_list,
batch_size=args.batch_size,
drop_last=False,
args=args)
reader_test = test_reader.get_test_reader(
args, shuffle=False, return_name=True)
args=args,
mode="VAL")
reader_test = test_reader.make_reader(return_name=True)
for data in zip(reader_test()):
real_img, label_org, name = data[0]
print("read {}".format(name))
......@@ -199,11 +197,9 @@ def infer(args):
test_reader = celeba_reader_creator(
image_dir=args.dataset_dir,
list_filename=args.test_list,
batch_size=args.batch_size,
drop_last=False,
args=args)
reader_test = test_reader.get_test_reader(
args, shuffle=False, return_name=True)
args=args,
mode="VAL")
reader_test = test_reader.make_reader(return_name=True)
for data in zip(reader_test()):
real_img, label_org, name = data[0]
print("read {}".format(name))
......@@ -256,9 +252,9 @@ def infer(args):
elif args.model_net == 'CGAN':
noise_data = np.random.uniform(
low=-1.0, high=1.0,
size=[args.batch_size, args.noise_size]).astype('float32')
size=[args.n_samples, args.noise_size]).astype('float32')
label = np.random.randint(
0, 9, size=[args.batch_size, 1]).astype('float32')
0, 9, size=[args.n_samples, 1]).astype('float32')
noise_tensor = fluid.LoDTensor()
conditions_tensor = fluid.LoDTensor()
noise_tensor.set(noise_data, place)
......@@ -267,7 +263,7 @@ def infer(args):
fetch_list=[fake.name],
feed={"noise": noise_tensor,
"conditions": conditions_tensor})[0]
fake_image = np.reshape(fake_temp, (args.batch_size, -1))
fake_image = np.reshape(fake_temp, (args.n_samples, -1))
fig = utility.plot(fake_image)
plt.savefig(
......@@ -277,12 +273,12 @@ def infer(args):
elif args.model_net == 'DCGAN':
noise_data = np.random.uniform(
low=-1.0, high=1.0,
size=[args.batch_size, args.noise_size]).astype('float32')
size=[args.n_samples, args.noise_size]).astype('float32')
noise_tensor = fluid.LoDTensor()
noise_tensor.set(noise_data, place)
fake_temp = exe.run(fetch_list=[fake.name],
feed={"noise": noise_tensor})[0]
fake_image = np.reshape(fake_temp, (args.batch_size, -1))
fake_image = np.reshape(fake_temp, (args.n_samples, -1))
fig = utility.plot(fake_image)
plt.savefig(
......
......@@ -71,10 +71,10 @@ class AttGAN_model(object):
d = min(dim * 2**i, MAX_DIM)
#SAME padding
z = conv2d(
z,
d,
4,
2,
input=z,
num_filters=d,
filter_size=4,
stride=2,
padding_type='SAME',
norm='batch_norm',
activation_fn='leaky_relu',
......@@ -104,10 +104,10 @@ class AttGAN_model(object):
if i < n_layers - 1:
d = min(dim * 2**(n_layers - 1 - i), MAX_DIM)
z = deconv2d(
z,
d,
4,
2,
input=z,
num_filters=d,
filter_size=4,
stride=2,
padding_type='SAME',
name=name + str(i),
norm='batch_norm',
......@@ -121,10 +121,10 @@ class AttGAN_model(object):
z = self.concat(z, a)
else:
x = z = deconv2d(
z,
3,
4,
2,
input=z,
num_filters=3,
filter_size=4,
stride=2,
padding_type='SAME',
name=name + str(i),
activation_fn='tanh',
......@@ -146,10 +146,10 @@ class AttGAN_model(object):
for i in range(n_layers):
d = min(dim * 2**i, MAX_DIM)
y = conv2d(
y,
d,
4,
2,
input=y,
num_filters=d,
filter_size=4,
stride=2,
norm=norm,
padding=1,
activation_fn='leaky_relu',
......@@ -159,8 +159,8 @@ class AttGAN_model(object):
initial='kaiming')
logit_gan = linear(
y,
fc_dim,
input=y,
output_size=fc_dim,
activation_fn='relu',
name=name + 'fc_adv_1',
initial='kaiming')
......@@ -168,8 +168,8 @@ class AttGAN_model(object):
logit_gan, 1, name=name + 'fc_adv_2', initial='kaiming')
logit_att = linear(
y,
fc_dim,
input=y,
output_size=fc_dim,
activation_fn='relu',
name=name + 'fc_cls_1',
initial='kaiming')
......
python infer.py --model_net AttGAN --init_model output/checkpoints/199/ --dataset_dir "data/celeba/" --image_size 128
python infer.py --model_net AttGAN --init_model output/checkpoints/119/ --dataset_dir "data/celeba/" --image_size 128
python train.py --model_net AttGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 90 >log_out 2>log_err
python train.py --model_net AttGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 120 >log_out 2>log_err
python train.py --model_net CycleGAN --dataset cityscapes --batch_size 1 --net_G resnet_9block --g_base_dim 32 --net_D basic --norm_type batch_norm --epoch 200 --load_size 286 --crop_size 256 --crop_type Random > log_out 2>log_err
python train.py --model_net CycleGAN --dataset cityscapes --batch_size 1 --net_G resnet_9block --g_base_dim 32 --net_D basic --norm_type batch_norm --epoch 200 --image_size 286 --crop_size 256 --crop_type Random > log_out 2>log_err
python train.py --model_net Pix2pix --dataset cityscapes --train_list data/cityscapes/pix2pix_train_list --test_list data/cityscapes/pix2pix_test_list --crop_type Random --dropout True --gan_mode vanilla --batch_size 1 > log_out 2>log_err
python train.py --model_net Pix2pix --dataset cityscapes --train_list data/cityscapes/pix2pix_train_list --test_list data/cityscapes/pix2pix_test_list --crop_type Random --dropout True --gan_mode vanilla --batch_size 1 --epoch 200 --image_size 286 --crop_size 256 > log_out 2>log_err
python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 16 --epoch 20 > log_out 2>log_err
python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 16 --epoch 20 > log_out 2>log_err
python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 50 >log_out 2>log_err
python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 50 >log_out 2>log_err
......@@ -24,51 +24,38 @@ import time
import numpy as np
import paddle
import paddle.fluid as fluid
import trainer
def train(cfg):
MODELS = [
"CGAN", "DCGAN", "Pix2pix", "CycleGAN", "StarGAN", "AttGAN", "STGAN"
]
if cfg.model_net not in MODELS:
raise NotImplementedError("{} is not support!".format(cfg.model_net))
reader = data_reader(cfg)
if cfg.model_net == 'CycleGAN':
if cfg.model_net in ['CycleGAN']:
a_reader, b_reader, a_reader_test, b_reader_test, batch_num = reader.make_data(
)
elif cfg.model_net == 'Pix2pix':
train_reader, test_reader, batch_num = reader.make_data()
elif cfg.model_net == 'StarGAN':
train_reader, test_reader, batch_num = reader.make_data()
else:
if cfg.dataset == 'mnist':
if cfg.dataset in ['mnist']:
train_reader = reader.make_data()
else:
train_reader, test_reader, batch_num = reader.make_data()
if cfg.model_net == 'CGAN':
from trainer.CGAN import CGAN
if cfg.dataset != 'mnist':
raise NotImplementedError('CGAN only support mnist now!')
model = CGAN(cfg, train_reader)
elif cfg.model_net == 'DCGAN':
from trainer.DCGAN import DCGAN
if cfg.model_net in ['CGAN', 'DCGAN']:
if cfg.dataset != 'mnist':
raise NotImplementedError('DCGAN only support mnist now!')
model = DCGAN(cfg, train_reader)
elif cfg.model_net == 'CycleGAN':
from trainer.CycleGAN import CycleGAN
model = CycleGAN(cfg, a_reader, b_reader, a_reader_test, b_reader_test,
batch_num)
elif cfg.model_net == 'Pix2pix':
from trainer.Pix2pix import Pix2pix
model = Pix2pix(cfg, train_reader, test_reader, batch_num)
elif cfg.model_net == 'StarGAN':
from trainer.StarGAN import StarGAN
model = StarGAN(cfg, train_reader, test_reader, batch_num)
elif cfg.model_net == 'AttGAN':
from trainer.AttGAN import AttGAN
model = AttGAN(cfg, train_reader, test_reader, batch_num)
elif cfg.model_net == 'STGAN':
from trainer.STGAN import STGAN
model = STGAN(cfg, train_reader, test_reader, batch_num)
raise NotImplementedError("CGAN/DCGAN only support MNIST now!")
model = trainer.__dict__[cfg.model_net](cfg, train_reader)
elif cfg.model_net in ['CycleGAN']:
model = trainer.__dict__[cfg.model_net](
cfg, a_reader, b_reader, a_reader_test, b_reader_test, batch_num)
else:
pass
model = trainer.__dict__[cfg.model_net](cfg, train_reader, test_reader,
batch_num)
model.build_model()
......@@ -77,13 +64,13 @@ if __name__ == "__main__":
cfg = config.parse_args()
config.print_arguments(cfg)
utility.check_gpu(cfg.use_gpu)
#assert cfg.load_size >= cfg.crop_size, "Load Size CANNOT less than Crop Size!"
if cfg.profile:
if cfg.use_gpu:
with profiler.profiler('All', 'total', '/tmp/profile') as prof:
with fluid.profiler.profiler('All', 'total',
'/tmp/profile') as prof:
train(cfg)
else:
with profiler.profiler("CPU", sorted_key='total') as cpuprof:
with fluid.profiler.profiler("CPU", sorted_key='total') as cpuprof:
train(cfg)
else:
train(cfg)
......@@ -175,8 +175,6 @@ class DTrainer():
class AttGAN(object):
def add_special_args(self, parser):
parser.add_argument(
'--image_size', type=int, default=256, help="image size")
parser.add_argument(
'--g_lr',
type=float,
......
......@@ -173,8 +173,6 @@ class DTrainer():
class STGAN(object):
def add_special_args(self, parser):
parser.add_argument(
'--image_size', type=int, default=256, help="image size")
parser.add_argument(
'--g_lr',
type=float,
......
......@@ -199,8 +199,6 @@ class DTrainer():
class StarGAN(object):
def add_special_args(self, parser):
parser.add_argument(
'--image_size', type=int, default=256, help="image size")
parser.add_argument(
'--g_lr', type=float, default=0.0001, help="learning rate of g")
parser.add_argument(
......
......@@ -12,6 +12,14 @@
#See the License for the specific language governing permissions and
#limitations under the License.
from .CGAN import CGAN
from .DCGAN import DCGAN
from .CycleGAN import CycleGAN
from .Pix2pix import Pix2pix
from .STGAN import STGAN
from .StarGAN import StarGAN
from .AttGAN import AttGAN
import importlib
......
......@@ -77,7 +77,7 @@ def base_parse_args(parser):
add_arg('epoch', int, 200, "The number of epoch to be trained.")
add_arg('g_base_dims', int, 64, "Base channels in generator")
add_arg('d_base_dims', int, 64, "Base channels in discriminator")
add_arg('load_size', int, 286, "the image size when load the image")
add_arg('image_size', int, 286, "the image size when load the image")
add_arg('crop_type', str, 'Centor',
"the crop type, choose = ['Centor', 'Random']")
add_arg('crop_size', int, 256, "crop size when preprocess image")
......
......@@ -113,9 +113,10 @@ def save_test_image(epoch,
images = [real_img_temp]
for i in range(cfg.c_dim):
label_trg_tmp = copy.deepcopy(label_org)
label_trg_tmp[0][i] = 1.0 - label_trg_tmp[0][i]
label_trg = check_attribute_conflict(label_trg_tmp,
attr_names[i], attr_names)
for j in range(len(label_org)):
label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
label_trg = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
tensor_label_trg = fluid.LoDTensor()
tensor_label_trg.set(label_trg, place)
fake_temp, rec_temp = exe.run(
......@@ -126,11 +127,13 @@ def save_test_image(epoch,
"label_trg": tensor_label_trg
},
fetch_list=[g_trainer.fake_img, g_trainer.rec_img])
fake_temp = save_batch_image(fake_temp[0])
rec_temp = save_batch_image(rec_temp[0])
fake_temp = save_batch_image(fake_temp)
rec_temp = save_batch_image(rec_temp)
images.append(fake_temp)
images.append(rec_temp)
images_concat = np.concatenate(images, 1)
if len(label_org) > 1:
images_concat = np.concatenate(images_concat, 1)
imageio.imwrite(out_path + "/fake_img" + str(epoch) + "_" + name[0],
((images_concat + 1) * 127.5).astype(np.uint8))
elif cfg.model_net == 'AttGAN' or cfg.model_net == 'STGAN':
......@@ -184,12 +187,12 @@ def save_test_image(epoch,
else:
for data_A, data_B in zip(A_test_reader(), B_test_reader()):
A_name = data_A[0][1]
B_name = data_B[0][1]
A_data, A_name = data_A
B_data, B_name = data_B
tensor_A = fluid.LoDTensor()
tensor_B = fluid.LoDTensor()
tensor_A.set(data_A[0][0], place)
tensor_B.set(data_B[0][0], place)
tensor_A.set(A_data, place)
tensor_B.set(B_data, place)
fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = exe.run(
test_program,
fetch_list=[
......@@ -205,18 +208,20 @@ def save_test_image(epoch,
input_A_temp = np.squeeze(data_A[0][0]).transpose([1, 2, 0])
input_B_temp = np.squeeze(data_B[0][0]).transpose([1, 2, 0])
imageio.imwrite(out_path + "/fakeB_" + str(epoch) + "_" + A_name, (
(fake_B_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/fakeA_" + str(epoch) + "_" + B_name, (
(fake_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/cycA_" + str(epoch) + "_" + A_name, (
(cyc_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/cycB_" + str(epoch) + "_" + B_name, (
(cyc_B_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/inputA_" + str(epoch) + "_" + A_name, (
(input_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/inputB_" + str(epoch) + "_" + B_name, (
(input_B_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/fakeB_" + str(epoch) + "_" + A_name[0],
((fake_B_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/fakeA_" + str(epoch) + "_" + B_name[0],
((fake_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/cycA_" + str(epoch) + "_" + A_name[0],
((cyc_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/cycB_" + str(epoch) + "_" + B_name[0],
((cyc_B_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(
out_path + "/inputA_" + str(epoch) + "_" + A_name[0], (
(input_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(
out_path + "/inputB_" + str(epoch) + "_" + B_name[0], (
(input_B_temp + 1) * 127.5).astype(np.uint8))
class ImagePool(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册