Unverified · Commit 4a93b001 · authored by lvmengsi, committed by GitHub

Update gan (#2871)

* refine gan
Parent: 7b1a5565
@@ -61,76 +61,77 @@ def get_preprocess_param(load_size, crop_size):
 class reader_creator(object):
     ''' read and preprocess dataset'''

-    def __init__(self, image_dir, list_filename, batch_size=1, drop_last=False):
+    def __init__(self,
+                 image_dir,
+                 list_filename,
+                 shuffle=False,
+                 batch_size=1,
+                 drop_last=False,
+                 mode="TRAIN"):
         self.image_dir = image_dir
         self.list_filename = list_filename
         self.batch_size = batch_size
-        self.drop_last = drop_last
+        self.mode = mode
         self.lines = open(self.list_filename).readlines()
+        if self.mode == "TRAIN":
+            self.shuffle = shuffle
+            self.drop_last = drop_last
+        else:
+            self.shuffle = False
+            self.drop_last = False

     def len(self):
         if self.drop_last or len(self.lines) % self.batch_size == 0:
             return len(self.lines) // self.batch_size
         else:
             return len(self.lines) // self.batch_size + 1

-    def get_train_reader(self, args, shuffle=False, return_name=False):
+    def make_reader(self, args, return_name=False):
         print(self.image_dir, self.list_filename)

         def reader():
             batch_out = []
-            while True:
-                if shuffle:
-                    np.random.shuffle(self.lines)
-                for file in self.lines:
-                    file = file.strip('\n\r\t ')
-                    img = Image.open(os.path.join(self.image_dir,
-                                                  file)).convert('RGB')
-                    img = img.resize((args.load_size, args.load_size),
-                                     Image.BICUBIC)
-                    if args.crop_type == 'Centor':
-                        img = CentorCrop(img, args.crop_size, args.crop_size)
-                    elif args.crop_type == 'Random':
-                        img = RandomCrop(img, args.crop_size, args.crop_size)
-                    img = (np.array(img).astype('float32') / 255.0 - 0.5) / 0.5
-                    img = img.transpose([2, 0, 1])
-                    if return_name:
-                        batch_out.append([img, os.path.basename(file)])
-                    else:
-                        batch_out.append(img)
-                    if len(batch_out) == self.batch_size:
-                        yield batch_out
-                        batch_out = []
-                if self.drop_last == False and len(batch_out) != 0:
-                    yield batch_out
-
-        return reader
-
-    def get_test_reader(self, args, shuffle=False, return_name=False):
-        print(self.image_dir, self.list_filename)
-
-        def reader():
-            batch_out = []
+            batch_out_name = []
+            if self.shuffle:
+                np.random.shuffle(self.lines)
             for file in self.lines:
                 file = file.strip('\n\r\t ')
                 img = Image.open(os.path.join(self.image_dir, file)).convert(
                     'RGB')
-                img = img.resize((args.crop_size, args.crop_size),
-                                 Image.BICUBIC)
+                if self.mode == "TRAIN":
+                    img = img.resize((args.image_size, args.image_size),
+                                     Image.BICUBIC)
+                    if args.crop_type == 'Centor':
+                        img = CentorCrop(img, args.crop_size, args.crop_size)
+                    elif args.crop_type == 'Random':
+                        img = RandomCrop(img, args.crop_size, args.crop_size)
+                else:
+                    img = img.resize((args.crop_size, args.crop_size),
+                                     Image.BICUBIC)
                 img = (np.array(img).astype('float32') / 255.0 - 0.5) / 0.5
                 img = img.transpose([2, 0, 1])
                 if return_name:
-                    batch_out.append(
-                        [img[np.newaxis, :], os.path.basename(file)])
+                    batch_out.append(img)
+                    batch_out_name.append(os.path.basename(file))
                 else:
                     batch_out.append(img)
                 if len(batch_out) == self.batch_size:
-                    yield batch_out
+                    if return_name:
+                        yield batch_out, batch_out_name
+                        batch_out_name = []
+                    else:
+                        yield batch_out
                     batch_out = []
-            if len(batch_out) != 0:
-                yield batch_out
+            if self.drop_last == False and len(batch_out) != 0:
+                if return_name:
+                    yield batch_out, batch_out_name
+                else:
+                    yield batch_out

         return reader
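
The net effect of this first hunk: `get_train_reader` and `get_test_reader` collapse into a single `make_reader`, with shuffling, drop-last, and crop behavior decided by the `mode` fixed at construction time. A minimal usage sketch (the `Namespace` fields and file paths are illustrative stand-ins for the real parsed config, not values from this commit):

from argparse import Namespace

# Illustrative config; real runs get these from config.parse_args().
args = Namespace(image_size=286, crop_size=256, crop_type='Random')

train_rc = reader_creator(
    image_dir='data/cityscapes', list_filename='data/cityscapes/trainA.txt',
    shuffle=True, batch_size=4, drop_last=True, mode="TRAIN")
test_rc = reader_creator(
    image_dir='data/cityscapes', list_filename='data/cityscapes/testA.txt',
    batch_size=1, mode="TEST")  # TEST forces shuffle=False, drop_last=False

train_reader = train_rc.make_reader(args)                  # yields [imgs]
test_reader = test_rc.make_reader(args, return_name=True)  # yields ([imgs], [names])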
@@ -138,29 +139,43 @@ class reader_creator(object):
 class pair_reader_creator(reader_creator):
     ''' read and preprocess dataset'''

-    def __init__(self, image_dir, list_filename, batch_size=1, drop_last=False):
+    def __init__(self,
+                 image_dir,
+                 list_filename,
+                 shuffle=False,
+                 batch_size=1,
+                 drop_last=False,
+                 mode="TRAIN"):
         super(pair_reader_creator, self).__init__(
-            image_dir, list_filename, batch_size=1, drop_last=drop_last)
+            image_dir,
+            list_filename,
+            shuffle=shuffle,
+            batch_size=batch_size,
+            drop_last=drop_last,
+            mode=mode)

-    def get_train_reader(self, args, shuffle=False, return_name=False):
+    def make_reader(self, args, return_name=False):
         print(self.image_dir, self.list_filename)

         def reader():
             batch_out_1 = []
             batch_out_2 = []
-            while True:
-                if shuffle:
-                    np.random.shuffle(self.lines)
-                for line in self.lines:
-                    files = line.strip('\n\r\t ').split('\t')
-                    img1 = Image.open(os.path.join(self.image_dir, files[
-                        0])).convert('RGB')
-                    img2 = Image.open(os.path.join(self.image_dir, files[
-                        1])).convert('RGB')
-                    param = get_preprocess_param(args.load_size, args.crop_size)
-                    img1 = img1.resize((args.load_size, args.load_size),
-                                       Image.BICUBIC)
-                    img2 = img2.resize((args.load_size, args.load_size),
-                                       Image.BICUBIC)
-                    if args.crop_type == 'Centor':
-                        img1 = CentorCrop(img1, args.crop_size, args.crop_size)
+            batch_out_name = []
+            if self.shuffle:
+                np.random.shuffle(self.lines)
+            for line in self.lines:
+                files = line.strip('\n\r\t ').split('\t')
+                img1 = Image.open(os.path.join(self.image_dir, files[
+                    0])).convert('RGB')
+                img2 = Image.open(os.path.join(self.image_dir, files[
+                    1])).convert('RGB')
+                if self.mode == "TRAIN":
+                    param = get_preprocess_param(args.image_size,
+                                                 args.crop_size)
+                    img1 = img1.resize((args.image_size, args.image_size),
+                                       Image.BICUBIC)
+                    img2 = img2.resize((args.image_size, args.image_size),
+                                       Image.BICUBIC)
+                    if args.crop_type == 'Centor':
+                        img1 = CentorCrop(img1, args.crop_size, args.crop_size)
@@ -172,65 +187,32 @@ class pair_reader_creator(reader_creator):
                         (x, y, x + args.crop_size, y + args.crop_size))
                     img2 = img2.crop(
                         (x, y, x + args.crop_size, y + args.crop_size))
-                    img1 = (
-                        np.array(img1).astype('float32') / 255.0 - 0.5) / 0.5
-                    img1 = img1.transpose([2, 0, 1])
-                    img2 = (
-                        np.array(img2).astype('float32') / 255.0 - 0.5) / 0.5
-                    img2 = img2.transpose([2, 0, 1])
-                    batch_out_1.append(img1)
-                    batch_out_2.append(img2)
-                    if len(batch_out_1) == self.batch_size:
-                        yield batch_out_1, batch_out_2
-                        batch_out_1 = []
-                        batch_out_2 = []
-                if self.drop_last == False and len(batch_out_1) != 0:
-                    yield batch_out_1, batch_out_2
-
-        return reader
-
-    def get_test_reader(self, args, shuffle=False, return_name=False):
-        print(self.image_dir, self.list_filename)
-
-        def reader():
-            batch_out_1 = []
-            batch_out_2 = []
-            batch_out_3 = []
-            for line in self.lines:
-                files = line.strip('\n\r\t ').split('\t')
-                img1 = Image.open(os.path.join(self.image_dir, files[
-                    0])).convert('RGB')
-                img2 = Image.open(os.path.join(self.image_dir, files[
-                    1])).convert('RGB')
-                img1 = img1.resize((args.crop_size, args.crop_size),
-                                   Image.BICUBIC)
-                img2 = img2.resize((args.crop_size, args.crop_size),
-                                   Image.BICUBIC)
+                else:
+                    img1 = img1.resize((args.crop_size, args.crop_size),
+                                       Image.BICUBIC)
+                    img2 = img2.resize((args.crop_size, args.crop_size),
+                                       Image.BICUBIC)
                 img1 = (np.array(img1).astype('float32') / 255.0 - 0.5) / 0.5
                 img1 = img1.transpose([2, 0, 1])
                 img2 = (np.array(img2).astype('float32') / 255.0 - 0.5) / 0.5
                 img2 = img2.transpose([2, 0, 1])
+                batch_out_1.append(img1)
+                batch_out_2.append(img2)
                 if return_name:
-                    batch_out_1.append(img1)
-                    batch_out_2.append(img2)
-                    batch_out_3.append(os.path.basename(files[0]))
-                else:
-                    batch_out_1.append(img1)
-                    batch_out_2.append(img2)
+                    batch_out_name.append(os.path.basename(files[0]))
                 if len(batch_out_1) == self.batch_size:
                     if return_name:
-                        yield batch_out_1, batch_out_2, batch_out_3
-                        batch_out_1 = []
-                        batch_out_2 = []
-                        batch_out_3 = []
+                        yield batch_out_1, batch_out_2, batch_out_name
+                        batch_out_name = []
                     else:
                         yield batch_out_1, batch_out_2
                     batch_out_1 = []
                     batch_out_2 = []
-            if len(batch_out_1) != 0:
+            if self.drop_last == False and len(batch_out_1) != 0:
                 if return_name:
-                    yield batch_out_1, batch_out_2, batch_out_3
+                    yield batch_out_1, batch_out_2, batch_out_name
                 else:
                     yield batch_out_1, batch_out_2
@@ -240,24 +222,38 @@ class pair_reader_creator(reader_creator):
 class celeba_reader_creator(reader_creator):
     ''' read and preprocess dataset'''

-    def __init__(self,
-                 image_dir,
-                 list_filename,
-                 args,
-                 batch_size=1,
-                 drop_last=False):
+    def __init__(self, image_dir, list_filename, args, mode="TRAIN"):
         self.image_dir = image_dir
         self.list_filename = list_filename
-        self.batch_size = batch_size
-        self.drop_last = drop_last
+        self.mode = mode
+        self.args = args
+        print(self.image_dir, self.list_filename)
         lines = open(self.list_filename).readlines()
+        all_num = int(lines[0])
+        train_end = 2 + int(all_num * 0.9)
+        test_end = train_end + int(all_num * 0.003)
         all_attr_names = lines[1].split()
         attr2idx = {}
         for i, attr_name in enumerate(all_attr_names):
             attr2idx[attr_name] = i
         lines = lines[2:]
+        if self.mode == "TRAIN":
+            self.batch_size = args.batch_size
+            self.drop_last = args.drop_last
+            self.shuffle = args.shuffle
+            lines = lines[2:train_end]
+        else:
+            self.batch_size = args.n_samples
+            self.shuffle = False
+            self.drop_last = False
+            if self.mode == "TEST":
+                lines = lines[train_end:test_end]
+            else:
+                lines = lines[test_end:]

         self.images = []
         attr_names = args.selected_attrs.split(',')
         for line in lines:
@@ -275,75 +271,41 @@ class celeba_reader_creator(reader_creator):
         else:
             return len(self.images) // self.batch_size + 1

-    def get_train_reader(self, args, shuffle=False, return_name=False):
-        def reader():
-            batch_out_1 = []
-            batch_out_2 = []
-            while True:
-                if shuffle:
-                    np.random.shuffle(self.images)
-                for file, label in self.images:
-                    if args.model_net == "StarGAN":
-                        img = Image.open(os.path.join(self.image_dir, file))
-                        label = np.array(label).astype("float32")
-                        img = RandomHorizonFlip(img)
-                        img = CentorCrop(img, args.crop_size, args.crop_size)
-                        img = img.resize((args.image_size, args.image_size),
-                                         Image.BILINEAR)
-                    else:
-                        img = Image.open(os.path.join(self.image_dir,
-                                                      file)).convert('RGB')
-                        label = np.array(label).astype("float32")
-                        img = CentorCrop(img, args.crop_size, args.crop_size)
-                        img = img.resize((args.image_size, args.image_size),
-                                         Image.BILINEAR)
-                    img = (np.array(img).astype('float32') / 255.0 - 0.5) / 0.5
-                    img = img.transpose([2, 0, 1])
-                    batch_out_1.append(img)
-                    batch_out_2.append(label)
-                    if len(batch_out_1) == self.batch_size:
-                        yield batch_out_1, batch_out_2
-                        batch_out_1 = []
-                        batch_out_2 = []
-                if self.drop_last == False and len(batch_out_1) != 0:
-                    yield batch_out_1, batch_out_2
-
-        return reader
-
-    def get_test_reader(self, args, shuffle=False, return_name=False):
+    def make_reader(self, return_name=False):
+        print(self.image_dir, self.list_filename)
+
         def reader():
             batch_out_1 = []
             batch_out_2 = []
-            batch_out_3 = []
+            batch_out_name = []
+            if self.shuffle:
+                np.random.shuffle(self.images)
             for file, label in self.images:
                 img = Image.open(os.path.join(self.image_dir, file))
                 label = np.array(label).astype("float32")
-                img = CentorCrop(img, args.crop_size, args.crop_size)
-                img = img.resize((args.image_size, args.image_size),
+                if self.args.model_net == "StarGAN":
+                    img = RandomHorizonFlip(img)
+                img = CentorCrop(img, self.args.crop_size, self.args.crop_size)
+                img = img.resize((self.args.image_size, self.args.image_size),
                                  Image.BILINEAR)
                 img = (np.array(img).astype('float32') / 255.0 - 0.5) / 0.5
                 img = img.transpose([2, 0, 1])
+                batch_out_1.append(img)
+                batch_out_2.append(label)
                 if return_name:
-                    batch_out_1.append(img)
-                    batch_out_2.append(label)
-                    batch_out_3.append(os.path.basename(file))
-                else:
-                    batch_out_1.append(img)
-                    batch_out_2.append(label)
+                    batch_out_name.append(os.path.basename(file))
                 if len(batch_out_1) == self.batch_size:
                     if return_name:
-                        yield batch_out_1, batch_out_2, batch_out_3
-                        batch_out_1 = []
-                        batch_out_2 = []
-                        batch_out_3 = []
+                        yield batch_out_1, batch_out_2, batch_out_name
+                        batch_out_name = []
                     else:
                         yield batch_out_1, batch_out_2
                     batch_out_1 = []
                     batch_out_2 = []
-            if len(batch_out_1) != 0:
+            if self.drop_last == False and len(batch_out_1) != 0:
                 if return_name:
-                    yield batch_out_1, batch_out_2, batch_out_3
+                    yield batch_out_1, batch_out_2, batch_out_name
                 else:
                     yield batch_out_1, batch_out_2
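
Unlike the generic readers, `celeba_reader_creator` now derives its train/test/val split directly from the header of the attribute list file rather than from separate list files. A sketch of the arithmetic, assuming the standard `list_attr_celeba.txt` layout (line 0 is the image count, line 1 the attribute names, then one line per image; the count shown is the full CelebA size and is purely illustrative):

all_num = 202599                             # int(lines[0])
train_end = 2 + int(all_num * 0.9)           # first ~90% of entries
test_end = train_end + int(all_num * 0.003)  # next ~0.3%
# mode == "TRAIN" -> lines[2:train_end]
# mode == "TEST"  -> lines[train_end:test_end]
# anything else (e.g. "VAL", used by infer.py) -> lines[test_end:]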
@@ -417,20 +379,24 @@ class data_reader(object):
                 batch_size=self.cfg.batch_size)
             return train_reader
         else:
-            if self.cfg.model_net == 'CycleGAN':
+            if self.cfg.model_net in ['CycleGAN']:
                 dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
                 trainA_list = os.path.join(dataset_dir, "trainA.txt")
                 trainB_list = os.path.join(dataset_dir, "trainB.txt")
                 a_train_reader = reader_creator(
                     image_dir=dataset_dir,
                     list_filename=trainA_list,
+                    shuffle=self.cfg.shuffle,
                     batch_size=self.cfg.batch_size,
-                    drop_last=self.cfg.drop_last)
+                    drop_last=self.cfg.drop_last,
+                    mode="TRAIN")
                 b_train_reader = reader_creator(
                     image_dir=dataset_dir,
                     list_filename=trainB_list,
+                    shuffle=self.cfg.shuffle,
                     batch_size=self.cfg.batch_size,
-                    drop_last=self.cfg.drop_last)
+                    drop_last=self.cfg.drop_last,
+                    mode="TRAIN")
                 a_reader_test = None
                 b_reader_test = None
                 if self.cfg.run_test:
@@ -439,27 +405,29 @@ class data_reader(object):
                     a_test_reader = reader_creator(
                         image_dir=dataset_dir,
                         list_filename=testA_list,
+                        shuffle=False,
                         batch_size=1,
-                        drop_last=self.cfg.drop_last)
+                        drop_last=self.cfg.drop_last,
+                        mode="TEST")
                     b_test_reader = reader_creator(
                         image_dir=dataset_dir,
                         list_filename=testB_list,
+                        shuffle=False,
                         batch_size=1,
-                        drop_last=self.cfg.drop_last)
-                    a_reader_test = a_test_reader.get_test_reader(
-                        self.cfg, shuffle=False, return_name=True)
-                    b_reader_test = b_test_reader.get_test_reader(
-                        self.cfg, shuffle=False, return_name=True)
+                        drop_last=self.cfg.drop_last,
+                        mode="TEST")
+                    a_reader_test = a_test_reader.make_reader(
+                        self.cfg, return_name=True)
+                    b_reader_test = b_test_reader.make_reader(
+                        self.cfg, return_name=True)

                 batch_num = max(a_train_reader.len(), b_train_reader.len())
-                a_reader = a_train_reader.get_train_reader(
-                    self.cfg, shuffle=self.shuffle)
-                b_reader = b_train_reader.get_train_reader(
-                    self.cfg, shuffle=self.shuffle)
+                a_reader = a_train_reader.make_reader(self.cfg)
+                b_reader = b_train_reader.make_reader(self.cfg)
                 return a_reader, b_reader, a_reader_test, b_reader_test, batch_num
-            elif self.cfg.model_net == 'StarGAN' or self.cfg.model_net == 'STGAN' or self.cfg.model_net == 'AttGAN':
+            elif self.cfg.model_net in ['StarGAN', 'STGAN', 'AttGAN']:
                 dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
                 train_list = os.path.join(dataset_dir, 'train.txt')
                 if self.cfg.train_list is not None:
@@ -467,28 +435,24 @@ class data_reader(object):
                 train_reader = celeba_reader_creator(
                     image_dir=dataset_dir,
                     list_filename=train_list,
-                    batch_size=self.cfg.batch_size,
                     args=self.cfg,
-                    drop_last=self.cfg.drop_last)
+                    mode="TRAIN")
                 reader_test = None
                 if self.cfg.run_test:
-                    test_list = os.path.join(dataset_dir, "test.txt")
+                    test_list = train_list
                     if self.cfg.test_list is not None:
                         test_list = self.cfg.test_list
                     test_reader = celeba_reader_creator(
                         image_dir=dataset_dir,
-                        list_filename=test_list,
-                        batch_size=self.cfg.n_samples,
-                        drop_last=self.cfg.drop_last,
-                        args=self.cfg)
-                    reader_test = test_reader.get_test_reader(
-                        self.cfg, shuffle=False, return_name=True)
+                        list_filename=train_list,
+                        args=self.cfg,
+                        mode="TEST")
+                    reader_test = test_reader.make_reader(return_name=True)
                 batch_num = train_reader.len()
-                reader = train_reader.get_train_reader(
-                    self.cfg, shuffle=self.shuffle)
+                reader = train_reader.make_reader()
                 return reader, reader_test, batch_num
-            elif self.cfg.model_net == 'Pix2pix':
+            elif self.cfg.model_net in ['Pix2pix']:
                 dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
                 train_list = os.path.join(dataset_dir, 'train.txt')
                 if self.cfg.train_list is not None:
@@ -496,8 +460,10 @@ class data_reader(object):
                 train_reader = pair_reader_creator(
                     image_dir=dataset_dir,
                     list_filename=train_list,
+                    shuffle=self.cfg.shuffle,
                     batch_size=self.cfg.batch_size,
-                    drop_last=self.cfg.drop_last)
+                    drop_last=self.cfg.drop_last,
+                    mode="TRAIN")
                 reader_test = None
                 if self.cfg.run_test:
                     test_list = os.path.join(dataset_dir, "test.txt")
@@ -506,13 +472,14 @@ class data_reader(object):
                     test_reader = pair_reader_creator(
                         image_dir=dataset_dir,
                         list_filename=test_list,
+                        shuffle=False,
                         batch_size=1,
-                        drop_last=self.cfg.drop_last)
-                    reader_test = test_reader.get_test_reader(
-                        self.cfg, shuffle=False, return_name=True)
+                        drop_last=self.cfg.drop_last,
+                        mode="TEST")
+                    reader_test = test_reader.make_reader(
+                        self.cfg, return_name=True)
                 batch_num = train_reader.len()
-                reader = train_reader.get_train_reader(
-                    self.cfg, shuffle=self.shuffle)
+                reader = train_reader.make_reader(self.cfg)
                 return reader, reader_test, batch_num
             else:
                 dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
@@ -527,7 +494,7 @@ class data_reader(object):
                 test_reader = reader_creator(
                     image_dir=dataset_dir,
                     list_filename=test_list,
-                    batch_size=1,
+                    batch_size=self.cfg.n_samples,
                     drop_last=self.cfg.drop_last)
                 reader_test = test_reader.get_test_reader(
                     self.cfg, shuffle=False, return_name=True)
......
@@ -54,8 +54,8 @@ add_arg('image_size', int, 128, "image size")
 add_arg('selected_attrs', str,
     "Bald,Bangs,Black_Hair,Blond_Hair,Brown_Hair,Bushy_Eyebrows,Eyeglasses,Male,Mouth_Slightly_Open,Mustache,No_Beard,Pale_Skin,Young",
     "the attributes we selected to change")
-add_arg('batch_size', int, 16, "batch size when test")
-add_arg('test_list', str, "./data/celeba/test_list_attr_celeba.txt", "the test list file")
+add_arg('n_samples', int, 16, "batch size when test")
+add_arg('test_list', str, "./data/celeba/list_attr_celeba.txt", "the test list file")
 add_arg('dataset_dir', str, "./data/celeba/", "the dataset directory to be infered")
 add_arg('n_layers', int, 5, "default layers in generotor")
 add_arg('gru_n_layers', int, 4, "default layers of GRU in generotor")
@@ -149,11 +149,9 @@ def infer(args):
         test_reader = celeba_reader_creator(
             image_dir=args.dataset_dir,
             list_filename=args.test_list,
-            batch_size=args.batch_size,
-            drop_last=False,
-            args=args)
-        reader_test = test_reader.get_test_reader(
-            args, shuffle=False, return_name=True)
+            args=args,
+            mode="VAL")
+        reader_test = test_reader.make_reader(return_name=True)
         for data in zip(reader_test()):
             real_img, label_org, name = data[0]
             print("read {}".format(name))
@@ -199,11 +197,9 @@ def infer(args):
         test_reader = celeba_reader_creator(
             image_dir=args.dataset_dir,
             list_filename=args.test_list,
-            batch_size=args.batch_size,
-            drop_last=False,
-            args=args)
-        reader_test = test_reader.get_test_reader(
-            args, shuffle=False, return_name=True)
+            args=args,
+            mode="VAL")
+        reader_test = test_reader.make_reader(return_name=True)
         for data in zip(reader_test()):
             real_img, label_org, name = data[0]
             print("read {}".format(name))
@@ -256,9 +252,9 @@ def infer(args):
     elif args.model_net == 'CGAN':
         noise_data = np.random.uniform(
             low=-1.0, high=1.0,
-            size=[args.batch_size, args.noise_size]).astype('float32')
+            size=[args.n_samples, args.noise_size]).astype('float32')
         label = np.random.randint(
-            0, 9, size=[args.batch_size, 1]).astype('float32')
+            0, 9, size=[args.n_samples, 1]).astype('float32')
         noise_tensor = fluid.LoDTensor()
         conditions_tensor = fluid.LoDTensor()
         noise_tensor.set(noise_data, place)
@@ -267,7 +263,7 @@ def infer(args):
             fetch_list=[fake.name],
             feed={"noise": noise_tensor,
                   "conditions": conditions_tensor})[0]
-        fake_image = np.reshape(fake_temp, (args.batch_size, -1))
+        fake_image = np.reshape(fake_temp, (args.n_samples, -1))
         fig = utility.plot(fake_image)
         plt.savefig(
@@ -277,12 +273,12 @@ def infer(args):
     elif args.model_net == 'DCGAN':
         noise_data = np.random.uniform(
             low=-1.0, high=1.0,
-            size=[args.batch_size, args.noise_size]).astype('float32')
+            size=[args.n_samples, args.noise_size]).astype('float32')
         noise_tensor = fluid.LoDTensor()
         noise_tensor.set(noise_data, place)
         fake_temp = exe.run(fetch_list=[fake.name],
                             feed={"noise": noise_tensor})[0]
-        fake_image = np.reshape(fake_temp, (args.batch_size, -1))
+        fake_image = np.reshape(fake_temp, (args.n_samples, -1))
         fig = utility.plot(fake_image)
         plt.savefig(
......
@@ -71,10 +71,10 @@ class AttGAN_model(object):
                 d = min(dim * 2**i, MAX_DIM)
                 #SAME padding
                 z = conv2d(
-                    z,
-                    d,
-                    4,
-                    2,
+                    input=z,
+                    num_filters=d,
+                    filter_size=4,
+                    stride=2,
                     padding_type='SAME',
                     norm='batch_norm',
                     activation_fn='leaky_relu',
@@ -104,10 +104,10 @@ class AttGAN_model(object):
                 if i < n_layers - 1:
                     d = min(dim * 2**(n_layers - 1 - i), MAX_DIM)
                     z = deconv2d(
-                        z,
-                        d,
-                        4,
-                        2,
+                        input=z,
+                        num_filters=d,
+                        filter_size=4,
+                        stride=2,
                         padding_type='SAME',
                         name=name + str(i),
                         norm='batch_norm',
@@ -121,10 +121,10 @@ class AttGAN_model(object):
                         z = self.concat(z, a)
                 else:
                     x = z = deconv2d(
-                        z,
-                        3,
-                        4,
-                        2,
+                        input=z,
+                        num_filters=3,
+                        filter_size=4,
+                        stride=2,
                         padding_type='SAME',
                         name=name + str(i),
                         activation_fn='tanh',
@@ -146,10 +146,10 @@ class AttGAN_model(object):
             for i in range(n_layers):
                 d = min(dim * 2**i, MAX_DIM)
                 y = conv2d(
-                    y,
-                    d,
-                    4,
-                    2,
+                    input=y,
+                    num_filters=d,
+                    filter_size=4,
+                    stride=2,
                     norm=norm,
                     padding=1,
                     activation_fn='leaky_relu',
@@ -159,8 +159,8 @@ class AttGAN_model(object):
                 initial='kaiming')
             logit_gan = linear(
-                y,
-                fc_dim,
+                input=y,
+                output_size=fc_dim,
                 activation_fn='relu',
                 name=name + 'fc_adv_1',
                 initial='kaiming')
@@ -168,8 +168,8 @@ class AttGAN_model(object):
                 logit_gan, 1, name=name + 'fc_adv_2', initial='kaiming')
             logit_att = linear(
-                y,
-                fc_dim,
+                input=y,
+                output_size=fc_dim,
                 activation_fn='relu',
                 name=name + 'fc_cls_1',
                 initial='kaiming')
......
-python infer.py --model_net AttGAN --init_model output/checkpoints/199/ --dataset_dir "data/celeba/" --image_size 128
+python infer.py --model_net AttGAN --init_model output/checkpoints/119/ --dataset_dir "data/celeba/" --image_size 128

-python train.py --model_net AttGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 90 >log_out 2>log_err
+python train.py --model_net AttGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 120 >log_out 2>log_err

-python train.py --model_net CycleGAN --dataset cityscapes --batch_size 1 --net_G resnet_9block --g_base_dim 32 --net_D basic --norm_type batch_norm --epoch 200 --load_size 286 --crop_size 256 --crop_type Random > log_out 2>log_err
+python train.py --model_net CycleGAN --dataset cityscapes --batch_size 1 --net_G resnet_9block --g_base_dim 32 --net_D basic --norm_type batch_norm --epoch 200 --image_size 286 --crop_size 256 --crop_type Random > log_out 2>log_err

-python train.py --model_net Pix2pix --dataset cityscapes --train_list data/cityscapes/pix2pix_train_list --test_list data/cityscapes/pix2pix_test_list --crop_type Random --dropout True --gan_mode vanilla --batch_size 1 > log_out 2>log_err
+python train.py --model_net Pix2pix --dataset cityscapes --train_list data/cityscapes/pix2pix_train_list --test_list data/cityscapes/pix2pix_test_list --crop_type Random --dropout True --gan_mode vanilla --batch_size 1 --epoch 200 --image_size 286 --crop_size 256 > log_out 2>log_err

-python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 16 --epoch 20 > log_out 2>log_err
+python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 16 --epoch 20 > log_out 2>log_err

-python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 50 >log_out 2>log_err
+python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 50 >log_out 2>log_err
@@ -24,51 +24,38 @@ import time
 import numpy as np
 import paddle
 import paddle.fluid as fluid
+import trainer


 def train(cfg):
+    MODELS = [
+        "CGAN", "DCGAN", "Pix2pix", "CycleGAN", "StarGAN", "AttGAN", "STGAN"
+    ]
+    if cfg.model_net not in MODELS:
+        raise NotImplementedError("{} is not support!".format(cfg.model_net))
+
     reader = data_reader(cfg)
-    if cfg.model_net == 'CycleGAN':
+    if cfg.model_net in ['CycleGAN']:
         a_reader, b_reader, a_reader_test, b_reader_test, batch_num = reader.make_data(
         )
-    elif cfg.model_net == 'Pix2pix':
-        train_reader, test_reader, batch_num = reader.make_data()
-    elif cfg.model_net == 'StarGAN':
-        train_reader, test_reader, batch_num = reader.make_data()
     else:
-        if cfg.dataset == 'mnist':
+        if cfg.dataset in ['mnist']:
             train_reader = reader.make_data()
         else:
             train_reader, test_reader, batch_num = reader.make_data()

-    if cfg.model_net == 'CGAN':
-        from trainer.CGAN import CGAN
-        if cfg.dataset != 'mnist':
-            raise NotImplementedError('CGAN only support mnist now!')
-        model = CGAN(cfg, train_reader)
-    elif cfg.model_net == 'DCGAN':
-        from trainer.DCGAN import DCGAN
+    if cfg.model_net in ['CGAN', 'DCGAN']:
         if cfg.dataset != 'mnist':
-            raise NotImplementedError('DCGAN only support mnist now!')
-        model = DCGAN(cfg, train_reader)
-    elif cfg.model_net == 'CycleGAN':
-        from trainer.CycleGAN import CycleGAN
-        model = CycleGAN(cfg, a_reader, b_reader, a_reader_test, b_reader_test,
-                         batch_num)
-    elif cfg.model_net == 'Pix2pix':
-        from trainer.Pix2pix import Pix2pix
-        model = Pix2pix(cfg, train_reader, test_reader, batch_num)
-    elif cfg.model_net == 'StarGAN':
-        from trainer.StarGAN import StarGAN
-        model = StarGAN(cfg, train_reader, test_reader, batch_num)
-    elif cfg.model_net == 'AttGAN':
-        from trainer.AttGAN import AttGAN
-        model = AttGAN(cfg, train_reader, test_reader, batch_num)
-    elif cfg.model_net == 'STGAN':
-        from trainer.STGAN import STGAN
-        model = STGAN(cfg, train_reader, test_reader, batch_num)
+            raise NotImplementedError("CGAN/DCGAN only support MNIST now!")
+        model = trainer.__dict__[cfg.model_net](cfg, train_reader)
+    elif cfg.model_net in ['CycleGAN']:
+        model = trainer.__dict__[cfg.model_net](
+            cfg, a_reader, b_reader, a_reader_test, b_reader_test, batch_num)
     else:
-        pass
+        model = trainer.__dict__[cfg.model_net](cfg, train_reader, test_reader,
+                                                batch_num)

     model.build_model()
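
The rewritten `train()` swaps the long import/elif chain for a whitelist check followed by a name-based lookup in the `trainer` package (whose `__init__.py` re-exports every trainer class, see below). A minimal sketch of the pattern, assuming the package from this repo is importable and `name` stands in for `cfg.model_net`:

import trainer

MODELS = [
    "CGAN", "DCGAN", "Pix2pix", "CycleGAN", "StarGAN", "AttGAN", "STGAN"
]

name = "STGAN"                        # would come from cfg.model_net
if name not in MODELS:                # validate before the lookup, so only
    raise NotImplementedError(        # the seven GAN trainers can be resolved
        "{} is not support!".format(name))
ModelClass = trainer.__dict__[name]   # class re-exported by trainer/__init__.py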
@@ -77,13 +64,13 @@ if __name__ == "__main__":
     cfg = config.parse_args()
     config.print_arguments(cfg)
     utility.check_gpu(cfg.use_gpu)
-    #assert cfg.load_size >= cfg.crop_size, "Load Size CANNOT less than Crop Size!"
     if cfg.profile:
         if cfg.use_gpu:
-            with profiler.profiler('All', 'total', '/tmp/profile') as prof:
+            with fluid.profiler.profiler('All', 'total',
+                                         '/tmp/profile') as prof:
                 train(cfg)
         else:
-            with profiler.profiler("CPU", sorted_key='total') as cpuprof:
+            with fluid.profiler.profiler("CPU", sorted_key='total') as cpuprof:
                 train(cfg)
     else:
         train(cfg)
@@ -175,8 +175,6 @@ class DTrainer():
 class AttGAN(object):
     def add_special_args(self, parser):
-        parser.add_argument(
-            '--image_size', type=int, default=256, help="image size")
         parser.add_argument(
             '--g_lr',
             type=float,
......
@@ -173,8 +173,6 @@ class DTrainer():
 class STGAN(object):
     def add_special_args(self, parser):
-        parser.add_argument(
-            '--image_size', type=int, default=256, help="image size")
         parser.add_argument(
             '--g_lr',
             type=float,
......
@@ -199,8 +199,6 @@ class DTrainer():
 class StarGAN(object):
     def add_special_args(self, parser):
-        parser.add_argument(
-            '--image_size', type=int, default=256, help="image size")
         parser.add_argument(
             '--g_lr', type=float, default=0.0001, help="learning rate of g")
         parser.add_argument(
......
@@ -12,6 +12,14 @@
 #See the License for the specific language governing permissions and
 #limitations under the License.

+from .CGAN import CGAN
+from .DCGAN import DCGAN
+from .CycleGAN import CycleGAN
+from .Pix2pix import Pix2pix
+from .STGAN import STGAN
+from .StarGAN import StarGAN
+from .AttGAN import AttGAN
+
 import importlib
......
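These relative imports are what make `trainer.__dict__[cfg.model_net]` in train.py resolve: each trainer class becomes an attribute of the package object itself. A quick sketch of the equivalence (runs only inside a checkout where the package imports cleanly):

import trainer

assert trainer.__dict__["AttGAN"] is trainer.AttGAN
assert getattr(trainer, "AttGAN") is trainer.AttGAN  # equivalent spelling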
@@ -77,7 +77,7 @@ def base_parse_args(parser):
     add_arg('epoch', int, 200, "The number of epoch to be trained.")
     add_arg('g_base_dims', int, 64, "Base channels in generator")
     add_arg('d_base_dims', int, 64, "Base channels in discriminator")
-    add_arg('load_size', int, 286, "the image size when load the image")
+    add_arg('image_size', int, 286, "the image size when load the image")
     add_arg('crop_type', str, 'Centor',
             "the crop type, choose = ['Centor', 'Random']")
     add_arg('crop_size', int, 256, "crop size when preprocess image")
......
@@ -113,9 +113,10 @@ def save_test_image(epoch,
             images = [real_img_temp]
             for i in range(cfg.c_dim):
                 label_trg_tmp = copy.deepcopy(label_org)
-                label_trg_tmp[0][i] = 1.0 - label_trg_tmp[0][i]
-                label_trg = check_attribute_conflict(label_trg_tmp,
-                                                     attr_names[i], attr_names)
+                for j in range(len(label_org)):
+                    label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
+                label_trg = check_attribute_conflict(
+                    label_trg_tmp, attr_names[i], attr_names)
                 tensor_label_trg = fluid.LoDTensor()
                 tensor_label_trg.set(label_trg, place)
                 fake_temp, rec_temp = exe.run(
@@ -126,11 +127,13 @@ def save_test_image(epoch,
                         "label_trg": tensor_label_trg
                     },
                     fetch_list=[g_trainer.fake_img, g_trainer.rec_img])
-                fake_temp = save_batch_image(fake_temp[0])
-                rec_temp = save_batch_image(rec_temp[0])
+                fake_temp = save_batch_image(fake_temp)
+                rec_temp = save_batch_image(rec_temp)
                 images.append(fake_temp)
                 images.append(rec_temp)
             images_concat = np.concatenate(images, 1)
+            if len(label_org) > 1:
+                images_concat = np.concatenate(images_concat, 1)
             imageio.imwrite(out_path + "/fake_img" + str(epoch) + "_" + name[0],
                             ((images_concat + 1) * 127.5).astype(np.uint8))
         elif cfg.model_net == 'AttGAN' or cfg.model_net == 'STGAN':
@@ -184,12 +187,12 @@ def save_test_image(epoch,
     else:
         for data_A, data_B in zip(A_test_reader(), B_test_reader()):
-            A_name = data_A[0][1]
-            B_name = data_B[0][1]
+            A_data, A_name = data_A
+            B_data, B_name = data_B
             tensor_A = fluid.LoDTensor()
             tensor_B = fluid.LoDTensor()
-            tensor_A.set(data_A[0][0], place)
-            tensor_B.set(data_B[0][0], place)
+            tensor_A.set(A_data, place)
+            tensor_B.set(B_data, place)
             fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = exe.run(
                 test_program,
                 fetch_list=[
@@ -205,18 +208,20 @@ def save_test_image(epoch,
             input_A_temp = np.squeeze(data_A[0][0]).transpose([1, 2, 0])
             input_B_temp = np.squeeze(data_B[0][0]).transpose([1, 2, 0])

-            imageio.imwrite(out_path + "/fakeB_" + str(epoch) + "_" + A_name, (
-                (fake_B_temp + 1) * 127.5).astype(np.uint8))
-            imageio.imwrite(out_path + "/fakeA_" + str(epoch) + "_" + B_name, (
-                (fake_A_temp + 1) * 127.5).astype(np.uint8))
-            imageio.imwrite(out_path + "/cycA_" + str(epoch) + "_" + A_name, (
-                (cyc_A_temp + 1) * 127.5).astype(np.uint8))
-            imageio.imwrite(out_path + "/cycB_" + str(epoch) + "_" + B_name, (
-                (cyc_B_temp + 1) * 127.5).astype(np.uint8))
-            imageio.imwrite(out_path + "/inputA_" + str(epoch) + "_" + A_name, (
-                (input_A_temp + 1) * 127.5).astype(np.uint8))
-            imageio.imwrite(out_path + "/inputB_" + str(epoch) + "_" + B_name, (
-                (input_B_temp + 1) * 127.5).astype(np.uint8))
+            imageio.imwrite(out_path + "/fakeB_" + str(epoch) + "_" + A_name[0],
+                            ((fake_B_temp + 1) * 127.5).astype(np.uint8))
+            imageio.imwrite(out_path + "/fakeA_" + str(epoch) + "_" + B_name[0],
+                            ((fake_A_temp + 1) * 127.5).astype(np.uint8))
+            imageio.imwrite(out_path + "/cycA_" + str(epoch) + "_" + A_name[0],
+                            ((cyc_A_temp + 1) * 127.5).astype(np.uint8))
+            imageio.imwrite(out_path + "/cycB_" + str(epoch) + "_" + B_name[0],
+                            ((cyc_B_temp + 1) * 127.5).astype(np.uint8))
+            imageio.imwrite(
+                out_path + "/inputA_" + str(epoch) + "_" + A_name[0], (
+                    (input_A_temp + 1) * 127.5).astype(np.uint8))
+            imageio.imwrite(
+                out_path + "/inputB_" + str(epoch) + "_" + B_name[0], (
+                    (input_B_temp + 1) * 127.5).astype(np.uint8))

 class ImagePool(object):
......
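One subtlety in the StarGAN branch above: `images` holds one panel per target attribute (plus the input and reconstructions), concatenated along axis 1, and when the test batch has more than one sample the second `np.concatenate(images_concat, 1)` iterates over the leading batch axis and stitches the samples side by side. A small NumPy sketch of that behavior (shapes are illustrative, not the exact ones in `save_test_image`):

import numpy as np

batch, panels, h, w = 4, 3, 128, 128
images = [np.zeros((batch, h, w, 3)) for _ in range(panels)]
grid = np.concatenate(images, 1)      # (batch, panels*h, w, 3)
if batch > 1:
    # An ndarray passed to np.concatenate is treated as a sequence of its
    # axis-0 slices, so this collapses the batch dimension:
    grid = np.concatenate(grid, 1)    # (panels*h, batch*w, 3)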